|
1 | 1 | from sqlalchemy.dialects.postgresql.base import ischema_names |
2 | 2 | from sqlalchemy.types import UserDefinedType, Float, String |
| 3 | +from sqlalchemy import Dialect |
| 4 | +from typing import Any |
3 | 5 | from .. import SparseVector |
4 | 6 |
|
5 | 7 |
|
6 | 8 | class SPARSEVEC(UserDefinedType): |
7 | 9 | cache_ok = True |
8 | 10 | _string = String() |
9 | 11 |
|
10 | | - def __init__(self, dim=None): |
| 12 | + def __init__(self, dim: int | None = None) -> None: |
11 | 13 | super(UserDefinedType, self).__init__() |
12 | 14 | self.dim = dim |
13 | 15 |
|
14 | | - def get_col_spec(self, **kw): |
| 16 | + def get_col_spec(self, **kw) -> str: |
15 | 17 | if self.dim is None: |
16 | 18 | return 'SPARSEVEC' |
17 | 19 | return 'SPARSEVEC(%d)' % self.dim |
18 | 20 |
|
19 | | - def bind_processor(self, dialect): |
| 21 | + def bind_processor(self, dialect: Dialect) -> Any: |
20 | 22 | def process(value): |
21 | 23 | return SparseVector._to_db(value, self.dim) |
22 | 24 | return process |
23 | 25 |
|
24 | | - def literal_processor(self, dialect): |
| 26 | + def literal_processor(self, dialect: Dialect) -> Any: |
25 | 27 | string_literal_processor = self._string._cached_literal_processor(dialect) |
26 | 28 |
|
27 | 29 | def process(value): |
28 | 30 | return string_literal_processor(SparseVector._to_db(value, self.dim)) # type: ignore |
29 | 31 | return process |
30 | 32 |
|
31 | | - def result_processor(self, dialect, coltype): |
| 33 | + def result_processor(self, dialect: Dialect, coltype: Any) -> Any: |
32 | 34 | def process(value): |
33 | 35 | return SparseVector._from_db(value) |
34 | 36 | return process |
35 | 37 |
|
36 | 38 | class comparator_factory(UserDefinedType.Comparator): |
37 | | - def l2_distance(self, other): |
| 39 | + def l2_distance(self, other: Any) -> Any: |
38 | 40 | return self.op('<->', return_type=Float)(other) |
39 | 41 |
|
40 | | - def max_inner_product(self, other): |
| 42 | + def max_inner_product(self, other: Any) -> Any: |
41 | 43 | return self.op('<#>', return_type=Float)(other) |
42 | 44 |
|
43 | | - def cosine_distance(self, other): |
| 45 | + def cosine_distance(self, other: Any) -> Any: |
44 | 46 | return self.op('<=>', return_type=Float)(other) |
45 | 47 |
|
46 | | - def l1_distance(self, other): |
| 48 | + def l1_distance(self, other: Any) -> Any: |
47 | 49 | return self.op('<+>', return_type=Float)(other) |
48 | 50 |
|
49 | 51 |
|
|
0 commit comments