[dxbc] Use raw access chains for buffer loads and stores

Raw access chains map more or less perfectly to D3D raw and structured buffers.
This commit is contained in:
Philip Rebohle 2024-02-16 21:50:43 +01:00
parent 69d74a46a0
commit c9cea93b7b
2 changed files with 161 additions and 32 deletions

View File

@ -2873,28 +2873,90 @@ namespace dxvk {
uint32_t sparseFeedbackId = 0;
const DxbcRegisterValue elementIndex = isStructured
? emitCalcBufferIndexStructured(
emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)),
emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false)),
bufferInfo.stride)
: emitCalcBufferIndexRaw(
emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)));
bool useRawAccessChains = m_hasRawAccessChains && isSsbo && !imageOperands.sparse;
DxbcRegisterValue index = emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false));
DxbcRegisterValue offset = index;
if (isStructured)
offset = emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false));
DxbcRegisterValue elementIndex = { };
uint32_t baseAlignment = sizeof(uint32_t);
if (useRawAccessChains) {
memoryOperands.flags |= spv::MemoryAccessAlignedMask;
if (isStructured && ins.src[1].type == DxbcOperandType::Imm32) {
baseAlignment = bufferInfo.stride | ins.src[1].imm.u32_1;
baseAlignment = baseAlignment & -baseAlignment;
baseAlignment = std::min(baseAlignment, uint32_t(m_moduleInfo.options.minSsboAlignment));
}
} else {
elementIndex = isStructured
? emitCalcBufferIndexStructured(index, offset, bufferInfo.stride)
: emitCalcBufferIndexRaw(offset);
}
uint32_t readMask = 0u;
for (uint32_t i = 0; i < 4; i++) {
uint32_t sindex = srcReg.swizzle[i];
if (dstReg.mask[i])
readMask |= 1u << srcReg.swizzle[i];
}
if (!dstReg.mask[i])
continue;
while (readMask) {
uint32_t sindex = bit::tzcnt(readMask);
uint32_t scount = bit::tzcnt(~(readMask >> sindex));
uint32_t zero = 0;
if (ccomps[sindex] == 0) {
if (useRawAccessChains) {
uint32_t alignment = baseAlignment;
uint32_t offsetId = offset.id;
if (sindex) {
offsetId = m_module.opIAdd(scalarTypeId,
offsetId, m_module.constu32(sizeof(uint32_t) * sindex));
alignment |= sizeof(uint32_t) * sindex;
}
DxbcRegisterInfo storeInfo;
storeInfo.type.ctype = DxbcScalarType::Uint32;
storeInfo.type.ccount = scount;
storeInfo.type.alength = 0;
storeInfo.sclass = spv::StorageClassStorageBuffer;
uint32_t loadTypeId = getArrayTypeId(storeInfo.type);
uint32_t ptrTypeId = getPointerTypeId(storeInfo);
uint32_t accessChain = isStructured
? m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId,
m_module.constu32(bufferInfo.stride), index.id, offsetId,
spv::RawAccessChainOperandsRobustnessPerElementNVMask)
: m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId,
m_module.constu32(0), m_module.constu32(0), offsetId,
spv::RawAccessChainOperandsRobustnessPerComponentNVMask);
memoryOperands.alignment = alignment & -alignment;
uint32_t vectorId = m_module.opLoad(loadTypeId, accessChain, memoryOperands);
for (uint32_t i = 0; i < scount; i++) {
ccomps[sindex + i] = vectorId;
if (scount > 1) {
ccomps[sindex + i] = m_module.opCompositeExtract(
scalarTypeId, vectorId, 1, &i);
}
}
readMask &= ~(((1u << scount) - 1u) << sindex);
} else {
uint32_t elementIndexAdjusted = m_module.opIAdd(
getVectorTypeId(elementIndex.type), elementIndex.id,
m_module.consti32(sindex));
// Load requested component from the buffer
uint32_t zero = 0;
if (isTgsm) {
ccomps[sindex] = m_module.opLoad(scalarTypeId,
m_module.opAccessChain(bufferInfo.typeId,
@ -2934,6 +2996,8 @@ namespace dxvk {
ccomps[sindex] = m_module.opCompositeExtract(scalarTypeId, resultId, 1, &zero);
}
readMask &= readMask - 1;
}
}
@ -2994,8 +3058,6 @@ namespace dxvk {
uint32_t scalarTypeId = getVectorTypeId({ DxbcScalarType::Uint32, 1 });
uint32_t vectorTypeId = getVectorTypeId({ DxbcScalarType::Uint32, 4 });
uint32_t srcComponentIndex = 0;
// Set memory operands according to resource properties
SpirvMemoryOperands memoryOperands;
SpirvImageOperands imageOperands;
@ -3020,26 +3082,90 @@ namespace dxvk {
}
}
// Compute flat element index
const DxbcRegisterValue elementIndex = isStructured
? emitCalcBufferIndexStructured(
emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)),
emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false)),
bufferInfo.stride)
: emitCalcBufferIndexRaw(
emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false)));
// Compute flat element index as necessary
bool useRawAccessChains = isSsbo && m_hasRawAccessChains;
for (uint32_t i = 0; i < 4; i++) {
if (dstReg.mask[i]) {
DxbcRegisterValue index = emitRegisterLoad(ins.src[0], DxbcRegMask(true, false, false, false));
DxbcRegisterValue offset = index;
if (isStructured)
offset = emitRegisterLoad(ins.src[1], DxbcRegMask(true, false, false, false));
DxbcRegisterValue elementIndex = { };
uint32_t baseAlignment = sizeof(uint32_t);
if (useRawAccessChains) {
memoryOperands.flags |= spv::MemoryAccessAlignedMask;
if (isStructured && ins.src[1].type == DxbcOperandType::Imm32) {
baseAlignment = bufferInfo.stride | ins.src[1].imm.u32_1;
baseAlignment = baseAlignment & -baseAlignment;
baseAlignment = std::min(baseAlignment, uint32_t(m_moduleInfo.options.minSsboAlignment));
}
} else {
elementIndex = isStructured
? emitCalcBufferIndexStructured(index, offset, bufferInfo.stride)
: emitCalcBufferIndexRaw(offset);
}
uint32_t writeMask = dstReg.mask.raw();
while (writeMask) {
uint32_t sindex = bit::tzcnt(writeMask);
uint32_t scount = bit::tzcnt(~(writeMask >> sindex));
if (useRawAccessChains) {
uint32_t alignment = baseAlignment;
uint32_t offsetId = offset.id;
if (sindex) {
offsetId = m_module.opIAdd(scalarTypeId,
offsetId, m_module.constu32(sizeof(uint32_t) * sindex));
alignment = alignment | (sizeof(uint32_t) * sindex);
}
DxbcRegisterInfo storeInfo;
storeInfo.type.ctype = DxbcScalarType::Uint32;
storeInfo.type.ccount = scount;
storeInfo.type.alength = 0;
storeInfo.sclass = spv::StorageClassStorageBuffer;
uint32_t storeTypeId = getArrayTypeId(storeInfo.type);
uint32_t ptrTypeId = getPointerTypeId(storeInfo);
uint32_t accessChain = isStructured
? m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId,
m_module.constu32(bufferInfo.stride), index.id, offsetId,
spv::RawAccessChainOperandsRobustnessPerElementNVMask)
: m_module.opRawAccessChain(ptrTypeId, bufferInfo.varId,
m_module.constu32(0), m_module.constu32(0), offsetId,
spv::RawAccessChainOperandsRobustnessPerComponentNVMask);
uint32_t valueId = value.id;
if (scount < value.type.ccount) {
if (scount == 1) {
valueId = m_module.opCompositeExtract(storeTypeId, value.id, 1, &sindex);
} else {
std::array<uint32_t, 4> indices = { sindex, sindex + 1u, sindex + 2u, sindex + 3u };
valueId = m_module.opVectorShuffle(storeTypeId, value.id, value.id, scount, indices.data());
}
}
memoryOperands.alignment = alignment & -alignment;
m_module.opStore(accessChain, valueId, memoryOperands);
writeMask &= ~(((1u << scount) - 1u) << sindex);
} else {
uint32_t srcComponentId = value.type.ccount > 1
? m_module.opCompositeExtract(scalarTypeId,
value.id, 1, &srcComponentIndex)
value.id, 1, &sindex)
: value.id;
// Add the component offset to the element index
uint32_t elementIndexAdjusted = i != 0
uint32_t elementIndexAdjusted = sindex != 0
? m_module.opIAdd(getVectorTypeId(elementIndex.type),
elementIndex.id, m_module.consti32(i))
elementIndex.id, m_module.consti32(sindex))
: elementIndex.id;
if (isTgsm) {
@ -3068,8 +3194,7 @@ namespace dxvk {
throw DxvkError("DxbcCompiler: Invalid operand type for strucured/raw store");
}
// Write next component
srcComponentIndex += 1;
writeMask &= writeMask - 1u;
}
}
}

View File

@ -149,6 +149,10 @@ namespace dxvk {
: m_mask((x ? 0x1 : 0) | (y ? 0x2 : 0)
| (z ? 0x4 : 0) | (w ? 0x8 : 0)) { }
/**
 * \brief Queries the raw component mask
 *
 * Exposes the stored mask bits as a plain integer so that
 * callers can iterate over them with bit-scan helpers.
 * \returns The stored mask value, widened to 32 bits
 */
uint32_t raw() const {
  // Explicit conversion; the value itself is passed through unchanged.
  return static_cast<uint32_t>(m_mask);
}
/**
 * \brief Tests a single bit of the mask
 *
 * \param [in] id Index of the bit to test
 * \returns \c true if bit \c id of the stored mask is set
 */
bool operator [] (uint32_t id) const {
  // Move the requested bit into the lowest position and isolate it.
  return ((m_mask >> id) & 1) != 0;
}