1+ import hashlib
12from collections import defaultdict
23from collections .abc import Sequence
34from typing import Any , Protocol , TypeAlias , TypeVar
1112
1213
1314class Encoder (Protocol ):
14- """An encoder protocol for SemHash."""
15+ """An encoder protocol for SemHash. Supports text, images, or any encodable data. """
1516
1617 def encode (
1718 self ,
18- sentences : list [ str ] | str | Sequence [ str ] ,
19+ inputs : Sequence [ Any ] | Any ,
1920 ** kwargs : Any ,
2021 ) -> np .ndarray :
2122 """
22- Encode a list of sentences into embeddings.
23+ Encode a list of inputs into embeddings.
2324
24- :param sentences : A list of sentences to encode.
25+ :param inputs : A list of inputs to encode (strings, images, etc.) .
2526 :param **kwargs: Additional keyword arguments.
26- :return: The embeddings of the sentences .
27+ :return: The embeddings of the inputs .
2728 """
2829 ... # pragma: no cover
2930
@@ -48,26 +49,67 @@ def __getitem__(self, key: str) -> Sequence[Any]:
4849 ... # pragma: no cover
4950
5051
51- def to_frozendict ( record : dict [ str , str ], columns : Sequence [ str ] | set [ str ] ) -> frozendict [ str , str ] :
def make_hashable(value: Any) -> Any:
    """
    Convert a value to a hashable representation for use as dict keys.

    Strings and other primitives are returned as-is. Objects exposing a
    ``tobytes()`` method (PIL images, numpy arrays, memoryviews, ...) are reduced
    to a hex digest computed over their raw bytes *and* their layout metadata
    (``shape``, ``dtype``, ``size``, ``mode`` when present). Any other value is
    returned unchanged if it is hashable, and stringified otherwise.

    :param value: The value to make hashable.
    :return: A hashable representation of the value.
    """
    # Fast path: most values are strings or other primitives (bool is an int subclass).
    if value is None or isinstance(value, (str, int, float)):
        return value

    # Handle objects with tobytes() (PIL Image, numpy array, etc.)
    to_bytes = getattr(value, "tobytes", None)
    if callable(to_bytes):
        # Raw bytes alone are ambiguous: a 2x3 and a 3x2 array, or two images with
        # different dimensions/modes, can serialize to the very same buffer. Mixing
        # the layout into the digest keeps such values from being treated as equal.
        layout = (
            getattr(value, "shape", None),
            str(getattr(value, "dtype", None)),
            getattr(value, "size", None),
            getattr(value, "mode", None),
        )
        # MD5 is only a content fingerprint here, not a security primitive.
        digest = hashlib.md5(repr(layout).encode("utf-8"), usedforsecurity=False)
        digest.update(to_bytes())
        return digest.hexdigest()

    # Fallback: keep the value if it can be hashed, otherwise stringify it.
    try:
        hash(value)
    except TypeError:
        return str(value)
    return value
74+
75+
def coerce_value(value: Any) -> Any:
    """
    Prepare a single value for encoding.

    Numeric primitives (int, float, and bool, which is an int subclass) are
    converted to their string form so they work with text encoders. Everything
    else, including str, bytes and complex objects such as PIL images or
    tensors, is handed back untouched for multimodal encoders.

    :param value: The value to coerce.
    :return: The coerced value.
    """
    is_numeric_primitive = isinstance(value, (int, float))
    return str(value) if is_numeric_primitive else value
91+
92+
def to_frozendict(record: dict[str, Any], columns: Sequence[str] | set[str]) -> frozendict[str, Any]:
    """
    Build an immutable, hashable view of a record restricted to the given columns.

    Each selected value is passed through ``make_hashable`` so the result can be
    used as a dictionary key.

    :param record: The record to convert.
    :param columns: The columns to include.
    :return: A frozendict with only the specified columns (values made hashable).
    :raises ValueError: If a column is missing from the record.
    """
    try:
        selected: dict[str, Any] = {}
        for name in columns:
            selected[name] = make_hashable(record[name])
        return frozendict(selected)
    except KeyError as err:
        # Surface the offending column name instead of a bare KeyError.
        missing = err.args[0]
        raise ValueError(f"Missing column '{missing}' in record {record}") from err
