Skip to content

Commit 9c2d44a

Browse files
committed
Added more type hints [skip ci]
1 parent 90fd6da commit 9c2d44a

7 files changed

Lines changed: 53 additions & 46 deletions

File tree

pgvector/django/bit.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,33 @@
11
from django import forms
22
from django.db.models import Field
3+
from typing import Any
34

45

56
# https://docs.djangoproject.com/en/5.0/howto/custom-model-fields/
67
class BitField(Field):
78
description = 'Bit string'
89

9-
def __init__(self, *args, length=None, **kwargs):
10+
def __init__(self, *args: Any, length: int | None = None, **kwargs: Any) -> None:
1011
self.length = length
1112
super().__init__(*args, **kwargs)
1213

13-
def deconstruct(self):
14+
def deconstruct(self) -> tuple:
1415
name, path, args, kwargs = super().deconstruct()
1516
if self.length is not None:
1617
kwargs['length'] = self.length
1718
return name, path, args, kwargs
1819

19-
def db_type(self, connection):
20+
def db_type(self, connection: Any) -> str:
2021
if self.length is None:
2122
return 'bit'
2223
return 'bit(%d)' % self.length
2324

24-
def formfield(self, **kwargs): # type: ignore
25+
def formfield(self, **kwargs: Any): # type: ignore
2526
return super().formfield(form_class=BitFormField, **kwargs)
2627

2728

2829
class BitFormField(forms.CharField):
29-
def to_python(self, value):
30+
def to_python(self, value: Any) -> Any:
3031
if isinstance(value, str) and value == '':
3132
return None
3233
return super().to_python(value)

pgvector/django/extensions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from django import VERSION
22
from django.contrib.postgres.operations import CreateExtension
3+
from typing import Any
34

45

56
class VectorExtension(CreateExtension):
6-
def __init__(self, hints=None):
7+
def __init__(self, hints: Any = None) -> None:
78
if VERSION[0] >= 6:
89
super().__init__('vector', hints=hints)
910
else:

pgvector/django/functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from django.db.models import FloatField, Func, Value
22
from .. import Vector, HalfVector, SparseVector
3+
from typing import Any
34

45

56
class DistanceBase(Func):
67
output_field = FloatField()
78

8-
def __init__(self, expression, vector, **extra):
9+
def __init__(self, expression: Any, vector: Any, **extra: Any) -> None:
910
if not hasattr(vector, 'resolve_expression'):
1011
if isinstance(vector, HalfVector):
1112
vector = Value(HalfVector._to_db(vector))
@@ -23,7 +24,7 @@ def __init__(self, expression, vector, **extra):
2324
class BitDistanceBase(Func):
2425
output_field = FloatField()
2526

26-
def __init__(self, expression, vector, **extra):
27+
def __init__(self, expression: Any, vector: Any, **extra: Any) -> None:
2728
if not hasattr(vector, 'resolve_expression'):
2829
vector = Value(vector)
2930
super().__init__(expression, vector, **extra)

pgvector/django/halfvec.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django import forms
22
from django.db.models import Field
3+
from typing import Any
34
from .. import HalfVector
45

56

@@ -8,44 +9,44 @@ class HalfVectorField(Field):
89
description = 'Half vector'
910
empty_strings_allowed = False
1011

11-
def __init__(self, *args, dimensions=None, **kwargs):
12+
def __init__(self, *args: Any, dimensions: int | None = None, **kwargs: Any) -> None:
1213
self.dimensions = dimensions
1314
super().__init__(*args, **kwargs)
1415

15-
def deconstruct(self):
16+
def deconstruct(self) -> tuple:
1617
name, path, args, kwargs = super().deconstruct()
1718
if self.dimensions is not None:
1819
kwargs['dimensions'] = self.dimensions
1920
return name, path, args, kwargs
2021

21-
def db_type(self, connection):
22+
def db_type(self, connection: Any) -> str:
2223
if self.dimensions is None:
2324
return 'halfvec'
2425
return 'halfvec(%d)' % self.dimensions
2526

26-
def from_db_value(self, value, expression, connection):
27+
def from_db_value(self, value: Any, expression: Any, connection: Any) -> HalfVector | None:
2728
return HalfVector._from_db(value)
2829

29-
def to_python(self, value):
30+
def to_python(self, value: Any) -> HalfVector | None:
3031
if value is None or isinstance(value, HalfVector):
3132
return value
3233
elif isinstance(value, str):
3334
return HalfVector._from_db(value)
3435
else:
3536
return HalfVector(value)
3637

