Skip to content

Commit ba4c406

Browse files
committed
Fix type errors: add type ignores for unsloth runtime function signatures
1 parent 7871978 commit ba4c406

3 files changed

Lines changed: 13 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,7 @@ allowed-unresolved-imports = [
198198
"nbclient.**",
199199
"nbmake.**",
200200
"peft.**",
201+
"safetensors.**",
201202
"pyarrow.**",
202203
"torch.**",
203204
"torchao.**",

src/art/unsloth/dtype_patch.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -86,14 +86,14 @@ def patched_matmul_lora(
8686
W_full = W.dequantize()
8787
else:
8888
W_full = W.contiguous()
89-
out = torch_matmul(X, W_full.t(), out=out)
89+
out = torch_matmul(X, W_full.t(), out=out) # type: ignore[call-arg]
9090
elif getattr(W, "dtype", None) == getattr(torch, "float8_e4m3fn", None):
9191
if fp8_linear is None:
9292
raise RuntimeError("FP8 weights detected but fp8_linear unavailable.")
9393
out = fp8_linear(X, W, W_quant)
9494
else:
95-
W_full = fast_dequantize(W, W_quant, use_global_buffer=True)
96-
out = torch_matmul(X, W_full.t(), out=out)
95+
W_full = fast_dequantize(W, W_quant, use_global_buffer=True) # type: ignore[call-arg]
96+
out = torch_matmul(X, W_full.t(), out=out) # type: ignore[call-arg]
9797

9898
if A is not None:
9999
td = _target_dtype(out, dtype)
@@ -113,16 +113,16 @@ def patched_fast_linear_forward(
113113
return patched_matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
114114

115115
if W_quant is None:
116-
out = torch_matmul(X, W.t(), out=out)
116+
out = torch_matmul(X, W.t(), out=out) # type: ignore[call-arg]
117117
elif getattr(W, "dtype", None) == getattr(torch, "float8_e4m3fn", None):
118118
if fp8_linear is None:
119119
raise RuntimeError("FP8 weights detected but fp8_linear unavailable.")
120120
out = fp8_linear(X, W, W_quant, bias)
121121
elif fast_gemv is not None and bsz == 1 and q_len == 1:
122122
out = fast_gemv(X, W, W_quant, out=out)
123123
else:
124-
W_full = fast_dequantize(W.t(), W_quant, use_global_buffer=True)
125-
out = torch_matmul(X, W_full, out=out)
124+
W_full = fast_dequantize(W.t(), W_quant, use_global_buffer=True) # type: ignore[call-arg]
125+
out = torch_matmul(X, W_full, out=out) # type: ignore[call-arg]
126126

127127
if lora_A is not None:
128128
td = _target_dtype(out, X.dtype)
@@ -166,3 +166,7 @@ def patched_fast_linear_forward(
166166
if log:
167167
log.debug("Applied Unsloth LoRA dtype harmonisation patch.")
168168
return True
169+
170+
171+
# Apply eagerly so import side-effects protect downstream callers.
172+
ensure_dtype_patch(logging.getLogger(__name__))

src/art/unsloth/service.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -592,7 +592,7 @@ async def _train_dedicated(
592592
# Load forked adapter weights on first training call if needed.
593593
forked_dir = getattr(self, "_forked_checkpoint_dir", None)
594594
if forked_dir is not None:
595-
del self._forked_checkpoint_dir
595+
self._forked_checkpoint_dir = None
596596
await self._state.load_lora_adapter(forked_dir)
597597
async for result in run_unsloth_rl_training(
598598
self._state,
@@ -638,7 +638,7 @@ async def _train_shared(
638638
# Load forked adapter weights on first training call if needed.
639639
forked_dir = getattr(self, "_forked_checkpoint_dir", None)
640640
if forked_dir is not None:
641-
del self._forked_checkpoint_dir
641+
self._forked_checkpoint_dir = None
642642
await self._state.load_lora_adapter(forked_dir)
643643
llm = await self.llm
644644

0 commit comments

Comments (0)