Skip to content

Commit d95fa2e

Browse files
committed
Added more type hints [skip ci]
1 parent afa127f commit d95fa2e

4 files changed

Lines changed: 40 additions & 32 deletions

File tree

pgvector/sqlalchemy/bit.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float
3+
from sqlalchemy import Dialect
4+
from typing import Any
35

46

57
class BIT(UserDefinedType):
68
cache_ok = True
79

8-
def __init__(self, length=None):
10+
def __init__(self, length: int | None = None) -> None:
911
super(UserDefinedType, self).__init__()
1012
self.length = length
1113

12-
def get_col_spec(self, **kw):
14+
def get_col_spec(self, **kw) -> str:
1315
if self.length is None:
1416
return 'BIT'
1517
return 'BIT(%d)' % self.length
1618

17-
def bind_processor(self, dialect):
19+
def bind_processor(self, dialect: Dialect) -> Any:
1820
if dialect.__class__.__name__ == 'PGDialect_asyncpg':
1921
import asyncpg
2022

@@ -27,10 +29,10 @@ def process(value):
2729
return super().bind_processor(dialect)
2830

2931
class comparator_factory(UserDefinedType.Comparator):
30-
def hamming_distance(self, other):
32+
def hamming_distance(self, other: Any) -> Any:
3133
return self.op('<~>', return_type=Float)(other)
3234

33-
def jaccard_distance(self, other):
35+
def jaccard_distance(self, other: Any) -> Any:
3436
return self.op('<%>', return_type=Float)(other)
3537

3638

pgvector/sqlalchemy/halfvec.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,51 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float, String
3+
from sqlalchemy import Dialect
4+
from typing import Any
35
from .. import HalfVector
46

57

68
class HALFVEC(UserDefinedType):
79
cache_ok = True
810
_string = String()
911

10-
def __init__(self, dim=None):
12+
def __init__(self, dim: int | None = None) -> None:
1113
super(UserDefinedType, self).__init__()
1214
self.dim = dim
1315

14-
def get_col_spec(self, **kw):
16+
def get_col_spec(self, **kw) -> str:
1517
if self.dim is None:
1618
return 'HALFVEC'
1719
return 'HALFVEC(%d)' % self.dim
1820

19-
def bind_processor(self, dialect):
21+
def bind_processor(self, dialect: Dialect):
2022
def process(value):
2123
return HalfVector._to_db(value, self.dim)
2224
return process
2325

24-
def literal_processor(self, dialect):
26+
def literal_processor(self, dialect: Dialect) -> Any:
2527
string_literal_processor = self._string._cached_literal_processor(dialect)
2628

2729
def process(value):
2830
return string_literal_processor(HalfVector._to_db(value, self.dim)) # type: ignore
2931
return process
3032

31-
def result_processor(self, dialect, coltype):
33+
def result_processor(self, dialect: Dialect, coltype: Any) -> Any:
3234
def process(value):
3335
return HalfVector._from_db(value)
3436
return process
3537

3638
class comparator_factory(UserDefinedType.Comparator):
37-
def l2_distance(self, other):
39+
def l2_distance(self, other: Any) -> Any:
3840
return self.op('<->', return_type=Float)(other)
3941

40-
def max_inner_product(self, other):
42+
def max_inner_product(self, other: Any) -> Any:
4143
return self.op('<#>', return_type=Float)(other)
4244

43-
def cosine_distance(self, other):
45+
def cosine_distance(self, other: Any) -> Any:
4446
return self.op('<=>', return_type=Float)(other)
4547

46-
def l1_distance(self, other):
48+
def l1_distance(self, other: Any) -> Any:
4749
return self.op('<+>', return_type=Float)(other)
4850

4951

pgvector/sqlalchemy/sparsevec.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,51 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float, String
3+
from sqlalchemy import Dialect
4+
from typing import Any
35
from .. import SparseVector
46

57

68
class SPARSEVEC(UserDefinedType):
79
cache_ok = True
810
_string = String()
911

