Skip to content

Commit b1f9c8d

Browse files
committed
Improved type hints [skip ci]
1 parent bc02e89 commit b1f9c8d

3 files changed

Lines changed: 11 additions & 7 deletions

File tree

pgvector/halfvec.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
from __future__ import annotations
22
import numpy as np
33
from struct import pack, unpack_from
4-
from typing import Any
54

65

76
class HalfVector:
8-
def __init__(self, value: Any) -> None:
7+
def __init__(self, value: object) -> None:
98
# asarray still copies if same dtype
109
if not isinstance(value, np.ndarray) or value.dtype != '>f2':
1110
value = np.asarray(value, dtype='>f2')
1211

12+
# for mypy
13+
assert isinstance(value, np.ndarray)
14+
1315
if value.ndim != 1:
1416
raise ValueError('expected ndim to be 1')
1517

16-
self._value = value
18+
self._value = np.atleast_1d(value)
1719

1820
def __repr__(self) -> str:
1921
return f'HalfVector({self.to_list()})'

pgvector/sparsevec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class SparseVector:
10-
def __init__(self, value: Any, dimensions: int | Any = NO_DEFAULT, /) -> None:
10+
def __init__(self, value: dict[int, float] | list[float] | Any, dimensions: int | Any = NO_DEFAULT, /) -> None:
1111
if value.__class__.__module__.startswith('scipy.sparse.'):
1212
if dimensions is not NO_DEFAULT:
1313
raise ValueError('extra argument')

pgvector/vector.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
from __future__ import annotations
22
import numpy as np
33
from struct import pack, unpack_from
4-
from typing import Any
54

65

76
class Vector:
8-
def __init__(self, value: Any) -> None:
7+
def __init__(self, value: object) -> None:
98
# asarray still copies if same dtype
109
if not isinstance(value, np.ndarray) or value.dtype != '>f4':
1110
value = np.asarray(value, dtype='>f4')
1211

12+
# for mypy
13+
assert isinstance(value, np.ndarray)
14+
1315
if value.ndim != 1:
1416
raise ValueError('expected ndim to be 1')
1517

16-
self._value = value
18+
self._value = np.atleast_1d(value)
1719

1820
def __repr__(self) -> str:
1921
return f'Vector({self.to_list()})'

0 commit comments

Comments
 (0)