Skip to content

Commit c64a72e

Browse files
committed
Removed from_dataset integration
1 parent e808922 commit c64a72e

9 files changed

Lines changed: 52 additions & 1317 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ dev = [
4343
"ruff",
4444
]
4545

46-
# Integrations
47-
datasets = ["datasets"]
48-
all = [
49-
"datasets",
50-
]
5146

5247
[project.urls]
5348
"Homepage" = "https://github.com/MinishLab"

semhash/records.py

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from frozendict import frozendict
66

77
from semhash.datamodels import DeduplicationResult, DuplicateRecord
8-
from semhash.utils import DatasetLike, Record, coerce_value, to_frozendict
8+
from semhash.utils import Record, coerce_value, to_frozendict
99

1010

1111
def group_records_by_key(
@@ -126,69 +126,6 @@ def prepare_records(
126126
return dict_records, columns, was_string
127127

128128

129-
def _validate_dataset(dataset: DatasetLike, columns: Sequence[str]) -> tuple[dict[str, Sequence[Any]], int]:
130-
"""Validate dataset structure and extract columns."""
131-
try:
132-
column_names = dataset.column_names
133-
except AttributeError as e:
134-
raise TypeError("dataset must satisfy DatasetLike (column_names, __len__, __getitem__)") from e
135-
136-
missing = set(columns) - set(column_names)
137-
if missing:
138-
raise ValueError(f"Columns {missing} not found in dataset")
139-
140-
n = len(dataset)
141-
if n == 0:
142-
raise ValueError("dataset must not be empty")
143-
144-
cols = {c: dataset[c] for c in columns}
145-
for c in columns:
146-
if len(cols[c]) != n:
147-
raise ValueError(f"Column '{c}' length ({len(cols[c])}) does not match dataset length ({n})")
148-
149-
return cols, n
150-
151-
152-
def prepare_dataset_records(
153-
dataset: DatasetLike,
154-
columns: Sequence[str],
155-
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]], bool]:
156-
"""
157-
Extract, validate, and exact-deduplicate dataset rows using columnar access.
158-
159-
:param dataset: A dataset-like object with columnar access.
160-
:param columns: Columns to use for deduplication.
161-
:return: Tuple of (deduplicated_records, items, was_string) where:
162-
- deduplicated_records: representative record per exact-duplicate bucket
163-
- items: buckets of exact duplicates (each bucket is list[record])
164-
- was_string: True iff columns == ["text"] and ALL raw values were strings
165-
"""
166-
cols, n = _validate_dataset(dataset, columns)
167-
168-
# was_string controls whether deduplicate() returns strings or dicts.
169-
# We only return strings if: (1) single column named "text", AND (2) all raw
170-
# values in the dataset are actual strings (not integers/floats coerced to str).
171-
was_string = len(columns) == 1 and columns[0] == "text"
172-
173-
def validate_and_coerce(raw: Any, *, col: str, idx: int) -> Any:
174-
"""Validate value is not None, then coerce for encoding."""
175-
if raw is None:
176-
raise ValueError(f"Column '{col}' has None at index {idx}")
177-
return coerce_value(raw)
178-
179-
# Build all records while tracking was_string
180-
records: list[dict[str, Any]] = []
181-
for i in range(n):
182-
if was_string and not isinstance(cols["text"][i], str):
183-
was_string = False
184-
records.append({c: validate_and_coerce(cols[c][i], col=c, idx=i) for c in columns})
185-
186-
# Group by exact match, preserving first-occurrence order
187-
deduplicated_records, items = group_records_by_key(records, columns)
188-
189-
return deduplicated_records, items, was_string
190-
191-
192129
def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str:
193130
r"""
194131
Turn a record into a single string.

semhash/semhash.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616
add_scores_to_records,
1717
group_records_by_key,
1818
map_deduplication_result_to_strings,
19-
prepare_dataset_records,
2019
prepare_records,
2120
remove_exact_duplicates,
2221
)
2322
from semhash.utils import (
24-
DatasetLike,
2523
Encoder,
2624
Record,
2725
coerce_value,
@@ -84,41 +82,6 @@ def from_records(
8482
index = Index.from_vectors_and_items(vectors=embeddings, items=items, backend_type=ann_backend, **kwargs)
8583
return cls(index=index, model=model, columns=columns, was_string=was_string)
8684

87-
@classmethod
88-
def from_dataset(
89-
cls,
90-
dataset: DatasetLike,
91-
columns: Sequence[str],
92-
model: Encoder | None = None,
93-
ann_backend: Backend | str = Backend.USEARCH,
94-
**kwargs: Any,
95-
) -> SemHash:
96-
"""
97-
Initialize SemHash from a dataset (e.g., HuggingFace Dataset).
98-
99-
Removes exact duplicates, featurizes the records, and fits a vicinity index.
100-
Supports any dataset-like object that follows the DatasetLike protocol.
101-
102-
:param dataset: A dataset-like object with columnar access.
103-
:param columns: Columns to use for deduplication (same as from_records).
104-
:param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M).
105-
:param ann_backend: (Optional) The ANN backend to use. Defaults to Backend.USEARCH.
106-
:param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
107-
:return: A SemHash instance with a fitted vicinity index.
108-
"""
109-
# Load default model if needed
110-
if model is None: # pragma: no cover
111-
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
112-
113-
# Extract, validate, and deduplicate dataset records
114-
deduplicated_records, items, was_string = prepare_dataset_records(dataset, columns)
115-
116-
# Create embeddings for deduplicated records only
117-
vectors = featurize(records=deduplicated_records, columns=columns, model=model)
118-
119-
index = Index.from_vectors_and_items(vectors=vectors, items=items, backend_type=ann_backend, **kwargs)
120-
return cls(index=index, model=model, columns=columns, was_string=was_string)
121-
12285
@classmethod
12386
def from_embeddings(
12487
cls,

semhash/utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,6 @@ def encode(
2828
... # pragma: no cover
2929

3030

31-
class DatasetLike(Protocol):
32-
"""
33-
Protocol for dataset-like objects compatible with SemHash.from_dataset().
34-
35-
Any object that provides columnar access (dataset[column_name] -> sequence)
36-
satisfies this protocol. HuggingFace datasets.Dataset is the primary example,
37-
but custom dataset implementations are supported.
38-
"""
39-
40-
column_names: Sequence[str]
41-
42-
def __len__(self) -> int:
43-
"""Return the number of rows in the dataset."""
44-
... # pragma: no cover
45-
46-
def __getitem__(self, key: str) -> Sequence[Any]:
47-
"""Return all values for the given column name."""
48-
... # pragma: no cover
49-
50-
5131
def make_hashable(value: Any) -> Any:
5232
"""
5333
Convert a value to a hashable representation for use as dict keys.

