diff --git a/src/dxbc/dxbc_compiler.cpp b/src/dxbc/dxbc_compiler.cpp index 26d0d66e..59323b08 100644 --- a/src/dxbc/dxbc_compiler.cpp +++ b/src/dxbc/dxbc_compiler.cpp @@ -2873,28 +2873,90 @@ namespace dxvk { uint32_t sparseFeedbackId = 0; - const DxbcRegisterValue elementIndex = isStructured - ? emitCalcBufferIndexStructured( - emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)), - emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false)), - bufferInfo.stride) - : emitCalcBufferIndexRaw( - emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false))); + bool useRawAccessChains = m_hasRawAccessChains && isSsbo && !imageOperands.sparse; + + DxbcRegisterValue index = emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)); + DxbcRegisterValue offset = index; + + if (isStructured) + offset = emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false)); + + DxbcRegisterValue elementIndex = { }; + + uint32_t baseAlignment = sizeof(uint32_t); + + if (useRawAccessChains) { + memoryOperands.flags |= spv::MemoryAccessAlignedMask; + + if (isStructured && ins.src[1].type == DxbcOperandType::Imm32) { + baseAlignment = bufferInfo.stride | ins.src[1].imm.u32_1; + baseAlignment = baseAlignment & -baseAlignment; + baseAlignment = std::min(baseAlignment, uint32_t(m_moduleInfo.options.minSsboAlignment)); + } + } else { + elementIndex = isStructured + ? emitCalcBufferIndexStructured(index, offset, bufferInfo.stride) + : emitCalcBufferIndexRaw(offset); + } + + uint32_t readMask = 0u; for (uint32_t i = 0; i < 4; i++) { - uint32_t sindex = srcReg.swizzle[i]; + if (dstReg.mask[i]) + readMask |= 1u << srcReg.swizzle[i]; + } - if (!dstReg.mask[i]) - continue; + while (readMask) { + uint32_t sindex = bit::tzcnt(readMask); + uint32_t scount = bit::tzcnt(~(readMask >> sindex)); + uint32_t zero = 0; - if (ccomps[sindex] == 0) { + if (useRawAccessChains) { + uint32_t alignment = baseAlignment; + uint32_t offsetId = offset.id; + + if (sindex) { + offsetId = m_module.opIAdd(scalarTypeId, + offsetId, m_module.constu32(sizeof(uint32_t) * sindex)); + alignment |= sizeof(uint32_t) * sindex; + } + + DxbcRegisterInfo storeInfo; + storeInfo.type.ctype = DxbcScalarType::Uint32; + storeInfo.type.ccount = scount; + storeInfo.type.alength = 0; + storeInfo.sclass = spv::StorageClassStorageBuffer; + + uint32_t loadTypeId = getArrayTypeId(storeInfo.type); + uint32_t ptrTypeId = getPointerTypeId(storeInfo); + + uint32_t accessChain = isStructured + ? m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId, + m_module.constu32(bufferInfo.stride), index.id, offsetId, + spv::RawAccessChainOperandsRobustnessPerElementNVMask) + : m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId, + m_module.constu32(0), m_module.constu32(0), offsetId, + spv::RawAccessChainOperandsRobustnessPerComponentNVMask); + + memoryOperands.alignment = alignment & -alignment; + + uint32_t vectorId = m_module.opLoad(loadTypeId, accessChain, memoryOperands); + + for (uint32_t i = 0; i < scount; i++) { + ccomps[sindex + i] = vectorId; + + if (scount > 1) { + ccomps[sindex + i] = m_module.opCompositeExtract( + scalarTypeId, vectorId, 1, &i); + } + } + + readMask &= ~(((1u << scount) - 1u) << sindex); + } else { uint32_t elementIndexAdjusted = m_module.opIAdd( getVectorTypeId(elementIndex.type), elementIndex.id, m_module.consti32(sindex)); - // Load requested component from the buffer - uint32_t zero = 0; - if (isTgsm) { ccomps[sindex] = m_module.opLoad(scalarTypeId, m_module.opAccessChain(bufferInfo.typeId, @@ -2934,6 +2996,8 @@ namespace dxvk { ccomps[sindex] = m_module.opCompositeExtract(scalarTypeId, resultId, 1, &zero); } + + readMask &= readMask - 1; } } @@ -2994,8 +3058,6 @@ namespace dxvk { uint32_t scalarTypeId = getVectorTypeId({ DxbcScalarType::Uint32, 1 }); uint32_t vectorTypeId = getVectorTypeId({ DxbcScalarType::Uint32, 4 }); - uint32_t srcComponentIndex = 0; - // Set memory operands according to resource properties SpirvMemoryOperands memoryOperands; SpirvImageOperands imageOperands; @@ -3020,26 +3082,90 @@ namespace dxvk { } } - // Compute flat element index - const DxbcRegisterValue elementIndex = isStructured - ? emitCalcBufferIndexStructured( - emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)), - emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false)), - bufferInfo.stride) - : emitCalcBufferIndexRaw( - emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false))); + // Compute flat element index as necessary + bool useRawAccessChains = isSsbo && m_hasRawAccessChains; - for (uint32_t i = 0; i < 4; i++) { - if (dstReg.mask[i]) { + DxbcRegisterValue index = emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)); + DxbcRegisterValue offset = index; + + if (isStructured) + offset = emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false)); + + DxbcRegisterValue elementIndex = { }; + + uint32_t baseAlignment = sizeof(uint32_t); + + if (useRawAccessChains) { + memoryOperands.flags |= spv::MemoryAccessAlignedMask; + + if (isStructured && ins.src[1].type == DxbcOperandType::Imm32) { + baseAlignment = bufferInfo.stride | ins.src[1].imm.u32_1; + baseAlignment = baseAlignment & -baseAlignment; + baseAlignment = std::min(baseAlignment, uint32_t(m_moduleInfo.options.minSsboAlignment)); + } + } else { + elementIndex = isStructured + ? emitCalcBufferIndexStructured(index, offset, bufferInfo.stride) + : emitCalcBufferIndexRaw(offset); + } + + uint32_t writeMask = dstReg.mask.raw(); + + while (writeMask) { + uint32_t sindex = bit::tzcnt(writeMask); + uint32_t scount = bit::tzcnt(~(writeMask >> sindex)); + + if (useRawAccessChains) { + uint32_t alignment = baseAlignment; + uint32_t offsetId = offset.id; + + if (sindex) { + offsetId = m_module.opIAdd(scalarTypeId, + offsetId, m_module.constu32(sizeof(uint32_t) * sindex)); + alignment = alignment | (sizeof(uint32_t) * sindex); + } + + DxbcRegisterInfo storeInfo; + storeInfo.type.ctype = DxbcScalarType::Uint32; + storeInfo.type.ccount = scount; + storeInfo.type.alength = 0; + storeInfo.sclass = spv::StorageClassStorageBuffer; + + uint32_t storeTypeId = getArrayTypeId(storeInfo.type); + uint32_t ptrTypeId = getPointerTypeId(storeInfo); + + uint32_t accessChain = isStructured + ? m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId, + m_module.constu32(bufferInfo.stride), index.id, offsetId, + spv::RawAccessChainOperandsRobustnessPerElementNVMask) + : m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId, + m_module.constu32(0), m_module.constu32(0), offsetId, + spv::RawAccessChainOperandsRobustnessPerComponentNVMask); + + uint32_t valueId = value.id; + + if (scount < value.type.ccount) { + if (scount == 1) { + valueId = m_module.opCompositeExtract(storeTypeId, value.id, 1, &sindex); + } else { + std::array indices = { sindex, sindex + 1u, sindex + 2u, sindex + 3u }; + valueId = m_module.opVectorShuffle(storeTypeId, value.id, value.id, scount, indices.data()); + } + } + + memoryOperands.alignment = alignment & -alignment; + m_module.opStore(accessChain, valueId, memoryOperands); + + writeMask &= ~(((1u << scount) - 1u) << sindex); + } else { uint32_t srcComponentId = value.type.ccount > 1 ? m_module.opCompositeExtract(scalarTypeId, - value.id, 1, &srcComponentIndex) + value.id, 1, &sindex) : value.id; - // Add the component offset to the element index - uint32_t elementIndexAdjusted = i != 0 + uint32_t elementIndexAdjusted = sindex != 0 ? m_module.opIAdd(getVectorTypeId(elementIndex.type), - elementIndex.id, m_module.consti32(i)) + elementIndex.id, m_module.consti32(sindex)) : elementIndex.id; if (isTgsm) { @@ -3068,8 +3194,7 @@ namespace dxvk { throw DxvkError("DxbcCompiler: Invalid operand type for strucured/raw store"); } - // Write next component - srcComponentIndex += 1; + writeMask &= writeMask - 1u; } } } diff --git a/src/dxbc/dxbc_decoder.h b/src/dxbc/dxbc_decoder.h index 3fa4c377..77183715 100644 --- a/src/dxbc/dxbc_decoder.h +++ b/src/dxbc/dxbc_decoder.h @@ -149,6 +149,10 @@ namespace dxvk { : m_mask((x ? 0x1 : 0) | (y ? 0x2 : 0) | (z ? 0x4 : 0) | (w ? 0x8 : 0)) { } + uint32_t raw() const { + return m_mask; + } + bool operator [] (uint32_t id) const { return (m_mask >> id) & 1; }