diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 0c38ecd..5f9047f 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -53,7 +53,7 @@ def __init__( def __call__( self, x: mx.array, - offsets: list[int], + offsets: int | list[int] | mx.array, cache: TinyKvCache, mask: mx.array | str | None = None, ) -> mx.array: @@ -172,7 +172,7 @@ def __init__( def __call__( self, x: mx.array, - offset: int, + offset: int | list[int] | mx.array, cache: TinyKvCache, mask: mx.array | str | None = None, ) -> mx.array: @@ -266,7 +266,7 @@ def __init__( def __call__( self, inputs: mx.array, - offset: int, + offset: int | list[int] | mx.array, cache: list[TinyKvCache], ) -> mx.array: h = self.embedding(inputs)