Skip to content

Commit e113ac8

Browse files
committed
Generalized hashing functions to support complex types
1 parent b101fdb commit e113ac8

2 files changed

Lines changed: 85 additions & 45 deletions

File tree

semhash/semhash.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
43
from collections.abc import Sequence
54
from math import ceil
65
from typing import Any, Generic, Literal
@@ -18,6 +17,7 @@
1817
DatasetLike,
1918
Encoder,
2019
Record,
20+
coerce_value,
2121
compute_candidate_limit,
2222
featurize,
2323
group_records_by_key,
@@ -373,12 +373,12 @@ def self_deduplicate(
373373

374374
return result
375375

376-
def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[dict[str, str]]:
376+
def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[dict[str, Any]]:
377377
"""
378378
Validate if the records are strings.
379379
380380
If the records are strings, they are converted to dictionaries with a single column.
381-
If the records are dicts, values are coerced to strings and None is rejected.
381+
If the records are dicts, primitives are stringified and complex types (images, etc.) are kept raw.
382382
383383
:param records: The records to validate.
384384
:return: The records as a list of dictionaries.
@@ -396,25 +396,22 @@ def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[
396396
raise ValueError("Records were not originally strings, but you passed strings.")
397397
if not all(isinstance(r, str) for r in records):
398398
raise ValueError("Records must be all strings.")
399-
# Type narrowing: we've validated all are strings
400-
return [{"text": str(r)} for r in records]
399+
return [{"text": r} for r in records]
401400

402-
# Dict path - coerce values to strings (matching prepare_records behavior)
401+
# Dict path
403402
if not all(isinstance(r, dict) for r in records):
404403
raise ValueError("Records must be all dictionaries.")
405404

406-
# Type narrowing: we've validated all are dicts
407405
dict_records: Sequence[dict[str, Any]] = records # type: ignore[assignment]
408-
coerced: list[dict[str, str]] = []
406+
result: list[dict[str, Any]] = []
409407
for r in dict_records:
410-
out: dict[str, str] = {}
408+
out = {}
411409
for c in self.columns:
412-
val = r.get(c)
413-
if val is None:
410+
if (val := r.get(c)) is None:
414411
raise ValueError(f"Column '{c}' has None value in record {r}")
415-
out[c] = val if isinstance(val, str) else str(val)
416-
coerced.append(out)
417-
return coerced
412+
out[c] = coerce_value(val)
413+
result.append(out)
414+
return result
418415

419416
def find_representative(
420417
self,

semhash/utils.py

Lines changed: 74 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
from collections import defaultdict
23
from collections.abc import Sequence
34
from typing import Any, Protocol, TypeAlias, TypeVar
@@ -11,19 +12,19 @@
1112

1213

1314
class Encoder(Protocol):
14-
"""An encoder protocol for SemHash."""
15+
"""An encoder protocol for SemHash. Supports text, images, or any encodable data."""
1516

1617
def encode(
1718
self,
18-
sentences: list[str] | str | Sequence[str],
19+
inputs: Sequence[Any] | Any,
1920
**kwargs: Any,
2021
) -> np.ndarray:
2122
"""
22-
Encode a list of sentences into embeddings.
23+
Encode a list of inputs into embeddings.
2324
24-
:param sentences: A list of sentences to encode.
25+
:param inputs: A list of inputs to encode (strings, images, etc.).
2526
:param **kwargs: Additional keyword arguments.
26-
:return: The embeddings of the sentences.
27+
:return: The embeddings of the inputs.
2728
"""
2829
... # pragma: no cover
2930

@@ -48,26 +49,67 @@ def __getitem__(self, key: str) -> Sequence[Any]:
4849
... # pragma: no cover
4950

5051

51-
def to_frozendict(record: dict[str, str], columns: Sequence[str] | set[str]) -> frozendict[str, str]:
52+
def make_hashable(value: Any) -> Any:
5253
"""
53-
Convert a record to a frozendict.
54+
Convert a value to a hashable representation for use as dict keys.
55+
56+
Strings and other hashable types are returned as-is.
57+
Non-hashable types (like PIL images, numpy arrays) are hashed to a string.
58+
59+
:param value: The value to make hashable.
60+
:return: A hashable representation of the value.
61+
"""
62+
# Fast path: most values are strings or already hashable
63+
if isinstance(value, (str, int, float, bool, type(None))):
64+
return value
65+
# Handle objects with tobytes() (PIL Image, numpy array, etc.)
66+
if hasattr(value, "tobytes"):
67+
return hashlib.md5(value.tobytes()).hexdigest()
68+
# Fallback: try to hash, otherwise stringify
69+
try:
70+
hash(value)
71+
return value
72+
except TypeError:
73+
return str(value)
74+
75+
76+
def coerce_value(value: Any) -> Any:
77+
"""
78+
Coerce a value for encoding: stringify primitives, keep complex types raw.
79+
80+
This ensures primitives (int, float, bool) work with text encoders,
81+
while complex types (PIL images, tensors, etc.) are passed through for multimodal encoders.
82+
83+
:param value: The value to coerce.
84+
:return: The coerced value.
85+
"""
86+
if isinstance(value, (str, bytes)):
87+
return value
88+
if isinstance(value, (int, float, bool)):
89+
return str(value)
90+
return value # Complex types (images, tensors, etc.)
91+
92+
93+
def to_frozendict(record: dict[str, Any], columns: Sequence[str] | set[str]) -> frozendict[str, Any]:
94+
"""
95+
Convert a record to a frozendict with hashable values.
5496
5597
:param record: The record to convert.
5698
:param columns: The columns to include.
57-
:return: A frozendict with only the specified columns.
99+
:return: A frozendict with only the specified columns (values made hashable).
58100
:raises ValueError: If a column is missing from the record.
59101
"""
60102
try:
61-
return frozendict({k: record[k] for k in columns})
103+
return frozendict({k: make_hashable(record[k]) for k in columns})
62104
except KeyError as e:
63105
missing = e.args[0]
64106
raise ValueError(f"Missing column '{missing}' in record {record}") from e
65107

66108

67109
def group_records_by_key(
68-
records: Sequence[dict[str, str]],
110+
records: Sequence[dict[str, Any]],
69111
columns: Sequence[str],
70-
) -> tuple[list[dict[str, str]], list[list[dict[str, str]]]]:
112+
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]:
71113
"""
72114
Group records by exact match on columns, preserving first-occurrence order.
73115
@@ -77,8 +119,8 @@ def group_records_by_key(
77119
- deduplicated_records: first record from each unique group
78120
- items: list of groups, each group is a list of exact duplicates
79121
"""
80-
buckets: dict[frozendict[str, str], list[dict[str, str]]] = {}
81-
order: list[frozendict[str, str]] = []
122+
buckets: dict[frozendict[str, Any], list[dict[str, Any]]] = {}
123+
order: list[frozendict[str, Any]] = []
82124

83125
for r in records:
84126
key = to_frozendict(r, columns)
@@ -123,7 +165,7 @@ def compute_candidate_limit(
123165

124166

125167
def featurize(
126-
records: Sequence[dict[str, str]],
168+
records: Sequence[dict[str, Any]],
127169
columns: Sequence[str],
128170
model: Encoder,
129171
) -> np.ndarray:
@@ -150,12 +192,12 @@ def featurize(
150192

151193

152194
def remove_exact_duplicates(
153-
records: Sequence[dict[str, str]],
195+
records: Sequence[dict[str, Any]],
154196
columns: Sequence[str],
155-
reference_records: list[list[dict[str, str]]] | None = None,
156-
) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
197+
reference_records: list[list[dict[str, Any]]] | None = None,
198+
) -> tuple[list[dict[str, Any]], list[tuple[dict[str, Any], list[dict[str, Any]]]]]:
157199
"""
158-
Remove exact duplicates based on the unpacked string representation of each record.
200+
Remove exact duplicates based on the hashable representation of each record.
159201
160202
If reference_records is None, the function will only check for duplicates within the records list.
161203
@@ -164,12 +206,12 @@ def remove_exact_duplicates(
164206
:param reference_records: A list of records to compare against. These are already unpacked
165207
:return: A list of deduplicated records and a list of duplicates.
166208
"""
167-
deduplicated = []
168-
duplicates = []
209+
deduplicated: list[dict[str, Any]] = []
210+
duplicates: list[tuple[dict[str, Any], list[dict[str, Any]]]] = []
169211

170212
column_set = set(columns)
171213
# Build a seen set from reference_records if provided
172-
seen: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list)
214+
seen: defaultdict[frozendict[str, Any], list[dict[str, Any]]] = defaultdict(list)
173215
if reference_records is not None:
174216
for record_set in reference_records:
175217
key = to_frozendict(record_set[0], column_set)
@@ -191,7 +233,7 @@ def remove_exact_duplicates(
191233

192234
def prepare_records(
193235
records: Sequence[Record], columns: Sequence[str] | None
194-
) -> tuple[list[dict[str, str]], Sequence[str], bool]:
236+
) -> tuple[list[dict[str, Any]], Sequence[str], bool]:
195237
"""
196238
Validate and prepare records for processing.
197239
@@ -214,23 +256,23 @@ def prepare_records(
214256
if not all(isinstance(r, str) for r in records):
215257
raise ValueError("All records must be strings when the first record is a string.")
216258
columns = ["text"]
217-
dict_records: list[dict[str, str]] = [{"text": str(record)} for record in records]
259+
dict_records: list[dict[str, Any]] = [{"text": record} for record in records]
218260
was_string = True
219261
else:
220262
# Validate all records are dicts
221263
if not all(isinstance(r, dict) for r in records):
222264
raise ValueError("All records must be dicts when the first record is a dict.")
223265
assert columns is not None
224-
# Coerce dict values to strings (matching dataset behavior)
266+
# Coerce values: stringify primitives, keep complex types raw (for images, etc.)
225267
dict_records_typed: list[dict[str, Any]] = list(records) # type: ignore[arg-type]
226268
dict_records = []
227269
for r in dict_records_typed:
228-
coerced = {}
270+
coerced: dict[str, Any] = {}
229271
for c in columns:
230272
val = r.get(c)
231273
if val is None:
232274
raise ValueError(f"Column '{c}' has None value in record {r}")
233-
coerced[c] = val if isinstance(val, str) else str(val)
275+
coerced[c] = coerce_value(val)
234276
dict_records.append(coerced)
235277
was_string = False
236278

@@ -263,7 +305,7 @@ def _validate_dataset(dataset: DatasetLike, columns: Sequence[str]) -> tuple[dic
263305
def prepare_dataset_records(
264306
dataset: DatasetLike,
265307
columns: Sequence[str],
266-
) -> tuple[list[dict[str, str]], list[list[dict[str, str]]], bool]:
308+
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]], bool]:
267309
"""
268310
Extract, validate, and exact-deduplicate dataset rows using columnar access.
269311
@@ -286,17 +328,18 @@ def prepare_dataset_records(
286328
# values in the dataset are actual strings (not integers/floats coerced to str).
287329
was_string = len(columns) == 1 and columns[0] == "text"
288330

289-
def coerce(raw: Any, *, col: str, idx: int) -> str:
331+
def validate_and_coerce(raw: Any, *, col: str, idx: int) -> Any:
332+
"""Validate value is not None, then coerce for encoding."""
290333
if raw is None:
291334
raise ValueError(f"Column '{col}' has None at index {idx}")
292-
return raw if isinstance(raw, str) else str(raw)
335+
return coerce_value(raw)
293336

294337
# Build all records while tracking was_string
295-
records: list[dict[str, str]] = []
338+
records: list[dict[str, Any]] = []
296339
for i in range(n):
297340
if was_string and not isinstance(cols["text"][i], str):
298341
was_string = False
299-
records.append({c: coerce(cols[c][i], col=c, idx=i) for c in columns})
342+
records.append({c: validate_and_coerce(cols[c][i], col=c, idx=i) for c in columns})
300343

301344
# Group by exact match, preserving first-occurrence order
302345
deduplicated_records, items = group_records_by_key(records, columns)

0 commit comments

Comments (0)