Skip to content

Commit d3e376e

Browse files
committed
Improved internal storage format for Bit [skip ci]
1 parent 033923b commit d3e376e

1 file changed

Lines changed: 15 additions & 10 deletions

File tree

pgvector/bit.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77
class Bit:
88
def __init__(self, value: bytes | str | list[bool] | np.ndarray[tuple[int], np.dtype[np.bool | np.uint8]]) -> None:
99
if isinstance(value, bytes):
10-
self._len = 8 * len(value)
11-
self._data = value
10+
_len = 8 * len(value)
11+
_data = value
1212
else:
1313
if isinstance(value, str):
1414
value = [v != '0' for v in value]
@@ -30,28 +30,34 @@ def __init__(self, value: bytes | str | list[bool] | np.ndarray[tuple[int], np.d
3030
if value.ndim != 1:
3131
raise ValueError('expected ndim to be 1')
3232

33-
self._len = len(value)
34-
self._data = np.packbits(value).tobytes()
33+
_len = len(value)
34+
_data = np.packbits(value).tobytes()
35+
36+
self._value = pack('>i', _len) + _data
3537

3638
def __repr__(self) -> str:
3739
return f'Bit({self.to_text()})'
3840

3941
def __eq__(self, other: object) -> bool:
4042
if isinstance(other, self.__class__):
41-
return self._len == other._len and self._data == other._data
43+
return self.to_binary() == other.to_binary()
4244
return False
4345

46+
def _len(self):
47+
_len, = unpack_from('>i', self._value)
48+
return _len
49+
4450
def to_list(self) -> list[bool]:
4551
return self.to_numpy().tolist()
4652

4753
def to_numpy(self) -> np.ndarray[tuple[int], np.dtype[np.bool]]:
48-
return np.unpackbits(np.frombuffer(self._data, dtype=np.uint8), count=self._len).astype(bool)
54+
return np.unpackbits(np.frombuffer(self._value[4:], dtype=np.uint8), count=self._len()).astype(bool)
4955

5056
def to_text(self) -> str:
51-
return ''.join(format(v, '08b') for v in self._data)[:self._len]
57+
return ''.join(format(v, '08b') for v in self._value[4:])[:self._len()]
5258

5359
def to_binary(self) -> bytes:
54-
return pack('>i', self._len) + self._data
60+
return self._value
5561

5662
@classmethod
5763
def from_text(cls, value: str) -> Bit:
@@ -63,8 +69,7 @@ def from_binary(cls, value: bytes) -> Bit:
6369
raise ValueError('expected bytes')
6470

6571
bit = cls.__new__(cls)
66-
bit._len = unpack_from('>i', value)[0]
67-
bit._data = value[4:]
72+
bit._value = value
6873
return bit
6974

7075
@classmethod

0 commit comments

Comments
 (0)