diff --git a/docs/concepts/architecture-and-performance.md b/docs/concepts/architecture-and-performance.md index 69f500c22..5d1545ad7 100644 --- a/docs/concepts/architecture-and-performance.md +++ b/docs/concepts/architecture-and-performance.md @@ -161,6 +161,7 @@ import data_designer.config as dd model = dd.ModelConfig( alias="my-model", model="nvidia/nemotron-3-nano-30b-a3b", + provider="nvidia", inference_parameters=dd.ChatCompletionInferenceParams( max_parallel_requests=8, ), diff --git a/docs/concepts/models/configure-model-settings-with-the-cli.md b/docs/concepts/models/configure-model-settings-with-the-cli.md index 9fa5b38db..e7c1b5edd 100644 --- a/docs/concepts/models/configure-model-settings-with-the-cli.md +++ b/docs/concepts/models/configure-model-settings-with-the-cli.md @@ -67,6 +67,12 @@ data-designer config providers **Change default provider**: Set which provider is used by default. This option is only available when multiple providers are configured. +!!! warning "Deprecated: 'Change default provider' workflow" + The "Change default provider" workflow is **deprecated** and will be removed in a future + release alongside the registry-level default. Specify `provider=` explicitly on each + `ModelConfig` instead — the workflow now emits a `DeprecationWarning` when entered. + See [issue #589](https://github.com/NVIDIA-NeMo/DataDesigner/issues/589). + ## Managing Model Configurations Run the interactive model configuration command: @@ -111,7 +117,7 @@ data-designer config list This command displays: - **Model Providers**: All configured providers with their endpoints (API keys are masked) -- **Default Provider**: The currently selected default provider +- **Default Provider**: The currently selected default provider _(deprecated; see issue #589)_ - **Model Configurations**: All configured models with their settings ## Resetting Configurations diff --git a/docs/concepts/models/custom-model-settings.md b/docs/concepts/models/custom-model-settings.md index 42e32d15c..be73ae408 100644 --- a/docs/concepts/models/custom-model-settings.md +++ b/docs/concepts/models/custom-model-settings.md @@ -90,6 +90,13 @@ preview_result.display_sample_record() !!! note "Default Providers Always Available" When you only specify `model_configs`, the default model providers (NVIDIA, OpenAI, and OpenRouter) are still available. You only need to create custom providers if you want to connect to different endpoints or modify provider settings. +!!! warning "Always specify `provider=` on `ModelConfig`" + Leaving `provider` unset (or passing `provider=None`) on `ModelConfig` is **deprecated**. + The legacy "implicit default provider" routing — used when `provider` is omitted — emits + a `DeprecationWarning` and will be removed in a future release. Always reference the + intended provider by name, as the examples below do. See + [issue #589](https://github.com/NVIDIA-NeMo/DataDesigner/issues/589). + !!! tip "Mixing Custom and Default Models" When you provide custom `model_configs` to `DataDesignerConfigBuilder`, they **replace** the defaults entirely. To use custom model configs in addition to the default configs, use the add_model_config method: diff --git a/docs/concepts/models/default-model-settings.md b/docs/concepts/models/default-model-settings.md index 8f6f6cd47..31fd14449 100644 --- a/docs/concepts/models/default-model-settings.md +++ b/docs/concepts/models/default-model-settings.md @@ -107,6 +107,13 @@ Both methods operate on the same files, ensuring consistency across your entire !!! warning "API Key Requirements" While default model configurations are always available, you need to set the appropriate API key environment variable (`NVIDIA_API_KEY`, `OPENAI_API_KEY`, or `OPENROUTER_API_KEY`) to actually use the corresponding models for data generation. Without a valid API key, any attempt to generate data using that provider's models will fail. +!!! warning "Deprecated: implicit default provider routing" + The `default:` key in `~/.data-designer/model_providers.yaml` and the registry-level + "default provider" concept are **deprecated** and will be removed in a future release. + Specify `provider=` explicitly on every `ModelConfig` instead — the built-in defaults + above already do this, and a `DeprecationWarning` is now emitted whenever the legacy + routing is exercised. See [issue #589](https://github.com/NVIDIA-NeMo/DataDesigner/issues/589). + !!! tip "Environment Variables" Store your API keys in environment variables rather than hardcoding them in your scripts: diff --git a/docs/concepts/models/inference-parameters.md b/docs/concepts/models/inference-parameters.md index e5772439d..03e932ed4 100644 --- a/docs/concepts/models/inference-parameters.md +++ b/docs/concepts/models/inference-parameters.md @@ -167,6 +167,7 @@ dd.ModelConfig( dd.ModelConfig( alias="dalle", model="dall-e-3", + provider="openai", inference_parameters=dd.ImageInferenceParams( extra_body={"size": "1024x1024", "quality": "hd"} ), diff --git a/docs/concepts/models/model-providers.md b/docs/concepts/models/model-providers.md index 9d397a87a..f8625ae9b 100644 --- a/docs/concepts/models/model-providers.md +++ b/docs/concepts/models/model-providers.md @@ -6,6 +6,14 @@ Model providers are external services that host and serve models. Data Designer A `ModelProvider` defines how Data Designer connects to a provider's API endpoint. When you create a `ModelConfig`, you reference a provider by name, and Data Designer uses that provider's settings to make API calls to the appropriate endpoint. +!!! warning "Deprecated: implicit default provider routing" + Earlier versions of Data Designer let you omit `provider=` on `ModelConfig` and + fall back to a registry-level default — including the `default:` key in + `~/.data-designer/model_providers.yaml`. That implicit routing is **deprecated** + and will be removed in a future release. Always reference a provider by name on + every `ModelConfig`. A `DeprecationWarning` is now emitted when the legacy path + is exercised. See [issue #589](https://github.com/NVIDIA-NeMo/DataDesigner/issues/589). + ## ModelProvider Configuration The `ModelProvider` class has the following fields: diff --git a/packages/data-designer-config/README.md b/packages/data-designer-config/README.md index 78eaa7933..3c7faaf9e 100644 --- a/packages/data-designer-config/README.md +++ b/packages/data-designer-config/README.md @@ -21,6 +21,7 @@ config_builder = dd.DataDesignerConfigBuilder( dd.ModelConfig( alias="my-model", model="nvidia/nemotron-3-nano-30b-a3b", + provider="nvidia", inference_parameters=dd.ChatCompletionInferenceParams(temperature=0.7), ), ] diff --git a/packages/data-designer-config/src/data_designer/config/default_model_settings.py b/packages/data-designer-config/src/data_designer/config/default_model_settings.py index d97a286a6..8a0366733 100644 --- a/packages/data-designer-config/src/data_designer/config/default_model_settings.py +++ b/packages/data-designer-config/src/data_designer/config/default_model_settings.py @@ -24,6 +24,7 @@ PREDEFINED_PROVIDERS_MODEL_MAP, ) from data_designer.config.utils.io_helpers import load_config_file, save_config_file +from data_designer.config.utils.warning_helpers import warn_at_caller logger = logging.getLogger(__name__) @@ -95,7 +96,28 @@ def get_default_providers() -> list[ModelProvider]: def get_default_provider_name() -> str | None: - return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default") + """Return the YAML's ``default:`` provider name, if set. + + Deprecated: this function and the underlying YAML key are deprecated and + will be removed in a future release. Specify ``provider=`` explicitly on + each ``ModelConfig`` instead. See issue #589. + """ + default = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default") + if default is not None: + # ``warn_at_caller`` (rather than ``warnings.warn(stacklevel=2)``) so the + # warning attributes to the user's call site rather than this library + # module. The only real call path is ``DataDesigner.__init__``, which + # is itself a ``data_designer`` frame; under default Python filters, + # library-attributed ``DeprecationWarning`` entries are silenced + # (``ignore::DeprecationWarning``), so library attribution = invisible + # warning. See PR #594 review. + warn_at_caller( + f"The 'default:' key in {MODEL_PROVIDERS_FILE_PATH} is deprecated and will " + "be removed in a future release. Remove it and specify provider= explicitly " + "on each ModelConfig instead. See issue #589.", + DeprecationWarning, + ) + return default def resolve_seed_default_model_settings() -> None: diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index cfe6520df..9e3d8c44c 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -31,6 +31,7 @@ load_image_path_to_base64, ) from data_designer.config.utils.io_helpers import smart_load_yaml +from data_designer.config.utils.warning_helpers import warn_at_caller logger = logging.getLogger(__name__) @@ -503,7 +504,10 @@ class ModelConfig(ConfigBase): model: Model identifier (e.g., from build.nvidia.com or other providers). inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.). The generation_type is determined by the type of inference_parameters. - provider: Optional model provider name if using custom providers. + provider: Name of the model provider. Required in a future release. Leaving + ``provider`` unset (or ``None``) currently routes through the registry's + implicit default and is **deprecated**; specify ``provider=`` explicitly. + See issue #589. skip_health_check: Whether to skip the health check for this model. Defaults to False. """ @@ -535,6 +539,22 @@ def _convert_inference_parameters(cls, value: Any) -> Any: return ChatCompletionInferenceParams(**value) return value + @model_validator(mode="after") + def _warn_on_implicit_provider(self) -> Self: + if self.provider is None: + # Use ``warn_at_caller`` so the warning is attributed to the user's + # ``ModelConfig(...)`` / ``model_validate(...)`` call rather than a + # pydantic-internal frame. Without this, every call dedupes to the + # same pydantic line and only the first emission is shown. See + # PR #594 review. + warn_at_caller( + f"ModelConfig.provider=None is deprecated and will be required in a future release. " + f"Specify provider= explicitly on ModelConfig(alias={self.alias!r}, ...). " + "See issue #589.", + DeprecationWarning, + ) + return self + class ModelProvider(ConfigBase): """Configuration for a custom model provider. diff --git a/packages/data-designer-config/src/data_designer/config/testing/fixtures.py b/packages/data-designer-config/src/data_designer/config/testing/fixtures.py index 113ee3080..8fece3bf8 100644 --- a/packages/data-designer-config/src/data_designer/config/testing/fixtures.py +++ b/packages/data-designer-config/src/data_designer/config/testing/fixtures.py @@ -139,6 +139,7 @@ def stub_model_configs() -> list[ModelConfig]: ModelConfig( alias="stub-model", model="stub-model", + provider="provider-1", inference_parameters=ChatCompletionInferenceParams( temperature=0.9, top_p=0.9, diff --git a/packages/data-designer-config/src/data_designer/config/utils/warning_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/warning_helpers.py new file mode 100644 index 000000000..45bbbdecb --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/utils/warning_helpers.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Helpers for emitting warnings that attribute correctly to user code. + +Library-internal warnings (typically emitted from a pydantic ``@model_validator`` +or from a helper function) need to be attributed to the *user's* call site, not +to the library frame that happened to fire them. Two reasons: + +1. Attribution — a source location pointing at library internals isn't + actionable. +2. Visibility under default filters — Python's default ``DeprecationWarning`` + filter is:: + + default::DeprecationWarning:__main__ + ignore::DeprecationWarning + + Library-attributed ``DeprecationWarning`` entries fall under the second + filter and are silenced. Attributing to user code is what gets the warning + actually shown. + +3. Deduplication — Python's once-per-location dedup key is + ``(category, module, lineno)``. When every call resolves to the same + library-internal line, every warning after the first is silently suppressed + regardless of which user file triggered it. + +``warn_at_caller`` walks the stack past frames whose module belongs to a known +internal package (pydantic, data_designer) and uses ``warnings.warn_explicit`` +to attribute the warning at the first user frame. +""" + +from __future__ import annotations + +import sys +import warnings + +DEFAULT_INTERNAL_PREFIXES: tuple[str, ...] = ("pydantic", "pydantic_core", "data_designer") +"""Modules whose frames are skipped when walking up to the user's call site. + +Matching is exact-or-dotted-prefix (see ``_module_in_prefixes``), so +``pydantic_helpers`` is *not* treated as part of ``pydantic``.""" + + +def _module_in_prefixes(module_name: str, prefixes: tuple[str, ...]) -> bool: + """Return True if ``module_name`` belongs to one of the prefix-rooted packages. + + Uses exact-equality plus dotted-prefix matching so that, e.g., + ``pydantic_helpers`` is NOT treated as part of the ``pydantic`` package + while ``pydantic.fields`` is. Same for ``data_designer`` vs. a hypothetical + ``data_designer_other``. + """ + return any(module_name == prefix or module_name.startswith(prefix + ".") for prefix in prefixes) + + +def warn_at_caller( + message: str, + category: type[Warning], + *, + skip_prefixes: tuple[str, ...] = DEFAULT_INTERNAL_PREFIXES, +) -> None: + """Emit ``message`` attributed to the first frame outside ``skip_prefixes``. + + Intended for warnings whose root cause is the user's call site but whose + emission point is library code (a pydantic validator, an internal helper, + etc.). The walk starts above this helper's own frame and skips every frame + whose module belongs to a package in ``skip_prefixes`` until it reaches a + user frame. + + The default skip set covers: + + * ``pydantic`` / ``pydantic_core`` — so warnings emitted from + ``@model_validator`` callbacks escape pydantic's dispatch frames. + * ``data_designer`` — so warnings emitted from a registry / model-config + built deep inside a DataDesigner helper still attribute to the outermost + user call. Without this, attribution lands on a library file and Python's + default ``DeprecationWarning`` filter silences the warning entirely. + + The user frame's ``__warningregistry__`` is passed to + ``warnings.warn_explicit`` so Python's built-in once-per-location dedup keys + on the *user's* (filename, lineno) rather than an internal frame. + + We deliberately do *not* pass ``module_globals`` — it's only used for + ``linecache`` source-line display, and for scripts run with ``python -c`` + (where the user frame's ``__loader__`` is ``BuiltinImporter`` for + ``__main__``) the lookup raises ``ImportError("'__main__' is not a built-in + module")``. Skipping ``module_globals`` keeps the warning path robust at + the cost of an empty source line in the formatted output. + """ + frame = sys._getframe(1) if hasattr(sys, "_getframe") else None + while frame is not None: + module_name = frame.f_globals.get("__name__", "") + if not _module_in_prefixes(module_name, skip_prefixes): + warnings.warn_explicit( + message, + category, + frame.f_code.co_filename, + frame.f_lineno, + module=module_name, + registry=frame.f_globals.setdefault("__warningregistry__", {}), + ) + return + frame = frame.f_back + + # Fallback: never escaped library frames (or no frame access). Use stacklevel. + warnings.warn(message, category, stacklevel=3) diff --git a/packages/data-designer-config/tests/config/test_config_builder.py b/packages/data-designer-config/tests/config/test_config_builder.py index c81893d9d..01580ae5d 100644 --- a/packages/data-designer-config/tests/config/test_config_builder.py +++ b/packages/data-designer-config/tests/config/test_config_builder.py @@ -868,6 +868,7 @@ def test_add_model_config(stub_empty_builder): new_model_config = ModelConfig( alias="new-model", model="openai/gpt-4", + provider="openai", inference_parameters=ChatCompletionInferenceParams( temperature=0.7, top_p=0.95, @@ -915,6 +916,7 @@ def test_add_model_config_duplicate_alias(stub_empty_builder): duplicate_model_config = ModelConfig( alias="stub-model", model="different/model", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams(temperature=0.5), ) @@ -931,11 +933,13 @@ def test_delete_model_config(stub_empty_builder): model_config_1 = ModelConfig( alias="model-to-delete", model="model/delete", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams(temperature=0.5), ) model_config_2 = ModelConfig( alias="model-to-keep", model="model/keep", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams(temperature=0.6), ) stub_empty_builder.add_model_config(model_config_1) diff --git a/packages/data-designer-config/tests/config/test_default_model_settings.py b/packages/data-designer-config/tests/config/test_default_model_settings.py index 7df4c731a..1c144b7dd 100644 --- a/packages/data-designer-config/tests/config/test_default_model_settings.py +++ b/packages/data-designer-config/tests/config/test_default_model_settings.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import warnings from pathlib import Path from unittest.mock import patch @@ -142,19 +143,54 @@ def test_get_default_providers_path_does_not_exist(): def test_get_default_provider_name_with_default_key(tmp_path: Path): + """When the YAML carries a non-None ``default:``, the function must + return that value AND emit a ``DeprecationWarning`` (regression for #589). + """ providers_file_path = tmp_path / "providers.yaml" providers_file_path.write_text( json.dumps(dict(providers=[p.model_dump() for p in get_builtin_model_providers()], default="nvidia")) ) with patch("data_designer.config.default_model_settings.MODEL_PROVIDERS_FILE_PATH", new=providers_file_path): - assert get_default_provider_name() == "nvidia" + with pytest.warns(DeprecationWarning, match="'default:' key.*is deprecated"): + assert get_default_provider_name() == "nvidia" def test_get_default_provider_name_without_default_key(tmp_path: Path): + """Pin the post-deprecation happy path: a YAML without ``default:`` must + return ``None`` and NOT emit a ``DeprecationWarning``. + """ providers_file_path = tmp_path / "providers.yaml" providers_file_path.write_text(json.dumps({"providers": [p.model_dump() for p in get_builtin_model_providers()]})) with patch("data_designer.config.default_model_settings.MODEL_PROVIDERS_FILE_PATH", new=providers_file_path): - assert get_default_provider_name() is None + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + assert get_default_provider_name() is None + + +def test_get_default_provider_name_warning_attributes_to_user_frame(tmp_path: Path): + """Regression for PR #594 review (andreatgretel): the YAML-default warning + must attribute to the user's call site, not to ``default_model_settings.py``. + Python's default filter ignores library-attributed ``DeprecationWarning`` + entries, so the previous ``stacklevel=2`` attribution rendered the warning + invisible under default filters on the only real call path + (``DataDesigner.__init__``). See issue #589. + """ + providers_file_path = tmp_path / "providers.yaml" + providers_file_path.write_text( + json.dumps(dict(providers=[p.model_dump() for p in get_builtin_model_providers()], default="nvidia")) + ) + with patch("data_designer.config.default_model_settings.MODEL_PROVIDERS_FILE_PATH", new=providers_file_path): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + assert get_default_provider_name() == "nvidia" + + matches = [w for w in caught if "'default:' key" in str(w.message)] + assert len(matches) == 1, [str(w.message) for w in caught] + assert matches[0].filename == __file__, ( + f"Warning attributed to {matches[0].filename!r} (line {matches[0].lineno}) " + f"instead of the test file. Library-attributed DeprecationWarnings are " + f"silenced under default filters." + ) def test_get_default_provider_name_path_does_not_exist(): diff --git a/packages/data-designer-config/tests/config/test_fingerprint.py b/packages/data-designer-config/tests/config/test_fingerprint.py index 0535de704..a02ca3603 100644 --- a/packages/data-designer-config/tests/config/test_fingerprint.py +++ b/packages/data-designer-config/tests/config/test_fingerprint.py @@ -79,10 +79,11 @@ def test_fingerprint_deterministic_across_processes(stub_data_designer_config_st # --------------------------------------------------------------------------- -def _make_model(alias: str = "m", model: str = "some-model") -> ModelConfig: +def _make_model(alias: str = "m", model: str = "some-model", provider: str = "some-provider") -> ModelConfig: return ModelConfig( alias=alias, model=model, + provider=provider, inference_parameters=ChatCompletionInferenceParams(temperature=0.5, top_p=0.9, max_tokens=128), ) @@ -129,7 +130,7 @@ def test_changing_sampler_params_changes_hash() -> None: def test_changing_model_identity_changes_hash() -> None: a = _make_minimal_config() - b = _make_minimal_config(model_configs=[ModelConfig(alias="m", model="other-model")]) + b = _make_minimal_config(model_configs=[ModelConfig(alias="m", model="other-model", provider="some-provider")]) assert _compute_hash(a) != _compute_hash(b) @@ -140,6 +141,7 @@ def test_changing_temperature_changes_hash() -> None: ModelConfig( alias="m", model="some-model", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams(temperature=0.99, top_p=0.9, max_tokens=128), ) ], @@ -199,6 +201,7 @@ def test_changing_extra_body_changes_hash() -> None: ModelConfig( alias="m", model="some-model", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams( temperature=0.5, top_p=0.9, max_tokens=128, extra_body={"frequency_penalty": 0.5} ), @@ -302,6 +305,7 @@ def test_skip_health_check_does_not_change_hash() -> None: ModelConfig( alias="m", model="some-model", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams(temperature=0.5, top_p=0.9, max_tokens=128), skip_health_check=True, ) @@ -317,6 +321,7 @@ def test_max_parallel_requests_does_not_change_hash() -> None: ModelConfig( alias="m", model="some-model", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams( temperature=0.5, top_p=0.9, max_tokens=128, max_parallel_requests=32 ), @@ -333,6 +338,7 @@ def test_inference_timeout_does_not_change_hash() -> None: ModelConfig( alias="m", model="some-model", + provider="some-provider", inference_parameters=ChatCompletionInferenceParams( temperature=0.5, top_p=0.9, max_tokens=128, timeout=30 ), diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index 7bdc6ce91..c5bacd818 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -4,6 +4,7 @@ import base64 import json import tempfile +import warnings from collections import Counter from pathlib import Path @@ -413,8 +414,8 @@ def test_generation_parameters_max_tokens_validation(): def test_load_model_configs(): stub_model_configs = [ - ModelConfig(alias="test", model="test"), - ModelConfig(alias="test2", model="test2"), + ModelConfig(alias="test", model="test", provider="test-provider"), + ModelConfig(alias="test2", model="test2", provider="test-provider"), ] stub_model_configs_dict_list = [mc.model_dump(mode="json") for mc in stub_model_configs] assert load_model_configs([]) == [] @@ -454,35 +455,72 @@ def test_load_model_configs(): def test_model_config_construction(): # test default construction - model_config = ModelConfig(alias="test", model="test") + model_config = ModelConfig(alias="test", model="test", provider="test-provider") assert model_config.inference_parameters == ChatCompletionInferenceParams() assert model_config.generation_type == GenerationType.CHAT_COMPLETION # test construction with completion inference parameters completion_params = ChatCompletionInferenceParams(temperature=0.5, top_p=0.5, max_tokens=100) - model_config = ModelConfig(alias="test", model="test", inference_parameters=completion_params) + model_config = ModelConfig( + alias="test", model="test", provider="test-provider", inference_parameters=completion_params + ) assert model_config.inference_parameters == completion_params assert model_config.generation_type == GenerationType.CHAT_COMPLETION # test construction with embedding inference parameters embedding_params = EmbeddingInferenceParams(dimensions=100) - model_config = ModelConfig(alias="test", model="test", inference_parameters=embedding_params) + model_config = ModelConfig( + alias="test", model="test", provider="test-provider", inference_parameters=embedding_params + ) assert model_config.inference_parameters == embedding_params assert model_config.generation_type == GenerationType.EMBEDDING # test construction with image inference parameters image_params = ImageInferenceParams(extra_body={"size": "1024x1024", "quality": "hd"}) - model_config = ModelConfig(alias="test", model="test", inference_parameters=image_params) + model_config = ModelConfig(alias="test", model="test", provider="test-provider", inference_parameters=image_params) assert model_config.inference_parameters == image_params assert model_config.generation_type == GenerationType.IMAGE +def test_model_config_provider_none_emits_deprecation_warning(): + """Regression for #589: omitting ``provider=`` (or passing ``provider=None``) + on a ``ModelConfig`` is deprecated; construction must emit a + ``DeprecationWarning`` pointing users at the explicit-provider migration. + """ + with pytest.warns(DeprecationWarning, match="ModelConfig.provider=None is deprecated"): + ModelConfig(alias="legacy", model="legacy-model") + + with pytest.warns(DeprecationWarning, match="ModelConfig.provider=None is deprecated"): + ModelConfig(alias="legacy", model="legacy-model", provider=None) + + +def test_model_config_provider_none_via_model_validate_emits_deprecation_warning(): + """Regression for #589 / PR #594 review: deserialising legacy on-disk configs + via ``ModelConfig.model_validate(...)`` must surface the same + ``DeprecationWarning`` as direct construction. Both paths funnel through + the same validator today, so this pin protects against a future refactor + that, e.g., only runs the validator on construction and not on revalidation. + """ + with pytest.warns(DeprecationWarning, match="ModelConfig.provider=None is deprecated"): + ModelConfig.model_validate({"alias": "legacy", "model": "legacy-model"}) + + +def test_model_config_with_provider_does_not_warn(): + """Pin the post-deprecation happy path: specifying ``provider=`` must not + emit any deprecation warning. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + ModelConfig(alias="modern", model="modern-model", provider="some-provider") + + def test_model_config_generation_type_from_dict(): # Test that generation_type in dict is used to create the right inference params type model_config = ModelConfig.model_validate( { "alias": "test", "model": "test", + "provider": "test-provider", "inference_parameters": {"generation_type": "embedding", "dimensions": 100}, } ) @@ -493,6 +531,7 @@ def test_model_config_generation_type_from_dict(): { "alias": "test", "model": "test", + "provider": "test-provider", "inference_parameters": {"generation_type": "chat-completion", "temperature": 0.5}, } ) @@ -503,6 +542,7 @@ def test_model_config_generation_type_from_dict(): { "alias": "test", "model": "image-model", + "provider": "test-provider", "inference_parameters": { "generation_type": "image", "extra_body": {"size": "1024x1024", "quality": "hd"}, diff --git a/packages/data-designer-config/tests/config/utils/test_warning_helpers.py b/packages/data-designer-config/tests/config/utils/test_warning_helpers.py new file mode 100644 index 000000000..9df72c588 --- /dev/null +++ b/packages/data-designer-config/tests/config/utils/test_warning_helpers.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import warnings + +from data_designer.config.utils.warning_helpers import _module_in_prefixes, warn_at_caller + + +def test_module_in_prefixes_exact_match(): + assert _module_in_prefixes("pydantic", ("pydantic",)) is True + + +def test_module_in_prefixes_dotted_submodule(): + assert _module_in_prefixes("pydantic.fields", ("pydantic",)) is True + assert _module_in_prefixes("data_designer.config.models", ("data_designer",)) is True + + +def test_module_in_prefixes_rejects_prefix_collision(): + """Regression for PR #594 review (johnnygreco): ``startswith`` matching + naively on the prefix would silently treat ``pydantic_helpers`` as part of + the ``pydantic`` package. Anchor on exact-or-dotted-prefix instead. + """ + assert _module_in_prefixes("pydantic_helpers", ("pydantic",)) is False + assert _module_in_prefixes("pydanticfoo", ("pydantic",)) is False + assert _module_in_prefixes("data_designer_other", ("data_designer",)) is False + + +def test_warn_at_caller_attributes_to_direct_caller(): + """When called from a non-skipped module, the warning attributes to the + caller's frame. + """ + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + warn_at_caller("hello", DeprecationWarning) # line anchored below + + assert len(caught) == 1 + assert caught[0].filename == __file__ + assert "hello" in str(caught[0].message) + + +def test_warn_at_caller_skips_skip_prefix_frames(): + """The walk should escape any frame whose module is listed in + ``skip_prefixes`` and attribute to the first frame outside them. We + simulate a library frame by ``exec``-ing a helper with a fake module name + in its globals; calling that helper produces a frame whose + ``f_globals["__name__"]`` is the fake name, mirroring how a real library + frame would appear during the walk. + """ + library_globals: dict[str, object] = { + "__name__": "fake_library.dispatch", + "warn_at_caller": warn_at_caller, + "DeprecationWarning": DeprecationWarning, + } + exec( + "def emit():\n warn_at_caller('from-library', DeprecationWarning, skip_prefixes=('fake_library',))\n", + library_globals, + ) + emit = library_globals["emit"] + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + emit() + + assert len(caught) == 1 + assert caught[0].filename == __file__, f"Expected attribution at {__file__!r}, got {caught[0].filename!r}" + + +def test_warn_at_caller_default_skips_pydantic_and_data_designer(): + """Default ``skip_prefixes`` should cover both pydantic and data_designer + so warnings emitted from validators inside DataDesigner internals attribute + to the user, not to either library. + """ + from data_designer.config.utils.warning_helpers import DEFAULT_INTERNAL_PREFIXES + + assert "pydantic" in DEFAULT_INTERNAL_PREFIXES + assert "data_designer" in DEFAULT_INTERNAL_PREFIXES + + +def test_warn_at_caller_dedup_keys_on_user_call_site(): + """Python's once-per-location dedup keys on (text, category, lineno) inside + the *attributing* frame's ``__warningregistry__``. With proper user + attribution, two distinct call sites in the user's file each emit a + warning under ``default`` filtering, while a repeated call at the same + site emits only the first. + """ + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("default", DeprecationWarning) + warn_at_caller("dedup-test", DeprecationWarning) # site A + warn_at_caller("dedup-test", DeprecationWarning) # site B + + linenos = {w.lineno for w in caught} + assert len(caught) == 2, [str(w.message) for w in caught] + assert len(linenos) == 2, "Each call site should produce a distinct dedup key" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 985ca2c09..8018ef648 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -9,7 +9,6 @@ import os import time import uuid -import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Callable @@ -23,6 +22,7 @@ ProcessorConfig, ProcessorType, ) +from data_designer.config.utils.warning_helpers import warn_at_caller from data_designer.config.version import get_library_version from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, @@ -327,7 +327,17 @@ def _resolve_async_compatibility(self) -> bool: "use workflow chaining instead (see issue #552)." ) logger.warning(f"⚠️ {msg}") - warnings.warn(msg, DeprecationWarning, stacklevel=4) + # ``warn_at_caller`` rather than ``warnings.warn(stacklevel=N)`` so + # attribution lands on the user's call site instead of an internal + # ``DatasetBuilder.build`` / ``data_designer.interface`` frame. + # The exact internal-frame depth from this method up to user code + # depends on which entry point invoked the builder (build vs. + # build_preview, sync vs. async wrapping), so a hard-coded + # ``stacklevel`` is brittle; ``warn_at_caller`` walks past every + # ``data_designer.*`` frame regardless of chain shape. Library + # attribution would also be silenced under Python's default + # ``ignore::DeprecationWarning`` filter. See PR #594 review. + warn_at_caller(msg, DeprecationWarning) return False return True diff --git a/packages/data-designer-engine/src/data_designer/engine/model_provider.py b/packages/data-designer-engine/src/data_designer/engine/model_provider.py index b34442fd5..d3d8dc7be 100644 --- a/packages/data-designer-engine/src/data_designer/engine/model_provider.py +++ b/packages/data-designer-engine/src/data_designer/engine/model_provider.py @@ -10,12 +10,16 @@ from data_designer.config.mcp import MCPProviderT from data_designer.config.models import ModelProvider +from data_designer.config.utils.warning_helpers import warn_at_caller from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError class ModelProviderRegistry(BaseModel): providers: list[ModelProvider] default: str | None = None + """Deprecated: registry-level default provider. Will be removed in a future + release; specify ``provider=`` explicitly on each ``ModelConfig`` instead. + See issue #589.""" @field_validator("providers", mode="after") @classmethod @@ -50,6 +54,26 @@ def check_default_exists(self) -> Self: raise ValueError(f"Specified default {self.default!r} not found in providers list") return self + @model_validator(mode="after") + def _warn_on_explicit_default(self) -> Self: + # Fires only when the caller actually passed a non-None ``default=``. + # The ``model_fields_set`` guard distinguishes "caller opted into the + # deprecated field" from "field at its default value of None", and the + # ``self.default is not None`` clause additionally lets callers + # explicitly opt *out* via ``default=None`` without tripping the + # warning. ``resolve_model_provider_registry`` avoids passing + # ``default=`` in the single-provider case so common construction paths + # stay quiet. ``warn_at_caller`` keeps attribution and dedup correct + # under pydantic's validator dispatch. See issue #589 / PR #594 review. + if "default" in self.model_fields_set and self.default is not None: + warn_at_caller( + "ModelProviderRegistry.default is deprecated and will be removed in a " + "future release. Specify provider= explicitly on each ModelConfig " + "instead of relying on a registry-level default. See issue #589.", + DeprecationWarning, + ) + return self + def get_default_provider_name(self) -> str: return self.default or self.providers[0].name @@ -72,6 +96,15 @@ def resolve_model_provider_registry( ) -> ModelProviderRegistry: if len(model_providers) == 0: raise NoModelProvidersError("At least one model provider must be defined") + # In the single-provider case, the registry's ``get_default_provider_name`` + # falls back to ``providers[0].name`` when ``default`` is unset, so we can + # avoid passing ``default=`` and keep the common construction path quiet + # under the #589 deprecation warning. The multi-provider case still + # requires ``default`` (per ``check_implicit_default``); callers who supply + # multiple providers with no explicit default fall back to first-wins, + # matching the contract pinned in #588. + if len(model_providers) == 1 and default_provider_name is None: + return ModelProviderRegistry(providers=model_providers) return ModelProviderRegistry( providers=model_providers, default=default_provider_name or model_providers[0].name, diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 846089095..dab109dbb 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -101,6 +101,14 @@ def test_resolve_async_compatibility(configs: list[Mock], expected: bool) -> Non assert len(w) == 1 assert issubclass(w[0].category, DeprecationWarning) assert "allow_resize" in str(w[0].message) + # Regression for PR #594 review: the warning must attribute to the + # caller's frame (this test file), not to a ``data_designer.*`` library + # frame. Library-attributed ``DeprecationWarning`` entries fall under + # Python's default ``ignore::DeprecationWarning`` filter and are + # silenced. A regression to ``warnings.warn(..., stacklevel=N)`` would + # land somewhere inside the engine package and silently break the + # user-facing nudge. + assert w[0].filename == __file__ else: assert len(w) == 0 diff --git a/packages/data-designer-engine/tests/engine/test_model_provider.py b/packages/data-designer-engine/tests/engine/test_model_provider.py index 55f750a76..72006f824 100644 --- a/packages/data-designer-engine/tests/engine/test_model_provider.py +++ b/packages/data-designer-engine/tests/engine/test_model_provider.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import warnings + import pytest from data_designer.config.mcp import LocalStdioMCPProvider @@ -56,10 +58,14 @@ def test_no_duplicate_provider_names(stub_foo_provider: ModelProvider): def test_get_provider(stub_foo_provider: ModelProvider, stub_bar_provider: ModelProvider): - registry = ModelProviderRegistry( - providers=[stub_foo_provider, stub_bar_provider], - default="foo", - ) + # Multi-provider construction with an explicit default exercises the #589 + # deprecation path; wrap so this test stays green if the project ever runs + # with ``-W error::DeprecationWarning``. + with pytest.warns(DeprecationWarning, match="ModelProviderRegistry.default is deprecated"): + registry = ModelProviderRegistry( + providers=[stub_foo_provider, stub_bar_provider], + default="foo", + ) assert registry.get_provider(None) == stub_foo_provider assert registry.get_provider("foo") == stub_foo_provider @@ -80,8 +86,15 @@ def test_resolve_model_provider_registry(stub_foo_provider: ModelProvider) -> No def test_resolve_model_provider_registry_with_explicit_default( stub_foo_provider: ModelProvider, stub_bar_provider: ModelProvider ) -> None: - """Test resolve_model_provider_registry with explicit default.""" - registry = resolve_model_provider_registry([stub_foo_provider, stub_bar_provider], default_provider_name="bar") + """Test resolve_model_provider_registry with explicit default. + + The multi-provider/explicit-default path is the deprecated one (see #589), + so the construction emits a ``DeprecationWarning``. Wrap the call in + ``pytest.warns`` so this test stays green if the project ever runs under + ``-W error::DeprecationWarning``. + """ + with pytest.warns(DeprecationWarning, match="ModelProviderRegistry.default is deprecated"): + registry = resolve_model_provider_registry([stub_foo_provider, stub_bar_provider], default_provider_name="bar") assert registry.get_default_provider_name() == "bar" @@ -92,6 +105,86 @@ def test_resolve_model_provider_registry_empty_error() -> None: resolve_model_provider_registry([]) +def test_explicit_default_emits_deprecation_warning(stub_foo_provider: ModelProvider) -> None: + """Regression for #589: passing ``default=`` explicitly to ``ModelProviderRegistry`` + must emit a ``DeprecationWarning``. The registry-level default field is on its + way out; users should specify ``provider=`` per ``ModelConfig`` instead. + """ + with pytest.warns(DeprecationWarning, match="ModelProviderRegistry.default is deprecated"): + ModelProviderRegistry(providers=[stub_foo_provider], default="foo") + + +def test_no_default_does_not_emit_deprecation_warning(stub_foo_provider: ModelProvider) -> None: + """Pin the post-deprecation happy path: omitting ``default=`` (single-provider + case) must NOT emit a warning, since callers haven't opted into the deprecated + field. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + ModelProviderRegistry(providers=[stub_foo_provider]) + + +def test_explicit_default_none_does_not_emit_deprecation_warning(stub_foo_provider: ModelProvider) -> None: + """Pin the predicate tightening from PR #594 review: passing ``default=None`` + explicitly is semantically equivalent to omitting it (caller is opting *out* + of a registry-level default), so the deprecation must NOT fire. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + ModelProviderRegistry(providers=[stub_foo_provider], default=None) + + +def test_explicit_default_warning_attributes_to_user_frame( + stub_foo_provider: ModelProvider, stub_bar_provider: ModelProvider +) -> None: + """Regression for PR #594 review (andreatgretel): the ``default=`` deprecation + warning must attribute to the *user's* call site, not the pydantic-internal + or ``data_designer`` library frame that emits it. Library-attributed + ``DeprecationWarning`` entries are silenced under Python's default + ``ignore::DeprecationWarning`` filter, so attribution determines whether + the warning is actually visible. + + Construction goes through ``resolve_model_provider_registry`` so the walk + has to escape both pydantic (validator dispatch) and ``data_designer`` + (the helper that builds the registry) before landing on the test frame. + """ + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + resolve_model_provider_registry([stub_foo_provider, stub_bar_provider], default_provider_name="bar") + + matches = [w for w in caught if "ModelProviderRegistry.default is deprecated" in str(w.message)] + assert len(matches) == 1, [str(w.message) for w in caught] + assert matches[0].filename == __file__, ( + f"Warning attributed to {matches[0].filename!r} (line {matches[0].lineno}) " + f"instead of the test file. Library-attributed DeprecationWarnings are " + f"silenced under default filters." + ) + + +def test_resolve_single_provider_quiet_under_deprecation(stub_foo_provider: ModelProvider) -> None: + """Pin the q3 tweak: ``resolve_model_provider_registry`` skips ``default=`` + in the single-provider case so common construction paths stay quiet under + the #589 deprecation warning. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + registry = resolve_model_provider_registry([stub_foo_provider]) + + assert registry.get_default_provider_name() == "foo" + + +def test_resolve_multi_provider_emits_deprecation_warning( + stub_foo_provider: ModelProvider, stub_bar_provider: ModelProvider +) -> None: + """Multi-provider registries currently require ``default``, so + ``resolve_model_provider_registry`` keeps passing it. That construction + path is the deprecated one users should migrate off; the warning fires + accordingly. + """ + with pytest.warns(DeprecationWarning, match="ModelProviderRegistry.default is deprecated"): + resolve_model_provider_registry([stub_foo_provider, stub_bar_provider]) + + def test_mcp_provider_registry_empty() -> None: """Test MCPProviderRegistry can be created empty.""" registry = MCPProviderRegistry() diff --git a/packages/data-designer/src/data_designer/cli/controllers/provider_controller.py b/packages/data-designer/src/data_designer/cli/controllers/provider_controller.py index 165422436..c72087036 100644 --- a/packages/data-designer/src/data_designer/cli/controllers/provider_controller.py +++ b/packages/data-designer/src/data_designer/cli/controllers/provider_controller.py @@ -27,6 +27,7 @@ print_warning, select_with_arrows, ) +from data_designer.config.utils.warning_helpers import warn_at_caller if TYPE_CHECKING: from data_designer.engine.model_provider import ModelProvider @@ -288,6 +289,23 @@ def _handle_delete_all(self) -> None: def _handle_change_default(self) -> None: """Handle changing the default provider.""" + deprecation_msg = ( + "The 'Change default provider' workflow is deprecated and will be removed " + "in a future release. Specify provider= explicitly on each ModelConfig " + "instead of relying on a registry-level default. See issue #589." + ) + print_warning(deprecation_msg) + # ``print_warning`` always shows the user the message in the console, + # but ``warnings.warn`` is what's observable to programmatic callers + # (``pytest.warns``, ``filterwarnings("error", ...)``). With + # ``stacklevel=2`` attribution lands on the menu dispatcher in this + # same module — a ``data_designer.cli.*`` frame — and Python's default + # ``ignore::DeprecationWarning`` filter silences it. ``warn_at_caller`` + # walks past every ``data_designer.*`` frame so the warning attributes + # to the user's call site and stays visible. See PR #594 review. + warn_at_caller(deprecation_msg, DeprecationWarning) + console.print() + providers = self.service.list_all() current_default = self.service.get_default() diff --git a/packages/data-designer/src/data_designer/cli/repositories/provider_repository.py b/packages/data-designer/src/data_designer/cli/repositories/provider_repository.py index a2b692fb3..f815ef127 100644 --- a/packages/data-designer/src/data_designer/cli/repositories/provider_repository.py +++ b/packages/data-designer/src/data_designer/cli/repositories/provider_repository.py @@ -11,6 +11,7 @@ from data_designer.config.models import ModelProvider from data_designer.config.utils.constants import MODEL_PROVIDERS_FILE_NAME from data_designer.config.utils.io_helpers import load_config_file, save_config_file +from data_designer.config.utils.warning_helpers import warn_at_caller class ModelProviderRegistry(BaseModel): @@ -35,6 +36,27 @@ def load(self) -> ModelProviderRegistry | None: try: config_dict = load_config_file(self.config_file) + except Exception: + return None + + # Emit the deprecation warning *outside* the validation try/except below. + # ``DeprecationWarning`` is an ``Exception`` subclass, so under + # ``filterwarnings("error", DeprecationWarning)`` a warn raised inside + # the catch-all would be silently swallowed and ``load`` would drop the + # registry. ``warn_at_caller`` (rather than ``warnings.warn(stacklevel=2)``) + # so the warning attributes to the user's call site rather than a + # ``data_designer.cli.*`` frame; under default Python filters, + # library-attributed ``DeprecationWarning`` entries are silenced + # (``ignore::DeprecationWarning``). See PR #594 review. + if config_dict.get("default") is not None: + warn_at_caller( + f"The 'default:' key in {self.config_file} is deprecated and will " + "be removed in a future release. Remove it and specify provider= " + "explicitly on each ModelConfig instead. See issue #589.", + DeprecationWarning, + ) + + try: return ModelProviderRegistry.model_validate(config_dict) except Exception: return None diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index cd928f611..242131ba1 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -165,7 +165,21 @@ def __init__( self._model_providers = self._resolve_model_providers(model_providers) default_provider_name = None self._mcp_providers = mcp_providers or [] - self._model_provider_registry = resolve_model_provider_registry(self._model_providers, default_provider_name) + # When the YAML carries a default, ``get_default_provider_name`` already + # nudged the user with a ``DeprecationWarning``. Building the registry + # below would re-fire ``ModelProviderRegistry._warn_on_explicit_default`` + # for the same root cause, so suppress that second warning. See PR #594 + # review. + with warnings.catch_warnings(): + if default_provider_name is not None: + warnings.filterwarnings( + "ignore", + message="ModelProviderRegistry.default is deprecated", + category=DeprecationWarning, + ) + self._model_provider_registry = resolve_model_provider_registry( + self._model_providers, default_provider_name + ) self._seed_reader_registry = SeedReaderRegistry(readers=seed_readers or DEFAULT_SEED_READERS) @property diff --git a/packages/data-designer/tests/cli/controllers/test_provider_controller.py b/packages/data-designer/tests/cli/controllers/test_provider_controller.py index fc83ad45a..265f393a3 100644 --- a/packages/data-designer/tests/cli/controllers/test_provider_controller.py +++ b/packages/data-designer/tests/cli/controllers/test_provider_controller.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import warnings from pathlib import Path from unittest.mock import MagicMock, patch @@ -208,6 +209,38 @@ def test_run_changes_default_provider( assert controller_with_providers.service.get_default() == "test-provider-2" +@patch("data_designer.cli.controllers.provider_controller.select_with_arrows") +def test_handle_change_default_emits_deprecation_warning( + mock_select: MagicMock, + controller_with_providers: ProviderController, +) -> None: + """Regression for #589: entering the 'Change default provider' workflow + must emit a ``DeprecationWarning`` so users see the migration nudge before + setting another value that's also slated for removal. + + Also pins the attribution contract from PR #594 review: the warning must + land on the caller's frame (this test file), not on a + ``data_designer.cli.*`` library frame. Library attribution falls under + Python's default ``ignore::DeprecationWarning`` filter and would silently + suppress the user-facing nudge for any caller that isn't using + ``simplefilter("always")``. + """ + mock_select.side_effect = ["change_default", "test-provider-2"] + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + controller_with_providers.run() + + deprecations = [ + w + for w in caught + if issubclass(w.category, DeprecationWarning) + and "'Change default provider' workflow is deprecated" in str(w.message) + ] + assert len(deprecations) == 1 + assert deprecations[0].filename == __file__ + + @patch("data_designer.cli.controllers.provider_controller.confirm_action", return_value=False) @patch("data_designer.cli.controllers.provider_controller.select_with_arrows") def test_run_respects_delete_cancellation( diff --git a/packages/data-designer/tests/cli/repositories/test_provider_repository.py b/packages/data-designer/tests/cli/repositories/test_provider_repository.py index 0becfb54a..e62a59ccd 100644 --- a/packages/data-designer/tests/cli/repositories/test_provider_repository.py +++ b/packages/data-designer/tests/cli/repositories/test_provider_repository.py @@ -1,8 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import warnings from pathlib import Path +import pytest + from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository from data_designer.config.models import ModelProvider from data_designer.config.utils.constants import MODEL_PROVIDERS_FILE_NAME @@ -20,10 +23,14 @@ def test_load_does_not_exist(): def test_load_exists(tmp_path: Path, stub_model_providers: list[ModelProvider]): + # Roundtrip test for the load/save cycle. We deliberately leave ``default`` + # unset so this test does not exercise the deprecated YAML ``default:`` path + # — that path is covered by ``test_load_with_yaml_default_emits_deprecation_warning`` + # below. See issue #589. providers_file_path = tmp_path / MODEL_PROVIDERS_FILE_NAME save_config_file( providers_file_path, - ModelProviderRegistry(providers=stub_model_providers, default=stub_model_providers[0].name).model_dump(), + ModelProviderRegistry(providers=stub_model_providers).model_dump(exclude_none=True), ) repository = ProviderRepository(tmp_path) assert repository.load() is not None @@ -31,7 +38,76 @@ def test_load_exists(tmp_path: Path, stub_model_providers: list[ModelProvider]): def test_save(tmp_path: Path, stub_model_providers: list[ModelProvider]): + # As above, leave ``default`` unset so the roundtrip stays clear of the + # YAML-default deprecation. See issue #589. repository = ProviderRepository(tmp_path) - repository.save(ModelProviderRegistry(providers=stub_model_providers, default=stub_model_providers[0].name)) + repository.save(ModelProviderRegistry(providers=stub_model_providers)) assert repository.load() is not None assert repository.load().providers == stub_model_providers + + +def test_load_with_yaml_default_emits_deprecation_warning( + tmp_path: Path, stub_model_providers: list[ModelProvider] +) -> None: + """Regression for #589: when the on-disk providers YAML carries a non-None + ``default:`` key, ``ProviderRepository.load`` must emit a + ``DeprecationWarning`` so users see the migration nudge regardless of which + entry point reads the file. + """ + providers_file_path = tmp_path / MODEL_PROVIDERS_FILE_NAME + save_config_file( + providers_file_path, + ModelProviderRegistry(providers=stub_model_providers, default=stub_model_providers[0].name).model_dump(), + ) + repository = ProviderRepository(tmp_path) + + with pytest.warns(DeprecationWarning, match="'default:' key.*is deprecated"): + registry = repository.load() + assert registry is not None + assert registry.default == stub_model_providers[0].name + + +def test_load_with_yaml_default_attributes_warning_to_caller( + tmp_path: Path, stub_model_providers: list[ModelProvider] +) -> None: + """Regression for PR #594 review: the YAML-default ``DeprecationWarning`` + must attribute to the *caller's* frame (this test file), not to a + ``data_designer.cli.*`` library frame. Library-attributed + ``DeprecationWarning`` entries fall under Python's default + ``ignore::DeprecationWarning`` filter and are silenced, so attribution at + a library frame == invisible warning. ``warn_at_caller`` keeps this + visible; a regression to ``warnings.warn(stacklevel=2)`` would land on + ``provider_repository.py`` and silently break the user nudge. + """ + providers_file_path = tmp_path / MODEL_PROVIDERS_FILE_NAME + save_config_file( + providers_file_path, + ModelProviderRegistry(providers=stub_model_providers, default=stub_model_providers[0].name).model_dump(), + ) + repository = ProviderRepository(tmp_path) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + repository.load() + + deprecations = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecations) == 1 + assert deprecations[0].filename == __file__ + + +def test_load_without_yaml_default_does_not_warn(tmp_path: Path, stub_model_providers: list[ModelProvider]) -> None: + """Pin the post-deprecation happy path: a YAML without a ``default:`` key + must load cleanly with no ``DeprecationWarning``. + """ + providers_file_path = tmp_path / MODEL_PROVIDERS_FILE_NAME + save_config_file( + providers_file_path, + ModelProviderRegistry(providers=stub_model_providers).model_dump(exclude_none=True), + ) + repository = ProviderRepository(tmp_path) + + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + registry = repository.load() + assert registry is not None + assert registry.default is None diff --git a/packages/data-designer/tests/interface/test_data_designer.py b/packages/data-designer/tests/interface/test_data_designer.py index d382f0fae..c98987103 100644 --- a/packages/data-designer/tests/interface/test_data_designer.py +++ b/packages/data-designer/tests/interface/test_data_designer.py @@ -544,7 +544,13 @@ def test_init_user_supplied_providers_preserve_first_wins_over_yaml_default( ), ] - with patch.object(dd_mod, "get_default_provider_name", return_value="second-provider"): + # Multi-provider construction (user-supplied list of length > 1) still + # passes ``default=`` to ``ModelProviderRegistry`` — that's the deprecated + # path under #589 — so the registry-level deprecation fires here. + with ( + patch.object(dd_mod, "get_default_provider_name", return_value="second-provider"), + pytest.warns(DeprecationWarning, match="ModelProviderRegistry.default is deprecated"), + ): data_designer = DataDesigner( artifact_path=stub_artifact_path, model_providers=user_providers, @@ -593,6 +599,61 @@ def test_init_no_user_providers_uses_yaml_default( assert data_designer.model_provider_registry.get_default_provider_name() == "yaml-second" +def test_init_yaml_default_emits_single_deprecation_warning( + stub_artifact_path: Path, + stub_managed_assets_path: Path, +) -> None: + """Regression for PR #594 review: when ``DataDesigner()`` falls back to the + YAML's ``providers:`` and ``default:``, the user should see a single + ``DeprecationWarning`` (the YAML one) rather than a duplicate cascade where + ``ModelProviderRegistry._warn_on_explicit_default`` also fires for the same + root cause. See issue #589. + """ + yaml_providers = [ + ModelProvider( + name="yaml-first", + endpoint="https://yaml-first.example.com/v1", + api_key="yaml-first-key", + ), + ModelProvider( + name="yaml-second", + endpoint="https://yaml-second.example.com/v1", + api_key="yaml-second-key", + ), + ] + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", DeprecationWarning) + with ( + patch.object(dd_mod, "get_default_providers", return_value=yaml_providers), + patch.object(dd_mod, "get_default_provider_name") as mock_get_default, + ): + mock_get_default.side_effect = lambda: ( + warnings.warn( + "The 'default:' key in /fake/path is deprecated and will " + "be removed in a future release. Remove it and specify provider= " + "explicitly on each ModelConfig instead. See issue #589.", + DeprecationWarning, + stacklevel=2, + ) + or "yaml-second" + ) + DataDesigner( + artifact_path=stub_artifact_path, + secret_resolver=PlaintextResolver(), + managed_assets_path=stub_managed_assets_path, + ) + + deprecation_messages = [str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)] + yaml_default_warnings = [m for m in deprecation_messages if "'default:' key" in m] + registry_default_warnings = [m for m in deprecation_messages if "ModelProviderRegistry.default is deprecated" in m] + assert len(yaml_default_warnings) == 1, deprecation_messages + assert registry_default_warnings == [], ( + "Registry-level deprecation should be suppressed in the YAML-fallback path " + "to avoid two warnings for the same root cause." + ) + + def test_run_config_setting_persists(stub_artifact_path, stub_model_providers): """Test that run config setting persists across multiple calls.""" data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers)