@@ -86,14 +86,14 @@ def patched_matmul_lora(
8686 W_full = W .dequantize ()
8787 else :
8888 W_full = W .contiguous ()
89- out = torch_matmul (X , W_full .t (), out = out )
89+ out = torch_matmul (X , W_full .t (), out = out ) # type: ignore[call-arg]
9090 elif getattr (W , "dtype" , None ) == getattr (torch , "float8_e4m3fn" , None ):
9191 if fp8_linear is None :
9292 raise RuntimeError ("FP8 weights detected but fp8_linear unavailable." )
9393 out = fp8_linear (X , W , W_quant )
9494 else :
95- W_full = fast_dequantize (W , W_quant , use_global_buffer = True )
96- out = torch_matmul (X , W_full .t (), out = out )
95+ W_full = fast_dequantize (W , W_quant , use_global_buffer = True ) # type: ignore[call-arg]
96+ out = torch_matmul (X , W_full .t (), out = out ) # type: ignore[call-arg]
9797
9898 if A is not None :
9999 td = _target_dtype (out , dtype )
@@ -113,16 +113,16 @@ def patched_fast_linear_forward(
113113 return patched_matmul_lora (X , W , W_quant , lora_A , lora_B , lora_S )
114114
115115 if W_quant is None :
116- out = torch_matmul (X , W .t (), out = out )
116+ out = torch_matmul (X , W .t (), out = out ) # type: ignore[call-arg]
117117 elif getattr (W , "dtype" , None ) == getattr (torch , "float8_e4m3fn" , None ):
118118 if fp8_linear is None :
119119 raise RuntimeError ("FP8 weights detected but fp8_linear unavailable." )
120120 out = fp8_linear (X , W , W_quant , bias )
121121 elif fast_gemv is not None and bsz == 1 and q_len == 1 :
122122 out = fast_gemv (X , W , W_quant , out = out )
123123 else :
124- W_full = fast_dequantize (W .t (), W_quant , use_global_buffer = True )
125- out = torch_matmul (X , W_full , out = out )
124+ W_full = fast_dequantize (W .t (), W_quant , use_global_buffer = True ) # type: ignore[call-arg]
125+ out = torch_matmul (X , W_full , out = out ) # type: ignore[call-arg]
126126
127127 if lora_A is not None :
128128 td = _target_dtype (out , X .dtype )
0 commit comments