@@ -221,7 +221,7 @@ def backward(ctx, grad_output):
221221
222222def supports_igemmlt (device : torch .device ) -> bool :
223223 """check if this device supports the optimized int8 kernel"""
224-    if device == torch.device("cpu"):
224+    if device.type in ("cpu", "xpu"):
225225 return True
226226 if torch .version .hip :
227227 return False if BNB_HIP_VERSION < 601 else True
@@ -463,7 +463,9 @@ def backward(ctx, grad_output):
463463 if len (grad_output .shape ) == 3 :
464464 grad_output = grad_output .reshape (- 1 , grad_output .shape [- 1 ]).contiguous ()
465465
466- Cgrad , Cgradt , SCgrad , SCgradt , coo_tensor = F .double_quant (grad_output .to (torch .float16 ))
466+ Cgrad , Cgradt , SCgrad , SCgradt , coo_tensor = None , None , None , None , None
467+    if req_gradB or (req_gradA and state.CBt is not None):
468+ Cgrad , Cgradt , SCgrad , SCgradt , coo_tensor = F .double_quant (grad_output .to (torch .float16 ))
467469 if req_gradB :
468470 CxAt , SAt = F .transform (CAt , formatB , transpose = True )
469471 C32grad , Sgrad = F .transform (Cgradt , "col32" , transpose = True )
@@ -517,7 +519,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
517519
518520 # 1. Dequantize
519521 # 2. MatmulnN
520- output = torch .nn .functional .linear (A , F .dequantize_4bit (B , quant_state ).to (A .dtype ).t (), bias )
522+ if A .device .type == "npu" :
523+ output = torch .matmul (A , F .dequantize_4bit (B , quant_state ).to (A .dtype ).t ())
524+ if bias is not None :
525+ output += bias
526+ else :
527+ output = torch .nn .functional .linear (A , F .dequantize_4bit (B , quant_state ).to (A .dtype ).t (), bias )
521528
522529 # 3. Save state
523530 ctx .state = quant_state
@@ -548,11 +555,37 @@ def backward(ctx, grad_output):
548555 # not supported by PyTorch. TODO: create work-around
549556 # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
550557 if req_gradA :
551- grad_A = torch .matmul (grad_output , F .dequantize_4bit (B , ctx .state ).to (grad_output .dtype ).t ())
558+ if grad_output .device .type == "npu" :
559+ grad_A = torch .matmul (grad_output , F .dequantize_4bit (B , ctx .state ).to (grad_output .dtype ))
560+ else :
561+ grad_A = torch .matmul (grad_output , F .dequantize_4bit (B , ctx .state ).to (grad_output .dtype ).t ())
552562
553563 return grad_A , grad_B , None , grad_bias , None
554564
555565
class MatMul8bitFp(torch.autograd.Function):
    """Dequantize-then-matmul fallback for Intel CPU and XPU.

    On these backends the int8 double-quant path has unsafe operations that
    break fine-tuning, so training instead dequantizes the int8 weight with
    its per-row scale (SCB / 127) and runs a regular floating-point matmul.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
        # Out-of-place mul(): `.to()` returns the *same* tensor when the dtype
        # already matches, so an in-place mul_() would rescale the cached
        # weight again on every call and corrupt it.
        CB = B.data.to(A.dtype).mul(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t()
        output = torch.matmul(A, CB).to(A.dtype)
        # Bug fix: `bias` was accepted but silently dropped.
        if bias is not None:
            output += bias
        ctx.state = state
        ctx.dtype_A = A.dtype
        ctx.grad_shape = A.shape
        ctx.has_bias = bias is not None
        return output

    @staticmethod
    def backward(ctx, grad_output):
        state = ctx.state
        B = state.CxB if state.CxB is not None else state.CB
        # Out-of-place mul() for the same aliasing reason as in forward().
        CB = B.to(ctx.dtype_A).mul(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
        grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
        grad_bias = None
        if ctx.has_bias:
            # Bias gradient is the sum of grad_output over all leading dims.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
        # Gradient slots correspond to (A, B, out, bias, state); only A and
        # bias are trainable through this fp fallback.
        return grad_A, None, None, grad_bias, None
587+
588+
556589def matmul (
557590 A : torch .Tensor ,
558591 B : torch .Tensor ,
@@ -564,6 +597,8 @@ def matmul(
564597 state = state or MatmulLtState ()
565598 if threshold > 0.0 :
566599 state .threshold = threshold
600+ if A .device .type in ("cpu" , "xpu" ) and state .is_training :
601+ return MatMul8bitFp .apply (A , B , out , bias , state )
567602 return MatMul8bitLt .apply (A , B , out , bias , state )
568603
569604
@@ -575,8 +610,16 @@ def matmul_4bit(
575610 bias = None ,
576611):
577612 assert quant_state is not None
578- if (A .numel () == A .shape [- 1 ] or A .device .type == "cpu" ) and A .requires_grad == False :
579- # CPU backend does not require A to be a vector
613+ if A .device .type in ("cpu" , "xpu" ) and A .requires_grad == False :
614+ if getattr (quant_state , "ipex" , False ):
615+ B = B .t () if len (B .shape ) == 2 else B
616+ out = F .gemv_4bit (A , B , out , state = quant_state )
617+ if bias is not None :
618+ out += bias
619+ return out
620+ else :
621+ return MatMul4Bit .apply (A , B , out , bias , quant_state )
622+ elif A .numel () == A .shape [- 1 ] and A .requires_grad == False and A .device .type != "npu" :
580623 if A .shape [- 1 ] % quant_state .blocksize != 0 :
581624 warn (
582625 f"Some matrices hidden dimension is not a multiple of { quant_state .blocksize } and efficient inference kernels are not supported for these (slow). Matrix input size found: { A .shape } " ,
0 commit comments