|
1 | 1 | from sqlalchemy.dialects.postgresql.base import ischema_names |
2 | 2 | from sqlalchemy.types import UserDefinedType, Float, String |
3 | | -from sqlalchemy import Dialect |
| 3 | +from sqlalchemy import Dialect, Operators |
4 | 4 | from typing import Any |
5 | 5 | from .. import SparseVector |
6 | 6 |
|
@@ -36,16 +36,16 @@ def process(value): |
36 | 36 | return process |
37 | 37 |
|
38 | 38 | class comparator_factory(UserDefinedType.Comparator): |
39 | | - def l2_distance(self, other: Any) -> Any: |
| 39 | + def l2_distance(self, other: Any) -> Operators: |
40 | 40 | return self.op('<->', return_type=Float)(other) |
41 | 41 |
|
42 | | - def max_inner_product(self, other: Any) -> Any: |
| 42 | + def max_inner_product(self, other: Any) -> Operators: |
43 | 43 | return self.op('<#>', return_type=Float)(other) |
44 | 44 |
|
45 | | - def cosine_distance(self, other: Any) -> Any: |
| 45 | + def cosine_distance(self, other: Any) -> Operators: |
46 | 46 | return self.op('<=>', return_type=Float)(other) |
47 | 47 |
|
48 | | - def l1_distance(self, other: Any) -> Any: |
| 48 | + def l1_distance(self, other: Any) -> Operators: |
49 | 49 | return self.op('<+>', return_type=Float)(other) |
50 | 50 |
|
51 | 51 |
|
|
0 commit comments