diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index bfdd5115..f9f2b64a 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -154,15 +154,6 @@ namespace dxvk { Rc<DxvkShader> DxbcCompiler::finalize() { - // Define the actual 'main' function of the shader - m_module.functionBegin( - m_module.defVoidType(), - m_entryPointId, - m_module.defFunctionType( - m_module.defVoidType(), 0, nullptr), - spv::FunctionControlMaskNone); - m_module.opLabel(m_module.allocateId()); - // Depending on the shader type, this will prepare // input registers, call various shader functions // and write back the output registers. @@ -175,10 +166,6 @@ namespace dxvk { case DxbcProgramType::ComputeShader: this->emitCsFinalize(); break; } - // End main function - m_module.opReturn(); - m_module.functionEnd(); - // Declare the entry point, we now have all the // information we need, including the interfaces m_module.addEntryPoint(m_entryPointId, @@ -5049,6 +5036,23 @@ namespace dxvk { } + void DxbcCompiler::emitMainFunctionBegin() { + m_module.functionBegin( + m_module.defVoidType(), + m_entryPointId, + m_module.defFunctionType( + m_module.defVoidType(), 0, nullptr), + spv::FunctionControlMaskNone); + m_module.opLabel(m_module.allocateId()); + } + + + void DxbcCompiler::emitMainFunctionEnd() { + m_module.opReturn(); + m_module.functionEnd(); + } + + void DxbcCompiler::emitVsInit() { m_module.enableCapability(spv::CapabilityClipDistance); m_module.enableCapability(spv::CapabilityCullDistance); @@ -5206,17 +5210,22 @@ namespace dxvk { void DxbcCompiler::emitVsFinalize() { + this->emitMainFunctionBegin(); this->emitInputSetup(); m_module.opFunctionCall( m_module.defVoidType(), m_vs.functionId, 0, nullptr); this->emitOutputSetup(); + this->emitMainFunctionEnd(); } void DxbcCompiler::emitHsFinalize() { - emitInputSetup(m_hs.vertexCountIn); + if (m_hs.cpPhase.functionId == 0) + m_hs.cpPhase = this->emitNewHullShaderPassthroughPhase(); + this->emitMainFunctionBegin(); + 
this->emitInputSetup(m_hs.vertexCountIn); this->emitHsControlPointPhase(m_hs.cpPhase); if (m_hs.forkPhases.size() != 0 @@ -5233,15 +5242,19 @@ namespace dxvk { this->emitHsInvocationBlockBegin(1); this->emitOutputSetup(); this->emitHsInvocationBlockEnd(); + this->emitMainFunctionEnd(); } void DxbcCompiler::emitDsFinalize() { + this->emitMainFunctionBegin(); // TODO implement + this->emitMainFunctionEnd(); } void DxbcCompiler::emitGsFinalize() { + this->emitMainFunctionBegin(); this->emitInputSetup( primitiveVertexCount(m_gs.inputPrimitive)); m_module.opFunctionCall( @@ -5249,32 +5262,35 @@ namespace dxvk { m_gs.functionId, 0, nullptr); // No output setup at this point as that was // already done during the EmitVertex step + this->emitMainFunctionEnd(); } void DxbcCompiler::emitPsFinalize() { + this->emitMainFunctionBegin(); this->emitInputSetup(); m_module.opFunctionCall( m_module.defVoidType(), m_ps.functionId, 0, nullptr); this->emitOutputSetup(); + this->emitMainFunctionEnd(); } void DxbcCompiler::emitCsFinalize() { + this->emitMainFunctionBegin(); m_module.opFunctionCall( m_module.defVoidType(), m_cs.functionId, 0, nullptr); + this->emitMainFunctionEnd(); } void DxbcCompiler::emitHsControlPointPhase( const DxbcCompilerHsControlPointPhase& phase) { - if (phase.functionId != 0) { - m_module.opFunctionCall( - m_module.defVoidType(), - phase.functionId, 0, nullptr); - } + m_module.opFunctionCall( + m_module.defVoidType(), + phase.functionId, 0, nullptr); } @@ -5353,6 +5369,62 @@ namespace dxvk { } + DxbcCompilerHsControlPointPhase DxbcCompiler::emitNewHullShaderPassthroughPhase() { + uint32_t funTypeId = m_module.defFunctionType( + m_module.defVoidType(), 0, nullptr); + + // Begin passthrough function + uint32_t funId = m_module.allocateId(); + m_module.setDebugName(funId, "hs_passthrough"); + + m_module.functionBegin(m_module.defVoidType(), + funId, funTypeId, spv::FunctionControlMaskNone); + m_module.opLabel(m_module.allocateId()); + + // We'll basically copy 
each input variable to the corresponding + // output, using the shader's invocation ID as the array index. + const uint32_t invocationId = m_module.opLoad( + getScalarTypeId(DxbcScalarType::Uint32), + m_hs.builtinInvocationId); + + for (auto i = m_isgn->begin(); i != m_isgn->end(); i++) { + this->emitDclInput( + i->registerId, m_hs.vertexCountIn, + i->componentMask, + DxbcSystemValue::None, + DxbcInterpolationMode::Undefined); + + // Vector type index + uint32_t vecTypeId = getVectorTypeId({ DxbcScalarType::Float32, 4 }); + + uint32_t dstPtrTypeId = m_module.defPointerType(vecTypeId, spv::StorageClassOutput); + uint32_t srcPtrTypeId = m_module.defPointerType(vecTypeId, spv::StorageClassInput); + + const std::array<uint32_t, 2> dstIndices + = {{ invocationId, m_module.constu32(i->registerId) }}; + + uint32_t dstPtr = m_module.opAccessChain( + dstPtrTypeId, m_hs.outputPerVertex, + dstIndices.size(), dstIndices.data()); + + uint32_t srcPtr = m_module.opAccessChain( + srcPtrTypeId, m_vRegs.at(i->registerId), + 1, &invocationId); + + m_module.opStore(dstPtr, + m_module.opLoad(vecTypeId, srcPtr)); + } + + // End function + m_module.opReturn(); + m_module.functionEnd(); + + DxbcCompilerHsControlPointPhase result; + result.functionId = funId; + return result; + } + + DxbcCompilerHsForkJoinPhase DxbcCompiler::emitNewHullShaderForkJoinPhase() { uint32_t funTypeId = m_module.defFunctionType( m_module.defVoidType(), 0, nullptr); diff --git a/src/dxbc/dxbc_compiler.h b/src/dxbc/dxbc_compiler.h index 987fc9af..9cda3231 100644 --- a/src/dxbc/dxbc_compiler.h +++ b/src/dxbc/dxbc_compiler.h @@ -849,10 +849,16 @@ namespace dxvk { DxbcRegMask mask, const DxbcRegisterValue& value); - ///////////////////////////////// - // Shader initialization methods + ////////////////////////////////////// + // Common function definition methods void emitInit(); + void emitMainFunctionBegin(); + + void emitMainFunctionEnd(); + + ///////////////////////////////// + // Shader initialization methods void 
emitVsInit(); void emitHsInit(); void emitDsInit(); @@ -902,6 +908,8 @@ namespace dxvk { DxbcCompilerHsControlPointPhase emitNewHullShaderControlPointPhase(); + DxbcCompilerHsControlPointPhase emitNewHullShaderPassthroughPhase(); + DxbcCompilerHsForkJoinPhase emitNewHullShaderForkJoinPhase(); ///////////////////////////////