semhash/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version_triple__ = (0, 3, 3)
2-
__version__ = ".".join(map(str, __version_triple__))
1+
__version_triple__ = (0, 3, 3) # pragma: no cover
2+
__version__ = ".".join(map(str, __version_triple__)) # pragma: no cover

tests/test_from_dataset.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

tests/test_semhash.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -279,37 +279,6 @@ def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None:
279279
assert semhash.index.vectors.tolist() == [[0.0], [1.0], [3.0]]
280280

281281

282-
def test_from_dataset_with_custom_dataset_like(model: Encoder) -> None:
283-
"""Test that from_dataset works with custom DatasetLike implementations (no HF dependency)."""
284-
285-
class MiniDataset:
286-
"""Minimal DatasetLike implementation for testing."""
287-
288-
column_names = ["text"]
289-
290-
def __init__(self, data: dict[str, list[str]]) -> None:
291-
self._data = data
292-
293-
def __len__(self) -> int:
294-
return len(self._data["text"])
295-
296-
def __getitem__(self, key: str) -> list[str]:
297-
return self._data[key]
298-
299-
# Create custom dataset with duplicates
300-
ds = MiniDataset({"text": ["apple", "banana", "apple"]})
301-
302-
semhash = SemHash.from_dataset(ds, columns=["text"], model=model)
303-
304-
# Should have deduplicated to 2 unique items
305-
assert len(semhash.index.items) == 2
306-
assert len(semhash.index.vectors) == 2
307-
308-
# Should work with deduplication
309-
result = semhash.self_deduplicate(threshold=0.95)
310-
assert len(result.selected) == 2
311-
312-
313282
def test_from_records_edge_cases(model: Encoder) -> None:
314283
"""Test from_records edge cases: coercion, order preservation, None rejection."""
315284
# Coerces non-string dict values to strings

tests/test_utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,54 @@
33
from frozendict import frozendict
44

55
from semhash.records import prepare_records, remove_exact_duplicates
6-
from semhash.utils import Encoder, compute_candidate_limit, featurize, to_frozendict
6+
from semhash.utils import Encoder, coerce_value, compute_candidate_limit, featurize, make_hashable, to_frozendict
7+
8+
9+
def test_make_hashable() -> None:
10+
"""Test make_hashable with various types."""
11+
# Fast path: primitives
12+
assert make_hashable("hello") == "hello"
13+
assert make_hashable(42) == 42
14+
assert make_hashable(3.14) == 3.14
15+
assert make_hashable(True) is True
16+
assert make_hashable(None) is None
17+
18+
# Objects with tobytes() (simulate PIL Image or numpy array)
19+
class MockImage:
20+
def tobytes(self) -> bytes:
21+
return b"fake_image_data"
22+
23+
img = MockImage()
24+
result = make_hashable(img)
25+
assert isinstance(result, str)
26+
assert len(result) == 32 # MD5 hex digest
27+
28+
# Hashable objects (like tuples)
29+
assert make_hashable((1, 2, 3)) == (1, 2, 3)
30+
31+
# Non-hashable fallback to string
32+
unhashable = {"key": "value"}
33+
result = make_hashable(unhashable)
34+
assert result == "{'key': 'value'}"
35+
36+
37+
def test_coerce_value() -> None:
38+
"""Test coerce_value for encoding preparation."""
39+
# Strings and bytes pass through
40+
assert coerce_value("hello") == "hello"
41+
assert coerce_value(b"bytes") == b"bytes"
42+
43+
# Primitives converted to strings
44+
assert coerce_value(42) == "42"
45+
assert coerce_value(3.14) == "3.14"
46+
assert coerce_value(True) == "True"
47+
48+
# Complex types pass through unchanged
49+
class MockImage:
50+
pass
51+
52+
img = MockImage()
53+
assert coerce_value(img) is img
754

855

956
def test_to_frozendict() -> None:

0 commit comments

Comments (0)