37-
def get_prep_value(self, value):
38+
def get_prep_value(self, value: Any) -> str | None:
3839
return HalfVector._to_db(value)
3940

40-
def value_to_string(self, obj):
41+
def value_to_string(self, obj: Any) -> str | None:
4142
return self.get_prep_value(self.value_from_object(obj))
4243

4344
def formfield(self, **kwargs): # type: ignore
4445
return super().formfield(form_class=HalfVectorFormField, **kwargs)
4546

4647

4748
class HalfVectorWidget(forms.TextInput):
48-
def format_value(self, value):
49+
def format_value(self, value: Any) -> str | None:
4950
if isinstance(value, HalfVector):
5051
value = value.to_list()
5152
return super().format_value(value)
@@ -54,7 +55,7 @@ def format_value(self, value):
5455
class HalfVectorFormField(forms.CharField):
5556
widget = HalfVectorWidget
5657

57-
def to_python(self, value):
58+
def to_python(self, value: Any) -> Any:
5859
if isinstance(value, str) and value == '':
5960
return None
6061
return super().to_python(value)

pgvector/django/indexes.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
from django.contrib.postgres.indexes import PostgresIndex
2+
from typing import Any
23

34

45
class IvfflatIndex(PostgresIndex):
56
suffix = 'ivfflat'
67

7-
def __init__(self, *expressions, lists=None, **kwargs):
8+
def __init__(self, *expressions: Any, lists: int | None = None, **kwargs: Any) -> None:
89
self.lists = lists
910
super().__init__(*expressions, **kwargs)
1011

11-
def deconstruct(self):
12+
def deconstruct(self) -> tuple:
1213
path, args, kwargs = super().deconstruct()
1314
if self.lists is not None:
1415
kwargs['lists'] = self.lists
1516
return path, args, kwargs
1617

17-
def get_with_params(self):
18+
def get_with_params(self) -> list[str]:
1819
with_params = []
1920
if self.lists is not None:
2021
with_params.append('lists = %d' % self.lists)
@@ -24,20 +25,20 @@ def get_with_params(self):
2425
class HnswIndex(PostgresIndex):
2526
suffix = 'hnsw'
2627

27-
def __init__(self, *expressions, m=None, ef_construction=None, **kwargs):
28+
def __init__(self, *expressions: Any, m: int | None = None, ef_construction: int | None = None, **kwargs: Any) -> None:
2829
self.m = m
2930
self.ef_construction = ef_construction
3031
super().__init__(*expressions, **kwargs)
3132

32-
def deconstruct(self):
33+
def deconstruct(self) -> tuple:
3334
path, args, kwargs = super().deconstruct()
3435
if self.m is not None:
3536
kwargs['m'] = self.m
3637
if self.ef_construction is not None:
3738
kwargs['ef_construction'] = self.ef_construction
3839
return path, args, kwargs
3940

40-
def get_with_params(self):
41+
def get_with_params(self) -> list[str]:
4142
with_params = []
4243
if self.m is not None:
4344
with_params.append('m = %d' % self.m)

pgvector/django/sparsevec.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django import forms
22
from django.db.models import Field
3+
from typing import Any
34
from .. import SparseVector
45

56

@@ -8,39 +9,39 @@ class SparseVectorField(Field):
89
description = 'Sparse vector'
910
empty_strings_allowed = False
1011

11-
def __init__(self, *args, dimensions=None, **kwargs):
12+
def __init__(self, *args: Any, dimensions: int | None = None, **kwargs: Any):
1213
self.dimensions = dimensions
1314
super().__init__(*args, **kwargs)
1415

15-
def deconstruct(self):
16+
def deconstruct(self) -> tuple:
1617
name, path, args, kwargs = super().deconstruct()
1718
if self.dimensions is not None:
1819
kwargs['dimensions'] = self.dimensions
1920
return name, path, args, kwargs
2021

21-
def db_type(self, connection):
22+
def db_type(self, connection: Any) -> str:
2223
if self.dimensions is None:
2324
return 'sparsevec'
2425
return 'sparsevec(%d)' % self.dimensions
2526

26-
def from_db_value(self, value, expression, connection):
27+
def from_db_value(self, value: Any, expression: Any, connection: Any) -> SparseVector | None:
2728
return SparseVector._from_db(value)
2829

29-
def to_python(self, value):
30+
def to_python(self, value: Any) -> SparseVector | None:
3031
return SparseVector._from_db(value)
3132

