Skip to content
This repository was archived by the owner on Feb 18, 2026. It is now read-only.

Commit d8800d3

Browse files
[FEATURE] RDS-1528: Add optional drop argument to add_column
GitOrigin-RevId: 23262f480128d45d63252c38b52714029be693b9
1 parent 93b6d6a commit d8800d3

4 files changed

Lines changed: 134 additions & 16 deletions

File tree

src/gretel_client/data_designer/data_designer.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,15 @@ def with_evaluation_report(
408408
self._evaluation_report = GeneralDatasetEvaluation(
409409
settings=settings
410410
or EvaluateDataDesignerDatasetSettings(
411-
llm_judge_columns=[c.name for c in self.llm_judge_columns],
412-
validation_columns=[c.name for c in self.code_validation_columns],
413-
defined_categorical_columns=[c.name for c in self._categorical_columns],
411+
llm_judge_columns=[
412+
c.name for c in self.llm_judge_columns if not c.drop
413+
],
414+
validation_columns=[
415+
c.name for c in self.code_validation_columns if not c.drop
416+
],
417+
defined_categorical_columns=[
418+
c.name for c in self._categorical_columns if not c.drop
419+
],
414420
)
415421
)
416422
return self
@@ -592,6 +598,21 @@ def validate(self) -> Self:
592598
self._build_workflow()
593599
# Run semantic validation on full schema.
594600
violations = self._run_semantic_validation()
601+
602+
# Ensure at least one column remains after drops (not every column may be dropped)
603+
remaining_cols = [
604+
name
605+
for name in self._columns
606+
if name not in self._latent_person_columns
607+
and name not in self._drop_columns
608+
]
609+
610+
if len(remaining_cols) == 0:
611+
raise DataDesignerValidationError(
612+
"🛑 All generated columns are configured to be dropped. Please mark at "
613+
"least one column with `drop=False`."
614+
)
615+
595616
if len(violations) == 0:
596617
logger.info("Validation passed ✅")
597618
return self
@@ -665,6 +686,13 @@ def _categorical_columns(self) -> list[SamplerColumn]:
665686
if (col.type == SamplerType.CATEGORY or col.type == SamplerType.SUBCATEGORY)
666687
]
667688

689+
@property
def _drop_columns(self) -> list[str]:
    """Return the names of columns flagged with ``drop=True``.

    The list is rebuilt from ``self._columns`` on every access, in that
    mapping's iteration order. Column objects that have no ``drop``
    attribute are treated as not dropped.
    """
    return [
        name for name, col in self._columns.items() if getattr(col, "drop", False)
    ]
