Skip to content

Commit 3cb3fcc

Browse files
committed
fix: Match patched LoRA slice signature
1 parent d210742 commit 3cb3fcc

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

src/art/vllm/patches.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,9 @@
11
"""Monkey patches and modifications for vLLM."""
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
4+
5+
if TYPE_CHECKING:
6+
from torch import Tensor
47

58

69
def patch_transformers_v5_compat() -> None:
@@ -69,7 +72,10 @@ def can_replace_layer(
6972

7073
MergedColumnParallelLinearWithLoRA.can_replace_layer = can_replace_layer
7174

72-
def slice_lora_a(self: Any, lora_a: list[Any]) -> list[Any]:
75+
def slice_lora_a(
76+
self: Any,
77+
lora_a: "list[Tensor | None]",
78+
) -> "list[Tensor | None]":
7379
output_shard_size = self.lora_a_stacked[0].shape[2]
7480
output_start_idx = self.tp_rank * output_shard_size
7581
return [
@@ -79,7 +85,7 @@ def slice_lora_a(self: Any, lora_a: list[Any]) -> list[Any]:
7985
for a in lora_a
8086
]
8187

82-
MergedColumnParallelLinearWithShardedLoRA.slice_lora_a = slice_lora_a
88+
MergedColumnParallelLinearWithShardedLoRA.slice_lora_a = slice_lora_a # ty:ignore[invalid-assignment]
8389

8490

8591
def subclass_chat_completion_request() -> None:

0 commit comments

Comments
 (0)