10-
def __init__(self, dim=None):
12+
def __init__(self, dim: int | None = None) -> None:
1113
super(UserDefinedType, self).__init__()
1214
self.dim = dim
1315

14-
def get_col_spec(self, **kw):
16+
def get_col_spec(self, **kw) -> str:
1517
if self.dim is None:
1618
return 'SPARSEVEC'
1719
return 'SPARSEVEC(%d)' % self.dim
1820

19-
def bind_processor(self, dialect):
21+
def bind_processor(self, dialect: Dialect) -> Any:
2022
def process(value):
2123
return SparseVector._to_db(value, self.dim)
2224
return process
2325

24-
def literal_processor(self, dialect):
26+
def literal_processor(self, dialect: Dialect) -> Any:
2527
string_literal_processor = self._string._cached_literal_processor(dialect)
2628

2729
def process(value):
2830
return string_literal_processor(SparseVector._to_db(value, self.dim)) # type: ignore
2931
return process
3032

31-
def result_processor(self, dialect, coltype):
33+
def result_processor(self, dialect: Dialect, coltype: Any) -> Any:
3234
def process(value):
3335
return SparseVector._from_db(value)
3436
return process
3537

3638
class comparator_factory(UserDefinedType.Comparator):
37-
def l2_distance(self, other):
39+
def l2_distance(self, other: Any) -> Any:
3840
return self.op('<->', return_type=Float)(other)
3941

40-
def max_inner_product(self, other):
42+
def max_inner_product(self, other: Any) -> Any:
4143
return self.op('<#>', return_type=Float)(other)
4244

43-
def cosine_distance(self, other):
45+
def cosine_distance(self, other: Any) -> Any:
4446
return self.op('<=>', return_type=Float)(other)
4547

46-
def l1_distance(self, other):
48+
def l1_distance(self, other: Any) -> Any:
4749
return self.op('<+>', return_type=Float)(other)
4850

4951

pgvector/sqlalchemy/vector.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,51 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float, String
3+
from sqlalchemy import Dialect
4+
from typing import Any
35
from .. import Vector
46

57

68
class VECTOR(UserDefinedType):
79
cache_ok = True
810
_string = String()
911

10-
def __init__(self, dim=None):
12+
def __init__(self, dim: int | None = None) -> None:
1113
super(UserDefinedType, self).__init__()
1214
self.dim = dim
1315

14-
def get_col_spec(self, **kw):
16+
def get_col_spec(self, **kw) -> str:
1517
if self.dim is None:
1618
return 'VECTOR'
1719
return 'VECTOR(%d)' % self.dim
1820

19-
def bind_processor(self, dialect):
21+
def bind_processor(self, dialect: Dialect) -> Any:
2022
def process(value):
2123
return Vector._to_db(value, self.dim)
2224
return process
2325

24-
def literal_processor(self, dialect):
26+
def literal_processor(self, dialect: Dialect) -> Any:
2527
string_literal_processor = self._string._cached_literal_processor(dialect)
2628

2729
def process(value):
2830
return string_literal_processor(Vector._to_db(value, self.dim)) # type: ignore
2931
return process
3032

31-
def result_processor(self, dialect, coltype):
33+
def result_processor(self, dialect: Dialect, coltype: Any) -> Any:
3234
def process(value):
3335
return Vector._from_db(value)
3436
return process
3537

3638
class comparator_factory(UserDefinedType.Comparator):
37-
def l2_distance(self, other):
39+
def l2_distance(self, other: Any) -> Any:
3840
return self.op('<->', return_type=Float)(other)
3941

40-
def max_inner_product(self, other):
42+
def max_inner_product(self, other: Any) -> Any:
4143
return self.op('<#>', return_type=Float)(other)
4244

43-
def cosine_distance(self, other):
45+
def cosine_distance(self, other: Any) -> Any:
4446
return self.op('<=>', return_type=Float)(other)
4547

46-
def l1_distance(self, other):
48+
def l1_distance(self, other: Any) -> Any:
4749
return self.op('<+>', return_type=Float)(other)
4850

4951

0 commit comments

Comments
 (0)