695+
668696
@handle_workflow_validation_error
669697
def _build_workflow(
670698
self,
@@ -773,7 +801,7 @@ def _build_workflow(
773801
last_step_added = next_step
774802

775803
########################################################
776-
# Drop all latent columns from the final dataset
804+
# Drop intermediate columns (`drop=True`) and latent person columns
777805
########################################################
778806

779807
if len(self._latent_person_columns) > 0:
@@ -790,6 +818,18 @@ def _build_workflow(
790818
)
791819
last_step_added = drop_latent_columns_step
792820

821+
if self._drop_columns:
822+
drop_cols_step = self._task_registry.DropColumns(columns=self._drop_columns)
823+
builder.add_step(
824+
step=drop_cols_step,
825+
step_inputs=[last_step_added],
826+
step_name=(
827+
f"dropping-{len(self._drop_columns)}-intermediate-column"
828+
f"{'s' if len(self._drop_columns) != 1 else ''}"
829+
),
830+
)
831+
last_step_added = drop_cols_step
832+
793833
########################################################
794834
# Run dataset evaluation if requested
795835
########################################################
@@ -806,11 +846,15 @@ def _build_workflow(
806846
)
807847
else:
808848
general_eval_step = self._task_registry.EvaluateDataDesignerDataset(
809-
llm_judge_columns=[c.name for c in self.llm_judge_columns],
849+
llm_judge_columns=[
850+
c.name for c in self.llm_judge_columns if not c.drop
851+
],
810852
columns_to_ignore=settings.columns_to_ignore,
811-
validation_columns=settings.validation_columns,
853+
validation_columns=[
854+
c.name for c in self.code_validation_columns if not c.drop
855+
],
812856
defined_categorical_columns=[
813-
c.name for c in self._categorical_columns
857+
c.name for c in self._categorical_columns if not c.drop
814858
],
815859
)
816860
builder.add_step(

src/gretel_client/data_designer/types.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ class SeedDataset(AIDDConfigBase):
128128
##########################################################
129129

130130

131+
class WithDropColumnMixin(BaseModel):
    """Mixin that adds a ``drop`` flag to a column model.

    A column with ``drop=True`` is meant to be removed from the final
    dataset before evaluation.
    """

    # Defaults to False, so a column is kept unless it is explicitly dropped.
    drop: bool = Field(
        default=False,
        description="If true, remove this column from the final dataset "
        "before evaluation.",
    )
140+
141+
131142
class WithDAGColumnMixin:
132143
@property
133144
def required_columns(self) -> list[str]:
@@ -138,7 +149,7 @@ def side_effect_columns(self) -> list[str]:
138149
return []
139150

140151

141-
class SamplerColumn(WithPrettyRepr, tasks.ConditionalDataColumn):
152+
class SamplerColumn(WithDropColumnMixin, WithPrettyRepr, tasks.ConditionalDataColumn):
142153
"""AIDD column that uses a sampler to generate data.
143154
144155
Sampler columns can be conditioned on other sampler columns using the `conditional_params` argument,
@@ -208,7 +219,10 @@ def unpack(cls, column: SerializableConditionalDataColumn | dict) -> Self:
208219

209220

210221
class LLMGenColumn(
211-
WithPrettyRepr, tasks.GenerateColumnFromTemplateV2, WithDAGColumnMixin
222+
WithDropColumnMixin,
223+
WithPrettyRepr,
224+
tasks.GenerateColumnFromTemplateV2,
225+
WithDAGColumnMixin,
212226
):
213227
@model_validator(mode="before")
214228
@classmethod
@@ -306,7 +320,9 @@ class LLMStructuredColumn(LLMGenColumn):
306320
output_type: OutputType = Field(default=OutputType.STRUCTURED)
307321

308322

309-
class LLMJudgeColumn(WithPrettyRepr, tasks.JudgeWithLlm, WithDAGColumnMixin):
323+
class LLMJudgeColumn(
324+
WithDropColumnMixin, WithPrettyRepr, tasks.JudgeWithLlm, WithDAGColumnMixin
325+
):
310326
"""AIDD column for llm-as-a-judge with custom rubrics.
311327
312328
Args:
@@ -334,7 +350,9 @@ def step_name(self) -> str:
334350
return f"using-llm-to-judge-column-{self.name}"
335351

336352

337-
class CodeValidationColumn(WithPrettyRepr, AIDDConfigBase, WithDAGColumnMixin):
353+
class CodeValidationColumn(
354+
WithDropColumnMixin, WithPrettyRepr, AIDDConfigBase, WithDAGColumnMixin
355+
):
338356
"""AIDD column for validating code in another column.
339357
340358
Code validation is currently supported for Python and SQL.
@@ -371,7 +389,10 @@ def step_name(self) -> str:
371389

372390

373391
class ExpressionColumn(
374-
WithPrettyRepr, tasks.GenerateColumnFromExpression, WithDAGColumnMixin
392+
WithDropColumnMixin,
393+
WithPrettyRepr,
394+
tasks.GenerateColumnFromExpression,
395+
WithDAGColumnMixin,
375396
):
376397
"""AIDD column for generated data based on jinja2 expressions.
377398

src/gretel_client/data_designer/viz_tools.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class AIDDMetadata(BaseModel):
5454
validation_columns: list[str] = []
5555
expression_columns: list[str] = []
5656
evaluation_columns: list[str] = []
57+
drop_columns: list[str] = []
5758
person_samplers: list[str] = []
5859
code_langs: list[CodeLang | str] = []
5960
eval_type: LLMJudgePromptTemplateType | None = None
@@ -104,6 +105,7 @@ def from_aidd(cls, aidd: "DataDesigner") -> Self:
104105
llm_judge_columns=[col.name for col in aidd.llm_judge_columns],
105106
validation_columns=code_validation_columns,
106107
expression_columns=[col.name for col in aidd.expression_columns],
108+
drop_columns=aidd._drop_columns,
107109
person_samplers=list(aidd._latent_person_columns.keys()),
108110
code_langs=[col.output_format for col in aidd.llm_code_columns],
109111
eval_type=None,
@@ -161,7 +163,7 @@ def display_sample_record(
161163
table = Table(title="Seed Columns", **table_kws)
162164
table.add_column("Name")
163165
table.add_column("Value")
164-
for col in aidd_metadata.seed_columns:
166+
for col in [c for c in aidd_metadata.seed_columns if c not in aidd_metadata.drop_columns]:
165167
table.add_row(col, _convert_to_row_element(record[col]))
166168
render_list.append(_pad_console_element(table))
167169

@@ -176,7 +178,7 @@ def display_sample_record(
176178
table = Table(title="Generated Columns", **table_kws)
177179
table.add_column("Name")
178180
table.add_column("Value")
179-
for col in [c for c in non_code_columns]:
181+
for col in [c for c in non_code_columns if c not in aidd_metadata.drop_columns]:
180182
table.add_row(col, _convert_to_row_element(record[col]))
181183
render_list.append(_pad_console_element(table))
182184

@@ -207,7 +209,11 @@ def display_sample_record(
207209
if len(aidd_metadata.validation_columns) > 0:
208210
table = Table(title="Validation", **table_kws)
209211
row = []
210-
for col in aidd_metadata.validation_columns:
212+
for col in [
213+
c
214+
for c in aidd_metadata.validation_columns
215+
if c not in aidd_metadata.drop_columns
216+
]:
211217
value = record[col]
212218
if isinstance(value, numbers.Number):
213219
table.add_column(col)
@@ -224,7 +230,11 @@ def display_sample_record(
224230
render_list.append(_pad_console_element(table, (1, 0, 1, 0)))
225231

226232
if len(aidd_metadata.llm_judge_columns) > 0:
227-
for col in aidd_metadata.llm_judge_columns:
233+
for col in [
234+
c
235+
for c in aidd_metadata.llm_judge_columns
236+
if c not in aidd_metadata.drop_columns
237+
]:
228238
table = Table(title=f"LLM-as-a-Judge: {col}", **table_kws)
229239
row = []
230240
judge = record[col]

tests/gretel_client/data_designer/test_data_designer.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,3 +521,46 @@ def test_get_column_from_kwargs():
521521
assert person_sampler_column_no_params.params.locale == "en_US"
522522
assert person_sampler_column_no_params.params.sex is None
523523
assert person_sampler_column_no_params.params.city is None
524+
525+
526+
def _minimal_designer(resource_provider):
    """Build a designer with one sampler column so workflow validation passes.

    Args:
        resource_provider: Passed through as ``gretel_resource_provider``.

    Returns:
        A ``DataDesigner`` holding a single ``uid`` column of type ``uuid``.
    """
    dd = DataDesigner(gretel_resource_provider=resource_provider)
    dd.add_column(name="uid", type="uuid", params={})
    return dd
531+
532+
533+
def test_drop_flag_adds_dropcolumns_step(mock_low_level_sdk_resources):
    """A column added with ``drop=True`` yields a ``DropColumns`` workflow step."""
    dd = _minimal_designer(mock_low_level_sdk_resources.mock_resource_provider)
    dd.add_column(
        name="dude", type="category", params={"values": ["John", "Jane"]}, drop=True
    )
    dd.preview()

    # Each recorded mock call unpacks as (name, args, kwargs); index 2 is kwargs.
    steps = [
        c[2]["step"]
        for c in mock_low_level_sdk_resources.mock_workflow_builder.add_step.mock_calls
    ]
    drop_step = next((s for s in steps if isinstance(s, DropColumns)), None)

    assert drop_step is not None
    assert drop_step.columns == ["dude"]
548+
549+
550+
def test_drop_flag_false_retains_column(mock_low_level_sdk_resources):
    """A column added with ``drop=False`` is kept and adds no ``DropColumns`` step."""
    dd = _minimal_designer(mock_low_level_sdk_resources.mock_resource_provider)
    dd.add_column(
        name="dude",
        type="category",
        params={"values": ["John", "Jane"]},
        drop=False,
    )
    dd.preview()

    assert "dude" not in dd._drop_columns

    # Each recorded mock call unpacks as (name, args, kwargs); index 2 is kwargs.
    steps = [
        call[2]["step"]
        for call in mock_low_level_sdk_resources.mock_workflow_builder.add_step.mock_calls
    ]
    assert next((s for s in steps if isinstance(s, DropColumns)), None) is None

0 commit comments

Comments
 (0)