Skip to content

Commit f4498f2

Browse files
committed
Improved type hints [skip ci]
1 parent 0673ca7 commit f4498f2

5 files changed

Lines changed: 42 additions & 37 deletions

File tree

pgvector/peewee/bit.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
from peewee import Expression, Field
2+
from typing import Any
23

34

45
class FixedBitField(Field):
56
field_type = 'bit'
67

7-
def __init__(self, max_length=None, *args, **kwargs):
8+
def __init__(self, max_length: int | None = None, *args, **kwargs) -> None:
89
self.max_length = max_length
910
super(FixedBitField, self).__init__(*args, **kwargs)
1011

11-
def get_modifiers(self):
12-
return self.max_length and [self.max_length] or None
12+
def get_modifiers(self) -> list[int] | None:
13+
return [self.max_length] if self.max_length else None
1314

14-
def _distance(self, op, vector):
15+
def _distance(self, op: str, vector: Any) -> Expression:
1516
return Expression(lhs=self, op=op, rhs=self.to_value(vector))
1617

17-
def hamming_distance(self, vector):
18+
def hamming_distance(self, vector: Any) -> Expression:
1819
return self._distance('<~>', vector)
1920

20-
def jaccard_distance(self, vector):
21+
def jaccard_distance(self, vector: Any) -> Expression:
2122
return self._distance('<%%>', vector)

pgvector/peewee/halfvec.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,35 @@
11
from peewee import Expression, Field
2+
from typing import Any
23
from .. import HalfVector
34

45

56
class HalfVectorField(Field):
67
field_type = 'halfvec'
78

8-
def __init__(self, dimensions=None, *args, **kwargs):
9+
def __init__(self, dimensions: int | None = None, *args, **kwargs) -> None:
910
self.dimensions = dimensions
1011
super(HalfVectorField, self).__init__(*args, **kwargs)
1112

12-
def get_modifiers(self):
13-
return self.dimensions and [self.dimensions] or None
13+
def get_modifiers(self) -> list[int] | None:
14+
return [self.dimensions] if self.dimensions else None
1415

15-
def db_value(self, value):
16+
def db_value(self, value: Any) -> str | None:
1617
return HalfVector._to_db(value)
1718

18-
def python_value(self, value):
19+
def python_value(self, value: Any) -> HalfVector | None:
1920
return HalfVector._from_db(value)
2021

21-
def _distance(self, op, vector):
22+
def _distance(self, op: str, vector: Any) -> Expression:
2223
return Expression(lhs=self, op=op, rhs=self.to_value(vector))
2324

24-
def l2_distance(self, vector):
25+
def l2_distance(self, vector: Any) -> Expression:
2526
return self._distance('<->', vector)
2627

27-
def max_inner_product(self, vector):
28+
def max_inner_product(self, vector: Any) -> Expression:
2829
return self._distance('<#>', vector)
2930

30-
def cosine_distance(self, vector):
31+
def cosine_distance(self, vector: Any) -> Expression:
3132
return self._distance('<=>', vector)
3233

33-
def l1_distance(self, vector):
34+
def l1_distance(self, vector: Any) -> Expression:
3435
return self._distance('<+>', vector)

pgvector/peewee/sparsevec.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,35 @@
11
from peewee import Expression, Field
2+
from typing import Any
23
from .. import SparseVector
34

45

56
class SparseVectorField(Field):
67
field_type = 'sparsevec'
78

8-
def __init__(self, dimensions=None, *args, **kwargs):
9+
def __init__(self, dimensions: int | None = None, *args, **kwargs) -> None:
910
self.dimensions = dimensions
1011
super(SparseVectorField, self).__init__(*args, **kwargs)
1112

12-
def get_modifiers(self):
13-
return self.dimensions and [self.dimensions] or None
13+
def get_modifiers(self) -> list[int] | None:
14+
return [self.dimensions] if self.dimensions else None
1415

15-
def db_value(self, value):
16+
def db_value(self, value: Any) -> str | None:
1617
return SparseVector._to_db(value)
1718

18-
def python_value(self, value):
19+
def python_value(self, value: Any) -> SparseVector | None:
1920
return SparseVector._from_db(value)
2021

21-
def _distance(self, op, vector):
22+
def _distance(self, op: str, vector: Any) -> Expression:
2223
return Expression(lhs=self, op=op, rhs=self.to_value(vector))
2324

24-
def l2_distance(self, vector):
25+
def l2_distance(self, vector: Any) -> Expression:
2526
return self._distance('<->', vector)
2627

27-
def max_inner_product(self, vector):
28+
def max_inner_product(self, vector: Any) -> Expression:
2829
return self._distance('<#>', vector)
2930

30-
def cosine_distance(self, vector):
31+
def cosine_distance(self, vector: Any) -> Expression:
3132
return self._distance('<=>', vector)
3233

33-
def l1_distance(self, vector):
34+
def l1_distance(self, vector: Any) -> Expression:
3435
return self._distance('<+>', vector)

pgvector/peewee/vector.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
1+
import numpy as np
12
from peewee import Expression, Field
3+
from typing import Any
24
from .. import Vector
35

46

57
class VectorField(Field):
68
field_type = 'vector'
79

8-
def __init__(self, dimensions=None, *args, **kwargs):
10+
def __init__(self, dimensions: int | None = None, *args, **kwargs) -> None:
911
self.dimensions = dimensions
1012
super(VectorField, self).__init__(*args, **kwargs)
1113

12-
def get_modifiers(self):
13-
return self.dimensions and [self.dimensions] or None
14+
def get_modifiers(self) -> list[int] | None:
15+
return [self.dimensions] if self.dimensions else None
1416

15-
def db_value(self, value):
17+
def db_value(self, value: Any) -> str | None:
1618
return Vector._to_db(value)
1719

18-
def python_value(self, value):
20+
def python_value(self, value: Any) -> np.ndarray | None:
1921
return Vector._from_db(value)
2022

21-
def _distance(self, op, vector):
23+
def _distance(self, op: str, vector: Any) -> Expression:
2224
return Expression(lhs=self, op=op, rhs=self.to_value(vector))
2325

24-
def l2_distance(self, vector):
26+
def l2_distance(self, vector: Any) -> Expression:
2527
return self._distance('<->', vector)
2628

27-
def max_inner_product(self, vector):
29+
def max_inner_product(self, vector: Any) -> Expression:
2830
return self._distance('<#>', vector)
2931

30-
def cosine_distance(self, vector):
32+
def cosine_distance(self, vector: Any) -> Expression:
3133
return self._distance('<=>', vector)
3234

33-
def l1_distance(self, vector):
35+
def l1_distance(self, vector: Any) -> Expression:
3436
return self._distance('<+>', vector)

pgvector/sqlalchemy/halfvec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def get_col_spec(self, **kw) -> str:
1818
return 'HALFVEC'
1919
return 'HALFVEC(%d)' % self.dim
2020

21-
def bind_processor(self, dialect: Dialect):
21+
def bind_processor(self, dialect: Dialect) -> Any:
2222
def process(value):
2323
return HalfVector._to_db(value, self.dim)
2424
return process

0 commit comments

Comments
 (0)