From 7187cb0b38c11653d5e4a65bbfaa236e404807c6 Mon Sep 17 00:00:00 2001 From: Sholto Armstrong Date: Thu, 28 May 2026 09:32:00 +0200 Subject: [PATCH 1/5] bump fix starlet version --- requirements/webapp.txt | 1 + spockflow/components/treelite/core.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/requirements/webapp.txt b/requirements/webapp.txt index babef30..1806984 100644 --- a/requirements/webapp.txt +++ b/requirements/webapp.txt @@ -1,3 +1,4 @@ fastapi>=0.104.1 +starlette>=1.1.0 uvicorn>=0.24.0.post1 gunicorn>=21.2.0 \ No newline at end of file diff --git a/spockflow/components/treelite/core.py b/spockflow/components/treelite/core.py index 8daa277..8b5b90e 100644 --- a/spockflow/components/treelite/core.py +++ b/spockflow/components/treelite/core.py @@ -36,13 +36,13 @@ class RangeTestNode(TestNode): thresholds: typing.List[float] = Field(min_length=1) @model_validator(mode="after") - def validate_thresholds(self) -> typing.Self: + def validate_thresholds(self) -> "typing.Self": if len(self.thresholds) < 1: raise ValueError(f"Range test nodes should have at least one threshold") return self @model_validator(mode="after") - def validate_children(self) -> typing.Self: + def validate_children(self) -> "typing.Self": # Ensure the number of children equals the number of thresholds + 1 # Since N thresholds create N+1 ranges # We allow more children here but ignore them if present @@ -133,7 +133,7 @@ class NumericalTestNode(TestNode): comparison_op: typing.Literal["<=", "<", "==", ">", ">="] @model_validator(mode="after") - def validate_two_children(self) -> typing.Self: + def validate_two_children(self) -> "typing.Self": if len(self.children) < 2: raise ValueError("Numerical nodes must have at least 2 child outputs") return self @@ -165,7 +165,7 @@ class CategoricalTestNode(TestNode): category_list_right_child: bool @model_validator(mode="after") - def validate_two_children(self) -> typing.Self: + def validate_two_children(self) -> "typing.Self": if len(self.children) < 2: raise ValueError("Categorical nodes must have at least 2 child outputs") return self From a2872a62e2365b1815ab9f478ac816594ca5ef1a Mon Sep 17 00:00:00 2001 From: Sholto Armstrong Date: Thu, 28 May 2026 09:40:48 +0200 Subject: [PATCH 2/5] formatting --- docs/concepts/decision_tables.ipynb | 23 +- docs/concepts/decision_trees.ipynb | 57 ++- docs/concepts/scorecard.ipynb | 24 +- docs/getting_started/quick_start.ipynb | 53 +- dsp-re-ui/example/tree/dtreevis.ipynb | 31 +- dsp-re-ui/example/tree/test.ipynb | 9 +- spockflow/_serializable.py | 1 - spockflow/components/calculate.py | 1 - spockflow/components/scorecard/__init__.py | 1 - spockflow/components/scorecard/v2/criteria.py | 1 - spockflow/components/tree/v1/compiled.py | 4 +- spockflow/inference/io/encoders.py | 1 - tests/test_pipelines/ptree02_basic/main.py | 1 - tree.ipynb | 475 +++++++++++------- 14 files changed, 422 insertions(+), 260 deletions(-) diff --git a/docs/concepts/decision_tables.ipynb b/docs/concepts/decision_tables.ipynb index eed270a..3a61652 100644 --- a/docs/concepts/decision_tables.ipynb +++ b/docs/concepts/decision_tables.ipynb @@ -55,14 +55,17 @@ } ], "source": [ - "example_dt\\\n", - " .add(dtable.DTMin, input_v1, [0,1,2,3,4,5,6,7,8,9,10])\\\n", - " .add(dtable.DTMax, input_v1, [1,2,3,4,5,6,7,8,9,10,11])\\\n", - " .add(dtable.DTMin, input_v2, [0,0,0,0,0,0,1,1,1,1,1])\\\n", - " .add(dtable.DTMax, input_v2, [1,1,1,1,1,1,2,2,2,2,2])\\\n", - " .set_default(pd.DataFrame({\"value\": [999], \"description\": [\"NA\"]}))\\\n", - " .output(\"value\", [1,2,0,None,-1,20,1,2,3,4,5])\\\n", - " .output(\"description\", [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\", \"h\", \"i\", \"j\", \"k\"])\n" + "example_dt.add(dtable.DTMin, input_v1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).add(\n", + " dtable.DTMax, input_v1, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]\n", + ").add(dtable.DTMin, input_v2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]).add(\n", + " dtable.DTMax, input_v2, [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]\n", + ").set_default(\n", + " pd.DataFrame({\"value\": [999], \"description\": [\"NA\"]})\n", + ").output(\n", + " \"value\", [1, 2, 0, None, -1, 20, 1, 2, 3, 4, 5]\n", + ").output(\n", + " \"description\", [\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\", \"h\", \"i\", \"j\", \"k\"]\n", + ")" ] }, { @@ -149,7 +152,7 @@ } ], "source": [ - "input_data = pd.DataFrame({input_v1: [5, 3, 8, 0, 10], input_v2: [0,0,0,0,4]})\n", + "input_data = pd.DataFrame({input_v1: [5, 3, 8, 0, 10], input_v2: [0, 0, 0, 0, 4]})\n", "result_df = example_dt.execute(inputs=input_data)\n", "result_df" ] @@ -257,7 +260,7 @@ " model_name=\"demo_spock_model\",\n", " model_version=\"1.0.0\",\n", " namespace=\"decision_table_config\",\n", - " config=example_dt.model_dump(mode='json')\n", + " config=example_dt.model_dump(mode=\"json\"),\n", ")\n", "\n", "# Load configuration\n", diff --git a/docs/concepts/decision_trees.ipynb b/docs/concepts/decision_trees.ipynb index 33ea6b9..c12be6e 100644 --- a/docs/concepts/decision_trees.ipynb +++ b/docs/concepts/decision_trees.ipynb @@ -43,10 +43,12 @@ "source": [ "from typing_extensions import TypedDict\n", "\n", + "\n", "class Reject(TypedDict):\n", " code: int\n", " description: str\n", "\n", + "\n", "RejectAction = Action[Reject]" ] }, @@ -145,6 +147,8 @@ "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n", "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n", " return (d > 5) & (e > 5) & (f > 5)\n", + "\n", + "\n", "tree.visualize(get_value_name=lambda x: x[\"description\"][0])" ] }, @@ -284,13 +288,21 @@ "def condition_a(a: pd.Series) -> pd.Series:\n", " return a > 5\n", "\n", - "@condition_a.condition(output=RejectAction(code=100, description=\"a and b are out of range\"))\n", + "\n", + "@condition_a.condition(\n", + " output=RejectAction(code=100, description=\"a and b are out of range\")\n", + ")\n", "def condition_b(b: pd.Series) -> pd.Series:\n", " return b > 5\n", "\n", - "@condition_a.condition(output=RejectAction(code=101, description=\"a and c are out of range\"))\n", + "\n", + "@condition_a.condition(\n", + " output=RejectAction(code=101, description=\"a and c are out of range\")\n", + ")\n", "def condition_c(c: pd.Series) -> pd.Series:\n", " return c > 5\n", + "\n", + "\n", "tree.visualize(get_value_name=lambda x: x[\"description\"][0])" ] }, @@ -597,17 +609,19 @@ } ], "source": [ - "test_data = pd.DataFrame({\n", - " \"a\": [5,6,7,8,1,2,3,4],\n", - " \"b\": [1,0,0,0,2,0,0,0],\n", - " \"c\": [0,10,0,0,0,10,0,0],\n", - " \"d\": [0,0,10,0,0,0,0,0],\n", - " \"e\": [0,0,10,0,0,0,0,0],\n", - " \"f\": [0,0,10,0,0,0,0,0],\n", - " # Below used later in the tutorial\n", - " \"nums\": [1,2,3,4,5,6,7,8],\n", - " \"input_condition\": [True, True, False, False, False, False, True, True],\n", - "})\n", + "test_data = pd.DataFrame(\n", + " {\n", + " \"a\": [5, 6, 7, 8, 1, 2, 3, 4],\n", + " \"b\": [1, 0, 0, 0, 2, 0, 0, 0],\n", + " \"c\": [0, 10, 0, 0, 0, 10, 0, 0],\n", + " \"d\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " \"e\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " \"f\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " # Below used later in the tutorial\n", + " \"nums\": [1, 2, 3, 4, 5, 6, 7, 8],\n", + " \"input_condition\": [True, True, False, False, False, False, True, True],\n", + " }\n", + ")\n", "test_data" ] }, @@ -874,7 +888,10 @@ } ], "source": [ - "tree.condition(condition=\"input_condition\", output=RejectAction(code=101, description=\"Input condition\"))\n", + "tree.condition(\n", + " condition=\"input_condition\",\n", + " output=RejectAction(code=101, description=\"Input condition\"),\n", + ")\n", "tree.visualize(get_value_name=lambda x: x[\"description\"][0])" ] }, @@ -1177,16 +1194,16 @@ ], "source": [ "tree = Tree()\n", - "tree.condition(output=Action(value=10), condition='A')\n", - "tree.condition(output=Action(value=20), condition='B')\n", + "tree.condition(output=Action(value=10), condition=\"A\")\n", + "tree.condition(output=Action(value=20), condition=\"B\")\n", "\n", "subtree = Tree()\n", - "subtree.condition(output=Action(value=100), condition='SubA')\n", - "subtree.condition(output=Action(value=200), condition='SubB')\n", + "subtree.condition(output=Action(value=100), condition=\"SubA\")\n", + "subtree.condition(output=Action(value=200), condition=\"SubB\")\n", "\n", "cond_subtree = Tree()\n", - "cond_subtree.condition(output=Action(value=1000), condition='SubD')\n", - "cond_subtree.condition(output=Action(value=2000), condition='SubE')\n", + "cond_subtree.condition(output=Action(value=1000), condition=\"SubD\")\n", + "cond_subtree.condition(output=Action(value=2000), condition=\"SubE\")\n", "\n", "tree.include_subtree(subtree)\n", "tree.include_subtree(cond_subtree, condition=\"SubC\")\n", diff --git a/docs/concepts/scorecard.ipynb b/docs/concepts/scorecard.ipynb index dbc7b02..0be9ac1 100644 --- a/docs/concepts/scorecard.ipynb +++ b/docs/concepts/scorecard.ipynb @@ -26,9 +26,9 @@ "var_2 = \"var_2\"\n", "\n", "sc = scorecard.ScoreCard(\n", - " bin_prefix='SCORE_BIN_',\n", - " score_prefix='SCORE_VALUE_',\n", - " description_prefix='SCORE_DESC_'\n", + " bin_prefix=\"SCORE_BIN_\",\n", + " score_prefix=\"SCORE_VALUE_\",\n", + " description_prefix=\"SCORE_DESC_\",\n", ")" ] }, @@ -109,8 +109,8 @@ "source": [ "sc.add_criteria(\n", " scorecard.ScoreCriteria(var_2, \"categorical\", default_behavior=\"regex\")\n", - " .add_discrete_score(['a', 'b', 'c'], 10, \"First pattern var_2\")\n", - " .add_discrete_score(['[b-z]'], 20, \"Second pattern var_2\")\n", + " .add_discrete_score([\"a\", \"b\", \"c\"], 10, \"First pattern var_2\")\n", + " .add_discrete_score([\"[b-z]\"], 20, \"Second pattern var_2\")\n", ")" ] }, @@ -241,10 +241,12 @@ } ], "source": [ - "test_data = pd.DataFrame({\n", - " \"var_1\": [ 0, 1, 2, None, 0, 1, 2, None],\n", - " \"var_2\": [ 'a', 'b', 'z', 'a', 'a', 'b', 'z', '9'],\n", - "})\n", + "test_data = pd.DataFrame(\n", + " {\n", + " \"var_1\": [0, 1, 2, None, 0, 1, 2, None],\n", + " \"var_2\": [\"a\", \"b\", \"z\", \"a\", \"a\", \"b\", \"z\", \"9\"],\n", + " }\n", + ")\n", "test_data" ] }, @@ -424,7 +426,7 @@ " model_name=\"demo_spock_model\",\n", " model_version=\"1.0.0\",\n", " namespace=\"scorecard_config\",\n", - " config=sc.model_dump(mode='json')\n", + " config=sc.model_dump(mode=\"json\"),\n", ")" ] }, @@ -458,7 +460,7 @@ } ], "source": [ - "config = conf_manager.get_config(\"demo_spock_model\", \"1.0.0\")['scorecard_config']\n", + "config = conf_manager.get_config(\"demo_spock_model\", \"1.0.0\")[\"scorecard_config\"]\n", "sc_loaded = scorecard.ScoreCard.from_config(\"\").load(config)\n", "\n", "# Retrieve view model and display widget\n", diff --git a/docs/getting_started/quick_start.ipynb b/docs/getting_started/quick_start.ipynb index 163ca98..10b99d2 100644 --- a/docs/getting_started/quick_start.ipynb +++ b/docs/getting_started/quick_start.ipynb @@ -47,10 +47,12 @@ "from typing_extensions import TypedDict\n", "import pandas as pd\n", "\n", + "\n", "class Reject(TypedDict):\n", " code: int\n", " description: str\n", "\n", + "\n", "RejectAction = Action[Reject]" ] }, @@ -106,11 +108,17 @@ "def condition_a(a: pd.Series) -> pd.Series:\n", " return a > 5\n", "\n", - "@condition_a.condition(output=RejectAction(code=100, description=\"a and b is out of range\"))\n", + "\n", + "@condition_a.condition(\n", + " output=RejectAction(code=100, description=\"a and b is out of range\")\n", + ")\n", "def condition_b(b: pd.Series) -> pd.Series:\n", " return b > 5\n", "\n", - "@condition_a.condition(output=RejectAction(code=101, description=\"a and c is out of range\"))\n", + "\n", + "@condition_a.condition(\n", + " output=RejectAction(code=101, description=\"a and c is out of range\")\n", + ")\n", "def condition_c(c: pd.Series) -> pd.Series:\n", " return c > 5" ] @@ -286,15 +294,17 @@ } ], "source": [ - "test_data = pd.DataFrame({\n", - " \"num\": [1,2,3,4,5,6,7,8], # Note this is for use later in the tutorial\n", - " \"a\": [5,6,7,8,1,2,3,4],\n", - " \"b\": [1,0,0,0,2,0,0,0],\n", - " \"c\": [0,10,0,0,0,10,0,0],\n", - " \"d\": [0,0,10,0,0,0,0,0],\n", - " \"e\": [0,0,10,0,0,0,0,0],\n", - " \"f\": [0,0,10,0,0,0,0,0],\n", - "})\n", + "test_data = pd.DataFrame(\n", + " {\n", + " \"num\": [1, 2, 3, 4, 5, 6, 7, 8], # Note this is for use later in the tutorial\n", + " \"a\": [5, 6, 7, 8, 1, 2, 3, 4],\n", + " \"b\": [1, 0, 0, 0, 2, 0, 0, 0],\n", + " \"c\": [0, 10, 0, 0, 0, 10, 0, 0],\n", + " \"d\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " \"e\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " \"f\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " }\n", + ")\n", "test_data" ] }, @@ -810,6 +820,7 @@ ], "source": [ "from spockflow.core import Driver\n", + "\n", "dr = Driver({}, demo_tree)\n", "df = dr.execute(inputs=test_data)\n", "df" @@ -916,6 +927,7 @@ ], "source": [ "from hamilton.driver import Driver as HamiltonDriver\n", + "\n", "dr_ham = HamiltonDriver({}, demo_tree)\n", "df = dr_ham.execute(inputs=test_data, final_vars=[\"tree\"])\n", "df" @@ -1171,7 +1183,11 @@ } ], "source": [ - "df = dr.execute(inputs=test_data, final_vars=[\"condition_b\",\"tree\"], overrides={\"b\": test_data[\"b\"]})\n", + "df = dr.execute(\n", + " inputs=test_data,\n", + " final_vars=[\"condition_b\", \"tree\"],\n", + " overrides={\"b\": test_data[\"b\"]},\n", + ")\n", "df" ] }, @@ -1422,7 +1438,11 @@ } ], "source": [ - "dr.visualize_execution(inputs=test_data, final_vars=[\"condition_b\",\"tree\"], overrides={\"b\": test_data[\"b\"]})" + "dr.visualize_execution(\n", + " inputs=test_data,\n", + " final_vars=[\"condition_b\", \"tree\"],\n", + " overrides={\"b\": test_data[\"b\"]},\n", + ")" ] }, { @@ -1538,6 +1558,7 @@ "import os\n", "import json\n", "import requests\n", + "\n", "os.environ[\"MODEL_PREFIX\"] = os.path.abspath(\".\")\n", "os.environ[\"MODEL_RELATIVE_PATH\"] = \"source_dir\"" ] @@ -1645,8 +1666,10 @@ } ], "source": [ - "resp = requests.post(\"http://localhost:8000/invocations\", json=test_data.to_dict(orient='records'))\n", - "pd.DataFrame(resp.json()['tree'])" + "resp = requests.post(\n", + " \"http://localhost:8000/invocations\", json=test_data.to_dict(orient=\"records\")\n", + ")\n", + "pd.DataFrame(resp.json()[\"tree\"])" ] }, { diff --git a/dsp-re-ui/example/tree/dtreevis.ipynb b/dsp-re-ui/example/tree/dtreevis.ipynb index 4d78b2f..e7719a6 100644 --- a/dsp-re-ui/example/tree/dtreevis.ipynb +++ b/dsp-re-ui/example/tree/dtreevis.ipynb @@ -16,8 +16,8 @@ "outputs": [], "source": [ "%config InlineBackend.figure_format = 'retina' # Make visualizations look good\n", - "#%config InlineBackend.figure_format = 'svg' \n", - "%matplotlib inline\n" + "# %config InlineBackend.figure_format = 'svg'\n", + "%matplotlib inline" ] }, { @@ -32,7 +32,7 @@ "import treelite.sklearn\n", "import numpy as np\n", "import pandas as pd\n", - "import dtreeviz\n" + "import dtreeviz" ] }, { @@ -55,11 +55,12 @@ "root_nodes = CompiledTreeliteTree._identify_independent_tree_roots(tree.nodes)\n", "trees = [\n", " CompiledTreeliteTree._build_treelite_tree(\n", - " root_nodes=[rn], \n", - " tree=tree, \n", + " root_nodes=[rn],\n", + " tree=tree,\n", " node_id_mapping=node_id_mapping,\n", - " output_encoding=OutputEncoding.ONE_HOT\n", - " ).commit() for rn in root_nodes\n", + " output_encoding=OutputEncoding.ONE_HOT,\n", + " ).commit()\n", + " for rn in root_nodes\n", "]" ] }, @@ -69,7 +70,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "sklearn_tree = treelite.sklearn.export_model(trees[0])" ] }, @@ -187,7 +187,7 @@ "with open(\"data.json\") as fp:\n", " input_data = pd.json_normalize(json.load(fp))\n", "input_data\n", - "input_data[\"cls\"] = [0,0,0,0,1,0,1,1]\n", + "input_data[\"cls\"] = [0, 0, 0, 0, 1, 0, 1, 1]\n", "input_data" ] }, @@ -197,7 +197,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "sklearn_tree.estimators_[0].classes_ = np.arange(2, dtype=np.int32)" ] }, @@ -207,10 +206,14 @@ "metadata": {}, "outputs": [], "source": [ - "viz_model = dtreeviz.model(sklearn_tree.estimators_[0],\n", - " X_train=input_data[tree.features], y_train=input_data[\"cls\"],\n", - " feature_names=tree.features,\n", - " target_name=\"cls\", class_names=[\"Positive\", \"Negative\"])" + "viz_model = dtreeviz.model(\n", + " sklearn_tree.estimators_[0],\n", + " X_train=input_data[tree.features],\n", + " y_train=input_data[\"cls\"],\n", + " feature_names=tree.features,\n", + " target_name=\"cls\",\n", + " class_names=[\"Positive\", \"Negative\"],\n", + ")" ] }, { diff --git a/dsp-re-ui/example/tree/test.ipynb b/dsp-re-ui/example/tree/test.ipynb index e82e7ae..db09025 100644 --- a/dsp-re-ui/example/tree/test.ipynb +++ b/dsp-re-ui/example/tree/test.ipynb @@ -37,7 +37,7 @@ "outputs": [], "source": [ "combined_res = model.execute(\n", - " inputs=input_data, \n", + " inputs=input_data,\n", ")\n", "combined_res" ] @@ -49,8 +49,7 @@ "outputs": [], "source": [ "res_everything = model.raw_execute(\n", - " inputs=input_data, \n", - " final_vars=list(model.graph.nodes.keys())\n", + " inputs=input_data, final_vars=list(model.graph.nodes.keys())\n", ")\n", "res_everything" ] @@ -63,8 +62,8 @@ "source": [ "resp = requests.post(\n", " \"http://localhost:8000/invocations\",\n", - " data = input_data.to_json(orient=\"records\").encode() ,\n", - " headers={\"accept\": \"*/*\", \"Content-Type\": \"application/json\"}\n", + " data=input_data.to_json(orient=\"records\").encode(),\n", + " headers={\"accept\": \"*/*\", \"Content-Type\": \"application/json\"},\n", ")\n", "resjson = resp.json()\n", "resjson" diff --git a/spockflow/_serializable.py b/spockflow/_serializable.py index 5438be3..4bcf14a 100644 --- a/spockflow/_serializable.py +++ b/spockflow/_serializable.py @@ -10,7 +10,6 @@ ) from pydantic.json_schema import JsonSchemaValue - values_schema = core_schema.dict_schema(keys_schema=core_schema.str_schema()) dataframe_json_schema = core_schema.model_fields_schema( diff --git a/spockflow/components/calculate.py b/spockflow/components/calculate.py index 111cc2e..14c5905 100644 --- a/spockflow/components/calculate.py +++ b/spockflow/components/calculate.py @@ -6,7 +6,6 @@ from spockflow.nodes import VariableNode, creates_node from hamilton import node - T = typing.TypeVar("T") diff --git a/spockflow/components/scorecard/__init__.py b/spockflow/components/scorecard/__init__.py index 63e3d57..072193e 100644 --- a/spockflow/components/scorecard/__init__.py +++ b/spockflow/components/scorecard/__init__.py @@ -9,7 +9,6 @@ from .v2.criteria_numerical import Bounds from .probability import log_odds_from_score, probability_of_default_from_log_odds - # This is for when there is more than one version # class ScoreCard(RootModel): # root: typing.Annotated[ diff --git a/spockflow/components/scorecard/v2/criteria.py b/spockflow/components/scorecard/v2/criteria.py index 6faacb2..c364cde 100644 --- a/spockflow/components/scorecard/v2/criteria.py +++ b/spockflow/components/scorecard/v2/criteria.py @@ -7,7 +7,6 @@ from .criteria_numerical import ScoreCriteriaNumerical from .criteria_categorical import ScoreCriteriaCategorical - ScoreCriteria = typing.Annotated[ typing.Union[ScoreCriteriaCategorical, ScoreCriteriaNumerical], Field(discriminator="type"), diff --git a/spockflow/components/tree/v1/compiled.py b/spockflow/components/tree/v1/compiled.py index 7bfb893..57b2f6b 100644 --- a/spockflow/components/tree/v1/compiled.py +++ b/spockflow/components/tree/v1/compiled.py @@ -65,7 +65,7 @@ def __init__(self, tree: Tree) -> None: ) self._has_priority = any(self._flattened_priority != 1) - (predefined_conditions, predefined_condition_names, execution_conditions) = ( + predefined_conditions, predefined_condition_names, execution_conditions = ( self._split_predefined( items=flattened_tree.conditions, predefined_types=(pd.Series, np.ndarray), @@ -82,7 +82,7 @@ def __init__(self, tree: Tree) -> None: self.all_condition_names = predefined_condition_names + execution_conditions self.execution_conditions = execution_conditions - (predefined_outputs, predefined_outputs_names, execution_outputs) = ( + predefined_outputs, predefined_outputs_names, execution_outputs = ( self._split_predefined( items=flattened_tree.outputs, predefined_types=pd.DataFrame, diff --git a/spockflow/inference/io/encoders.py b/spockflow/inference/io/encoders.py index f49f338..7b7fd75 100644 --- a/spockflow/inference/io/encoders.py +++ b/spockflow/inference/io/encoders.py @@ -7,7 +7,6 @@ from .responses import Response, CSVResponse, JSONResponse from . import content_types - TDefaultResult = typing.Union[ pd.Series, pd.DataFrame, np.ndarray, typing.Dict[str, typing.Any] ] diff --git a/tests/test_pipelines/ptree02_basic/main.py b/tests/test_pipelines/ptree02_basic/main.py index 7c020f0..7908cd4 100644 --- a/tests/test_pipelines/ptree02_basic/main.py +++ b/tests/test_pipelines/ptree02_basic/main.py @@ -2,7 +2,6 @@ from spockflow.components.tree import Tree from spockflow.components.common import Reject - tree = Tree() diff --git a/tree.ipynb b/tree.ipynb index 9010a83..d968251 100644 --- a/tree.ipynb +++ b/tree.ipynb @@ -23,14 +23,18 @@ "\n", "from typing_extensions import TypedDict\n", "\n", + "\n", "class Reject(TypedDict):\n", " code: int\n", " description: str\n", "\n", + "\n", "RejectAction = Action[Reject]\n", "\n", "\n", "tree = Tree()\n", + "\n", + "\n", "@tree.condition(output=pd.DataFrame([dict(code=102, description=\"My first condition\")]))\n", "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n", " return (d > 5) & (e > 5) & (f > 5)" @@ -50,6 +54,8 @@ "outputs": [], "source": [ "tree = Tree()\n", + "\n", + "\n", "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n", "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n", " return (d > 5) & (e > 5) & (f > 5)" @@ -248,38 +254,55 @@ } ], "source": [ - "example_dt = dtable.DecisionTable()\\\n", - " .add(dtable.DTMin, \"input_v1\", [0,1,2,3,4,5,6,7,8,9,10])\\\n", - " .add(dtable.DTMax, \"input_v1\", [1,2,3,4,5,6,7,8,9,10,11])\\\n", - " .add(dtable.DTMin, \"input_v2\", [0,0,0,0,0,0,1,1,1,1,1])\\\n", - " .add(dtable.DTMax, \"input_v2\", [1,1,1,1,1,1,2,2,2,2,2])\\\n", - " .set_default(pd.DataFrame({\"value\": [1]}))\\\n", - " .output(\"value\", [0,2,4,0,2,3,2,2,1,1])\n", + "example_dt = (\n", + " dtable.DecisionTable()\n", + " .add(dtable.DTMin, \"input_v1\", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n", + " .add(dtable.DTMax, \"input_v1\", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])\n", + " .add(dtable.DTMin, \"input_v2\", [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n", + " .add(dtable.DTMax, \"input_v2\", [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2])\n", + " .set_default(pd.DataFrame({\"value\": [1]}))\n", + " .output(\"value\", [0, 2, 4, 0, 2, 3, 2, 2, 1, 1])\n", + ")\n", "\n", "tree = Tree()\n", + "\n", + "\n", "@tree.condition(output=RejectAction(code=102, description=\"My first condition\"))\n", "def first_condition(d: pd.Series, e: pd.Series, f: pd.Series) -> pd.Series:\n", " return (d > 5) & (e > 5) & (f > 5)\n", "\n", - "tree.condition(condition=TableCondition(name=\"test\", table = example_dt), output=[\n", - " RejectAction(code=100, description=\"o1\"),\n", - " RejectAction(code=100, description=\"o2\"),\n", - " RejectAction(code=100, description=\"o3\"),\n", - " RejectAction(code=100, description=\"o4\"),\n", - " RejectAction(code=100, description=\"o5\"),\n", - "])\n", + "\n", + "tree.condition(\n", + " condition=TableCondition(name=\"test\", table=example_dt),\n", + " output=[\n", + " RejectAction(code=100, description=\"o1\"),\n", + " RejectAction(code=100, description=\"o2\"),\n", + " RejectAction(code=100, description=\"o3\"),\n", + " RejectAction(code=100, description=\"o4\"),\n", + " RejectAction(code=100, description=\"o5\"),\n", + " ],\n", + ")\n", + "\n", "\n", "@tree.condition()\n", "def condition_a(a: pd.Series) -> pd.Series:\n", " return a > 5\n", "\n", - "@condition_a.condition(output=RejectAction(code=100, description=\"a and b are out of range\"))\n", + "\n", + "@condition_a.condition(\n", + " output=RejectAction(code=100, description=\"a and b are out of range\")\n", + ")\n", "def condition_b(b: pd.Series) -> pd.Series:\n", " return b > 5\n", "\n", - "@condition_a.condition(output=RejectAction(code=101, description=\"a and c are out of range\"))\n", + "\n", + "@condition_a.condition(\n", + " output=RejectAction(code=101, description=\"a and c are out of range\")\n", + ")\n", "def condition_c(c: pd.Series) -> pd.Series:\n", " return c > 5\n", + "\n", + "\n", "tree.visualize(get_value_name=lambda x: x[\"description\"][0])" ] }, @@ -965,17 +988,19 @@ } ], "source": [ - "test_data = pd.DataFrame({\n", - " \"a\": [5,6,7,8,1,2,3,4],\n", - " \"b\": [1,0,0,0,2,0,0,0],\n", - " \"c\": [0,10,0,0,0,10,0,0],\n", - " \"d\": [0,0,10,0,0,0,0,0],\n", - " \"e\": [0,0,10,0,0,0,0,0],\n", - " \"f\": [0,0,10,0,0,0,0,0],\n", - " # Below used later in the tutorial\n", - " \"nums\": [1,2,3,4,5,6,7,8],\n", - " \"input_condition\": [True, True, False, False, False, False, True, True],\n", - "})\n", + "test_data = pd.DataFrame(\n", + " {\n", + " \"a\": [5, 6, 7, 8, 1, 2, 3, 4],\n", + " \"b\": [1, 0, 0, 0, 2, 0, 0, 0],\n", + " \"c\": [0, 10, 0, 0, 0, 10, 0, 0],\n", + " \"d\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " \"e\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " \"f\": [0, 0, 10, 0, 0, 0, 0, 0],\n", + " # Below used later in the tutorial\n", + " \"nums\": [1, 2, 3, 4, 5, 6, 7, 8],\n", + " \"input_condition\": [True, True, False, False, False, False, True, True],\n", + " }\n", + ")\n", "test_data" ] }, @@ -1699,7 +1724,8 @@ " value: DataFrame\n", " # value: typing.Union[typing.Callable[..., pd.DataFrame], DataFrame, str, \"Test\", None] = None\n", "\n", - "Test(value=pd.DataFrame({\"a\": [1]}))\n" + "\n", + "Test(value=pd.DataFrame({\"a\": [1]}))" ] }, { @@ -1754,20 +1780,20 @@ "def dump_dataframe_to_dict(instance: pd.DataFrame) -> Dict[str, Any]:\n", " \"\"\"Serialize a DataFrame to a dictionary.\"\"\"\n", " return {\n", - " \"type\": \"DataFrame\", \n", + " \"type\": \"DataFrame\",\n", " \"values\": instance.to_dict(orient=\"list\"),\n", " \"columns\": list(instance.columns),\n", - " \"index\": instance.index.tolist()\n", + " \"index\": instance.index.tolist(),\n", " }\n", "\n", "\n", "def dump_series_to_dict(instance: pd.Series) -> Dict[str, Any]:\n", " \"\"\"Serialize a Series to a dictionary.\"\"\"\n", " return {\n", - " \"type\": \"Series\", \n", + " \"type\": \"Series\",\n", " \"values\": instance.tolist(),\n", " \"name\": instance.name,\n", - " \"index\": instance.index.tolist()\n", + " \"index\": instance.index.tolist(),\n", " }\n", "\n", "\n", @@ -1782,22 +1808,24 @@ " # Ensure the input is a dictionary with the correct type\n", " if not isinstance(value, dict) or value.get(\"type\") != \"DataFrame\":\n", " raise ValueError(\"Invalid DataFrame representation\")\n", - " \n", + "\n", " # Reconstruct DataFrame with original columns and index\n", " df = pd.DataFrame(\n", - " value[\"values\"], \n", - " columns=value.get(\"columns\"),\n", - " index=value.get(\"index\")\n", + " value[\"values\"], columns=value.get(\"columns\"), index=value.get(\"index\")\n", " )\n", " return df\n", "\n", " return core_schema.json_or_python_schema(\n", - " json_schema=core_schema.no_info_plain_validator_function(validate_dataframe),\n", - " python_schema=core_schema.union_schema([\n", - " # Specifically check for DataFrame type\n", - " core_schema.is_instance_schema(pd.DataFrame),\n", - " core_schema.no_info_plain_validator_function(validate_dataframe)\n", - " ]),\n", + " json_schema=core_schema.no_info_plain_validator_function(\n", + " validate_dataframe\n", + " ),\n", + " python_schema=core_schema.union_schema(\n", + " [\n", + " # Specifically check for DataFrame type\n", + " core_schema.is_instance_schema(pd.DataFrame),\n", + " core_schema.no_info_plain_validator_function(validate_dataframe),\n", + " ]\n", + " ),\n", " serialization=core_schema.plain_serializer_function_ser_schema(\n", " dump_dataframe_to_dict\n", " ),\n", @@ -1814,9 +1842,9 @@ " \"type\": {\"const\": \"DataFrame\"},\n", " \"values\": {\"type\": \"object\"},\n", " \"columns\": {\"type\": \"array\"},\n", - " \"index\": {\"type\": \"array\"}\n", + " \"index\": {\"type\": \"array\"},\n", " },\n", - " \"required\": [\"type\", \"values\"]\n", + " \"required\": [\"type\", \"values\"],\n", " }\n", "\n", "\n", @@ -1831,21 +1859,21 @@ " # Ensure the input is a dictionary with the correct type\n", " if not isinstance(value, dict) or value.get(\"type\") != \"Series\":\n", " raise ValueError(\"Invalid Series representation\")\n", - " \n", + "\n", " # Reconstruct Series with original name and index\n", " return pd.Series(\n", - " value[\"values\"], \n", - " name=value.get(\"name\"),\n", - " index=value.get(\"index\")\n", + " value[\"values\"], name=value.get(\"name\"), index=value.get(\"index\")\n", " )\n", "\n", " return core_schema.json_or_python_schema(\n", " json_schema=core_schema.no_info_plain_validator_function(validate_series),\n", - " python_schema=core_schema.union_schema([\n", - " # Specifically check for Series type\n", - " core_schema.is_instance_schema(pd.Series),\n", - " core_schema.no_info_plain_validator_function(validate_series)\n", - " ]),\n", + " python_schema=core_schema.union_schema(\n", + " [\n", + " # Specifically check for Series type\n", + " core_schema.is_instance_schema(pd.Series),\n", + " core_schema.no_info_plain_validator_function(validate_series),\n", + " ]\n", + " ),\n", " serialization=core_schema.plain_serializer_function_ser_schema(\n", " dump_series_to_dict\n", " ),\n", @@ -1862,9 +1890,9 @@ " \"type\": {\"const\": \"Series\"},\n", " \"values\": {\"type\": \"array\"},\n", " \"name\": {\"type\": [\"string\", \"null\"]},\n", - " \"index\": {\"type\": \"array\"}\n", + " \"index\": {\"type\": \"array\"},\n", " },\n", - " \"required\": [\"type\", \"values\"]\n", + " \"required\": [\"type\", \"values\"],\n", " }\n", "\n", "\n", @@ -1879,13 +1907,9 @@ " metric: Series\n", "\n", "\n", - "\n", "# Create sample DataFrame and Series\n", - "df = pd.DataFrame({\n", - " 'A': [1, 2, 3],\n", - " 'B': [4, 5, 6]\n", - "})\n", - "series = pd.Series([10, 20, 30], name='values')\n", + "df = pd.DataFrame({\"A\": [1, 2, 3], \"B\": [4, 5, 6]})\n", + "series = pd.Series([10, 20, 30], name=\"values\")\n", "\n", "# Create a model instance\n", "model = MyModel(data=df, metric=series)\n", @@ -1907,8 +1931,7 @@ "try:\n", " # This should raise a validation error\n", " invalid_model = MyModel(\n", - " data={\"type\": \"Series\", \"values\": [1, 2, 3]}, # Wrong type\n", - " metric=series\n", + " data={\"type\": \"Series\", \"values\": [1, 2, 3]}, metric=series # Wrong type\n", " )\n", "except ValidationError as e:\n", " print(\"\\nValidation Error:\", e)" @@ -1943,12 +1966,13 @@ " return pd.DataFrame(value[\"values\"])\n", " return pd.Series(value[\"values\"], name=value[\"name\"])\n", "\n", + "\n", "core_schema.chain_schema(\n", - " [\n", - " core_schema.dict_schema(), # TODO make this more comprehensive\n", - " core_schema.no_info_plain_validator_function(validate_from_dict),\n", - " ]\n", - " )" + " [\n", + " core_schema.dict_schema(), # TODO make this more comprehensive\n", + " core_schema.no_info_plain_validator_function(validate_from_dict),\n", + " ]\n", + ")" ] }, { @@ -1960,6 +1984,7 @@ "from dataclasses import dataclass\n", "import typing\n", "\n", + "\n", "@dataclass\n", "class TestClass:\n", " class_value: int\n", @@ -1995,26 +2020,44 @@ "source": [ "values_schema = core_schema.dict_schema(keys_schema=core_schema.str_schema())\n", "\n", - "dataframe_json_schema = core_schema.model_fields_schema({\n", - " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", - " \"values\": core_schema.model_field(core_schema.union_schema([values_schema, core_schema.list_schema(values_schema)])),\n", - " \"dtypes\": core_schema.model_field(\n", - " core_schema.with_default_schema(\n", - " core_schema.union_schema([core_schema.none_schema(), core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema())]),\n", - " default=None,\n", - " )\n", - " )\n", - "})\n", - "series_json_schema = core_schema.model_fields_schema({\n", - " \"type\": core_schema.model_field(core_schema.literal_schema([\"Series\"])),\n", - " \"values\": core_schema.model_field(core_schema.list_schema()),\n", - " \"name\": core_schema.model_field(\n", - " core_schema.with_default_schema(\n", - " core_schema.union_schema([core_schema.none_schema(), core_schema.str_schema()]),\n", - " default=None,\n", - " )\n", - " )\n", - "})\n", + "dataframe_json_schema = core_schema.model_fields_schema(\n", + " {\n", + " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", + " \"values\": core_schema.model_field(\n", + " core_schema.union_schema(\n", + " [values_schema, core_schema.list_schema(values_schema)]\n", + " )\n", + " ),\n", + " \"dtypes\": core_schema.model_field(\n", + " core_schema.with_default_schema(\n", + " core_schema.union_schema(\n", + " [\n", + " core_schema.none_schema(),\n", + " core_schema.dict_schema(\n", + " keys_schema=core_schema.str_schema(),\n", + " values_schema=core_schema.str_schema(),\n", + " ),\n", + " ]\n", + " ),\n", + " default=None,\n", + " )\n", + " ),\n", + " }\n", + ")\n", + "series_json_schema = core_schema.model_fields_schema(\n", + " {\n", + " \"type\": core_schema.model_field(core_schema.literal_schema([\"Series\"])),\n", + " \"values\": core_schema.model_field(core_schema.list_schema()),\n", + " \"name\": core_schema.model_field(\n", + " core_schema.with_default_schema(\n", + " core_schema.union_schema(\n", + " [core_schema.none_schema(), core_schema.str_schema()]\n", + " ),\n", + " default=None,\n", + " )\n", + " ),\n", + " }\n", + ")\n", "\n", "\n", "def validate_to_dataframe_schema(value: dict) -> pd.DataFrame:\n", @@ -2033,6 +2076,7 @@ " value, *_ = value\n", " return pd.Series(value[\"values\"], name=value.get(\"name\"))\n", "\n", + "\n", "from_df_dict_schema = core_schema.chain_schema(\n", " [\n", " dataframe_json_schema,\n", @@ -2048,15 +2092,22 @@ " ]\n", ")\n", "\n", + "\n", "def dump_df_to_dict(instance: pd.DataFrame) -> dict:\n", " values = instance.to_dict(orient=\"records\")\n", " if len(values) == 1:\n", " values = values[0]\n", - " return {\"type\": \"DataFrame\", \"values\": values, \"dtypes\": {k: str(v) for k,v in instance.dtypes.items()}}\n", + " return {\n", + " \"type\": \"DataFrame\",\n", + " \"values\": values,\n", + " \"dtypes\": {k: str(v) for k, v in instance.dtypes.items()},\n", + " }\n", + "\n", "\n", "def dump_series_to_dict(instance: pd.Series) -> dict:\n", " return {\"type\": \"Series\", \"values\": instance.to_list(), \"name\": instance.name}\n", "\n", + "\n", "class _PandasDataFramePydanticAnnotation:\n", " @classmethod\n", " def __get_pydantic_core_schema__(\n", @@ -2083,8 +2134,9 @@ " def __get_pydantic_json_schema__(\n", " cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler\n", " ) -> JsonSchemaValue:\n", - " return handler(dataframe_json_schema) \n", - " \n", + " return handler(dataframe_json_schema)\n", + "\n", + "\n", "class _PandasSeriesPydanticAnnotation:\n", " @classmethod\n", " def __get_pydantic_core_schema__(\n", @@ -2111,7 +2163,7 @@ " def __get_pydantic_json_schema__(\n", " cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler\n", " ) -> JsonSchemaValue:\n", - " return handler(dataframe_json_schema) \n", + " return handler(dataframe_json_schema)\n", "\n", "\n", "DataFrame = Annotated[pd.DataFrame, _PandasDataFramePydanticAnnotation]\n", @@ -2119,9 +2171,11 @@ "\n", "\n", "import typing\n", + "\n", + "\n", "class MyModel(BaseModel):\n", " data: DataFrame\n", - " metric: typing.Union[Series,DataFrame]" + " metric: typing.Union[Series, DataFrame]" ] }, { @@ -2130,11 +2184,8 @@ "metadata": {}, "outputs": [], "source": [ - "df = pd.DataFrame({\n", - " 'A': [1, 2, 3],\n", - " 'B': [4, 5, 6]\n", - "})\n", - "series = pd.Series([10, 20, 30], name='values')\n", + "df = pd.DataFrame({\"A\": [1, 2, 3], \"B\": [4, 5, 6]})\n", + "series = pd.Series([10, 20, 30], name=\"values\")\n", "\n", "# Create a model instance\n", "model = MyModel(data=df, metric=series)" @@ -2170,7 +2221,6 @@ } ], "source": [ - "\n", "serialized = model.model_dump()\n", "print(\"Serialized:\", serialized)\n", "reconstructed_model = MyModel.model_validate(serialized)\n", @@ -2195,12 +2245,15 @@ } ], "source": [ - "{'data': {'type': 'DataFrame', 'values': [{'A': 1, 'B': 4}, {'A': 2, 'B': 5}, {'A': 3, 'B': 6}], 'dtypes': {'A': 'int64', 'B': 'int64'}}}\n", + "{\n", + " \"data\": {\n", + " \"type\": \"DataFrame\",\n", + " \"values\": [{\"A\": 1, \"B\": 4}, {\"A\": 2, \"B\": 5}, {\"A\": 3, \"B\": 6}],\n", + " \"dtypes\": {\"A\": \"int64\", \"B\": \"int64\"},\n", + " }\n", + "}\n", "\n", - "df = pd.DataFrame({\n", - " 'A': [1, 2, 3],\n", - " 'B': [4, 5, 6]\n", - "})\n", + "df = pd.DataFrame({\"A\": [1, 2, 3], \"B\": [4, 5, 6]})\n", "\n", "model = MyModel2(data=df)\n", "\n", @@ -2294,34 +2347,37 @@ ")\n", "from pydantic.json_schema import JsonSchemaValue\n", "\n", + "\n", "def validate_to_dataframe_schema(value: dict) -> pd.DataFrame:\n", " # Ensure the input is a dictionary with the correct type\n", " if not isinstance(value, dict) or value.get(\"type\") != \"DataFrame\":\n", " raise ValueError(\"Invalid DataFrame representation\")\n", - " \n", + "\n", " values = value[\"values\"]\n", " # Handle both single dict and list of dicts\n", " if isinstance(values, dict):\n", " values = [values]\n", - " \n", + "\n", " # Reconstruct DataFrame with specified dtypes\n", " dtypes = value.get(\"dtypes\", {})\n", " return pd.DataFrame(values, dtype=dtypes)\n", "\n", + "\n", "def dump_df_to_dict(instance: pd.DataFrame) -> dict:\n", " # Convert DataFrame to dictionary representation\n", " values = instance.to_dict(orient=\"list\")\n", - " \n", + "\n", " # Simplify values if there's only one column\n", " if len(values) == 1:\n", " values = values[list(values.keys())[0]]\n", - " \n", + "\n", " return {\n", - " \"type\": \"DataFrame\", \n", - " \"values\": values, \n", - " \"dtypes\": {k: str(v) for k, v in instance.dtypes.items()}\n", + " \"type\": \"DataFrame\",\n", + " \"values\": values,\n", + " \"dtypes\": {k: str(v) for k, v in instance.dtypes.items()},\n", " }\n", "\n", + "\n", "class _PandasDataFramePydanticAnnotation:\n", " @classmethod\n", " def __get_pydantic_core_schema__(\n", @@ -2330,38 +2386,45 @@ " _handler: GetCoreSchemaHandler,\n", " ) -> core_schema.CoreSchema:\n", " # Create schemas for different parts of the DataFrame representation\n", - " values_schema = core_schema.dict_schema(\n", - " keys_schema=core_schema.str_schema()\n", + " values_schema = core_schema.dict_schema(keys_schema=core_schema.str_schema())\n", + "\n", + " dataframe_json_schema = core_schema.model_fields_schema(\n", + " {\n", + " \"type\": core_schema.literal_schema([\"DataFrame\"]),\n", + " \"values\": core_schema.union_schema(\n", + " [values_schema, core_schema.list_schema(values_schema)]\n", + " ),\n", + " \"dtypes\": core_schema.union_schema(\n", + " [\n", + " core_schema.none_schema(),\n", + " core_schema.dict_schema(\n", + " keys_schema=core_schema.str_schema(),\n", + " values_schema=core_schema.str_schema(),\n", + " ),\n", + " ]\n", + " ),\n", + " }\n", " )\n", - " \n", - " dataframe_json_schema = core_schema.model_fields_schema({\n", - " \"type\": core_schema.literal_schema([\"DataFrame\"]),\n", - " \"values\": core_schema.union_schema([\n", - " values_schema, \n", - " core_schema.list_schema(values_schema)\n", - " ]),\n", - " \"dtypes\": core_schema.union_schema([\n", - " core_schema.none_schema(), \n", - " core_schema.dict_schema(\n", - " keys_schema=core_schema.str_schema(), \n", - " values_schema=core_schema.str_schema()\n", - " )\n", - " ])\n", - " })\n", - " \n", + "\n", " # Create a schema that validates the DataFrame representation\n", - " from_df_dict_schema = core_schema.chain_schema([\n", - " core_schema.dict_schema(),\n", - " core_schema.no_info_plain_validator_function(validate_to_dataframe_schema),\n", - " ])\n", + " from_df_dict_schema = core_schema.chain_schema(\n", + " [\n", + " core_schema.dict_schema(),\n", + " core_schema.no_info_plain_validator_function(\n", + " validate_to_dataframe_schema\n", + " ),\n", + " ]\n", + " )\n", "\n", " return core_schema.json_or_python_schema(\n", " json_schema=from_df_dict_schema,\n", - " python_schema=core_schema.union_schema([\n", - " # Check if it's a DataFrame instance first\n", - " core_schema.is_instance_schema(pd.DataFrame),\n", - " from_df_dict_schema,\n", - " ]),\n", + " python_schema=core_schema.union_schema(\n", + " [\n", + " # Check if it's a DataFrame instance first\n", + " core_schema.is_instance_schema(pd.DataFrame),\n", + " from_df_dict_schema,\n", + " ]\n", + " ),\n", " serialization=core_schema.plain_serializer_function_ser_schema(\n", " dump_df_to_dict\n", " ),\n", @@ -2374,40 +2437,36 @@ " # Use the predefined dataframe_json_schema\n", " values_schema = {\n", " \"type\": \"object\",\n", - " \"additionalProperties\": {\"type\": [\"number\", \"string\", \"boolean\", \"null\"]}\n", + " \"additionalProperties\": {\"type\": [\"number\", \"string\", \"boolean\", \"null\"]},\n", " }\n", - " \n", + "\n", " return {\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"type\": {\"const\": \"DataFrame\"},\n", " \"values\": {\n", - " \"oneOf\": [\n", - " values_schema,\n", - " {\"type\": \"array\", \"items\": values_schema}\n", - " ]\n", + " \"oneOf\": [values_schema, {\"type\": \"array\", \"items\": values_schema}]\n", " },\n", " \"dtypes\": {\n", " \"type\": [\"object\", \"null\"],\n", - " \"additionalProperties\": {\"type\": \"string\"}\n", - " }\n", + " \"additionalProperties\": {\"type\": \"string\"},\n", + " },\n", " },\n", - " \"required\": [\"type\", \"values\"]\n", + " \"required\": [\"type\", \"values\"],\n", " }\n", "\n", + "\n", "# Type annotation for easy use\n", "DataFrame = Annotated[pd.DataFrame, _PandasDataFramePydanticAnnotation]\n", "\n", + "\n", "# Example usage\n", "class MyModel(BaseModel):\n", " data: DataFrame\n", "\n", "\n", "# Create sample DataFrame\n", - "df = pd.DataFrame({\n", - " 'A': [1, 2, 3],\n", - " 'B': [4, 5, 6]\n", - "})\n", + "df = pd.DataFrame({\"A\": [1, 2, 3], \"B\": [4, 5, 6]})\n", "\n", "# Create a model instance\n", "model = MyModel(data=df)\n", @@ -2421,8 +2480,7 @@ "\n", "# Verify reconstruction\n", "print(\"\\nReconstructed DataFrame:\")\n", - "print(reconstructed_model.data)\n", - "\n" + "print(reconstructed_model.data)" ] }, { @@ -2446,10 +2504,10 @@ "from pydantic_core import SchemaValidator, core_schema\n", "\n", "wrapper_schema = core_schema.model_fields_schema(\n", - " {'a': core_schema.model_field(core_schema.str_schema())}\n", + " {\"a\": core_schema.model_field(core_schema.str_schema())}\n", ")\n", "v = SchemaValidator(dataframe_json_schema)\n", - "print(v.validate_python({'a': 'hello'}))" + "print(v.validate_python({\"a\": \"hello\"}))" ] }, { @@ -2472,14 +2530,30 @@ "source": [ "from pydantic_core import SchemaValidator, core_schema\n", "\n", - "dataframe_json_schema = core_schema.model_fields_schema({\n", - " # \"value\": core_schema.model_field(core_schema.dict_schema(keys_schema=core_schema.str_schema()))\n", - " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", - " \"values\": core_schema.model_field(core_schema.union_schema([values_schema, core_schema.list_schema(values_schema)])),\n", - " \"dtypes\": core_schema.model_field(core_schema.union_schema([core_schema.none_schema(), core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema())]))\n", - "})\n", + "dataframe_json_schema = core_schema.model_fields_schema(\n", + " {\n", + " # \"value\": core_schema.model_field(core_schema.dict_schema(keys_schema=core_schema.str_schema()))\n", + " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", + " \"values\": core_schema.model_field(\n", + " core_schema.union_schema(\n", + " [values_schema, core_schema.list_schema(values_schema)]\n", + " )\n", + " ),\n", + " \"dtypes\": core_schema.model_field(\n", + " core_schema.union_schema(\n", + " [\n", + " core_schema.none_schema(),\n", + " core_schema.dict_schema(\n", + " keys_schema=core_schema.str_schema(),\n", + " values_schema=core_schema.str_schema(),\n", + " ),\n", + " ]\n", + " )\n", + " ),\n", + " }\n", + ")\n", "v = SchemaValidator(dataframe_json_schema)\n", - "print(v.validate_python({'a': 'hello'}))" + "print(v.validate_python({\"a\": \"hello\"}))" ] }, { @@ -2526,11 +2600,28 @@ ], "source": [ "from pydantic_core import SchemaValidator, core_schema\n", - "dataframe_json_schema = core_schema.model_fields_schema({\n", - " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", - " \"values\": core_schema.model_field(core_schema.union_schema([values_schema, core_schema.list_schema(values_schema)])),\n", - " \"dtypes\": core_schema.model_field(core_schema.union_schema([core_schema.none_schema(), core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema())]))\n", - "})\n", + "\n", + "dataframe_json_schema = core_schema.model_fields_schema(\n", + " {\n", + " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", + " \"values\": core_schema.model_field(\n", + " core_schema.union_schema(\n", + " [values_schema, core_schema.list_schema(values_schema)]\n", + " )\n", + " ),\n", + " \"dtypes\": core_schema.model_field(\n", + " core_schema.union_schema(\n", + " [\n", + " core_schema.none_schema(),\n", + " core_schema.dict_schema(\n", + " keys_schema=core_schema.str_schema(),\n", + " values_schema=core_schema.str_schema(),\n", + " ),\n", + " ]\n", + " )\n", + " ),\n", + " }\n", + ")\n", "\n", "\n", "def validate_to_dataframe_schema(value: dict) -> pd.DataFrame:\n", @@ -2555,7 +2646,15 @@ "\n", "\n", "v = SchemaValidator(from_df_dict_schema)\n", - "print(v.validate_python({'type': 'DataFrame', 'values': {'A': 1, 'B': 2}, 'dtypes': {'A': 'int64', 'B': 'int64'}}))" + "print(\n", + " v.validate_python(\n", + " {\n", + " \"type\": \"DataFrame\",\n", + " \"values\": {\"A\": 1, \"B\": 2},\n", + " \"dtypes\": {\"A\": \"int64\", \"B\": \"int64\"},\n", + " }\n", + " )\n", + ")" ] }, { @@ -2572,16 +2671,38 @@ } ], "source": [ - "\n", - "\n", - "dataframe_json_schema = core_schema.model_fields_schema({\n", - " # \"value\": core_schema.model_field(core_schema.dict_schema(keys_schema=core_schema.str_schema()))\n", - " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", - " \"values\": core_schema.model_field(core_schema.union_schema([values_schema, core_schema.list_schema(values_schema)])),\n", - " \"dtypes\": core_schema.model_field(core_schema.union_schema([core_schema.none_schema(), core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema())]))\n", - "})\n", + "dataframe_json_schema = core_schema.model_fields_schema(\n", + " {\n", + " # \"value\": core_schema.model_field(core_schema.dict_schema(keys_schema=core_schema.str_schema()))\n", + " \"type\": core_schema.model_field(core_schema.literal_schema([\"DataFrame\"])),\n", + " \"values\": core_schema.model_field(\n", + " core_schema.union_schema(\n", + " [values_schema, core_schema.list_schema(values_schema)]\n", + " )\n", + " ),\n", + " \"dtypes\": core_schema.model_field(\n", + " core_schema.union_schema(\n", + " [\n", + " core_schema.none_schema(),\n", + " core_schema.dict_schema(\n", + " keys_schema=core_schema.str_schema(),\n", + " values_schema=core_schema.str_schema(),\n", + " ),\n", + " ]\n", + " )\n", + " ),\n", + " }\n", + ")\n", "v = SchemaValidator(dataframe_json_schema)\n", - "print(v.validate_python({'type': 'DataFrame', 'values': {'A': 1, 'B': 2}, 'dtypes': {'A': 'int64', 'B': 'int64'}}))" + "print(\n", + " v.validate_python(\n", + " {\n", + " \"type\": \"DataFrame\",\n", + " \"values\": {\"A\": 1, \"B\": 2},\n", + " \"dtypes\": {\"A\": \"int64\", \"B\": \"int64\"},\n", + " }\n", + " )\n", + ")" ] }, { From 969e4f0ae31c1a04c20d3409cd9c4d7ffe135ba4 Mon Sep 17 00:00:00 2001 From: Sholto Armstrong Date: Thu, 28 May 2026 09:45:35 +0200 Subject: [PATCH 3/5] fix test failing due to newer pandas --- tests/unit/test_treelite.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_treelite.py b/tests/unit/test_treelite.py index e6a27bf..65da081 100644 --- a/tests/unit/test_treelite.py +++ b/tests/unit/test_treelite.py @@ -65,7 +65,9 @@ def test_parse_and_compile(basic_treelite_config: str): assert compiled_tree.output_priority_mapping.shape == (3,) assert compiled_tree.leaf_output_mapping.shape == (3,) - assert (compiled_tree.output_dataframe_mapping.dtypes == ["O", "O", "O"]).all() + assert all( + str(dt) in ("object", "string") for dt in compiled_tree.output_dataframe_mapping.dtypes + ) leaf_priorities = [ compiled_tree.output_priority_mapping[compiled_tree.node_id_mapping["1"]], compiled_tree.output_priority_mapping[ From 9d2e192d71b7f87aff149608c61e765163b920db Mon Sep 17 00:00:00 2001 From: Sholto Armstrong Date: Thu, 28 May 2026 09:55:21 +0200 Subject: [PATCH 4/5] formatting --- tests/unit/test_treelite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_treelite.py b/tests/unit/test_treelite.py index 65da081..0e5c737 100644 --- a/tests/unit/test_treelite.py +++ b/tests/unit/test_treelite.py @@ -66,7 +66,8 @@ def test_parse_and_compile(basic_treelite_config: str): assert compiled_tree.output_priority_mapping.shape == (3,) assert compiled_tree.leaf_output_mapping.shape == (3,) assert all( - str(dt) in ("object", "string") for dt in compiled_tree.output_dataframe_mapping.dtypes + str(dt) in ("object", "string") + for dt in compiled_tree.output_dataframe_mapping.dtypes ) leaf_priorities = [ compiled_tree.output_priority_mapping[compiled_tree.node_id_mapping["1"]], From b15fddec5f9fe279cebd4bcd45a9614b5de23051 Mon Sep 17 00:00:00 2001 From: Sholto Armstrong Date: Thu, 28 May 2026 09:59:54 +0200 Subject: [PATCH 5/5] the test is not needed --- tests/unit/test_treelite.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit/test_treelite.py b/tests/unit/test_treelite.py index 0e5c737..6a9625b 100644 --- a/tests/unit/test_treelite.py +++ b/tests/unit/test_treelite.py @@ -65,10 +65,6 @@ def test_parse_and_compile(basic_treelite_config: str): assert compiled_tree.output_priority_mapping.shape == (3,) assert compiled_tree.leaf_output_mapping.shape == (3,) - assert all( - str(dt) in ("object", "string") - for dt in compiled_tree.output_dataframe_mapping.dtypes - ) leaf_priorities = [ compiled_tree.output_priority_mapping[compiled_tree.node_id_mapping["1"]], compiled_tree.output_priority_mapping[