Skip to content

Commit dbeb66d

Browse files
Deeven-Seruanubhav756
authored andcommitted
fix(toolbox-core): honor explicit defaults and clean lint issues
1 parent ebc18a5 commit dbeb66d

5 files changed

Lines changed: 42 additions & 3 deletions

File tree

packages/toolbox-core/src/toolbox_core/protocol.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ class ParameterSchema(BaseModel):
8383
additionalProperties: Optional[Union[bool, AdditionalPropertiesSchema]] = None
8484
default: Optional[Any] = None
8585

86+
@property
87+
def has_default(self) -> bool:
88+
"""Returns True if `default` was explicitly provided in schema input."""
89+
return "default" in self.model_fields_set
90+
8691
def __get_type(self) -> Type:
8792
base_type: Type
8893
if self.type == "array":
@@ -105,7 +110,7 @@ def __get_type(self) -> Type:
105110

106111
def to_param(self) -> Parameter:
107112
default_value = Parameter.empty
108-
if self.default is not None:
113+
if self.has_default:
109114
default_value = self.default
110115
elif not self.required:
111116
default_value = None

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import copy
16-
import itertools
1716
from collections import OrderedDict
1817
from inspect import Parameter, Signature
1918
from types import MappingProxyType

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def params_to_pydantic_model(
115115
# '...' (Ellipsis) signifies a required field in Pydantic.
116116
# If a default value is provided in the schema, it should be used.
117117
default_value = ... if field.required else None
118-
if field.default is not None:
118+
if field.has_default:
119119
default_value = field.default
120120

121121
field_definitions[field.name] = cast(

packages/toolbox-core/tests/test_protocol.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def test_parameter_schema_map_unsupported_value_type_error():
290290
with pytest.raises(ValueError, match=expected_error_msg):
291291
schema._ParameterSchema__get_type()
292292

293+
293294
def test_parameter_schema_with_default():
294295
"""Tests ParameterSchema with a default value provided."""
295296
schema = ParameterSchema(
@@ -326,3 +327,17 @@ def test_parameter_schema_required_with_default():
326327

327328
param = schema.to_param()
328329
assert param.default == 3
330+
331+
332+
def test_parameter_schema_required_with_explicit_none_default():
333+
"""Tests explicit default=None is treated as a provided default."""
334+
schema = ParameterSchema(
335+
name="opt_in",
336+
type="boolean",
337+
description="Optional flag",
338+
required=True,
339+
default=None,
340+
)
341+
342+
param = schema.to_param()
343+
assert param.default is None

packages/toolbox-core/tests/test_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def create_param_mock(name: str, description: str, annotation: Type) -> Mock:
3535
param_mock.name = name
3636
param_mock.description = description
3737
param_mock.required = True
38+
param_mock.default = None
39+
param_mock.has_default = False
3840

3941
mock_param_info = Mock()
4042
mock_param_info.annotation = annotation
@@ -422,6 +424,24 @@ def test_params_to_pydantic_model_with_params():
422424
Model(name="Bob", age="thirty", is_active=True)
423425

424426

427+
def test_params_to_pydantic_model_uses_explicit_default_none():
428+
"""Test that explicit default=None is honored for required schema fields."""
429+
tool_name = "MyToolWithExplicitNoneDefault"
430+
params = [
431+
ParameterSchema(
432+
name="message",
433+
type="string",
434+
description="Message value",
435+
required=True,
436+
default=None,
437+
)
438+
]
439+
Model = params_to_pydantic_model(tool_name, params)
440+
441+
assert "message" in Model.model_fields
442+
assert Model.model_fields["message"].default is None
443+
444+
425445
@pytest.mark.asyncio
426446
async def test_resolve_value_plain_value():
427447
"""Test resolving a plain, non-callable value."""

0 commit comments

Comments
 (0)