Skip to content

Commit 0673ca7

Browse files
committed
Improved type hints [skip ci]
1 parent d95fa2e commit 0673ca7

4 files changed

Lines changed: 18 additions & 18 deletions

File tree

pgvector/sqlalchemy/bit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float
3-
from sqlalchemy import Dialect
3+
from sqlalchemy import Dialect, Operators
44
from typing import Any
55

66

@@ -29,10 +29,10 @@ def process(value):
2929
return super().bind_processor(dialect)
3030

3131
class comparator_factory(UserDefinedType.Comparator):
32-
def hamming_distance(self, other: Any) -> Any:
32+
def hamming_distance(self, other: Any) -> Operators:
3333
return self.op('<~>', return_type=Float)(other)
3434

35-
def jaccard_distance(self, other: Any) -> Any:
35+
def jaccard_distance(self, other: Any) -> Operators:
3636
return self.op('<%>', return_type=Float)(other)
3737

3838

pgvector/sqlalchemy/halfvec.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float, String
3-
from sqlalchemy import Dialect
3+
from sqlalchemy import Dialect, Operators
44
from typing import Any
55
from .. import HalfVector
66

@@ -36,16 +36,16 @@ def process(value):
3636
return process
3737

3838
class comparator_factory(UserDefinedType.Comparator):
39-
def l2_distance(self, other: Any) -> Any:
39+
def l2_distance(self, other: Any) -> Operators:
4040
return self.op('<->', return_type=Float)(other)
4141

42-
def max_inner_product(self, other: Any) -> Any:
42+
def max_inner_product(self, other: Any) -> Operators:
4343
return self.op('<#>', return_type=Float)(other)
4444

45-
def cosine_distance(self, other: Any) -> Any:
45+
def cosine_distance(self, other: Any) -> Operators:
4646
return self.op('<=>', return_type=Float)(other)
4747

48-
def l1_distance(self, other: Any) -> Any:
48+
def l1_distance(self, other: Any) -> Operators:
4949
return self.op('<+>', return_type=Float)(other)
5050

5151

pgvector/sqlalchemy/sparsevec.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float, String
3-
from sqlalchemy import Dialect
3+
from sqlalchemy import Dialect, Operators
44
from typing import Any
55
from .. import SparseVector
66

@@ -36,16 +36,16 @@ def process(value):
3636
return process
3737

3838
class comparator_factory(UserDefinedType.Comparator):
39-
def l2_distance(self, other: Any) -> Any:
39+
def l2_distance(self, other: Any) -> Operators:
4040
return self.op('<->', return_type=Float)(other)
4141

42-
def max_inner_product(self, other: Any) -> Any:
42+
def max_inner_product(self, other: Any) -> Operators:
4343
return self.op('<#>', return_type=Float)(other)
4444

45-
def cosine_distance(self, other: Any) -> Any:
45+
def cosine_distance(self, other: Any) -> Operators:
4646
return self.op('<=>', return_type=Float)(other)
4747

48-
def l1_distance(self, other: Any) -> Any:
48+
def l1_distance(self, other: Any) -> Operators:
4949
return self.op('<+>', return_type=Float)(other)
5050

5151

pgvector/sqlalchemy/vector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
22
from sqlalchemy.types import UserDefinedType, Float, String
3-
from sqlalchemy import Dialect
3+
from sqlalchemy import Dialect, Operators
44
from typing import Any
55
from .. import Vector
66

@@ -36,16 +36,16 @@ def process(value):
3636
return process
3737

3838
class comparator_factory(UserDefinedType.Comparator):
39-
def l2_distance(self, other: Any) -> Any:
39+
def l2_distance(self, other: Any) -> Operators:
4040
return self.op('<->', return_type=Float)(other)
4141

42-
def max_inner_product(self, other: Any) -> Any:
42+
def max_inner_product(self, other: Any) -> Operators:
4343
return self.op('<#>', return_type=Float)(other)
4444

45-
def cosine_distance(self, other: Any) -> Any:
45+
def cosine_distance(self, other: Any) -> Operators:
4646
return self.op('<=>', return_type=Float)(other)
4747

48-
def l1_distance(self, other: Any) -> Any:
48+
def l1_distance(self, other: Any) -> Operators:
4949
return self.op('<+>', return_type=Float)(other)
5050

5151

0 commit comments

Comments
 (0)