Skip to content

Commit 09a9837

Browse files
committed
Added overloads [skip ci]
1 parent b1f9c8d commit 09a9837

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

pgvector/sparsevec.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
from __future__ import annotations
22
import numpy as np
33
from struct import pack, unpack_from
4-
from typing import Any
4+
from typing import Any, overload
55

66
NO_DEFAULT = object()
77

88

99
class SparseVector:
10+
@overload
11+
def __init__(self, value: dict[int, float], dimensions: int, /) -> None:
12+
...
13+
14+
@overload
15+
def __init__(self, value: list[float], /) -> None:
16+
...
17+
18+
@overload
19+
def __init__(self, value: Any, /) -> None:
20+
...
21+
1022
def __init__(self, value: dict[int, float] | list[float] | Any, dimensions: int | Any = NO_DEFAULT, /) -> None:
1123
if value.__class__.__module__.startswith('scipy.sparse.'):
1224
if dimensions is not NO_DEFAULT:

tests/test_sparse_vector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_list(self):
1414

1515
def test_list_dimensions(self):
1616
with pytest.raises(ValueError) as error:
17-
SparseVector([1, 0, 2, 0, 3, 0], 6)
17+
SparseVector([1, 0, 2, 0, 3, 0], 6) # ty: ignore[invalid-argument-type]
1818
assert str(error.value) == 'extra argument'
1919

2020
def test_ndarray(self):
@@ -40,7 +40,7 @@ def test_coo_array(self):
4040

4141
def test_coo_array_dimensions(self):
4242
with pytest.raises(ValueError) as error:
43-
SparseVector(coo_array(np.array([1, 0, 2, 0, 3, 0])), 6)
43+
SparseVector(coo_array(np.array([1, 0, 2, 0, 3, 0])), 6) # ty: ignore[invalid-argument-type]
4444
assert str(error.value) == 'extra argument'
4545

4646
def test_coo_matrix(self):

0 commit comments

Comments
 (0)