65107
66108
67109def group_records_by_key (
68- records : Sequence [dict [str , str ]],
110+ records : Sequence [dict [str , Any ]],
69111 columns : Sequence [str ],
70- ) -> tuple [list [dict [str , str ]], list [list [dict [str , str ]]]]:
112+ ) -> tuple [list [dict [str , Any ]], list [list [dict [str , Any ]]]]:
71113 """
72114 Group records by exact match on columns, preserving first-occurrence order.
73115
@@ -77,8 +119,8 @@ def group_records_by_key(
77119 - deduplicated_records: first record from each unique group
78120 - items: list of groups, each group is a list of exact duplicates
79121 """
80- buckets : dict [frozendict [str , str ], list [dict [str , str ]]] = {}
81- order : list [frozendict [str , str ]] = []
122+ buckets : dict [frozendict [str , Any ], list [dict [str , Any ]]] = {}
123+ order : list [frozendict [str , Any ]] = []
82124
83125 for r in records :
84126 key = to_frozendict (r , columns )
@@ -123,7 +165,7 @@ def compute_candidate_limit(
123165
124166
125167def featurize (
126- records : Sequence [dict [str , str ]],
168+ records : Sequence [dict [str , Any ]],
127169 columns : Sequence [str ],
128170 model : Encoder ,
129171) -> np .ndarray :
@@ -150,12 +192,12 @@ def featurize(
150192
151193
152194def remove_exact_duplicates (
153- records : Sequence [dict [str , str ]],
195+ records : Sequence [dict [str , Any ]],
154196 columns : Sequence [str ],
155- reference_records : list [list [dict [str , str ]]] | None = None ,
156- ) -> tuple [list [dict [str , str ]], list [tuple [dict [str , str ], list [dict [str , str ]]]]]:
197+ reference_records : list [list [dict [str , Any ]]] | None = None ,
198+ ) -> tuple [list [dict [str , Any ]], list [tuple [dict [str , Any ], list [dict [str , Any ]]]]]:
157199 """
158- Remove exact duplicates based on the unpacked string representation of each record.
200+ Remove exact duplicates based on the hashable representation of each record.
159201
160202 If reference_records is None, the function will only check for duplicates within the records list.
161203
@@ -164,12 +206,12 @@ def remove_exact_duplicates(
164206 :param reference_records: A list of records to compare against. These are already unpacked
165207 :return: A list of deduplicated records and a list of duplicates.
166208 """
167- deduplicated = []
168- duplicates = []
209+ deduplicated : list [ dict [ str , Any ]] = []
210+ duplicates : list [ tuple [ dict [ str , Any ], list [ dict [ str , Any ]]]] = []
169211
170212 column_set = set (columns )
171213 # Build a seen set from reference_records if provided
172- seen : defaultdict [frozendict [str , str ], list [dict [str , str ]]] = defaultdict (list )
214+ seen : defaultdict [frozendict [str , Any ], list [dict [str , Any ]]] = defaultdict (list )
173215 if reference_records is not None :
174216 for record_set in reference_records :
175217 key = to_frozendict (record_set [0 ], column_set )
@@ -191,7 +233,7 @@ def remove_exact_duplicates(
191233
192234def prepare_records (
193235 records : Sequence [Record ], columns : Sequence [str ] | None
194- ) -> tuple [list [dict [str , str ]], Sequence [str ], bool ]:
236+ ) -> tuple [list [dict [str , Any ]], Sequence [str ], bool ]:
195237 """
196238 Validate and prepare records for processing.
197239
@@ -214,23 +256,23 @@ def prepare_records(
214256 if not all (isinstance (r , str ) for r in records ):
215257 raise ValueError ("All records must be strings when the first record is a string." )
216258 columns = ["text" ]
217- dict_records : list [dict [str , str ]] = [{"text" : str ( record ) } for record in records ]
259+ dict_records : list [dict [str , Any ]] = [{"text" : record } for record in records ]
218260 was_string = True
219261 else :
220262 # Validate all records are dicts
221263 if not all (isinstance (r , dict ) for r in records ):
222264 raise ValueError ("All records must be dicts when the first record is a dict." )
223265 assert columns is not None
224- # Coerce dict values to strings (matching dataset behavior )
266+ # Coerce values: stringify primitives, keep complex types raw (for images, etc. )
225267 dict_records_typed : list [dict [str , Any ]] = list (records ) # type: ignore[arg-type]
226268 dict_records = []
227269 for r in dict_records_typed :
228- coerced = {}
270+ coerced : dict [ str , Any ] = {}
229271 for c in columns :
230272 val = r .get (c )
231273 if val is None :
232274 raise ValueError (f"Column '{ c } ' has None value in record { r } " )
233- coerced [c ] = val if isinstance ( val , str ) else str (val )
275+ coerced [c ] = coerce_value (val )
234276 dict_records .append (coerced )
235277 was_string = False
236278
@@ -263,7 +305,7 @@ def _validate_dataset(dataset: DatasetLike, columns: Sequence[str]) -> tuple[dic
263305def prepare_dataset_records (
264306 dataset : DatasetLike ,
265307 columns : Sequence [str ],
266- ) -> tuple [list [dict [str , str ]], list [list [dict [str , str ]]], bool ]:
308+ ) -> tuple [list [dict [str , Any ]], list [list [dict [str , Any ]]], bool ]:
267309 """
268310 Extract, validate, and exact-deduplicate dataset rows using columnar access.
269311
@@ -286,17 +328,18 @@ def prepare_dataset_records(
286328 # values in the dataset are actual strings (not integers/floats coerced to str).
287329 was_string = len (columns ) == 1 and columns [0 ] == "text"
288330
289- def coerce (raw : Any , * , col : str , idx : int ) -> str :
331+ def validate_and_coerce (raw : Any , * , col : str , idx : int ) -> Any :
332+ """Validate value is not None, then coerce for encoding."""
290333 if raw is None :
291334 raise ValueError (f"Column '{ col } ' has None at index { idx } " )
292- return raw if isinstance ( raw , str ) else str (raw )
335+ return coerce_value (raw )
293336
294337 # Build all records while tracking was_string
295- records : list [dict [str , str ]] = []
338+ records : list [dict [str , Any ]] = []
296339 for i in range (n ):
297340 if was_string and not isinstance (cols ["text" ][i ], str ):
298341 was_string = False
299- records .append ({c : coerce (cols [c ][i ], col = c , idx = i ) for c in columns })
342+ records .append ({c : validate_and_coerce (cols [c ][i ], col = c , idx = i ) for c in columns })
300343
301344 # Group by exact match, preserving first-occurrence order
302345 deduplicated_records , items = group_records_by_key (records , columns )
0 commit comments