// Patch summary for src/dxbc/dxbc_compiler.cpp (two hunks).
//
// Hunk 1 (@@ -5620): in the code that emits the hull shader phases,
//   the call emitHsInvocationBlockBegin(1) is moved from just before
//   emitOutputSetup() to just before the loops over m_hs.forkPhases and
//   m_hs.joinPhases. After the change, every emitHsForkJoinPhase() call and
//   emitOutputSetup() are emitted between a single
//   emitHsInvocationBlockBegin(1) / emitHsInvocationBlockEnd() pair.
//
// Hunk 2 (@@ -5692): DxbcCompiler::emitHsForkJoinPhase(phase) no longer
//   opens its own block with emitHsInvocationBlockBegin(phase.instanceCount),
//   no longer loads m_hs.builtinInvocationId, and no longer ends with
//   emitHsInvocationBlockEnd() + emitHsPhaseBarrier(). Instead it loops
//   i = 0 .. phase.instanceCount - 1 and emits one opFunctionCall to
//   phase.functionId per iteration, passing m_module.constu32(i) as the
//   single argument.
//
// NOTE(review): presumably the intent is to run all fork/join instances
//   one after another rather than spread across invocations (the removed
//   comment said "will run in parallel") - confirm against the helpers,
//   whose definitions are not part of this patch.
// NOTE(review): the patch text below is stored on one physical line; its
//   hunk headers announce 15- and 19/14-line hunks, so the line breaks have
//   to be restored before a patch tool can apply it.
diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index 0a6fa7a0..2f76b182 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -5620,15 +5620,15 @@ namespace dxvk { this->emitHsControlPointPhase(m_hs.cpPhase); this->emitHsPhaseBarrier(); - // Fork-join phases (will run in parallel) + // Fork-join phases and output setup + this->emitHsInvocationBlockBegin(1); + for (const auto& phase : m_hs.forkPhases) this->emitHsForkJoinPhase(phase); for (const auto& phase : m_hs.joinPhases) this->emitHsForkJoinPhase(phase); - // Output setup phase - this->emitHsInvocationBlockBegin(1); this->emitOutputSetup(); this->emitHsInvocationBlockEnd(); this->emitMainFunctionEnd(); @@ -5692,19 +5692,14 @@ namespace dxvk { void DxbcCompiler::emitHsForkJoinPhase( const DxbcCompilerHsForkJoinPhase& phase) { - this->emitHsInvocationBlockBegin(phase.instanceCount); - - uint32_t invocationId = m_module.opLoad( - getScalarTypeId(DxbcScalarType::Uint32), - m_hs.builtinInvocationId); - - m_module.opFunctionCall( - m_module.defVoidType(), - phase.functionId, 1, - &invocationId); - - this->emitHsInvocationBlockEnd(); - this->emitHsPhaseBarrier(); + for (uint32_t i = 0; i < phase.instanceCount; i++) { + uint32_t invocationId = m_module.constu32(i); + + m_module.opFunctionCall( + m_module.defVoidType(), + phase.functionId, 1, + &invocationId); + } }