|
1 | 1 | from peewee import Expression, Field |
| 2 | +from typing import Any |
2 | 3 | from .. import SparseVector |
3 | 4 |
|
4 | 5 |
|
5 | 6 | class SparseVectorField(Field): |
6 | 7 | field_type = 'sparsevec' |
7 | 8 |
|
8 | | - def __init__(self, dimensions=None, *args, **kwargs): |
| 9 | + def __init__(self, dimensions: int | None = None, *args, **kwargs) -> None: |
9 | 10 | self.dimensions = dimensions |
10 | 11 | super(SparseVectorField, self).__init__(*args, **kwargs) |
11 | 12 |
|
12 | | - def get_modifiers(self): |
13 | | - return self.dimensions and [self.dimensions] or None |
| 13 | + def get_modifiers(self) -> list[int] | None: |
| 14 | + return [self.dimensions] if self.dimensions else None |
14 | 15 |
|
15 | | - def db_value(self, value): |
| 16 | + def db_value(self, value: Any) -> str | None: |
16 | 17 | return SparseVector._to_db(value) |
17 | 18 |
|
18 | | - def python_value(self, value): |
| 19 | + def python_value(self, value: Any) -> SparseVector | None: |
19 | 20 | return SparseVector._from_db(value) |
20 | 21 |
|
21 | | - def _distance(self, op, vector): |
| 22 | + def _distance(self, op: str, vector: Any) -> Expression: |
22 | 23 | return Expression(lhs=self, op=op, rhs=self.to_value(vector)) |
23 | 24 |
|
24 | | - def l2_distance(self, vector): |
| 25 | + def l2_distance(self, vector: Any) -> Expression: |
25 | 26 | return self._distance('<->', vector) |
26 | 27 |
|
27 | | - def max_inner_product(self, vector): |
| 28 | + def max_inner_product(self, vector: Any) -> Expression: |
28 | 29 | return self._distance('<#>', vector) |
29 | 30 |
|
30 | | - def cosine_distance(self, vector): |
| 31 | + def cosine_distance(self, vector: Any) -> Expression: |
31 | 32 | return self._distance('<=>', vector) |
32 | 33 |
|
33 | | - def l1_distance(self, vector): |
| 34 | + def l1_distance(self, vector: Any) -> Expression: |
34 | 35 | return self._distance('<+>', vector) |
0 commit comments