Skip to content

Commit afa127f

Browse files
committed
Added type hints [skip ci]
1 parent c1233e2 commit afa127f

17 files changed

Lines changed: 152 additions & 110 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
## 0.5.0 (unreleased)
22

3+
- Added type hints
34
- Dropped support for Python < 3.10
45
- Dropped support for SQLAlchemy < 2
56

pgvector/asyncpg/register.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
1+
from asyncpg import Connection
12
from .. import Vector, HalfVector, SparseVector
23

34

4-
async def register_vector(conn, schema='public'):
5+
async def register_vector(conn: Connection, schema: str = 'public') -> None:
56
await conn.set_type_codec(
67
'vector',
78
schema=schema,

pgvector/bit.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
1+
from __future__ import annotations
12
import numpy as np
23
from struct import pack, unpack_from
4+
from typing import Any
35
from warnings import warn
46

57

68
class Bit:
7-
def __init__(self, value):
9+
def __init__(self, value: Any) -> None:
810
if isinstance(value, bytes):
911
self._len = 8 * len(value)
1012
self._data = value
@@ -26,32 +28,32 @@ def __init__(self, value):
2628
self._len = len(value)
2729
self._data = np.packbits(value).tobytes()
2830

29-
def __repr__(self):
31+
def __repr__(self) -> str:
3032
return f'Bit({self.to_text()})'
3133

32-
def __eq__(self, other):
34+
def __eq__(self, other: Any) -> bool:
3335
if isinstance(other, self.__class__):
3436
return self._len == other._len and self._data == other._data
3537
return False
3638

37-
def to_list(self):
39+
def to_list(self) -> list[bool]:
3840
return self.to_numpy().tolist()
3941

40-
def to_numpy(self):
42+
def to_numpy(self) -> np.ndarray:
4143
return np.unpackbits(np.frombuffer(self._data, dtype=np.uint8), count=self._len).astype(bool)
4244

43-
def to_text(self):
45+
def to_text(self) -> str:
4446
return ''.join(format(v, '08b') for v in self._data)[:self._len]
4547

46-
def to_binary(self):
48+
def to_binary(self) -> bytes:
4749
return pack('>i', self._len) + self._data
4850

4951
@classmethod
50-
def from_text(cls, value):
52+
def from_text(cls, value: str) -> Bit:
5153
return cls(str(value))
5254

5355
@classmethod
54-
def from_binary(cls, value):
56+
def from_binary(cls, value: bytes) -> Bit:
5557
if not isinstance(value, bytes):
5658
raise ValueError('expected bytes')
5759

@@ -61,15 +63,15 @@ def from_binary(cls, value):
6163
return bit
6264

6365
@classmethod
64-
def _to_db(cls, value):
65-
if not isinstance(value, cls):
66+
def _to_db(cls, value: Bit) -> str:
67+
if not isinstance(value, Bit):
6668
raise ValueError('expected bit')
6769

6870
return value.to_text()
6971

7072
@classmethod
71-
def _to_db_binary(cls, value):
72-
if not isinstance(value, cls):
73+
def _to_db_binary(cls, value: Bit) -> bytes:
74+
if not isinstance(value, Bit):
7375
raise ValueError('expected bit')
7476

7577
return value.to_binary()

pgvector/halfvec.py

Lines changed: 18 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
1+
from __future__ import annotations
12
import numpy as np
23
from struct import pack, unpack_from
4+
from typing import Any
35

46

57
class HalfVector:
6-
def __init__(self, value):
8+
def __init__(self, value: Any) -> None:
79
# asarray still copies if same dtype
810
if not isinstance(value, np.ndarray) or value.dtype != '>f2':
911
value = np.asarray(value, dtype='>f2')
@@ -13,40 +15,40 @@ def __init__(self, value):
1315

1416
self._value = value
1517

16-
def __repr__(self):
18+
def __repr__(self) -> str:
1719
return f'HalfVector({self.to_list()})'
1820

19-
def __eq__(self, other):
21+
def __eq__(self, other: Any) -> bool:
2022
if isinstance(other, self.__class__):
2123
return np.array_equal(self.to_numpy(), other.to_numpy())
2224
return False
2325

24-
def dimensions(self):
26+
def dimensions(self) -> int:
2527
return len(self._value)
2628

27-
def to_list(self):
29+
def to_list(self) -> list[float]:
2830
return self._value.tolist()
2931

30-
def to_numpy(self):
32+
def to_numpy(self) -> np.ndarray:
3133
return self._value
3234

33-
def to_text(self):
35+
def to_text(self) -> str:
3436
return '[' + ','.join([str(float(v)) for v in self._value]) + ']'
3537

36-
def to_binary(self):
38+
def to_binary(self) -> bytes:
3739
return pack('>HH', self.dimensions(), 0) + self._value.tobytes()
3840

3941
@classmethod
40-
def from_text(cls, value):
42+
def from_text(cls, value: str) -> HalfVector:
4143
return cls([float(v) for v in value[1:-1].split(',')])
4244

4345
@classmethod
44-
def from_binary(cls, value):
46+
def from_binary(cls, value: bytes) -> HalfVector:
4547
dim, unused = unpack_from('>HH', value)
4648
return cls(np.frombuffer(value, dtype='>f2', count=dim, offset=4))
4749

4850
@classmethod
49-
def _to_db(cls, value, dim=None):
51+
def _to_db(cls, value: Any, dim: int | None = None) -> str | None:
5052
if value is None:
5153
return value
5254

@@ -59,7 +61,7 @@ def _to_db(cls, value, dim=None):
5961
return value.to_text()
6062

6163
@classmethod
62-
def _to_db_binary(cls, value):
64+
def _to_db_binary(cls, value: Any) -> bytes | None:
6365
if value is None:
6466
return value
6567

@@ -69,15 +71,15 @@ def _to_db_binary(cls, value):
6971
return value.to_binary()
7072

7173
@classmethod
72-
def _from_db(cls, value):
73-
if value is None or isinstance(value, cls):
74+
def _from_db(cls, value: str | HalfVector | None) -> HalfVector | None:
75+
if value is None or isinstance(value, HalfVector):
7476
return value
7577

7678
return cls.from_text(value)
7779

7880
@classmethod
79-
def _from_db_binary(cls, value):
80-
if value is None or isinstance(value, cls):
81+
def _from_db_binary(cls, value: bytes | HalfVector | None) -> HalfVector | None:
82+
if value is None or isinstance(value, HalfVector):
8183
return value
8284

8385
return cls.from_binary(value)

pgvector/pg8000/register.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
import numpy as np
2+
from pg8000.native import Connection
23
from .. import Vector, HalfVector, SparseVector
34

45

5-
def register_vector(conn):
6+
def register_vector(conn: Connection) -> None:
67
# use to_regtype to get first matching type in search path
78
res = conn.run("SELECT typname, oid FROM pg_type WHERE oid IN (to_regtype('vector'), to_regtype('halfvec'), to_regtype('sparsevec'))")
89
type_info = dict(res)

pgvector/psycopg/bit.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,30 @@
1+
from psycopg import BaseConnection
2+
from psycopg.types import TypeInfo
13
from psycopg.adapt import Dumper
24
from psycopg.pq import Format
5+
from typing import Any, TypeAlias
36
from .. import Bit
47

8+
Buffer: TypeAlias = bytes | bytearray | memoryview
9+
510

611
class BitDumper(Dumper):
712

813
format = Format.TEXT
914

10-
def dump(self, obj):
15+
def dump(self, obj: Bit) -> Buffer | None:
1116
return Bit._to_db(obj).encode('utf8')
1217

1318

1419
class BitBinaryDumper(BitDumper):
1520

1621
format = Format.BINARY
1722

18-
def dump(self, obj):
23+
def dump(self, obj: Bit) -> Buffer | None:
1924
return Bit._to_db_binary(obj)
2025

2126

22-
def register_bit_info(context, info):
27+
def register_bit_info(context: BaseConnection[Any], info: TypeInfo) -> None:
2328
info.register(context)
2429

2530
# add oid to anonymous class for set_types

pgvector/psycopg/halfvec.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,29 +1,35 @@
1+
from psycopg import BaseConnection
12
from psycopg.adapt import Loader, Dumper
23
from psycopg.pq import Format
4+
from psycopg.types import TypeInfo
5+
from typing import Any, TypeAlias
36
from .. import HalfVector
47

8+
Buffer: TypeAlias = bytes | bytearray | memoryview
9+
510

611
class HalfVectorDumper(Dumper):
712

813
format = Format.TEXT
914

10-
def dump(self, obj):
11-
return HalfVector._to_db(obj).encode('utf8')
15+
def dump(self, obj: HalfVector) -> Buffer | None:
16+
value = HalfVector._to_db(obj)
17+
return value if value is None else value.encode('utf8')
1218

1319

1420
class HalfVectorBinaryDumper(HalfVectorDumper):
1521

1622
format = Format.BINARY
1723

18-
def dump(self, obj):
24+
def dump(self, obj: HalfVector) -> Buffer | None:
1925
return HalfVector._to_db_binary(obj)
2026

2127

2228
class HalfVectorLoader(Loader):
2329

2430
format = Format.TEXT
2531

26-
def load(self, data):
32+
def load(self, data: Buffer) -> HalfVector | None:
2733
if isinstance(data, memoryview):
2834
data = bytes(data)
2935
return HalfVector._from_db(data.decode('utf8'))
@@ -33,13 +39,13 @@ class HalfVectorBinaryLoader(HalfVectorLoader):
3339

3440
format = Format.BINARY
3541

36-
def load(self, data):
37-
if isinstance(data, memoryview):
42+
def load(self, data: Buffer) -> HalfVector | None:
43+
if isinstance(data, (bytearray, memoryview)):
3844
data = bytes(data)
3945
return HalfVector._from_db_binary(data)
4046

4147

42-
def register_halfvec_info(context, info):
48+
def register_halfvec_info(context: BaseConnection[Any], info: TypeInfo) -> None:
4349
info.register(context)
4450

4551
# add oid to anonymous class for set_types

pgvector/psycopg/register.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,18 @@
1+
from psycopg import AsyncConnection, Connection
12
from psycopg.types import TypeInfo
3+
from typing import Any
24
from .bit import register_bit_info
35
from .halfvec import register_halfvec_info
46
from .sparsevec import register_sparsevec_info
57
from .vector import register_vector_info
68

79

8-
def register_vector(context):
10+
def register_vector(context: Connection[Any]) -> None:
911
info = TypeInfo.fetch(context, 'vector')
1012
register_vector_info(context, info)
1113

1214
info = TypeInfo.fetch(context, 'bit')
15+
assert info is not None
1316
register_bit_info(context, info)
1417

1518
info = TypeInfo.fetch(context, 'halfvec')
@@ -21,11 +24,12 @@ def register_vector(context):
2124
register_sparsevec_info(context, info)
2225

2326

24-
async def register_vector_async(context):
27+
async def register_vector_async(context: AsyncConnection[Any]) -> None:
2528
info = await TypeInfo.fetch(context, 'vector')
2629
register_vector_info(context, info)
2730

2831
info = await TypeInfo.fetch(context, 'bit')
32+
assert info is not None
2933
register_bit_info(context, info)
3034

3135
info = await TypeInfo.fetch(context, 'halfvec')

pgvector/psycopg/sparsevec.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,29 +1,35 @@
1+
from psycopg import BaseConnection
12
from psycopg.adapt import Loader, Dumper
23
from psycopg.pq import Format
4+
from psycopg.types import TypeInfo
5+
from typing import Any, TypeAlias
36
from .. import SparseVector
47

8+
Buffer: TypeAlias = bytes | bytearray | memoryview
9+
510

611
class SparseVectorDumper(Dumper):
712

813
format = Format.TEXT
914

10-
def dump(self, obj):
11-
return SparseVector._to_db(obj).encode('utf8')
15+
def dump(self, obj: SparseVector) -> Buffer | None:
16+
value = SparseVector._to_db(obj)
17+
return value if value is None else value.encode('utf8')
1218

1319

1420
class SparseVectorBinaryDumper(SparseVectorDumper):
1521

1622
format = Format.BINARY
1723

18-
def dump(self, obj):
24+
def dump(self, obj: SparseVector) -> Buffer | None:
1925
return SparseVector._to_db_binary(obj)
2026

2127

2228
class SparseVectorLoader(Loader):
2329

2430
format = Format.TEXT
2531

26-
def load(self, data):
32+
def load(self, data: Buffer) -> SparseVector | None:
2733
if isinstance(data, memoryview):
2834
data = bytes(data)
2935
return SparseVector._from_db(data.decode('utf8'))
@@ -33,13 +39,13 @@ class SparseVectorBinaryLoader(SparseVectorLoader):
3339

3440
format = Format.BINARY
3541

36-
def load(self, data):
37-
if isinstance(data, memoryview):
42+
def load(self, data: Buffer) -> SparseVector | None:
43+
if isinstance(data, (bytearray, memoryview)):
3844
data = bytes(data)
3945
return SparseVector._from_db_binary(data)
4046

4147

42-
def register_sparsevec_info(context, info):
48+
def register_sparsevec_info(context: BaseConnection[Any], info: TypeInfo) -> None:
4349
info.register(context)
4450

4551
# add oid to anonymous class for set_types

0 commit comments

Comments
 (0)