Skip to content

Commit 9346894

Browse files
committed
Fix type errors: add type ignores for unsloth runtime function signatures
1 parent 7871978 commit 9346894

File tree

2 files changed: 7 additions and 6 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ allowed-unresolved-imports = [
198198
"nbclient.**",
199199
"nbmake.**",
200200
"peft.**",
201+
"safetensors.**",
201202
"pyarrow.**",
202203
"torch.**",
203204
"torchao.**",

src/art/unsloth/dtype_patch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,14 @@ def patched_matmul_lora(
8686
W_full = W.dequantize()
8787
else:
8888
W_full = W.contiguous()
89-
out = torch_matmul(X, W_full.t(), out=out)
89+
out = torch_matmul(X, W_full.t(), out=out) # type: ignore[call-arg]
9090
elif getattr(W, "dtype", None) == getattr(torch, "float8_e4m3fn", None):
9191
if fp8_linear is None:
9292
raise RuntimeError("FP8 weights detected but fp8_linear unavailable.")
9393
out = fp8_linear(X, W, W_quant)
9494
else:
95-
W_full = fast_dequantize(W, W_quant, use_global_buffer=True)
96-
out = torch_matmul(X, W_full.t(), out=out)
95+
W_full = fast_dequantize(W, W_quant, use_global_buffer=True) # type: ignore[call-arg]
96+
out = torch_matmul(X, W_full.t(), out=out) # type: ignore[call-arg]
9797

9898
if A is not None:
9999
td = _target_dtype(out, dtype)
@@ -113,16 +113,16 @@ def patched_fast_linear_forward(
113113
return patched_matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
114114

115115
if W_quant is None:
116-
out = torch_matmul(X, W.t(), out=out)
116+
out = torch_matmul(X, W.t(), out=out) # type: ignore[call-arg]
117117
elif getattr(W, "dtype", None) == getattr(torch, "float8_e4m3fn", None):
118118
if fp8_linear is None:
119119
raise RuntimeError("FP8 weights detected but fp8_linear unavailable.")
120120
out = fp8_linear(X, W, W_quant, bias)
121121
elif fast_gemv is not None and bsz == 1 and q_len == 1:
122122
out = fast_gemv(X, W, W_quant, out=out)
123123
else:
124-
W_full = fast_dequantize(W.t(), W_quant, use_global_buffer=True)
125-
out = torch_matmul(X, W_full, out=out)
124+
W_full = fast_dequantize(W.t(), W_quant, use_global_buffer=True) # type: ignore[call-arg]
125+
out = torch_matmul(X, W_full, out=out) # type: ignore[call-arg]
126126

127127
if lora_A is not None:
128128
td = _target_dtype(out, X.dtype)

0 commit comments

Comments (0)