@@ -8368,96 +8368,59 @@ void TranslateStructBufSubscriptUser(Instruction *user, Value *handle,
83688368 baseOffset, status, OP, DL);
83698369 }
83708370 } else if (isa<LoadInst>(user) || isa<StoreInst>(user)) {
8371- LoadInst *ldInst = dyn_cast<LoadInst>(user);
8372- StoreInst *stInst = dyn_cast<StoreInst>(user);
8371+ LoadInst *LdInst = dyn_cast<LoadInst>(user);
8372+ StoreInst *StInst = dyn_cast<StoreInst>(user);
83738373
8374- Type *Ty = isa<LoadInst>(user) ? ldInst ->getType ()
8375- : stInst ->getValueOperand ()->getType ();
8374+ Type *Ty = isa<LoadInst>(user) ? LdInst ->getType ()
8375+ : StInst ->getValueOperand ()->getType ();
83768376 Type *pOverloadTy = Ty->getScalarType ();
8377- Value *offset = baseOffset;
8378- unsigned arraySize = 1 ;
8379- Value *eltSize = nullptr ;
8377+ Value *Offset = baseOffset;
83808378
8381- if (pOverloadTy->isArrayTy ()) {
8382- arraySize = pOverloadTy->getArrayNumElements ();
8383- eltSize = OP->GetU32Const (
8384- DL.getTypeAllocSize (pOverloadTy->getArrayElementType ()));
8379+ if (LdInst) {
8380+ unsigned NumComponents = 0 ;
8381+ Value *NewLd = nullptr ;
8382+ if (VectorType *VTy = dyn_cast<VectorType>(Ty))
8383+ NumComponents = VTy->getNumElements ();
8384+ else
8385+ NumComponents = 1 ;
83858386
8386- pOverloadTy = pOverloadTy->getArrayElementType ()->getScalarType ();
8387- }
8387+ if (ResKind == HLResource::Kind::TypedBuffer) {
8388+ // Typed buffer cannot have offsets, they must be loaded all at once
8389+ ResRetValueArray ResRet = GenerateTypedBufferLoad (
8390+ handle, pOverloadTy, bufIdx, status, OP, Builder);
83888391
8389- if (ldInst) {
8390- auto LdElement = [=](Value *offset, IRBuilder<> &Builder) -> Value * {
8391- unsigned numComponents = 0 ;
8392- if (VectorType *VTy = dyn_cast<VectorType>(Ty)) {
8393- numComponents = VTy->getNumElements ();
8394- } else {
8395- numComponents = 1 ;
8396- }
8397- Constant *alignment =
8392+ NewLd = ExtractFromTypedBufferLoad (ResRet, Ty, Offset, Builder);
8393+ } else {
8394+ Value *ResultElts[4 ];
8395+ Constant *Alignment =
83988396 OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8399- if (ResKind == HLResource::Kind::TypedBuffer) {
8400- // Typed buffer cannot have offsets, they must be loaded all at once
8401- ResRetValueArray ResRet = GenerateTypedBufferLoad (
8402- handle, pOverloadTy, bufIdx, status, OP, Builder);
8403-
8404- return ExtractFromTypedBufferLoad (ResRet, Ty, offset, Builder);
8405- } else {
8406- Value *ResultElts[4 ];
8407- GenerateRawBufLd (handle, bufIdx, offset, status, pOverloadTy,
8408- ResultElts, OP, Builder, numComponents, alignment);
8409- return ScalarizeElements (Ty, ResultElts, Builder);
8410- }
8411- };
8412-
8413- Value *newLd = LdElement (offset, Builder);
8414- if (arraySize > 1 ) {
8415- newLd =
8416- Builder.CreateInsertValue (UndefValue::get (Ty), newLd, (uint64_t )0 );
8417-
8418- for (unsigned i = 1 ; i < arraySize; i++) {
8419- offset = Builder.CreateAdd (offset, eltSize);
8420- Value *eltLd = LdElement (offset, Builder);
8421- newLd = Builder.CreateInsertValue (newLd, eltLd, i);
8422- }
8397+ GenerateRawBufLd (handle, bufIdx, Offset, status, pOverloadTy,
8398+ ResultElts, OP, Builder, NumComponents, Alignment);
8399+ NewLd = ScalarizeElements (Ty, ResultElts, Builder);
84238400 }
8424- ldInst->replaceAllUsesWith (newLd);
8401+
8402+ LdInst->replaceAllUsesWith (NewLd);
84258403 } else {
8426- Value *val = stInst->getValueOperand ();
8427- auto StElement = [&](Value *offset, Value *val, IRBuilder<> &Builder) {
8428- Value *undefVal = llvm::UndefValue::get (pOverloadTy);
8429- Value *vals[] = {undefVal, undefVal, undefVal, undefVal};
8430- uint8_t mask = 0 ;
8431- if (Ty->isVectorTy ()) {
8432- unsigned vectorNumElements = Ty->getVectorNumElements ();
8433- DXASSERT (vectorNumElements <= 4 , " up to 4 elements in vector" );
8434- assert (vectorNumElements <= 4 );
8435- for (unsigned i = 0 ; i < vectorNumElements; i++) {
8436- vals[i] = Builder.CreateExtractElement (val, i);
8437- mask |= (1 << i);
8438- }
8439- } else {
8440- vals[0 ] = val;
8441- mask = DXIL::kCompMask_X ;
8442- }
8443- Constant *alignment =
8444- OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8445- GenerateStructBufSt (handle, bufIdx, offset, pOverloadTy, OP, Builder,
8446- vals, mask, alignment);
8447- };
8448- if (arraySize > 1 )
8449- val = Builder.CreateExtractValue (val, 0 );
8450-
8451- StElement (offset, val, Builder);
8452- if (arraySize > 1 ) {
8453- val = stInst->getValueOperand ();
8454-
8455- for (unsigned i = 1 ; i < arraySize; i++) {
8456- offset = Builder.CreateAdd (offset, eltSize);
8457- Value *eltVal = Builder.CreateExtractValue (val, i);
8458- StElement (offset, eltVal, Builder);
8404+ Value *val = StInst->getValueOperand ();
8405+ Value *undefVal = llvm::UndefValue::get (pOverloadTy);
8406+ Value *vals[] = {undefVal, undefVal, undefVal, undefVal};
8407+ uint8_t mask = 0 ;
8408+ if (Ty->isVectorTy ()) {
8409+ unsigned vectorNumElements = Ty->getVectorNumElements ();
8410+ DXASSERT (vectorNumElements <= 4 , " up to 4 elements in vector" );
8411+ assert (vectorNumElements <= 4 );
8412+ for (unsigned i = 0 ; i < vectorNumElements; i++) {
8413+ vals[i] = Builder.CreateExtractElement (val, i);
8414+ mask |= (1 << i);
84598415 }
8416+ } else {
8417+ vals[0 ] = val;
8418+ mask = DXIL::kCompMask_X ;
84608419 }
8420+ Constant *alignment =
8421+ OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8422+ GenerateStructBufSt (handle, bufIdx, Offset, pOverloadTy, OP, Builder,
8423+ vals, mask, alignment);
84618424 }
84628425 user->eraseFromParent ();
84638426 } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(user)) {
0 commit comments