|
1 | 1 | from dataclasses import dataclass |
2 | | -from typing import List |
| 2 | +from typing import Any, Dict, List, Sequence, cast |
| 3 | + |
| 4 | +from nucleus.async_job import AsyncJob, JobError |
| 5 | + |
| 6 | +REQUIRED_RESULT_FIELDS = ("unique_item_ids", "unique_reference_ids", "stats") |
| 7 | +REQUIRED_STATS_FIELDS = ("threshold", "original_count", "deduplicated_count") |
| 8 | + |
| 9 | + |
| 10 | +def _require_fields( |
| 11 | + payload: Dict[str, Any], required_fields: Sequence[str], context: str |
| 12 | +) -> None: |
| 13 | + missing_fields = [ |
| 14 | + field for field in required_fields if field not in payload |
| 15 | + ] |
| 16 | + if missing_fields: |
| 17 | + missing_fields_message = ", ".join(missing_fields) |
| 18 | + raise RuntimeError( |
| 19 | + f"Deduplication job result missing {context} field(s): {missing_fields_message}" |
| 20 | + ) |
3 | 21 |
|
4 | 22 |
|
@dataclass
class DeduplicationStats:
    """Headline numbers describing one deduplication run.

    Attributes:
        threshold: Hamming distance threshold the run used. The lower the
            value, the stricter the matching; ``0`` means exact matches
            only.
        original_count: Number of items considered before deduplication.
        deduplicated_count: Number of unique items left after
            deduplication.
    """

    threshold: int
    original_count: int
    deduplicated_count: int
10 | 37 |
|
11 | 38 |
|
@dataclass
class DeduplicationResult:
    """What a deduplication run produced.

    Attributes:
        unique_item_ids: Nucleus-internal dataset item IDs (for example
            ``"di_abc123..."``) of the items that survived deduplication,
            one entry per kept item.
        unique_reference_ids: User-defined reference IDs supplied at upload
            time, ordered to line up with ``unique_item_ids``.
        stats: Summary statistics for the run; see
            :class:`DeduplicationStats`.
    """

    unique_item_ids: List[str]
    unique_reference_ids: List[str]
    stats: DeduplicationStats
| 55 | + |
| 56 | + |
class DeduplicationJob(AsyncJob):
    """Handle to a long-running deduplication job.

    Returned from :meth:`Dataset.deduplicate` and
    :meth:`Dataset.deduplicate_by_ids`. Deduplication always runs in the
    background; collect the completed output with :meth:`result`.

    Inherits all the standard :class:`AsyncJob` controls
    (:meth:`status`, :meth:`errors`, :meth:`sleep_until_complete`).

    ::

        import nucleus

        client = nucleus.NucleusClient(YOUR_API_KEY)
        dataset = client.get_dataset("ds_xxx")

        job = dataset.deduplicate(threshold=10)
        result = job.result()  # blocks until done
        print(result.stats.deduplicated_count)
        print(result.unique_reference_ids)

        # You can also deduplicate a known set of internal dataset item IDs.
        job = dataset.deduplicate_by_ids(
            threshold=10,
            dataset_item_ids=["di_xxx", "di_yyy"],
        )
        result = job.result()

        # Or split the wait and fetch yourself.
        job.sleep_until_complete()
        result = job.result(wait_for_completion=False)
    """

    @staticmethod
    def _payload_as_dict(value: Any, context: str) -> Dict[str, Any]:
        """Return ``value`` as a dict, or raise if it has another shape.

        A falsy ``value`` (``None``, ``""``, ``{}``) is treated as an empty
        payload so the caller's required-field check reports exactly which
        keys are absent.

        Parameters:
            value: The raw payload taken from the job status response.
            context: Short label (``"result"`` or ``"stats"``) used in the
                error message.

        Raises:
            RuntimeError: If ``value`` is truthy but not a ``dict``.
        """
        if not value:
            return {}
        if not isinstance(value, dict):
            raise RuntimeError(
                f"Deduplication job result has malformed {context} payload: "
                f"expected a JSON object, got {type(value).__name__}"
            )
        return cast(Dict[str, Any], value)

    def result(
        self, wait_for_completion: bool = True
    ) -> "DeduplicationResult":
        """Return the deduplication result, optionally waiting for the job.

        Parameters:
            wait_for_completion: When ``True`` (default), block until the job
                reaches a terminal state. When ``False``, the caller is
                expected to have already waited (e.g. via
                :meth:`sleep_until_complete`).

        Returns:
            A :class:`DeduplicationResult` containing the kept item IDs,
            reference IDs, and run statistics.

        Raises:
            JobError: If the job did not finish successfully (e.g. it was
                cancelled or hit a server error).
            RuntimeError: If the completed job response is missing expected
                result fields, or if its result or stats payload is not a
                JSON object.
        """
        if wait_for_completion:
            self.sleep_until_complete(verbose_std_out=False)

        status = self.status()
        if status["status"] != "Completed":
            raise JobError(status, self)

        # AsyncJob.status() is typed as Dict[str, str] in the base class, but
        # the `message` slot is a JSON dict in practice. typing.cast() does
        # not check anything at runtime, so verify the shape explicitly: a
        # string payload would otherwise satisfy the `in` checks by substring
        # match and then fail later with an unrelated AttributeError.
        msg = self._payload_as_dict(status.get("message"), "result")
        _require_fields(msg, REQUIRED_RESULT_FIELDS, "result")
        stats = self._payload_as_dict(msg.get("stats"), "stats")
        _require_fields(stats, REQUIRED_STATS_FIELDS, "stats")
        return DeduplicationResult(
            unique_item_ids=msg["unique_item_ids"],
            unique_reference_ids=msg["unique_reference_ids"],
            stats=DeduplicationStats(
                threshold=stats["threshold"],
                original_count=stats["original_count"],
                deduplicated_count=stats["deduplicated_count"],
            ),
        )
0 commit comments