diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index eb06565..b83428a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -7,4 +7,5 @@ /.github/ @NVIDIA-NeMo/data_designer_reviewers # Plugins +/plugins/data-designer-generalist-agent-env/ eric.tramel@gmail.com /plugins/data-designer-template/ @NVIDIA-NeMo/data_designer_reviewers diff --git a/catalog/plugins.json b/catalog/plugins.json index bab1a70..aa3f9e5 100644 --- a/catalog/plugins.json +++ b/catalog/plugins.json @@ -1,6 +1,56 @@ { "schema_version": 1, "plugins": [ + { + "name": "generalist-agent-environment", + "plugin_type": "column-generator", + "description": "Generalist agent environment and task assemblers for generated Data Designer data", + "package": { + "name": "data-designer-generalist-agent-env", + "version": "0.1.0", + "path": "plugins/data-designer-generalist-agent-env" + }, + "entry_point": { + "group": "data_designer.plugins", + "name": "generalist-agent-environment", + "value": "data_designer_generalist_agent_env.plugin:environment_plugin" + }, + "compatibility": { + "python": { + "specifier": ">=3.10" + }, + "data_designer": { + "requirement": "data-designer>=0.5.9", + "specifier": ">=0.5.9", + "marker": null + } + } + }, + { + "name": "generalist-agent-task", + "plugin_type": "column-generator", + "description": "Generalist agent environment and task assemblers for generated Data Designer data", + "package": { + "name": "data-designer-generalist-agent-env", + "version": "0.1.0", + "path": "plugins/data-designer-generalist-agent-env" + }, + "entry_point": { + "group": "data_designer.plugins", + "name": "generalist-agent-task", + "value": "data_designer_generalist_agent_env.plugin:task_plugin" + }, + "compatibility": { + "python": { + "specifier": ">=3.10" + }, + "data_designer": { + "requirement": "data-designer>=0.5.9", + "specifier": ">=0.5.9", + "marker": null + } + } + }, { "name": "text-transform", "plugin_type": "column-generator", diff --git 
a/docs/plugins/data-designer-generalist-agent-env/index.md b/docs/plugins/data-designer-generalist-agent-env/index.md new file mode 100644 index 0000000..84d217c --- /dev/null +++ b/docs/plugins/data-designer-generalist-agent-env/index.md @@ -0,0 +1,200 @@ +# data-designer-generalist-agent-env + +The `data-designer-generalist-agent-env` plugin adds a two-stage Generalist +environment workflow for Data Designer. It is designed for workflows where Data +Designer generates the topic, constraints, database schema, and database rows, +then the plugin assembles those generated artifacts into executable RL rollout +tuples. + +The workflow is: + +1. Use ordinary Data Designer columns, such as `llm-text` and `llm-structured`, + to generate a task topic and constraints. +2. Use additional Data Designer generation columns to generate a row-local + database schema and records that follow that schema. +3. Use `generalist-agent-environment` to validate and assemble the generated + schema and records into a sandbox with executable tools. +4. Use `generalist-agent-task` to synthesize the task prompt, tool-only solution, + verifier, reference answer, and simple-to-hard augmentation trace. + +No search provider or external retrieval step is required, and the plugin does +not fabricate fallback records. + +## Installation + +```bash +uv add "data-designer>=0.5.9" data-designer-generalist-agent-env +``` + +## Column types + +Use `generalist-agent-environment` to assemble the generated sandbox and toolset. + +| Field | Required | Description | +| --- | --- | --- | +| `name` | Yes | Output environment column name. | +| `task_topic_column` | Yes | Existing column containing a generated task topic such as `trip planning`. | +| `task_constraints_column` | No | Existing column containing generated constraints as text, JSON, or a structured object. | +| `database_schema_column` | Yes | Existing column containing the generated database schema. 
| +| `database_records_column` | Yes | Existing column containing generated database records. | +| `context_columns` | No | Existing columns copied into environment context. | + +Generated records must include `record_id`, `name`, `summary`, `cost`, +`duration`, `score`, and `tags`. Additional fields are preserved, and an +`attributes` object is recommended for topic-specific fields. + +Use `generalist-agent-task` to generate tasks from an environment. + +| Field | Required | Description | +| --- | --- | --- | +| `name` | Yes | Output task tuple column name. | +| `environment_column` | Yes | Column containing a `generalist-agent-environment` artifact. | +| `difficulty` | No | Final task difficulty: `simple`, `medium`, or `hard`; defaults to `hard`. | +| `required_tag` | No | Optional tag that the valid answer must contain. | +| `max_cost` | No | Optional maximum cost constraint. Unsatisfiable values are repaired upward. | +| `min_score` | No | Optional minimum score constraint. Unsatisfiable values are repaired downward. 
| + +## Usage + +```python +import pandas as pd +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.seed_source_dataframe import DataFrameSeedSource + +seed_df = pd.DataFrame( + { + "seed": ["travel planning"], + "brief": ["family-friendly museums, moderate budget, reliable transport"], + } +) + +constraint_schema = { + "type": "object", + "properties": { + "goal": {"type": "string"}, + "constraints": {"type": "array", "items": {"type": "string"}}, + "success_criteria": {"type": "array", "items": {"type": "string"}}, + "data_dimensions": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["goal", "constraints", "success_criteria", "data_dimensions"], +} + +database_schema_format = { + "type": "object", + "properties": { + "record_type": {"type": "string"}, + "primary_key": {"type": "string", "const": "record_id"}, + "fields": {"type": "array", "items": {"type": "object"}}, + "attribute_fields": {"type": "array", "items": {"type": "object"}}, + }, + "required": ["record_type", "primary_key", "fields", "attribute_fields"], +} + +records_format = { + "type": "object", + "properties": { + "records": { + "type": "array", + "items": { + "type": "object", + "properties": { + "record_id": {"type": "string"}, + "name": {"type": "string"}, + "summary": {"type": "string"}, + "cost": {"type": "integer"}, + "duration": {"type": "integer"}, + "score": {"type": "integer"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "attributes": {"type": "object"}, + }, + "required": ["record_id", "name", "summary", "cost", "duration", "score", "tags"], + }, + } + }, + "required": ["records"], +} + +builder = DataDesignerConfigBuilder() +builder.with_seed_dataset(DataFrameSeedSource(df=seed_df)) +builder.add_column( + name="task_topic", + column_type="llm-text", + model_alias="deepseek-v4-pro-live", + prompt="From {{ seed }} and {{ brief }}, write a concise task topic.", +) +builder.add_column( + 
name="task_constraints", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt="Generate constraints for topic {{ task_topic }} with brief {{ brief }}.", + output_format=constraint_schema, +) +builder.add_column( + name="database_schema", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt="Generate a database schema for topic {{ task_topic }} and constraints {{ task_constraints }}.", + output_format=database_schema_format, +) +builder.add_column( + name="database_records", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "Generate 8 records that follow schema {{ database_schema }} for topic " + "{{ task_topic }} and constraints {{ task_constraints }}. Include varied " + "cost, duration, score, tags, and attributes." + ), + output_format=records_format, +) +builder.add_column( + name="agent_environment", + column_type="generalist-agent-environment", + task_topic_column="task_topic", + task_constraints_column="task_constraints", + database_schema_column="database_schema", + database_records_column="database_records", + context_columns=["brief"], +) +builder.add_column( + name="agent_task", + column_type="generalist-agent-task", + environment_column="agent_environment", + difficulty="hard", + required_tag="reliable", +) +``` + +The generated `agent_task` value is a dictionary with these top-level keys: + +| Key | Description | +| --- | --- | +| `environment` | Sandbox metadata, generated database schema, generated records, and source context. | +| `tools` | Synthesized tool descriptors and Python function sources. | +| `tool_module_source` | Executable Python source defining the generated schema, generated database, and selected tools. | +| `task` | Prompt, difficulty, constraints, and answer schema. | +| `solution` | Python `solve(tools)` source restricted to tool calls and local logic. | +| `verifier` | Python `verify(answer, database)` source and reference validation status. 
| +| `reference_answer` | The generated solution output that the verifier accepts. | +| `task_iterations` | Simple-to-final task, solution, verifier, and augmentation artifacts. | +| `synthesis_trace` | Topic/constraint intake, schema intake, generated-data intake, task synthesis, solution, and verification events. | + +## Row validation helper + +The package includes a helper module for executable row validation. It executes +the generated tool module, smoke-tests the declared tools, runs the generated +solution, checks the generated verifier, and replays every task iteration: + +```python +from data_designer_generalist_agent_env.validation import verify_row_record + +validation = verify_row_record(result.dataset.loc[0], output_column="agent_task") +assert validation.passed, validation.errors +``` + +## Behavior notes + +The plugin does not generate the grounding records. It requires generated +schema and generated records from upstream Data Designer columns, validates the +minimum executable contract, and then builds tools and verifiers around that +generated data. diff --git a/docs/plugins/data-designer-generalist-agent-env/usage.md b/docs/plugins/data-designer-generalist-agent-env/usage.md new file mode 100644 index 0000000..4217a6c --- /dev/null +++ b/docs/plugins/data-designer-generalist-agent-env/usage.md @@ -0,0 +1,171 @@ +# Usage + +This example creates one Generalist RL rollout tuple from generated data. The +model generates the topic, constraints, database schema, and database records. +The plugin assembles those generated artifacts, adds executable tools, and then +synthesizes a task, tool-only solution, and verifier. 
+ +```python +import pandas as pd +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.seed_source_dataframe import DataFrameSeedSource +from data_designer.interface.data_designer import DataDesigner + +constraint_schema = { + "type": "object", + "properties": { + "goal": {"type": "string"}, + "constraints": {"type": "array", "items": {"type": "string"}}, + "success_criteria": {"type": "array", "items": {"type": "string"}}, + "data_dimensions": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["goal", "constraints", "success_criteria", "data_dimensions"], +} + +database_schema_format = { + "type": "object", + "properties": { + "record_type": {"type": "string"}, + "primary_key": {"type": "string"}, + "fields": {"type": "array", "items": {"type": "object"}}, + "attribute_fields": {"type": "array", "items": {"type": "object"}}, + }, + "required": ["record_type", "primary_key", "fields", "attribute_fields"], +} + +records_format = { + "type": "object", + "properties": { + "records": { + "type": "array", + "items": { + "type": "object", + "properties": { + "record_id": {"type": "string"}, + "name": {"type": "string"}, + "summary": {"type": "string"}, + "cost": {"type": "integer"}, + "duration": {"type": "integer"}, + "score": {"type": "integer"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "attributes": {"type": "object"}, + }, + "required": ["record_id", "name", "summary", "cost", "duration", "score", "tags"], + }, + } + }, + "required": ["records"], +} + +seed_df = pd.DataFrame( + { + "seed": ["planning a travel itinerary"], + "brief": ["family-friendly museums, moderate budget, reliable transport"], + } +) + +builder = DataDesignerConfigBuilder() +builder.with_seed_dataset(DataFrameSeedSource(df=seed_df)) +builder.add_column( + name="task_topic", + column_type="llm-text", + model_alias="deepseek-v4-pro-live", + prompt="From this seed {{ seed }}, write a concise task topic.", +) 
+builder.add_column( + name="task_constraints", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "For topic {{ task_topic }} and brief {{ brief }}, generate constraints " + "that make the task hard to solve but easy to verify." + ), + output_format=constraint_schema, +) +builder.add_column( + name="database_schema", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "Generate a database schema for topic {{ task_topic }} and constraints " + "{{ task_constraints }}. Include record_id, name, summary, cost, " + "duration, score, tags, and topic-specific attributes." + ), + output_format=database_schema_format, +) +builder.add_column( + name="database_records", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "Generate 8 diverse records that follow schema {{ database_schema }} " + "for topic {{ task_topic }} and constraints {{ task_constraints }}. " + "At least two records must include the tag reliable." 
+ ), + output_format=records_format, +) +builder.add_column( + name="agent_environment", + column_type="generalist-agent-environment", + task_topic_column="task_topic", + task_constraints_column="task_constraints", + database_schema_column="database_schema", + database_records_column="database_records", + context_columns=["brief"], +) +builder.add_column( + name="agent_task", + column_type="generalist-agent-task", + environment_column="agent_environment", + difficulty="hard", + required_tag="reliable", +) + +result = DataDesigner(artifact_path="artifacts").preview(builder, num_records=1) +environment_tuple = result.dataset.loc[0, "agent_task"] +``` + +The generated row can be validated with the package helper: + +```python +from data_designer_generalist_agent_env.validation import verify_environment_tuple + +validation = verify_environment_tuple(environment_tuple) +assert validation.passed, validation.errors +assert validation.answer == environment_tuple["reference_answer"] +``` + +The output task is intentionally search-like: the solving agent must inspect the +generated schema, filter records, and rank candidates through the tool interface. +The verifier remains straightforward because it checks fixed constraints and a +deterministic tie-break order directly against the generated database. + +## Expected output shape + +`generalist-agent-environment` emits: + +```text +schema_version +environment.database_schema +environment.database +environment.data_generation +tools +tool_module_source +synthesis_trace +``` + +`generalist-agent-task` emits: + +```text +schema_version +environment +tools +tool_module_source +task +solution +verifier +reference_answer +task_iterations +synthesis_trace +rl_filter_note +``` diff --git a/docs/plugins/index.md b/docs/plugins/index.md index 4e54e2e..1c74551 100644 --- a/docs/plugins/index.md +++ b/docs/plugins/index.md @@ -5,6 +5,17 @@ Browse available Data Designer plugins by what they add to your data generation workflow.
+ + + data-designer-generalist-agent-env + v0.1.0 + + Generalist agent environment and task assemblers for generated Data Designer data + + Column types + generalist-agent-environmentgeneralist-agent-task + + data-designer-template diff --git a/plugins/data-designer-generalist-agent-env/CODEOWNERS b/plugins/data-designer-generalist-agent-env/CODEOWNERS new file mode 100644 index 0000000..013e51f --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/CODEOWNERS @@ -0,0 +1,3 @@ +# Owner(s) of this plugin — used to generate the root CODEOWNERS file. +# GitHub accepts @username, @org/team, or email format. +* eric.tramel@gmail.com diff --git a/plugins/data-designer-generalist-agent-env/README.md b/plugins/data-designer-generalist-agent-env/README.md new file mode 100644 index 0000000..9f91ef3 --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/README.md @@ -0,0 +1,47 @@ +# data-designer-generalist-agent-env + +Generate Generalist-style agent environments and tasks from Data Designer +generated topics, constraints, database schemas, and records. The plugin +assembles generated grounding data into executable tool environments, then +generates task prompts, tool-only solution functions, and verifier functions +from those environments. + +## Installation + +```bash +uv add "data-designer>=0.5.9" data-designer-generalist-agent-env +``` + +## Usage + +Once installed, the `generalist-agent-environment` and +`generalist-agent-task` column types are automatically discovered by +[NeMo Data Designer](https://github.com/NVIDIA-NeMo/DataDesigner). 
+ +Configure the workflow after generating a task topic, constraints, schema, and +records: + +```python +builder.add_column( + name="agent_environment", + column_type="generalist-agent-environment", + task_topic_column="task_topic", + task_constraints_column="task_constraints", + database_schema_column="database_schema", + database_records_column="database_records", + context_columns=["brief"], +) +builder.add_column( + name="agent_task", + column_type="generalist-agent-task", + environment_column="agent_environment", + difficulty="hard", + required_tag="reliable", +) +``` + +For the full plugin authoring guide, see the +[main repository docs](https://nvidia-nemo.github.io/DataDesignerPlugins/authoring/). + +Plugin documentation for the repository site lives in this package's `docs/` +directory. diff --git a/plugins/data-designer-generalist-agent-env/docs/index.md b/plugins/data-designer-generalist-agent-env/docs/index.md new file mode 100644 index 0000000..84d217c --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/docs/index.md @@ -0,0 +1,200 @@ +# data-designer-generalist-agent-env + +The `data-designer-generalist-agent-env` plugin adds a two-stage Generalist +environment workflow for Data Designer. It is designed for workflows where Data +Designer generates the topic, constraints, database schema, and database rows, +then the plugin assembles those generated artifacts into executable RL rollout +tuples. + +The workflow is: + +1. Use ordinary Data Designer columns, such as `llm-text` and `llm-structured`, + to generate a task topic and constraints. +2. Use additional Data Designer generation columns to generate a row-local + database schema and records that follow that schema. +3. Use `generalist-agent-environment` to validate and assemble the generated + schema and records into a sandbox with executable tools. +4. 
Use `generalist-agent-task` to synthesize the task prompt, tool-only solution, + verifier, reference answer, and simple-to-hard augmentation trace. + +No search provider or external retrieval step is required, and the plugin does +not fabricate fallback records. + +## Installation + +```bash +uv add "data-designer>=0.5.9" data-designer-generalist-agent-env +``` + +## Column types + +Use `generalist-agent-environment` to assemble the generated sandbox and toolset. + +| Field | Required | Description | +| --- | --- | --- | +| `name` | Yes | Output environment column name. | +| `task_topic_column` | Yes | Existing column containing a generated task topic such as `trip planning`. | +| `task_constraints_column` | No | Existing column containing generated constraints as text, JSON, or a structured object. | +| `database_schema_column` | Yes | Existing column containing the generated database schema. | +| `database_records_column` | Yes | Existing column containing generated database records. | +| `context_columns` | No | Existing columns copied into environment context. | + +Generated records must include `record_id`, `name`, `summary`, `cost`, +`duration`, `score`, and `tags`. Additional fields are preserved, and an +`attributes` object is recommended for topic-specific fields. + +Use `generalist-agent-task` to generate tasks from an environment. + +| Field | Required | Description | +| --- | --- | --- | +| `name` | Yes | Output task tuple column name. | +| `environment_column` | Yes | Column containing a `generalist-agent-environment` artifact. | +| `difficulty` | No | Final task difficulty: `simple`, `medium`, or `hard`; defaults to `hard`. | +| `required_tag` | No | Optional tag that the valid answer must contain. | +| `max_cost` | No | Optional maximum cost constraint. Unsatisfiable values are repaired upward. | +| `min_score` | No | Optional minimum score constraint. Unsatisfiable values are repaired downward. 
| + +## Usage + +```python +import pandas as pd +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.seed_source_dataframe import DataFrameSeedSource + +seed_df = pd.DataFrame( + { + "seed": ["travel planning"], + "brief": ["family-friendly museums, moderate budget, reliable transport"], + } +) + +constraint_schema = { + "type": "object", + "properties": { + "goal": {"type": "string"}, + "constraints": {"type": "array", "items": {"type": "string"}}, + "success_criteria": {"type": "array", "items": {"type": "string"}}, + "data_dimensions": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["goal", "constraints", "success_criteria", "data_dimensions"], +} + +database_schema_format = { + "type": "object", + "properties": { + "record_type": {"type": "string"}, + "primary_key": {"type": "string", "const": "record_id"}, + "fields": {"type": "array", "items": {"type": "object"}}, + "attribute_fields": {"type": "array", "items": {"type": "object"}}, + }, + "required": ["record_type", "primary_key", "fields", "attribute_fields"], +} + +records_format = { + "type": "object", + "properties": { + "records": { + "type": "array", + "items": { + "type": "object", + "properties": { + "record_id": {"type": "string"}, + "name": {"type": "string"}, + "summary": {"type": "string"}, + "cost": {"type": "integer"}, + "duration": {"type": "integer"}, + "score": {"type": "integer"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "attributes": {"type": "object"}, + }, + "required": ["record_id", "name", "summary", "cost", "duration", "score", "tags"], + }, + } + }, + "required": ["records"], +} + +builder = DataDesignerConfigBuilder() +builder.with_seed_dataset(DataFrameSeedSource(df=seed_df)) +builder.add_column( + name="task_topic", + column_type="llm-text", + model_alias="deepseek-v4-pro-live", + prompt="From {{ seed }} and {{ brief }}, write a concise task topic.", +) +builder.add_column( + 
name="task_constraints", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt="Generate constraints for topic {{ task_topic }} with brief {{ brief }}.", + output_format=constraint_schema, +) +builder.add_column( + name="database_schema", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt="Generate a database schema for topic {{ task_topic }} and constraints {{ task_constraints }}.", + output_format=database_schema_format, +) +builder.add_column( + name="database_records", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "Generate 8 records that follow schema {{ database_schema }} for topic " + "{{ task_topic }} and constraints {{ task_constraints }}. Include varied " + "cost, duration, score, tags, and attributes." + ), + output_format=records_format, +) +builder.add_column( + name="agent_environment", + column_type="generalist-agent-environment", + task_topic_column="task_topic", + task_constraints_column="task_constraints", + database_schema_column="database_schema", + database_records_column="database_records", + context_columns=["brief"], +) +builder.add_column( + name="agent_task", + column_type="generalist-agent-task", + environment_column="agent_environment", + difficulty="hard", + required_tag="reliable", +) +``` + +The generated `agent_task` value is a dictionary with these top-level keys: + +| Key | Description | +| --- | --- | +| `environment` | Sandbox metadata, generated database schema, generated records, and source context. | +| `tools` | Synthesized tool descriptors and Python function sources. | +| `tool_module_source` | Executable Python source defining the generated schema, generated database, and selected tools. | +| `task` | Prompt, difficulty, constraints, and answer schema. | +| `solution` | Python `solve(tools)` source restricted to tool calls and local logic. | +| `verifier` | Python `verify(answer, database)` source and reference validation status. 
| +| `reference_answer` | The generated solution output that the verifier accepts. | +| `task_iterations` | Simple-to-final task, solution, verifier, and augmentation artifacts. | +| `synthesis_trace` | Topic/constraint intake, schema intake, generated-data intake, task synthesis, solution, and verification events. | + +## Row validation helper + +The package includes a helper module for executable row validation. It executes +the generated tool module, smoke-tests the declared tools, runs the generated +solution, checks the generated verifier, and replays every task iteration: + +```python +from data_designer_generalist_agent_env.validation import verify_row_record + +validation = verify_row_record(result.dataset.loc[0], output_column="agent_task") +assert validation.passed, validation.errors +``` + +## Behavior notes + +The plugin does not generate the grounding records. It requires generated +schema and generated records from upstream Data Designer columns, validates the +minimum executable contract, and then builds tools and verifiers around that +generated data. diff --git a/plugins/data-designer-generalist-agent-env/docs/usage.md b/plugins/data-designer-generalist-agent-env/docs/usage.md new file mode 100644 index 0000000..4217a6c --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/docs/usage.md @@ -0,0 +1,171 @@ +# Usage + +This example creates one Generalist RL rollout tuple from generated data. The +model generates the topic, constraints, database schema, and database records. +The plugin assembles those generated artifacts, adds executable tools, and then +synthesizes a task, tool-only solution, and verifier. 
+ +```python +import pandas as pd +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.seed_source_dataframe import DataFrameSeedSource +from data_designer.interface.data_designer import DataDesigner + +constraint_schema = { + "type": "object", + "properties": { + "goal": {"type": "string"}, + "constraints": {"type": "array", "items": {"type": "string"}}, + "success_criteria": {"type": "array", "items": {"type": "string"}}, + "data_dimensions": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["goal", "constraints", "success_criteria", "data_dimensions"], +} + +database_schema_format = { + "type": "object", + "properties": { + "record_type": {"type": "string"}, + "primary_key": {"type": "string"}, + "fields": {"type": "array", "items": {"type": "object"}}, + "attribute_fields": {"type": "array", "items": {"type": "object"}}, + }, + "required": ["record_type", "primary_key", "fields", "attribute_fields"], +} + +records_format = { + "type": "object", + "properties": { + "records": { + "type": "array", + "items": { + "type": "object", + "properties": { + "record_id": {"type": "string"}, + "name": {"type": "string"}, + "summary": {"type": "string"}, + "cost": {"type": "integer"}, + "duration": {"type": "integer"}, + "score": {"type": "integer"}, + "tags": {"type": "array", "items": {"type": "string"}}, + "attributes": {"type": "object"}, + }, + "required": ["record_id", "name", "summary", "cost", "duration", "score", "tags"], + }, + } + }, + "required": ["records"], +} + +seed_df = pd.DataFrame( + { + "seed": ["planning a travel itinerary"], + "brief": ["family-friendly museums, moderate budget, reliable transport"], + } +) + +builder = DataDesignerConfigBuilder() +builder.with_seed_dataset(DataFrameSeedSource(df=seed_df)) +builder.add_column( + name="task_topic", + column_type="llm-text", + model_alias="deepseek-v4-pro-live", + prompt="From this seed {{ seed }}, write a concise task topic.", +) 
+builder.add_column( + name="task_constraints", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "For topic {{ task_topic }} and brief {{ brief }}, generate constraints " + "that make the task hard to solve but easy to verify." + ), + output_format=constraint_schema, +) +builder.add_column( + name="database_schema", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "Generate a database schema for topic {{ task_topic }} and constraints " + "{{ task_constraints }}. Include record_id, name, summary, cost, " + "duration, score, tags, and topic-specific attributes." + ), + output_format=database_schema_format, +) +builder.add_column( + name="database_records", + column_type="llm-structured", + model_alias="deepseek-v4-pro-live", + prompt=( + "Generate 8 diverse records that follow schema {{ database_schema }} " + "for topic {{ task_topic }} and constraints {{ task_constraints }}. " + "At least two records must include the tag reliable." 
+ ), + output_format=records_format, +) +builder.add_column( + name="agent_environment", + column_type="generalist-agent-environment", + task_topic_column="task_topic", + task_constraints_column="task_constraints", + database_schema_column="database_schema", + database_records_column="database_records", + context_columns=["brief"], +) +builder.add_column( + name="agent_task", + column_type="generalist-agent-task", + environment_column="agent_environment", + difficulty="hard", + required_tag="reliable", +) + +result = DataDesigner(artifact_path="artifacts").preview(builder, num_records=1) +environment_tuple = result.dataset.loc[0, "agent_task"] +``` + +The generated row can be validated with the package helper: + +```python +from data_designer_generalist_agent_env.validation import verify_environment_tuple + +validation = verify_environment_tuple(environment_tuple) +assert validation.passed, validation.errors +assert validation.answer == environment_tuple["reference_answer"] +``` + +The output task is intentionally search-like: the solving agent must inspect the +generated schema, filter records, and rank candidates through the tool interface. +The verifier remains straightforward because it checks fixed constraints and a +deterministic tie-break order directly against the generated database. 
+ +## Expected output shape + +`generalist-agent-environment` emits: + +```text +schema_version +environment.database_schema +environment.database +environment.data_generation +tools +tool_module_source +synthesis_trace +``` + +`generalist-agent-task` emits: + +```text +schema_version +environment +tools +tool_module_source +task +solution +verifier +reference_answer +task_iterations +synthesis_trace +rl_filter_note +``` diff --git a/plugins/data-designer-generalist-agent-env/pyproject.toml b/plugins/data-designer-generalist-agent-env/pyproject.toml new file mode 100644 index 0000000..ffc0d8e --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/pyproject.toml @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "data-designer-generalist-agent-env" +version = "0.1.0" +description = "Generalist agent environment and task assemblers for generated Data Designer data" +requires-python = ">=3.10" +dependencies = [ + "data-designer>=0.5.9", +] +license = "Apache-2.0" +readme = "README.md" +authors = [ + {name = "NVIDIA Corporation"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", +] + +[project.entry-points."data_designer.plugins"] +generalist-agent-environment = "data_designer_generalist_agent_env.plugin:environment_plugin" +generalist-agent-task = "data_designer_generalist_agent_env.plugin:task_plugin" + +[project.urls] +Repository = "https://github.com/NVIDIA-NeMo/DataDesignerPlugins" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/data_designer_generalist_agent_env"] + +[tool.ruff] +extend = "../../pyproject.toml" diff --git a/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/__init__.py 
b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/__init__.py new file mode 100644 index 0000000..4c635f4 --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer_generalist_agent_env.validation import ( + IterationExecutionCheck, + RowRecordValidationResult, + ToolExecutionCheck, + verify_environment_tuple, + verify_row_record, +) + +__all__ = [ + "IterationExecutionCheck", + "RowRecordValidationResult", + "ToolExecutionCheck", + "verify_environment_tuple", + "verify_row_record", +] diff --git a/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/config.py b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/config.py new file mode 100644 index 0000000..6ade46a --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/config.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Literal + +from data_designer.config.base import SingleColumnConfig +from pydantic import Field, field_validator, model_validator +from typing_extensions import Self + +Difficulty = Literal["simple", "medium", "hard"] + + +def normalize_column_name(value: str, field_name: str) -> str: + """Normalize and validate one column name. + + Args: + value: Candidate column name. + field_name: Name used in validation messages. + + Returns: + The stripped column name. + + Raises: + ValueError: If the column name is empty. 
+ """ + value = value.strip() + if not value: + raise ValueError(f"{field_name} must not be empty") + return value + + +def normalize_context_columns(value: list[str]) -> list[str]: + """Validate and de-duplicate context column names. + + Args: + value: Candidate context column names. + + Returns: + Context column names with duplicates removed while preserving order. + + Raises: + ValueError: If any context column name is empty. + """ + columns: list[str] = [] + for column in value: + column = normalize_column_name(column, "context_columns") + if column not in columns: + columns.append(column) + return columns + + +def normalize_required_tag(value: str | None) -> str | None: + """Normalize an optional required tag. + + Args: + value: Candidate tag value. + + Returns: + A lower-cased tag, or ``None`` when unset. + + Raises: + ValueError: If the tag contains only whitespace. + """ + if value is None: + return None + value = value.strip().lower() + if not value: + raise ValueError("required_tag must not be empty when provided") + return value + + +class GeneralistAgentEnvironmentColumnConfig(SingleColumnConfig): + """Configuration for constructing generated Generalist sandbox environments. + + The generator consumes Data Designer generated topic, constraints, database + schema, and database records, then emits a row-local environment with + executable tool implementations over those generated records. 
+ """ + + column_type: Literal["generalist-agent-environment"] = "generalist-agent-environment" + + task_topic_column: str = Field( + description="Input column containing a generated task topic, such as 'trip planning'.", + ) + task_constraints_column: str | None = Field( + default=None, + description="Optional input column containing generated constraints as text, JSON, or a structured object.", + ) + database_schema_column: str = Field( + description="Input column containing the generated database schema for this row.", + ) + database_records_column: str = Field( + description="Input column containing generated database records for this row.", + ) + context_columns: list[str] = Field( + default_factory=list, + description="Optional seed columns copied into environment context.", + ) + + @staticmethod + def get_column_emoji() -> str: + return "🧱" + + @field_validator("task_topic_column") + @classmethod + def validate_task_topic_column(cls, value: str) -> str: + """Validate the task topic source column name.""" + return normalize_column_name(value, "task_topic_column") + + @field_validator("task_constraints_column") + @classmethod + def validate_task_constraints_column(cls, value: str | None) -> str | None: + """Validate the optional task constraints source column name.""" + if value is None: + return None + return normalize_column_name(value, "task_constraints_column") + + @field_validator("database_schema_column") + @classmethod + def validate_database_schema_column(cls, value: str) -> str: + """Validate the generated database schema source column name.""" + return normalize_column_name(value, "database_schema_column") + + @field_validator("database_records_column") + @classmethod + def validate_database_records_column(cls, value: str) -> str: + """Validate the generated database records source column name.""" + return normalize_column_name(value, "database_records_column") + + @field_validator("context_columns") + @classmethod + def validate_context_columns(cls, 
value: list[str]) -> list[str]: + """Validate context column names.""" + return normalize_context_columns(value) + + @model_validator(mode="after") + def validate_distinct_columns(self) -> Self: + """Validate cross-field column references.""" + named_columns = [ + self.task_topic_column, + self.database_schema_column, + self.database_records_column, + *self.context_columns, + ] + if self.task_constraints_column is not None: + named_columns.append(self.task_constraints_column) + if len(named_columns) != len(set(named_columns)): + raise ValueError( + "task_topic_column, task_constraints_column, database_schema_column, " + "database_records_column, and context_columns must be distinct" + ) + return self + + @property + def required_columns(self) -> list[str]: + columns = [self.task_topic_column] + if self.task_constraints_column is not None: + columns.append(self.task_constraints_column) + columns.extend([self.database_schema_column, self.database_records_column]) + columns.extend(self.context_columns) + return columns + + @property + def side_effect_columns(self) -> list[str]: + return [] + + +class GeneralistAgentTaskColumnConfig(SingleColumnConfig): + """Configuration for synthesizing tasks from generated environments.""" + + column_type: Literal["generalist-agent-task"] = "generalist-agent-task" + + environment_column: str = Field( + description="Column containing a generalist-agent-environment artifact.", + ) + difficulty: Difficulty = Field( + default="hard", + description="Final task difficulty to synthesize after the simple-to-hard iteration trace.", + ) + required_tag: str | None = Field( + default=None, + description="Optional tag that every valid solution candidate must contain.", + ) + max_cost: int | None = Field( + default=None, + ge=1, + description="Optional maximum cost constraint. Unsatisfiable values are repaired upward.", + ) + min_score: int | None = Field( + default=None, + ge=0, + le=100, + description="Optional minimum score constraint. 
Unsatisfiable values are repaired downward.", + ) + + @staticmethod + def get_column_emoji() -> str: + return "🧪" + + @field_validator("environment_column") + @classmethod + def validate_environment_column(cls, value: str) -> str: + """Validate the environment source column name.""" + return normalize_column_name(value, "environment_column") + + @field_validator("required_tag") + @classmethod + def validate_required_tag(cls, value: str | None) -> str | None: + """Normalize the optional required tag.""" + return normalize_required_tag(value) + + @property + def required_columns(self) -> list[str]: + return [self.environment_column] + + @property + def side_effect_columns(self) -> list[str]: + return [] diff --git a/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/impl.py b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/impl.py new file mode 100644 index 0000000..7801b6d --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/impl.py @@ -0,0 +1,1202 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +import json +import math +import re +import textwrap +from collections.abc import Mapping +from pprint import pformat +from typing import TYPE_CHECKING, Any + +from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn + +from data_designer_generalist_agent_env.config import ( + Difficulty, + GeneralistAgentEnvironmentColumnConfig, + GeneralistAgentTaskColumnConfig, +) + +if TYPE_CHECKING: + import pandas as pd + +BASE_SANDBOX_TOOLS = ["data_designer_generated_schema", "data_designer_generated_records"] +DIFFICULTY_ORDER: list[Difficulty] = ["simple", "medium", "hard"] +REQUIRED_RECORD_FIELDS = ["record_id", "name", "summary", "cost", "duration", "score", "tags"] +DEFAULT_DATABASE_SCHEMA = { + "record_type": "generated_candidate", + "primary_key": "record_id", + "fields": [ + {"name": "record_id", "type": "string", "description": "Stable row-local identifier."}, + {"name": "name", "type": "string", "description": "Human-readable candidate name."}, + {"name": "summary", "type": "string", "description": "Short generated candidate description."}, + {"name": "cost", "type": "integer", "description": "Integer cost proxy; lower is better."}, + {"name": "duration", "type": "integer", "description": "Integer duration or effort proxy."}, + {"name": "score", "type": "integer", "description": "Integer quality score from 0 to 100; higher is better."}, + {"name": "tags", "type": "list[string]", "description": "Searchable task-specific labels."}, + {"name": "attributes", "type": "object", "description": "Topic-specific generated attributes."}, + ], +} + +TOOL_FUNCTION_SOURCES = { + "describe_schema": ''' +def describe_schema(): + """Return the generated database schema.""" + return dict(DATABASE_SCHEMA) +''', + "list_records": ''' +def list_records(): + """Return every record in the sandbox database.""" + return [dict(record) for record in DATABASE] 
+''', + "search_records": ''' +def search_records(query="", max_results=10): + """Search database records by name, summary, topic, tag, or generated attribute.""" + needle = str(query or "").casefold() + limit = max(0, int(max_results)) + matches = [] + for record in DATABASE: + attributes = record.get("attributes", {}) + attribute_text = " ".join(str(value) for value in attributes.values()) if isinstance(attributes, dict) else "" + haystack = " ".join( + [ + str(record.get("name", "")), + str(record.get("summary", "")), + str(record.get("topic", "")), + " ".join(str(tag) for tag in record.get("tags", [])), + attribute_text, + ], + ).casefold() + if not needle or needle in haystack: + matches.append(dict(record)) + return matches[:limit] +''', + "get_record": ''' +def get_record(record_id): + """Return one record by id, or None when the id is unknown.""" + for record in DATABASE: + if str(record.get("record_id")) == str(record_id): + return dict(record) + return None +''', + "filter_records": ''' +def filter_records(max_cost=None, min_score=None, required_tag=None): + """Filter records by cost, score, and tag constraints.""" + matches = [] + for record in DATABASE: + if max_cost is not None and int(record["cost"]) > int(max_cost): + continue + if min_score is not None and int(record["score"]) < int(min_score): + continue + if required_tag is not None and str(required_tag) not in record.get("tags", []): + continue + matches.append(dict(record)) + return matches +''', + "rank_records": ''' +def rank_records(records=None, metric="score", descending=True): + """Rank supplied records, or all database records, by a numeric metric.""" + source = DATABASE if records is None else records + return sorted( + [dict(record) for record in source], + key=lambda record: int(record.get(metric, 0)), + reverse=bool(descending), + ) +''', +} + +TOOL_DESCRIPTIONS = { + "describe_schema": "Inspect the generated row-local database schema.", + "list_records": "Inspect all generated rows 
def stable_int(seed: str, modulo: int) -> int:
    """Map text onto a deterministic integer in ``[0, modulo)``.

    Args:
        seed: Text hashed to derive the integer.
        modulo: Exclusive upper bound of the result.

    Returns:
        A reproducible integer in ``[0, modulo)``.
    """
    # SHA-256 keeps the mapping stable across processes and platforms,
    # unlike the salted builtin hash().
    hex_digest = hashlib.sha256(seed.encode("utf-8")).hexdigest()
    return int(hex_digest[:12], 16) % modulo
def to_plain_data(value: Any) -> Any:
    """Recursively convert array-like values into JSON-style Python containers.

    Args:
        value: Arbitrary generated artifact value.

    Returns:
        The value with mappings and list-like values normalized recursively,
        and scalar wrappers unwrapped via ``.item()`` where available.
    """
    if isinstance(value, Mapping):
        return {key: to_plain_data(inner) for key, inner in value.items()}

    as_list = coerce_list_like(value)
    if as_list is not None:
        return [to_plain_data(inner) for inner in as_list]

    # numpy-style scalars expose .item(); str/bytes are excluded because they
    # are already plain and must pass through untouched.
    unwrap = getattr(value, "item", None)
    if callable(unwrap) and not isinstance(value, (str, bytes)):
        try:
            return unwrap()
        except (TypeError, ValueError):
            return value
    return value
def normalize_database_schema(value: Any) -> dict[str, Any]:
    """Normalize a generated database schema payload.

    Args:
        value: Generated schema value from a row.

    Returns:
        Schema metadata as a dictionary, with ``record_type``, ``primary_key``,
        and ``fields`` backfilled from the default schema when absent.

    Raises:
        ValueError: If the schema is not mapping-like.
    """
    parsed = parse_generated_payload(value, "database_schema")
    if not isinstance(parsed, Mapping):
        msg = "database_schema must be a mapping generated by an upstream Data Designer column"
        raise ValueError(msg)

    schema = dict(parsed)
    # Missing top-level keys are filled from the defaults; generated keys win.
    for key in ("record_type", "primary_key", "fields"):
        schema.setdefault(key, DEFAULT_DATABASE_SCHEMA[key])
    return schema
+ """ + parsed = parse_generated_payload(value, "database_records") + if isinstance(parsed, Mapping): + for key in ("records", "items", "data", "rows"): + nested = parsed.get(key) + if nested is not None: + records = coerce_list_like(nested) + if records is not None: + return records + msg = "database_records mapping must contain a records, items, data, or rows list" + raise ValueError(msg) + + records = coerce_list_like(parsed) + if records is None: + msg = "database_records must be a list or an object containing a records list" + raise ValueError(msg) + return records + + +def normalize_tags(value: Any, record_id: str) -> list[str]: + """Normalize generated record tags. + + Args: + value: Generated tags value. + record_id: Record id used in errors. + + Returns: + List of tag strings. + + Raises: + ValueError: If tags are missing or cannot be interpreted as a non-empty list. + """ + tags = coerce_list_like(value) + if tags is None and isinstance(value, str): + tags = [tag.strip() for tag in re.split(r"[,;]", value) if tag.strip()] + if tags is None or not tags: + msg = f"generated record {record_id!r} must include at least one tag" + raise ValueError(msg) + return [str(tag).strip().lower() for tag in tags if str(tag).strip()] + + +def normalize_int_field(value: Any, field_name: str, record_id: str) -> int: + """Normalize an integer field from a generated record. + + Args: + value: Generated field value. + field_name: Field name. + record_id: Record id used in errors. + + Returns: + Integer field value. + + Raises: + ValueError: If the value cannot be converted to an integer. + """ + try: + return int(value) + except (TypeError, ValueError) as exc: + msg = f"generated record {record_id!r} field {field_name!r} must be an integer" + raise ValueError(msg) from exc + + +def normalize_generated_record(value: Any, index: int, topic: str) -> dict[str, Any]: + """Normalize and validate one generated database record. + + Args: + value: Generated record value. 
def normalize_database_records(value: Any, topic: str | None = None) -> list[dict[str, Any]]:
    """Normalize generated records restored from memory or saved artifacts.

    Args:
        value: Database records payload.
        topic: Generated topic used when records omit a topic field.

    Returns:
        Database records as plain dictionaries.

    Raises:
        ValueError: If records are absent, invalid, or share a record_id.
    """
    records = extract_records_payload(value)
    if not records:
        msg = "database_records must contain at least one generated record"
        raise ValueError(msg)
    topic = topic or "general task"
    normalized = [normalize_generated_record(record, index, topic) for index, record in enumerate(records)]

    # Single O(n) pass over the ids; the previous version re-counted the whole
    # id list for every record, which is accidentally O(n^2).
    seen_ids: set[str] = set()
    duplicate_ids: set[str] = set()
    for record in normalized:
        record_id = record["record_id"]
        if record_id in seen_ids:
            duplicate_ids.add(record_id)
        seen_ids.add(record_id)
    if duplicate_ids:
        msg = f"database_records contain duplicate record_id values: {', '.join(sorted(duplicate_ids))}"
        raise ValueError(msg)
    return normalized
+ """ + required_tag = constraints.get("required_tag") + return ( + int(record["cost"]) <= int(constraints["max_cost"]) + and int(record["score"]) >= int(constraints["min_score"]) + and (required_tag is None or str(required_tag) in record.get("tags", [])) + ) + + +def eligible_records(database: list[dict[str, Any]], constraints: dict[str, Any]) -> list[dict[str, Any]]: + """Filter records that satisfy task constraints. + + Args: + database: Sandbox database records. + constraints: Task constraints. + + Returns: + Eligible records. + """ + return [record for record in database if record_matches_constraints(record, constraints)] + + +def select_best_record(records: list[dict[str, Any]]) -> dict[str, Any] | None: + """Select the optimal answer under the verifier ordering. + + Args: + records: Candidate records. + + Returns: + The best record, or ``None`` when no candidates exist. + """ + if not records: + return None + return sorted(records, key=lambda record: (-int(record["score"]), int(record["cost"]), str(record["record_id"])))[0] + + +def default_constraints( + database: list[dict[str, Any]], + config: GeneralistAgentTaskColumnConfig, + difficulty: Difficulty | None = None, +) -> dict[str, Any]: + """Create feasible default constraints for the requested difficulty. + + Args: + database: Sandbox database records. + config: Task column configuration. + difficulty: Difficulty to synthesize; defaults to the configured final difficulty. + + Returns: + Constraint values and repair notes. 
+ """ + difficulty = difficulty or config.difficulty + required_tag = config.required_tag + target_pool = [record for record in database if required_tag is None or required_tag in record["tags"]] + target = select_best_record(target_pool) or select_best_record(database) + if target is None: + msg = "database must contain at least one record" + raise ValueError(msg) + + if required_tag is None and difficulty in ("medium", "hard"): + required_tag = str(target["tags"][0]) + + if difficulty == "simple": + default_max_cost = max(int(record["cost"]) for record in database) + default_min_score = min(int(record["score"]) for record in database) + elif difficulty == "medium": + default_max_cost = int(target["cost"]) + 120 + default_min_score = max(0, int(target["score"]) - 12) + else: + default_max_cost = int(target["cost"]) + 40 + default_min_score = max(0, int(target["score"]) - 4) + + constraints = { + "max_cost": config.max_cost if config.max_cost is not None else default_max_cost, + "min_score": config.min_score if config.min_score is not None else default_min_score, + "required_tag": required_tag, + "repair_notes": [], + } + return repair_constraints(database, constraints) + + +def repair_constraints(database: list[dict[str, Any]], constraints: dict[str, Any]) -> dict[str, Any]: + """Repair constraints that would otherwise make the task unsatisfiable. + + Args: + database: Sandbox database records. + constraints: Initial task constraints. + + Returns: + Feasible constraints plus repair notes. 
+ """ + if eligible_records(database, constraints): + return constraints + + required_tag = constraints.get("required_tag") + target_pool = [record for record in database if required_tag is None or required_tag in record["tags"]] + target = select_best_record(target_pool) or select_best_record(database) + if target is None: + return constraints + + if required_tag is not None and required_tag not in target["tags"]: + constraints["required_tag"] = target["tags"][0] + constraints["repair_notes"].append("required_tag changed to a tag present in the database") + + if int(target["cost"]) > int(constraints["max_cost"]): + constraints["max_cost"] = int(target["cost"]) + constraints["repair_notes"].append("max_cost increased to keep at least one valid candidate") + + if int(target["score"]) < int(constraints["min_score"]): + constraints["min_score"] = int(target["score"]) + constraints["repair_notes"].append("min_score decreased to keep at least one valid candidate") + + return constraints + + +def selected_tool_names(difficulty: Difficulty) -> list[str]: + """Select the synthesized toolset for a difficulty level. + + Args: + difficulty: Final task difficulty. + + Returns: + Tool names to expose to the solution function. + """ + tool_names = ["describe_schema", "list_records", "search_records", "get_record"] + if difficulty in ("medium", "hard"): + tool_names.append("filter_records") + if difficulty == "hard": + tool_names.append("rank_records") + return tool_names + + +def build_tool_specs(tool_names: list[str]) -> list[dict[str, str]]: + """Build tool metadata and function source snippets. + + Args: + tool_names: Selected tool names. + + Returns: + Tool descriptors for the output tuple. 
+ """ + return [ + { + "name": tool_name, + "description": TOOL_DESCRIPTIONS[tool_name], + "source": textwrap.dedent(TOOL_FUNCTION_SOURCES[tool_name]).strip(), + } + for tool_name in tool_names + ] + + +def build_tool_module_source( + database_schema: Mapping[str, Any], database: list[dict[str, Any]], tool_names: list[str] +) -> str: + """Build executable Python source for the generated tool module. + + Args: + database_schema: Generated row-local database schema. + database: Generated sandbox database. + tool_names: Selected tool names. + + Returns: + Python module source defining ``DATABASE_SCHEMA``, ``DATABASE``, and tool functions. + """ + parts = [ + f"DATABASE_SCHEMA = {pformat(dict(database_schema), sort_dicts=False, width=120)}", + f"DATABASE = {pformat(database, sort_dicts=False, width=120)}", + ] + parts.extend(textwrap.dedent(TOOL_FUNCTION_SOURCES[tool_name]).strip() for tool_name in tool_names) + return "\n\n".join(parts) + "\n" + + +def build_task_prompt(topic: str, difficulty: Difficulty, constraints: dict[str, Any]) -> str: + """Create the task prompt presented to a solving agent. + + Args: + topic: Generated task topic. + difficulty: Final task difficulty. + constraints: Task constraints. + + Returns: + Natural language task prompt. 
+ """ + clauses = [ + f"Use the synthesized tools to solve this {difficulty} {topic!r} task.", + "Inspect the generated schema and records through the tool interface; do not access the database directly.", + "Return the record_id for the eligible database record with the highest score.", + f"Only consider records with cost <= {constraints['max_cost']} and score >= {constraints['min_score']}.", + ] + if constraints.get("required_tag") is not None: + clauses.append(f"The record must include the tag {constraints['required_tag']!r}.") + clauses.append("Break ties by lower cost, then lexicographic record_id.") + return " ".join(clauses) + + +def build_reference_answer(database: list[dict[str, Any]], constraints: dict[str, Any]) -> dict[str, Any]: + """Compute the verifier's expected answer. + + Args: + database: Sandbox database records. + constraints: Task constraints. + + Returns: + JSON-compatible answer object. + """ + best = select_best_record(eligible_records(database, constraints)) + if best is None: + return {"record_id": None, "reason": "no eligible records"} + return { + "record_id": best["record_id"], + "score": best["score"], + "cost": best["cost"], + "tags": list(best["tags"]), + } + + +def verify_answer(answer: dict[str, Any], database: list[dict[str, Any]], constraints: dict[str, Any]) -> bool: + """Verify an answer against the database and constraints. + + Args: + answer: Candidate answer. + database: Sandbox database records. + constraints: Task constraints. + + Returns: + ``True`` when the answer is exactly the verifier-optimal record. 
+ """ + if not isinstance(answer, dict): + return False + best = select_best_record(eligible_records(database, constraints)) + if best is None: + return answer.get("record_id") is None + return ( + answer.get("record_id") == best["record_id"] + and int(answer.get("score", -1)) == int(best["score"]) + and int(answer.get("cost", -1)) == int(best["cost"]) + ) + + +def build_solution_source(constraints: dict[str, Any], difficulty: Difficulty) -> str: + """Build a tool-only Python solution function. + + Args: + constraints: Task constraints. + difficulty: Final task difficulty. + + Returns: + Python source defining ``solve(tools)``. + """ + required_tag = repr(constraints.get("required_tag")) + lines = [ + "def solve(tools):", + ' """Solve the task using only synthesized tool functions and local logic."""', + ' tools["describe_schema"]()', + f" required_tag = {required_tag}", + ] + + if difficulty == "simple": + lines.extend( + [ + " candidates = []", + ' for record in tools["list_records"]():', + f' if int(record["cost"]) > {constraints["max_cost"]}:', + " continue", + f' if int(record["score"]) < {constraints["min_score"]}:', + " continue", + ' if required_tag is not None and required_tag not in record.get("tags", []):', + " continue", + " candidates.append(record)", + ] + ) + else: + lines.extend( + [ + ' candidates = tools["filter_records"](', + f" max_cost={constraints['max_cost']},", + f" min_score={constraints['min_score']},", + " required_tag=required_tag,", + " )", + ] + ) + + lines.extend( + [ + " if not candidates:", + ' return {"record_id": None, "reason": "no eligible records"}', + ] + ) + if difficulty == "hard": + lines.extend( + [ + ' ranked = tools["rank_records"](candidates, metric="score", descending=True)', + ' ranked = sorted(ranked, key=lambda record: (-int(record["score"]), int(record["cost"]), str(record["record_id"])))', + ] + ) + else: + lines.append( + ' ranked = sorted(candidates, key=lambda record: (-int(record["score"]), 
def build_verifier_source(constraints: dict[str, Any]) -> str:
    """Build a Python verifier function for the synthesized task.

    Args:
        constraints: Task constraints.

    Returns:
        Python source defining ``verify(answer, database)``.
    """
    # Pin only the fields the emitted verifier needs so the source stays stable.
    pinned = {
        "max_cost": constraints["max_cost"],
        "min_score": constraints["min_score"],
        "required_tag": constraints.get("required_tag"),
    }
    constraints_literal = pformat(pinned, sort_dicts=False, width=120)
    source = f'''
    CONSTRAINTS = {constraints_literal}


    def verify(answer, database):
        """Return True when answer satisfies the task and is verifier-optimal."""
        if not isinstance(answer, dict):
            return False

        eligible = []
        for record in database:
            if int(record["cost"]) > int(CONSTRAINTS["max_cost"]):
                continue
            if int(record["score"]) < int(CONSTRAINTS["min_score"]):
                continue
            required_tag = CONSTRAINTS.get("required_tag")
            if required_tag is not None and str(required_tag) not in record.get("tags", []):
                continue
            eligible.append(record)

        if not eligible:
            return answer.get("record_id") is None

        best = sorted(
            eligible,
            key=lambda record: (-int(record["score"]), int(record["cost"]), str(record["record_id"])),
        )[0]
        return (
            answer.get("record_id") == best["record_id"]
            and int(answer.get("score", -1)) == int(best["score"])
            and int(answer.get("cost", -1)) == int(best["cost"])
        )
    '''
    return textwrap.dedent(source).strip()
def difficulty_trace(final_difficulty: Difficulty) -> list[Difficulty]:
    """List difficulty levels synthesized before the final task.

    Args:
        final_difficulty: Requested final difficulty.

    Returns:
        Ordered difficulty names through the final level.
    """
    levels = list(DIFFICULTY_ORDER)
    # Include every level up to and including the requested one, in canonical order.
    return levels[: levels.index(final_difficulty) + 1]
def build_environment_id(topic: str, context_values: dict[str, str], row_number: int) -> str:
    """Build a stable row-local environment identifier.

    Args:
        topic: Generated task topic.
        context_values: Context copied from the seed row.
        row_number: Zero-based row position.

    Returns:
        Stable environment identifier.
    """
    # Sorted-key JSON keeps the context hash deterministic across runs.
    serialized_context = json.dumps(context_values, sort_keys=True)
    parts = (
        slugify(topic, "task"),
        format(row_number + 1, "04d"),
        format(stable_int(serialized_context, 10_000), "04d"),
    )
    return "-".join(parts)
def build_environment_artifact(
    topic: str,
    constraints_payload: Any,
    constraints_text: str,
    context_values: dict[str, str],
    database_schema: dict[str, Any],
    database: list[dict[str, Any]],
    row_number: int,
) -> dict[str, Any]:
    """Build one standalone generated environment and toolset artifact.

    Args:
        topic: Generated task topic.
        constraints_payload: Raw generated constraints payload normalized to JSON-like data.
        constraints_text: Generated constraints flattened to text.
        context_values: Context copied from the seed row.
        database_schema: Generated database schema.
        database: Generated database records.
        row_number: Zero-based row position used for stable ids.

    Returns:
        Structured Generalist environment artifact.
    """
    # Fail fast if the generated records do not conform to the generated schema.
    validate_schema_covers_records(database_schema, database)
    environment_id = build_environment_id(topic, context_values, row_number)
    # The environment always ships the full ("hard") toolset.
    tool_names = selected_tool_names("hard")

    environment = {
        "environment_id": environment_id,
        "topic": topic,
        "sandbox": {
            "base_tools": list(BASE_SANDBOX_TOOLS),
            "database_name": f"{environment_id}_db",
        },
        "database_schema": database_schema,
        "database": database,
        "database_record_count": len(database),
        "task_constraints": constraints_payload,
        "task_constraints_text": constraints_text,
        "source_context": dict(context_values),
        "data_generation": {
            "mode": "generated_by_data_designer_columns",
            "note": "Topic, constraints, schema, and records are generated upstream by Data Designer columns.",
        },
    }
    synthesis_trace = [
        {
            "stage": "topic_and_constraint_intake",
            "topic": topic,
            "constraints_available": bool(constraints_text),
        },
        {
            "stage": "schema_intake",
            "record_type": database_schema.get("record_type"),
            "primary_key": database_schema.get("primary_key"),
        },
        {
            "stage": "generated_data_intake",
            "database_record_count": len(database),
            "toolset": tool_names,
        },
    ]
    return {
        "schema_version": "generalist-agent-environment/v1",
        "source_workflow": "Generated Generalist environment and toolset assembly",
        "environment": environment,
        "tools": build_tool_specs(tool_names),
        "tool_module_source": build_tool_module_source(database_schema, database, tool_names),
        "synthesis_trace": synthesis_trace,
    }
def build_task_tuple(
    environment_artifact: dict[str, Any],
    config: GeneralistAgentTaskColumnConfig,
) -> dict[str, Any]:
    """Build one ``(task, solution, verifier)`` tuple from an environment.

    Args:
        environment_artifact: Output from ``generalist-agent-environment``.
        config: Task column configuration.

    Returns:
        Structured Generalist task tuple.
    """
    environment = dict(environment_artifact["environment"])
    # Re-normalize schema and records so saved/restored (e.g. Parquet) artifacts
    # behave the same as in-memory ones.
    database_schema = normalize_database_schema(environment["database_schema"])
    topic = str(environment.get("topic") or "general task")
    database = normalize_database_records(environment["database"], topic)
    environment["database_schema"] = database_schema
    environment["database"] = database
    task_iterations = build_task_iterations(topic, database, config)
    # The last iteration is the final-difficulty task; earlier ones are the
    # simple-to-hard augmentation trace.
    final_iteration = task_iterations[-1]
    constraints = final_iteration["constraints"]
    answer = final_iteration["reference_answer"]
    verified = bool(final_iteration["reference_solution_passed"])
    tool_names = final_iteration["tool_names"]

    return {
        "schema_version": "generalist-agent-task/v1",
        "source_workflow": "Generalist task synthesis from generated environment",
        "environment": environment,
        "tools": build_tool_specs(tool_names),
        "tool_module_source": build_tool_module_source(database_schema, database, tool_names),
        "task": {
            "difficulty": config.difficulty,
            "topic": topic,
            "prompt": build_task_prompt(topic, config.difficulty, constraints),
            "constraints": constraints,
            "answer_schema": {
                "record_id": "string or null",
                "score": "integer when record_id is not null",
                "cost": "integer when record_id is not null",
                "tags": "list of strings when record_id is not null",
            },
        },
        "solution": {
            "language": "python",
            "entrypoint": "solve",
            "source": final_iteration["solution_source"],
            "restrictions": [
                "may call synthesized tool functions",
                "may perform local logical computation",
                "must not directly access the sandbox database",
            ],
        },
        "verifier": {
            "language": "python",
            "entrypoint": "verify",
            "source": final_iteration["verifier_source"],
            "reference_solution_passed": verified,
        },
        "reference_answer": answer,
        "task_iterations": task_iterations,
        "synthesis_trace": [
            *environment_artifact.get("synthesis_trace", []),
            *build_task_synthesis_trace(topic, config.difficulty, tool_names, constraints, verified),
        ],
        "rl_filter_note": "Downstream RL retention can keep generated tuples with non-zero pass@100.",
    }
class GeneralistAgentEnvironmentColumnGenerator(ColumnGeneratorFullColumn[GeneralistAgentEnvironmentColumnConfig]):
    """Assemble generated Generalist environment and toolset artifacts."""

    def _row_artifact(self, row: pd.Series, row_number: int) -> dict[str, Any]:
        # Build the environment artifact for a single seed row.
        config = self.config
        topic = normalize_cell(row[config.task_topic_column]) or "general task"
        if config.task_constraints_column is not None:
            constraints_cell = row[config.task_constraints_column]
        else:
            constraints_cell = None
        constraints_payload = to_plain_data(constraints_cell) if constraints_cell is not None else {}
        constraints_text = constraint_payload_to_text(constraints_cell)
        database_schema = normalize_database_schema(row[config.database_schema_column])
        database = normalize_database_records(row[config.database_records_column], topic)
        context_values: dict[str, str] = {}
        for column in config.context_columns:
            cell_text = normalize_cell(row[column])
            if cell_text:
                context_values[column] = cell_text
        return build_environment_artifact(
            topic,
            constraints_payload,
            constraints_text,
            context_values,
            database_schema,
            database,
            row_number,
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Generate environment artifacts from upstream generated schema and records.

        Args:
            data: Input DataFrame containing generated task topic, optional
                generated constraints, generated database schema, generated
                database records, and optional context columns.

        Returns:
            The input DataFrame with the configured output column populated.
        """
        data[self.config.name] = [
            self._row_artifact(row, row_number) for row_number, (_, row) in enumerate(data.iterrows())
        ]
        return data
class GeneralistAgentTaskColumnGenerator(ColumnGeneratorFullColumn[GeneralistAgentTaskColumnConfig]):
    """Generate Generalist task tuples from constructed generated environments."""

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Generate task, solution, and verifier tuples from environments.

        Args:
            data: Input DataFrame containing the configured environment column.

        Returns:
            The input DataFrame with the configured output column populated.
        """
        column = self.config.environment_column
        task_tuples: list[dict[str, Any]] = []
        for _, row in data.iterrows():
            artifact = row[column]
            # Upstream must supply plain dict artifacts; anything else is a wiring error.
            if not isinstance(artifact, dict):
                raise ValueError(f"{column!r} must contain environment artifact dictionaries")
            task_tuples.append(build_task_tuple(artifact, self.config))
        data[self.config.name] = task_tuples
        return data
+# SPDX-License-Identifier: Apache-2.0 + +from data_designer.plugins.plugin import Plugin, PluginType + +environment_plugin = Plugin( + config_qualified_name="data_designer_generalist_agent_env.config.GeneralistAgentEnvironmentColumnConfig", + impl_qualified_name="data_designer_generalist_agent_env.impl.GeneralistAgentEnvironmentColumnGenerator", + plugin_type=PluginType.COLUMN_GENERATOR, +) + +task_plugin = Plugin( + config_qualified_name="data_designer_generalist_agent_env.config.GeneralistAgentTaskColumnConfig", + impl_qualified_name="data_designer_generalist_agent_env.impl.GeneralistAgentTaskColumnGenerator", + plugin_type=PluginType.COLUMN_GENERATOR, +) diff --git a/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/validation.py b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/validation.py new file mode 100644 index 0000000..e824ec2 --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/src/data_designer_generalist_agent_env/validation.py @@ -0,0 +1,599 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class ToolExecutionCheck: + """Execution result for one generated tool. + + Attributes: + name: Tool name from the generated row artifact. + passed: Whether the tool executed and returned the expected output shape. + output_type: Python type name returned by the smoke invocation. + output_size: Length of the output when the output is a sized collection. + error: Error message when execution failed. 
@dataclass(frozen=True)
class IterationExecutionCheck:
    """Execution result for one generated task iteration.

    Attributes:
        difficulty: Iteration difficulty label.
        passed: Whether the iteration solution was accepted by its verifier.
        answer: Answer returned by the iteration solution.
        verifier_passed: Raw verifier decision for the generated answer.
        error: Error message when execution failed.
    """

    difficulty: str
    passed: bool
    answer: Any | None = None
    verifier_passed: bool = False
    error: str | None = None


@dataclass(frozen=True)
class RowRecordValidationResult:
    """Validation result for a generated Generalist environment row record.

    Attributes:
        passed: Whether all executable artifacts passed validation.
        answer: Answer returned by the final generated solution.
        verifier_passed: Raw verifier decision for the final generated answer.
        tools_passed: Whether all generated tools passed smoke execution.
        tool_checks: Per-tool execution checks.
        iteration_checks: Per-iteration solution and verifier checks.
        errors: Validation errors collected across executable artifacts.
    """

    passed: bool
    answer: Any | None
    verifier_passed: bool
    tools_passed: bool
    tool_checks: list[ToolExecutionCheck] = field(default_factory=list)
    iteration_checks: list[IterationExecutionCheck] = field(default_factory=list)
    errors: list[str] = field(default_factory=list)
def execute_source_module(source: str, expected_entrypoints: list[str] | None = None) -> dict[str, Any]:
    """Execute generated Python source and return its namespace.

    Args:
        source: Python module source emitted by the plugin.
        expected_entrypoints: Callable names that must exist after execution.

    Returns:
        The execution namespace.

    Raises:
        ValueError: If an expected entrypoint is missing or is not callable.
        Exception: Any exception raised by the generated source during execution.
    """
    # NOTE(review): exec runs plugin-emitted, same-process source; never feed
    # untrusted input through this helper.
    module_globals: dict[str, Any] = {}
    exec(source, module_globals)  # noqa: S102
    for name in expected_entrypoints or []:
        if not callable(module_globals.get(name)):
            raise ValueError(f"expected callable {name!r} in generated source")
    return module_globals
def coerce_list_like(value: Any) -> list[Any] | None:
    """Coerce common list-like values into a Python list.

    Args:
        value: Candidate list-like value. This includes values restored from
            nested Parquet structures, such as NumPy arrays, without importing
            optional array libraries directly.

    Returns:
        A Python list when coercion is possible, otherwise ``None``.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    # Duck-type array-likes (e.g. NumPy arrays) via their tolist() method.
    converter = getattr(value, "tolist", None)
    if not callable(converter):
        return None
    converted = converter()
    return converted if isinstance(converted, list) else None
def normalize_database(database: Any) -> tuple[list[dict[str, Any]] | None, str | None]:
    """Normalize an environment database for execution.

    Args:
        database: Database value from a generated row. In-memory Data Designer
            rows use lists, while saved Parquet artifacts may restore nested
            lists as array-like values.

    Returns:
        A tuple of normalized database records and an error message. Exactly one
        element is non-``None``.
    """
    raw_records = coerce_list_like(database)
    if raw_records is None:
        return None, "environment.database must be list-like"

    result: list[dict[str, Any]] = []
    for position, raw_record in enumerate(raw_records):
        if not isinstance(raw_record, Mapping):
            return None, f"environment.database[{position}] must be a mapping"
        record = dict(raw_record)
        tags = record.get("tags")
        if tags is None or isinstance(tags, list):
            result.append(record)
            continue
        # Parquet round-trips may turn tags into array-likes; coerce back to list.
        coerced_tags = coerce_list_like(tags)
        if coerced_tags is None:
            return None, f"environment.database[{position}].tags must be list-like"
        record["tags"] = coerced_tags
        result.append(record)
    return result, None
def tool_output_error(tool_name: str, output: Any) -> str | None:
    """Validate the expected output shape for a generated tool.

    Args:
        tool_name: Tool name.
        output: Value returned by the tool smoke invocation.

    Returns:
        An error message when the output shape is unexpected, otherwise ``None``.
    """
    list_returning = {"list_records", "search_records", "filter_records", "rank_records"}
    if tool_name == "describe_schema" and not isinstance(output, dict):
        return f"{tool_name} returned {type(output).__name__}; expected dict"
    if tool_name in list_returning and not isinstance(output, list):
        return f"{tool_name} returned {type(output).__name__}; expected list"
    if tool_name == "get_record" and not (output is None or isinstance(output, dict)):
        return f"get_record returned {type(output).__name__}; expected dict or None"
    # Unknown tools are not shape-checked.
    return None
def build_tools_from_namespace(
    tool_names: list[str],
    namespace: Mapping[str, Any],
) -> tuple[dict[str, Callable[..., Any]], list[str]]:
    """Build the generated tool mapping from an executed namespace.

    Args:
        tool_names: Names requested by the row artifact.
        namespace: Namespace returned by ``execute_source_module``.

    Returns:
        A tuple of callable tools and validation errors.
    """
    tools: dict[str, Callable[..., Any]] = {}
    problems: list[str] = []
    for name in tool_names:
        candidate = namespace.get(name)
        if callable(candidate):
            tools[name] = candidate
        else:
            problems.append(f"generated tool {name!r} is missing or not callable")
    return tools, problems
def tool_names_from_specs(tool_specs: Any) -> tuple[list[str], list[str]]:
    """Extract tool names from generated tool specs.

    Args:
        tool_specs: Value from the environment tuple ``tools`` field.

    Returns:
        A tuple of tool names and validation errors.
    """
    specs = coerce_list_like(tool_specs)
    if specs is None:
        return [], ["environment tuple tools field must be list-like"]

    names: list[str] = []
    problems: list[str] = []
    for position, spec in enumerate(specs):
        if not isinstance(spec, Mapping):
            problems.append(f"tools[{position}] must be a mapping")
            continue
        name = spec.get("name")
        if isinstance(name, str) and name:
            names.append(name)
        else:
            problems.append(f"tools[{position}].name must be a non-empty string")
    return names, problems
def run_solution_and_verifier(
    solution_source: str,
    solution_entrypoint: str,
    verifier_source: str,
    verifier_entrypoint: str,
    tools: Mapping[str, Callable[..., Any]],
    database: list[dict[str, Any]],
) -> tuple[Any | None, bool, list[str]]:
    """Execute generated solution source and validate it with generated verifier source.

    Args:
        solution_source: Python source defining the solution function.
        solution_entrypoint: Name of the solution function.
        verifier_source: Python source defining the verifier function.
        verifier_entrypoint: Name of the verifier function.
        tools: Callable tool mapping exposed to the solution.
        database: Row-local sandbox database exposed to the verifier.

    Returns:
        The solution answer, verifier decision, and collected execution errors.
    """
    # Run the generated solution; a crash here yields no answer at all.
    try:
        solve_namespace = execute_source_module(solution_source, [solution_entrypoint])
        answer = solve_namespace[solution_entrypoint](dict(tools))
    except Exception as exc:  # noqa: BLE001
        return None, False, [f"solution execution failed: {exc}"]

    # Run the generated verifier over the obtained answer.
    try:
        verify_namespace = execute_source_module(verifier_source, [verifier_entrypoint])
        accepted = bool(verify_namespace[verifier_entrypoint](answer, database))
    except Exception as exc:  # noqa: BLE001
        return answer, False, [f"verifier execution failed: {exc}"]

    problems: list[str] = []
    if not accepted:
        problems.append("verifier rejected generated solution answer")
    return answer, accepted, problems
def run_iteration_execution_check(
    iteration: Mapping[str, Any],
    tool_namespace: Mapping[str, Any],
    database: list[dict[str, Any]],
) -> IterationExecutionCheck:
    """Execute and verify one generated task iteration.

    Args:
        iteration: Task iteration artifact from ``task_iterations``.
        tool_namespace: Executed namespace from ``tool_module_source``.
        database: Row-local sandbox database.

    Returns:
        A structured per-iteration execution result.
    """
    difficulty = str(iteration.get("difficulty", "unknown"))
    names = coerce_list_like(iteration.get("tool_names", []))
    if names is None or any(not isinstance(name, str) for name in names):
        return IterationExecutionCheck(
            difficulty=difficulty,
            passed=False,
            error="iteration tool_names must be a list-like value of strings",
        )

    tools, tool_problems = build_tools_from_namespace(names, tool_namespace)
    if tool_problems:
        return IterationExecutionCheck(difficulty=difficulty, passed=False, error="; ".join(tool_problems))

    answer, verifier_passed, problems = run_solution_and_verifier(
        str(iteration.get("solution_source", "")),
        "solve",
        str(iteration.get("verifier_source", "")),
        "verify",
        tools,
        database,
    )

    # Cross-check the replayed run against the artifact's recorded outcome.
    reference_answer = iteration.get("reference_answer")
    if reference_answer is not None and to_plain_data(answer) != to_plain_data(reference_answer):
        problems.append("iteration answer does not match reference_answer")
    expected_passed = iteration.get("reference_solution_passed")
    if expected_passed is not None and bool(expected_passed) != verifier_passed:
        problems.append("iteration reference_solution_passed does not match verifier result")

    return IterationExecutionCheck(
        difficulty=difficulty,
        passed=not problems and verifier_passed,
        answer=answer,
        verifier_passed=verifier_passed,
        error="; ".join(problems) if problems else None,
    )
def verify_environment_tuple(environment_tuple: Mapping[str, Any]) -> RowRecordValidationResult:
    """Verify one generated environment tuple by executing all generated artifacts.

    The helper executes the generated tool module, smoke-tests every declared
    tool, runs the final generated ``solve(tools)`` function, and checks the
    result with the generated ``verify(answer, database)`` function. It also
    replays every artifact in ``task_iterations`` when present.

    Args:
        environment_tuple: Generated ``generalist-agent-task`` output value.

    Returns:
        A structured validation result with per-artifact status and errors.
    """
    errors: list[str] = []
    # Phase 1: structural intake — all four keys are hard requirements.
    try:
        environment = environment_tuple["environment"]
        database = environment["database"]
        task = environment_tuple["task"]
        constraints = task["constraints"]
    except KeyError as exc:
        return failed_validation(f"environment tuple is missing required key: {exc}")

    database, database_error = normalize_database(database)
    if database_error is not None or database is None:
        return failed_validation(database_error or "environment.database could not be normalized")
    if not isinstance(constraints, Mapping):
        return failed_validation("task.constraints must be a mapping")

    # Phase 2: load the generated tool module; spec errors are collected, a
    # module execution failure is fatal for everything downstream.
    tool_names, tool_spec_errors = tool_names_from_specs(environment_tuple.get("tools"))
    errors.extend(tool_spec_errors)

    try:
        tool_namespace = execute_source_module(str(environment_tuple["tool_module_source"]), tool_names)
    except Exception as exc:  # noqa: BLE001
        return RowRecordValidationResult(
            passed=False,
            answer=None,
            verifier_passed=False,
            tools_passed=False,
            errors=[*errors, f"tool_module_source execution failed: {exc}"],
        )

    # Phase 3: smoke-test each declared tool individually.
    tools, tool_errors = build_tools_from_namespace(tool_names, tool_namespace)
    errors.extend(tool_errors)
    tool_checks = [
        run_tool_execution_check(tool_name, tool, database, constraints) for tool_name, tool in tools.items()
    ]
    errors.extend(
        f"tool {check.name!r} failed smoke execution: {check.error}" for check in tool_checks if not check.passed
    )
    tools_passed = not tool_spec_errors and not tool_errors and all(check.passed for check in tool_checks)

    # Phase 4: run the final solution and verifier sources.
    solution = environment_tuple.get("solution", {})
    verifier = environment_tuple.get("verifier", {})
    if not isinstance(solution, Mapping) or not isinstance(verifier, Mapping):
        return RowRecordValidationResult(
            passed=False,
            answer=None,
            verifier_passed=False,
            tools_passed=tools_passed,
            tool_checks=tool_checks,
            errors=[*errors, "solution and verifier fields must be mappings"],
        )

    answer, verifier_passed, source_errors = run_solution_and_verifier(
        str(solution.get("source", "")),
        str(solution.get("entrypoint", "solve")),
        str(verifier.get("source", "")),
        str(verifier.get("entrypoint", "verify")),
        tools,
        database,
    )
    errors.extend(source_errors)

    # Phase 5: cross-check the replayed run against the recorded artifacts.
    reference_answer = environment_tuple.get("reference_answer")
    if reference_answer is not None and to_plain_data(answer) != to_plain_data(reference_answer):
        errors.append("final solution answer does not match reference_answer")

    expected_passed = verifier.get("reference_solution_passed")
    if expected_passed is not None and bool(expected_passed) != verifier_passed:
        errors.append("verifier.reference_solution_passed does not match verifier result")

    # Phase 6: replay each recorded task iteration, when present.
    iteration_checks: list[IterationExecutionCheck] = []
    task_iterations = environment_tuple.get("task_iterations", [])
    task_iterations = coerce_list_like(task_iterations)
    if task_iterations is not None:
        iteration_checks = [
            run_iteration_execution_check(iteration, tool_namespace, database)
            for iteration in task_iterations
            if isinstance(iteration, Mapping)
        ]
        errors.extend(
            f"iteration {check.difficulty!r} failed execution: {check.error}"
            for check in iteration_checks
            if not check.passed
        )
    else:
        errors.append("task_iterations must be list-like when present")

    # Overall pass requires: no collected errors, all tools OK, verifier accepted
    # the final answer, and every replayed iteration passed.
    passed = not errors and tools_passed and verifier_passed and all(check.passed for check in iteration_checks)
    return RowRecordValidationResult(
        passed=passed,
        answer=answer,
        verifier_passed=verifier_passed,
        tools_passed=tools_passed,
        tool_checks=tool_checks,
        iteration_checks=iteration_checks,
        errors=errors,
    )
+ output_column: Optional column name containing the generated tuple. + + Returns: + A structured validation result. + """ + try: + environment_tuple = extract_environment_tuple(row_record, output_column) + except (KeyError, TypeError, ValueError) as exc: + return failed_validation(str(exc)) + return verify_environment_tuple(environment_tuple) diff --git a/plugins/data-designer-generalist-agent-env/tests/test_plugin.py b/plugins/data-designer-generalist-agent-env/tests/test_plugin.py new file mode 100644 index 0000000..d3eee7c --- /dev/null +++ b/plugins/data-designer-generalist-agent-env/tests/test_plugin.py @@ -0,0 +1,393 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from copy import deepcopy +from pathlib import Path + +import pandas as pd +import pytest +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.seed_source_dataframe import DataFrameSeedSource +from data_designer.engine.testing.utils import assert_valid_plugin +from data_designer.interface.data_designer import DataDesigner +from pydantic import ValidationError + +from data_designer_generalist_agent_env.config import ( + GeneralistAgentEnvironmentColumnConfig, + GeneralistAgentTaskColumnConfig, +) +from data_designer_generalist_agent_env.impl import ( + GeneralistAgentEnvironmentColumnGenerator, + GeneralistAgentTaskColumnGenerator, + build_environment_artifact, + build_reference_answer, + build_task_tuple, + default_constraints, + selected_tool_names, +) +from data_designer_generalist_agent_env.plugin import environment_plugin, task_plugin +from data_designer_generalist_agent_env.validation import verify_environment_tuple, verify_row_record + + +def generated_schema() -> dict: + """Return a representative upstream-generated database schema.""" + return { + "record_type": "trip_candidate", + "primary_key": "record_id", + "fields": [ + {"name": "record_id", 
"type": "string"}, + {"name": "name", "type": "string"}, + {"name": "summary", "type": "string"}, + {"name": "cost", "type": "integer"}, + {"name": "duration", "type": "integer"}, + {"name": "score", "type": "integer"}, + {"name": "tags", "type": "list[string]"}, + {"name": "attributes", "type": "object"}, + ], + "attribute_fields": [ + {"name": "hotel_fit", "type": "integer"}, + {"name": "transport_risk", "type": "integer"}, + {"name": "restaurant_quality", "type": "integer"}, + ], + } + + +def generated_records() -> list[dict]: + """Return representative upstream-generated database records.""" + return [ + { + "record_id": "trip-001", + "name": "Museum Rail Plan", + "summary": "Generated itinerary candidate with reliable transit and moderate cost.", + "cost": 240, + "duration": 3, + "score": 92, + "tags": ["reliable", "museum", "budget"], + "attributes": {"hotel_fit": 88, "transport_risk": 12, "restaurant_quality": 82}, + }, + { + "record_id": "trip-002", + "name": "Luxury Dining Plan", + "summary": "Generated itinerary candidate with high restaurant quality and higher cost.", + "cost": 520, + "duration": 3, + "score": 97, + "tags": ["restaurant", "premium", "ranked"], + "attributes": {"hotel_fit": 90, "transport_risk": 18, "restaurant_quality": 96}, + }, + { + "record_id": "trip-003", + "name": "Compact Family Plan", + "summary": "Generated itinerary candidate that balances family activities and reliable transport.", + "cost": 180, + "duration": 2, + "score": 95, + "tags": ["reliable", "family", "verified"], + "attributes": {"hotel_fit": 91, "transport_risk": 10, "restaurant_quality": 80}, + }, + ] + + +def test_valid_plugin() -> None: + assert_valid_plugin(environment_plugin) + assert_valid_plugin(task_plugin) + + +def make_environment_generator( + config: GeneralistAgentEnvironmentColumnConfig, +) -> GeneralistAgentEnvironmentColumnGenerator: + """Create an environment generator instance without requiring a ResourceProvider.""" + generator = 
GeneralistAgentEnvironmentColumnGenerator.__new__(GeneralistAgentEnvironmentColumnGenerator) + generator._config = config + return generator + + +def make_task_generator(config: GeneralistAgentTaskColumnConfig) -> GeneralistAgentTaskColumnGenerator: + """Create a task generator instance without requiring a ResourceProvider.""" + generator = GeneralistAgentTaskColumnGenerator.__new__(GeneralistAgentTaskColumnGenerator) + generator._config = config + return generator + + +def build_valid_task_tuple() -> dict: + """Build a representative valid task tuple for validation tests.""" + task_config = GeneralistAgentTaskColumnConfig( + name="agent_task", + environment_column="agent_environment", + difficulty="hard", + required_tag="reliable", + ) + environment = build_environment_artifact( + "trip planning", + { + "goal": "plan a constrained itinerary", + "constraints": ["moderate budget", "reliable transport", "strong local evidence"], + }, + "goal: plan a constrained itinerary; constraints: moderate budget; reliable transport", + {"notes": "family-friendly museums and restaurants"}, + generated_schema(), + generated_records(), + row_number=0, + ) + return build_task_tuple(environment, task_config) + + +class TestGeneralistAgentEnvColumnConfig: + def test_environment_config_required_columns_include_generated_schema_and_records(self) -> None: + config = GeneralistAgentEnvironmentColumnConfig( + name="agent_environment", + task_topic_column="topic", + task_constraints_column="constraints", + database_schema_column="schema", + database_records_column="records", + context_columns=["notes", "persona"], + ) + + assert config.required_columns == ["topic", "constraints", "schema", "records", "notes", "persona"] + assert config.side_effect_columns == [] + + def test_task_config_requires_environment_column(self) -> None: + config = GeneralistAgentTaskColumnConfig( + name="agent_task", + environment_column="agent_environment", + ) + + assert config.required_columns == 
["agent_environment"] + assert config.side_effect_columns == [] + + def test_rejects_repeated_input_columns(self) -> None: + with pytest.raises(ValidationError, match="must be distinct"): + GeneralistAgentEnvironmentColumnConfig( + name="agent_environment", + task_topic_column="topic", + task_constraints_column="constraints", + database_schema_column="schema", + database_records_column="records", + context_columns=["constraints"], + ) + + def test_rejects_empty_topic_column(self) -> None: + with pytest.raises(ValidationError, match="task_topic_column must not be empty"): + GeneralistAgentEnvironmentColumnConfig( + name="agent_environment", + task_topic_column=" ", + database_schema_column="schema", + database_records_column="records", + ) + + def test_normalizes_task_required_tag(self) -> None: + config = GeneralistAgentTaskColumnConfig( + name="agent_task", + environment_column="agent_environment", + required_tag=" Reliable ", + ) + + assert config.required_tag == "reliable" + + +class TestGeneralistAgentEnvHelpers: + def test_tool_names_follow_difficulty(self) -> None: + assert selected_tool_names("simple") == ["describe_schema", "list_records", "search_records", "get_record"] + assert selected_tool_names("medium") == [ + "describe_schema", + "list_records", + "search_records", + "get_record", + "filter_records", + ] + assert selected_tool_names("hard") == [ + "describe_schema", + "list_records", + "search_records", + "get_record", + "filter_records", + "rank_records", + ] + + def test_reference_answer_is_verifier_optimal(self) -> None: + environment_tuple = build_valid_task_tuple() + + validation = verify_environment_tuple(environment_tuple) + + assert validation.passed is True + assert validation.verifier_passed is True + assert validation.tools_passed is True + assert validation.answer == environment_tuple["reference_answer"] + assert environment_tuple["verifier"]["reference_solution_passed"] is True + assert 
environment_tuple["task"]["constraints"]["required_tag"] == "reliable" + + def test_constraints_are_repaired_when_user_values_are_unsat(self) -> None: + task_config = GeneralistAgentTaskColumnConfig( + name="agent_task", + environment_column="agent_environment", + required_tag="rare", + max_cost=1, + min_score=100, + ) + environment = build_environment_artifact( + "debugging a build failure", + {}, + "", + {}, + generated_schema(), + generated_records(), + row_number=0, + ) + task_tuple = build_task_tuple(environment, task_config) + database = task_tuple["environment"]["database"] + constraints = default_constraints(database, task_config) + answer = build_reference_answer(database, constraints) + + assert constraints["repair_notes"] + assert answer["record_id"] is not None + + +class TestGeneralistAgentEnvColumnGenerator: + def test_two_step_environment_then_task_generation(self) -> None: + source_df = pd.DataFrame( + { + "topic": ["trip planning"], + "constraints": [ + { + "goal": "build a three-day itinerary", + "constraints": ["hotels, restaurants, and attractions", "moderate budget"], + "success_criteria": ["reliable transport", "strong local evidence"], + } + ], + "schema": [generated_schema()], + "records": [{"records": generated_records()}], + "notes": ["family-friendly museums, moderate budget, reliable transport"], + } + ) + environment_config = GeneralistAgentEnvironmentColumnConfig( + name="agent_environment", + task_topic_column="topic", + task_constraints_column="constraints", + database_schema_column="schema", + database_records_column="records", + context_columns=["notes"], + ) + task_config = GeneralistAgentTaskColumnConfig( + name="agent_task", + environment_column="agent_environment", + difficulty="hard", + required_tag="reliable", + ) + environment_generator = make_environment_generator(environment_config) + task_generator = make_task_generator(task_config) + + with_environment = environment_generator.generate(source_df) + result = 
task_generator.generate(with_environment) + environment_artifact = result.loc[0, "agent_environment"] + task_tuple = result.loc[0, "agent_task"] + validation = verify_environment_tuple(task_tuple) + + assert environment_artifact["schema_version"] == "generalist-agent-environment/v1" + assert environment_artifact["environment"]["data_generation"]["mode"] == "generated_by_data_designer_columns" + assert environment_artifact["environment"]["database_schema"]["record_type"] == "trip_candidate" + assert environment_artifact["environment"]["database"][0]["record_id"] == "trip-001" + assert task_tuple["schema_version"] == "generalist-agent-task/v1" + assert task_tuple["task"]["constraints"]["required_tag"] == "reliable" + assert "describe_schema" in task_tuple["solution"]["source"] + assert validation.passed is True + + def test_generated_python_sources_pass_verifier(self) -> None: + task_tuple = build_valid_task_tuple() + + validation = verify_environment_tuple(task_tuple) + + assert validation.passed is True + assert validation.answer["record_id"] + assert {check.name for check in validation.tool_checks} == set(selected_tool_names("hard")) + assert [check.difficulty for check in validation.iteration_checks] == ["simple", "medium", "hard"] + + def test_row_record_validation_reads_named_output_column(self) -> None: + task_tuple = build_valid_task_tuple() + row = pd.Series({"agent_task": task_tuple}) + + validation = verify_row_record(row, output_column="agent_task") + + assert validation.passed is True + assert validation.verifier_passed is True + + def test_row_record_validation_reports_missing_tool_implementation(self) -> None: + task_tuple = build_valid_task_tuple() + broken_tuple = deepcopy(task_tuple) + broken_tuple["tool_module_source"] = broken_tuple["tool_module_source"].replace( + "def rank_records(", + "def missing_rank_records(", + ) + + validation = verify_environment_tuple(broken_tuple) + + assert validation.passed is False + assert validation.tools_passed is 
False + assert any("rank_records" in error for error in validation.errors) + + def test_rejects_generated_records_missing_required_fields(self) -> None: + source_df = pd.DataFrame( + { + "topic": ["trip planning"], + "schema": [generated_schema()], + "records": [{"records": [{"record_id": "bad"}]}], + } + ) + config = GeneralistAgentEnvironmentColumnConfig( + name="agent_environment", + task_topic_column="topic", + database_schema_column="schema", + database_records_column="records", + ) + generator = make_environment_generator(config) + + with pytest.raises(ValueError, match="missing required fields"): + generator.generate(source_df) + + def test_row_record_validation_accepts_parquet_restored_arrays(self, tmp_path: Path) -> None: + task_tuple = build_valid_task_tuple() + dataset_path = tmp_path / "dataset.parquet" + pd.DataFrame({"agent_task": [task_tuple]}).to_parquet(dataset_path) + restored = pd.read_parquet(dataset_path) + + validation = verify_row_record(restored.loc[0], output_column="agent_task") + + assert validation.passed is True + assert validation.answer == task_tuple["reference_answer"] + + +class TestGeneralistAgentEnvPreviewIntegration: + def test_preview_generates_environment_tuple(self, tmp_path: Path) -> None: + seed_df = pd.DataFrame( + { + "topic": ["planning a travel itinerary"], + "constraints": ["compare candidate plans by score, cost, and family suitability"], + "schema": [generated_schema()], + "records": [{"records": generated_records()}], + "notes": ["family-friendly museums and restaurants"], + } + ) + + builder = DataDesignerConfigBuilder() + builder.with_seed_dataset(DataFrameSeedSource(df=seed_df)) + builder.add_column( + name="agent_environment", + column_type="generalist-agent-environment", + task_topic_column="topic", + task_constraints_column="constraints", + database_schema_column="schema", + database_records_column="records", + context_columns=["notes"], + ) + builder.add_column( + name="agent_task", + 
column_type="generalist-agent-task", + environment_column="agent_environment", + required_tag="family", + ) + + result = DataDesigner(artifact_path=tmp_path / "artifacts").preview(builder, num_records=1) + + assert result.dataset is not None + environment_tuple = result.dataset.loc[0, "agent_task"] + assert environment_tuple["task"]["constraints"]["required_tag"] == "family" + assert environment_tuple["verifier"]["reference_solution_passed"] is True diff --git a/pyproject.toml b/pyproject.toml index cffc13f..34b9263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ ignore = [ ] [tool.ruff.lint.isort] -known-first-party = ["ddp", "data_designer_template"] +known-first-party = ["ddp", "data_designer_generalist_agent_env", "data_designer_template"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" diff --git a/uv.lock b/uv.lock index 781f8c8..90231e2 100644 --- a/uv.lock +++ b/uv.lock @@ -9,6 +9,7 @@ resolution-markers = [ [manifest] members = [ + "data-designer-generalist-agent-env", "data-designer-plugins-workspace", "data-designer-template", "ddp", @@ -353,7 +354,7 @@ wheels = [ [[package]] name = "data-designer" -version = "0.5.7" +version = "0.5.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "data-designer-config" }, @@ -361,14 +362,14 @@ dependencies = [ { name = "prompt-toolkit" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dd/4b/00aeaaf364f1a7efbf5103954196ca351cdecc6d65203e5a7e4e33a69a2b/data_designer-0.5.7.tar.gz", hash = "sha256:374f9d15f7774fb5a79935b9e6ce989b7b5c364a8d1e0ce0e6e792258376b1a3", size = 120078, upload-time = "2026-04-17T22:03:14.088Z" } +sdist = { url = "https://files.pythonhosted.org/packages/21/a5/ee29d0f858e8b36e348fc4b99929384865a31cbce81c4e303bf96d6a39c7/data_designer-0.5.9.tar.gz", hash = "sha256:2b50e075bf4f58532fb22dea5aa777fe5679050b57e540ab7f51de33e8d74299", size = 120115, upload-time = "2026-04-28T23:23:04.673Z" } wheels = [ - { 
url = "https://files.pythonhosted.org/packages/18/93/ddadb9707ba8bde858bcca5cdfcaa016426b9ae9e5ce3bbfbaf3813a281f/data_designer-0.5.7-py3-none-any.whl", hash = "sha256:ec4f162b1c248c8d7fe81a8ca19c246998e0bb557f8dfbe629b8c85ac7e68182", size = 99133, upload-time = "2026-04-17T22:03:12.507Z" }, + { url = "https://files.pythonhosted.org/packages/3f/2b/b6adf669d53d04d2c5c78608974f8e68d32cb22334e1bbbb0c4998ac34db/data_designer-0.5.9-py3-none-any.whl", hash = "sha256:5c583d635fc26e5effe63f96001d1d11cdb8a5e23df96d1855780a7224559b8e", size = 99164, upload-time = "2026-04-28T23:23:03.397Z" }, ] [[package]] name = "data-designer-config" -version = "0.5.7" +version = "0.5.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2" }, @@ -384,14 +385,14 @@ dependencies = [ { name = "requests" }, { name = "rich" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/b6/e1b29e2fc98322f9865f20a0c3baa18972bbe353b65dd52c7f9786f8b9c5/data_designer_config-0.5.7.tar.gz", hash = "sha256:248b28ad2ec446599614e4656bae443ba9a9f3805e14ab478374fb34eb89636d", size = 128660, upload-time = "2026-04-17T22:03:06.42Z" } +sdist = { url = "https://files.pythonhosted.org/packages/72/a4/304374c08ede51262aaeb59848adb6a3cf8b18811847111e6152821ecebd/data_designer_config-0.5.9.tar.gz", hash = "sha256:401cad60ac68f15f05d694c4a053f8ee9001609838f96f6022394884eab10bba", size = 128778, upload-time = "2026-04-28T23:22:57.847Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/65/fc/a85d1d9d7436e4ebccc11b1fc7c0be1287584f4bee3a9ed71d58da0d0b0a/data_designer_config-0.5.7-py3-none-any.whl", hash = "sha256:3d4d22c4d8e4b36189f62ef103122a108add1cfbbacf8afcfdd281e9458bd77d", size = 114479, upload-time = "2026-04-17T22:03:04.993Z" }, + { url = "https://files.pythonhosted.org/packages/a7/ea/e46d62abe4eeb2180bf1c512930501dcee85b606974fb4eebb0be5180c23/data_designer_config-0.5.9-py3-none-any.whl", hash = 
"sha256:13f0b3232db3f565c8eb759b08763bf66ed0da9f712efe451d242cb81ec396ee", size = 114597, upload-time = "2026-04-28T23:22:56.406Z" }, ] [[package]] name = "data-designer-engine" -version = "0.5.7" +version = "0.5.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyascii" }, @@ -417,11 +418,22 @@ dependencies = [ { name = "sqlfluff" }, { name = "tiktoken" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5c/35/a8abd88c44aa603bacff33d6983959b95bdc4c0116fc03460fb4ef04f803/data_designer_engine-0.5.7.tar.gz", hash = "sha256:f1dfeaad52a12fe12bf9796ae45dddb9d1eed82bdb02979d6cdab8c723631651", size = 794680, upload-time = "2026-04-17T22:03:10.25Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f7/c9/16a5d199fcb30f045cfc83e26a3a69dcf06c6e0387b13f98f43d8ed345e7/data_designer_engine-0.5.9.tar.gz", hash = "sha256:3adf2185404b156fe7ebc03d022eef2398c96f033a31cf96434b1eaa81f51cea", size = 799004, upload-time = "2026-04-28T23:23:01.811Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/35/d4/3844529ae989be9e63b0b8f47c28492793993427dc7d54d6d2a923ad2acc/data_designer_engine-0.5.7-py3-none-any.whl", hash = "sha256:75cd7d5ad0b230ddf75950ba7f97c9ad75c54887ad1247cdf623dc008e31a418", size = 631945, upload-time = "2026-04-17T22:03:08.584Z" }, + { url = "https://files.pythonhosted.org/packages/8d/9a/48816e9e83854c00b26c165a8fbb53701a5b845c8f403d19dad52418fe84/data_designer_engine-0.5.9-py3-none-any.whl", hash = "sha256:09b4db93304d7a1fc1500f562684cc9199e9b9d986fbb18f21609e152390b01a", size = 632446, upload-time = "2026-04-28T23:22:59.734Z" }, ] +[[package]] +name = "data-designer-generalist-agent-env" +version = "0.1.0" +source = { editable = "plugins/data-designer-generalist-agent-env" } +dependencies = [ + { name = "data-designer" }, +] + +[package.metadata] +requires-dist = [{ name = "data-designer", specifier = ">=0.5.9" }] + [[package]] name = "data-designer-plugins-workspace" version = "0.0.0" diff --git 
a/zensical.toml b/zensical.toml index 3f1af80..d2ac622 100644 --- a/zensical.toml +++ b/zensical.toml @@ -19,6 +19,10 @@ nav = [ {"Plugins" = [ {"Overview" = "plugins/index.md"}, # BEGIN GENERATED PLUGIN DOCS NAV + {"data-designer-generalist-agent-env" = [ + {"Overview" = "plugins/data-designer-generalist-agent-env/index.md"}, + {"Usage" = "plugins/data-designer-generalist-agent-env/usage.md"}, + ]}, {"data-designer-template" = [ {"Overview" = "plugins/data-designer-template/index.md"}, {"Usage" = "plugins/data-designer-template/usage.md"},