32-
def get_prep_value(self, value):
33+
def get_prep_value(self, value: Any) -> str | None:
3334
return SparseVector._to_db(value)
3435

35-
def value_to_string(self, obj):
36+
def value_to_string(self, obj: Any) -> str | None:
3637
return self.get_prep_value(self.value_from_object(obj))
3738

38-
def formfield(self, **kwargs): # type: ignore
39+
def formfield(self, **kwargs: Any): # type: ignore
3940
return super().formfield(form_class=SparseVectorFormField, **kwargs)
4041

4142

4243
class SparseVectorWidget(forms.TextInput):
43-
def format_value(self, value):
44+
def format_value(self, value: Any) -> str | None:
4445
if isinstance(value, SparseVector):
4546
value = value.to_text()
4647
return super().format_value(value)
@@ -49,7 +50,7 @@ def format_value(self, value):
4950
class SparseVectorFormField(forms.CharField):
5051
widget = SparseVectorWidget
5152

52-
def to_python(self, value):
53+
def to_python(self, value: Any) -> Any:
5354
if isinstance(value, str) and value == '':
5455
return None
5556
return super().to_python(value)

pgvector/django/vector.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from django import forms
22
from django.db.models import Field
33
import numpy as np
4+
from typing import Any
45
from .. import Vector
56

67

@@ -9,51 +10,51 @@ class VectorField(Field):
910
description = 'Vector'
1011
empty_strings_allowed = False
1112

12-
def __init__(self, *args, dimensions=None, **kwargs):
13+
def __init__(self, *args: Any, dimensions: int | None = None, **kwargs: Any) -> None:
1314
self.dimensions = dimensions
1415
super().__init__(*args, **kwargs)
1516

16-
def deconstruct(self):
17+
def deconstruct(self) -> tuple:
1718
name, path, args, kwargs = super().deconstruct()
1819
if self.dimensions is not None:
1920
kwargs['dimensions'] = self.dimensions
2021
return name, path, args, kwargs
2122

22-
def db_type(self, connection):
23+
def db_type(self, connection: Any) -> str:
2324
if self.dimensions is None:
2425
return 'vector'
2526
return 'vector(%d)' % self.dimensions
2627

27-
def from_db_value(self, value, expression, connection):
28+
def from_db_value(self, value: Any, expression: Any, connection: Any) -> np.ndarray | None:
2829
return Vector._from_db(value)
2930

30-
def to_python(self, value):
31+
def to_python(self, value: Any) -> np.ndarray | None:
3132
if isinstance(value, list):
3233
return np.array(value, dtype=np.float32)
3334
return Vector._from_db(value)
3435

35-
def get_prep_value(self, value):
36+
def get_prep_value(self, value: Any) -> str | None:
3637
return Vector._to_db(value)
3738

38-
def value_to_string(self, obj):
39+
def value_to_string(self, obj: Any) -> str | None:
3940
return self.get_prep_value(self.value_from_object(obj))
4041

41-
def validate(self, value, model_instance):
42+
def validate(self, value: Any, model_instance: Any) -> None:
4243
if isinstance(value, np.ndarray):
4344
value = value.tolist()
4445
super().validate(value, model_instance)
4546

46-
def run_validators(self, value):
47+
def run_validators(self, value: Any) -> None:
4748
if isinstance(value, np.ndarray):
4849
value = value.tolist()
4950
super().run_validators(value)
5051

51-
def formfield(self, **kwargs): # type: ignore
52+
def formfield(self, **kwargs: Any): # type: ignore
5253
return super().formfield(form_class=VectorFormField, **kwargs)
5354

5455

5556
class VectorWidget(forms.TextInput):
56-
def format_value(self, value):
57+
def format_value(self, value: Any) -> str | None:
5758
if isinstance(value, np.ndarray):
5859
value = value.tolist()
5960
return super().format_value(value)
@@ -62,12 +63,12 @@ def format_value(self, value):
6263
class VectorFormField(forms.CharField):
6364
widget = VectorWidget
6465

65-
def has_changed(self, initial, data):
66+
def has_changed(self, initial: Any, data: Any) -> bool:
6667
if isinstance(initial, np.ndarray):
6768
initial = initial.tolist()
6869
return super().has_changed(initial, data)
6970

70-
def to_python(self, value):
71+
def to_python(self, value: Any) -> Any:
7172
if isinstance(value, str) and value == '':
7273
return None
7374
return super().to_python(value)

0 commit comments

Comments
 (0)