Add CLI binaries for statistical models and serialization for CITrees #45
Add command-line utilities for stochastic forests, conditional random forests, and conditional inference trees (train + predict for each). All binaries use ArgumentHandler for CLI args, support CSV/TSV input with auto-detection, and follow existing repository conventions.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add write_to/read_from serialization to ConditionalInferenceTrees
- Create ygor_stochastic_forest_train and ygor_stochastic_forest_predict
- Create ygor_conditional_forest_train and ygor_conditional_forest_predict
- Create ygor_ci_tree_train and ygor_ci_tree_predict
- Update binaries/CMakeLists.txt to build and install all new binaries

Co-authored-by: hdclark <934858+hdclark@users.noreply.github.com>
Agent-Logs-Url: https://github.com/hdclark/Ygor/sessions/a5269fa4-9646-45da-9023-4ba7802d958c
…ader flag

Co-authored-by: hdclark <934858+hdclark@users.noreply.github.com>
Agent-Logs-Url: https://github.com/hdclark/Ygor/sessions/a5269fa4-9646-45da-9023-4ba7802d958c
Pull request overview
This PR exposes Ygor’s statistical tree/forest models to end-users via six new command-line binaries and adds text-stream serialization support to Stats::ConditionalInferenceTrees, aligning it with the existing serialization patterns used by the forest models.
Changes:
- Added `write_to()`/`read_from()` (and recursive node serialization helpers) for `ConditionalInferenceTrees`.
- Introduced six new `ygor_*` CLI binaries to train/predict with stochastic forests, conditional random forests, and CI trees from CSV/TSV input.
- Updated `binaries/CMakeLists.txt` to build, link, and install the new executables.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| src/YgorStatsCITrees.h | Declares CI tree serialization APIs and helpers. |
| src/YgorStatsCITrees.cc | Implements CI tree serialization/deserialization logic. |
| binaries/Ygor_Stochastic_Forest_Train.cc | New CLI trainer for stochastic forests (CSV/TSV ingestion + model save). |
| binaries/Ygor_Stochastic_Forest_Predict.cc | New CLI predictor for stochastic forests (model load + per-row predictions). |
| binaries/Ygor_Conditional_Forest_Train.cc | New CLI trainer for conditional random forests (hyperparameters + model save). |
| binaries/Ygor_Conditional_Forest_Predict.cc | New CLI predictor for conditional random forests. |
| binaries/Ygor_CI_Tree_Train.cc | New CLI trainer for conditional inference trees (hyperparameters + model save). |
| binaries/Ygor_CI_Tree_Predict.cc | New CLI predictor for conditional inference trees. |
| binaries/CMakeLists.txt | Adds build/link/install rules for the six new CLI targets. |
    }else if(node_type == "I"){
        node->is_leaf = false;
        is >> node->split_feature;
        std::string thresh_str;
        is >> thresh_str;
        if(is.fail()) return nullptr;
        node->split_threshold = static_cast<T>(std::stold(thresh_str));
        node->left = read_tree_node(is);
        node->right = read_tree_node(is);
        if(!node->left || !node->right) return nullptr;
read_tree_node() accepts split_feature from the stream without validating it against n_features_trained. A malformed/hostile model file could deserialize a negative or out-of-range feature index, which can later trigger out-of-bounds access in predict_tree() (via x.read_coeff(0, node->split_feature)). Add a bounds check during deserialization (after n_features_trained is read) and fail the load if the index is invalid.
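The bounds check requested above could look roughly like the following sketch. It is not the Ygor implementation: `Node`, `read_node_checked`, and the `L <value>` leaf layout are hypothetical stand-ins, and the threshold is read with plain stream extraction rather than via `std::stold` as in the excerpt.

```cpp
#include <cassert>
#include <istream>
#include <memory>
#include <sstream>
#include <string>

// Hypothetical stand-in for the tree node type; field names follow the excerpt above.
struct Node {
    bool is_leaf = true;
    long split_feature = -1;
    double split_threshold = 0.0;
    double value = 0.0;           // assumed leaf payload
    std::unique_ptr<Node> left;
    std::unique_ptr<Node> right;
};

// Reads one node in pre-order. Returns nullptr on any malformed input,
// including a split feature outside [0, n_features).
std::unique_ptr<Node> read_node_checked(std::istream &is, long n_features){
    assert(n_features > 0); // the caller must already have parsed the feature count
    std::string node_type;
    if(!(is >> node_type)) return nullptr;

    auto node = std::make_unique<Node>();
    if(node_type == "L"){
        node->is_leaf = true;
        if(!(is >> node->value)) return nullptr;
    }else if(node_type == "I"){
        node->is_leaf = false;
        if(!(is >> node->split_feature >> node->split_threshold)) return nullptr;
        // Reject indices that would later read outside the feature row.
        if(node->split_feature < 0 || node->split_feature >= n_features) return nullptr;
        node->left = read_node_checked(is, n_features);
        node->right = read_node_checked(is, n_features);
        if(!node->left || !node->right) return nullptr;
    }else{
        return nullptr; // unknown node tag
    }
    return node;
}
```

Passing the trained feature count into the recursive reader keeps the check next to the parse, so an invalid index fails the load instead of surfacing later during prediction.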
    is >> label >> this->random_seed;
    if(is.fail() || label != "random_seed") return false;
read_from() assigns serialized hyperparameters directly to members without re-validating the invariants enforced by the constructor (e.g., max_depth > 0, min_samples_split >= 2, alpha in (0,1), n_permutations > 0). This can leave the object in an invalid state after loading an incorrect/corrupted model file. Consider validating these ranges immediately after parsing and returning false if any are invalid.
    // Validate hyperparameter invariants after deserialization.
    if(this->max_depth <= 0){
        return false;
    }
    if(this->min_samples_split < 2){
        return false;
    }
    if(this->alpha <= static_cast<T>(0) || static_cast<T>(1) <= this->alpha){
        return false;
    }
    if(this->n_permutations <= 0){
        return false;
    }
    std::stringstream ss(line);
    std::string token;
    while(std::getline(ss, token, delimiter)){
        vals.push_back(std::stod(token));
    }
CSV/TSV numeric parsing uses std::stod(token) without handling std::invalid_argument/std::out_of_range. If the input contains a non-numeric token (including empty fields), this will throw and terminate the program without a clear, row/column-specific error. Wrap conversions in a try/catch and rethrow a runtime_error that includes row/column context (and consider trimming whitespace).
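One way to apply this suggestion is sketched below. The helper names `trim` and `parse_row` and the error wording are hypothetical; only `std::stod` and the exception types come from the comment above.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Strips leading and trailing spaces, tabs, and carriage returns.
std::string trim(const std::string &s){
    const auto first = s.find_first_not_of(" \t\r");
    if(first == std::string::npos) return "";
    const auto last = s.find_last_not_of(" \t\r");
    return s.substr(first, last - first + 1);
}

// Parses one delimited line into doubles. row_number is 1-based and is used
// only to build error messages.
std::vector<double> parse_row(const std::string &line, char delimiter, std::size_t row_number){
    assert(row_number >= 1);
    std::vector<double> vals;
    std::stringstream ss(line);
    std::string token;
    std::size_t col = 0;
    while(std::getline(ss, token, delimiter)){
        ++col;
        const std::string t = trim(token);
        const std::string where = " at row " + std::to_string(row_number)
                                + ", column " + std::to_string(col);
        try{
            std::size_t consumed = 0;
            const double v = std::stod(t, &consumed);
            // std::stod accepts a numeric prefix; reject trailing junk such as "1.5abc".
            if(consumed != t.size()) throw std::invalid_argument("trailing characters");
            vals.push_back(v);
        }catch(const std::invalid_argument &){
            throw std::runtime_error("Non-numeric value '" + t + "'" + where);
        }catch(const std::out_of_range &){
            throw std::runtime_error("Value out of range '" + t + "'" + where);
        }
    }
    return vals;
}
```

A caller would pass the running line count, so that a bad cell in a large file is reported as, for example, "Non-numeric value 'abc' at row 7, column 2" rather than an uncaught exception.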
    std::stringstream ss(line);
    std::string token;
    while(std::getline(ss, token, delimiter)){
        vals.push_back(std::stod(token));
    }
CSV/TSV numeric parsing uses std::stod(token) without handling std::invalid_argument/std::out_of_range. If the input contains a non-numeric token (including empty fields), this will throw and terminate the program without a clear, row/column-specific error. Wrap conversions in a try/catch and rethrow a runtime_error that includes row/column context (and consider trimming whitespace).
    std::stringstream ss(line);
    std::string token;
    while(std::getline(ss, token, delimiter)){
        vals.push_back(std::stod(token));
    }
CSV/TSV numeric parsing uses std::stod(token) without handling std::invalid_argument/std::out_of_range. If the input contains a non-numeric token (including empty fields), this will throw and terminate the program without a clear, row/column-specific error. Wrap conversions in a try/catch and rethrow a runtime_error that includes row/column context (and consider trimming whitespace).
    bool Stats::ConditionalInferenceTrees<T>::write_to(std::ostream &os) const {
        const auto original_precision = os.precision();
        os.precision( std::numeric_limits<T>::max_digits10 );
write_to()/read_from() add a new persistence format for ConditionalInferenceTrees, but there are currently no unit tests exercising round-trip correctness or invalid-input rejection (unlike StochasticForests and ConditionalRandomForests, which have serialization tests under tests2/). Adding tests in tests2/YgorStatsCITrees.cc for round-tripping predictions and for rejecting malformed streams would help prevent format regressions.
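A round-trip test of the kind requested might be shaped like the sketch below. It does not use the real `ConditionalInferenceTrees` API or the repository's test framework: `ToyModel`, `PrecisionGuard`, the `toy_model_v1` header, and `test_round_trip` are hypothetical stand-ins chosen to keep the example self-contained.

```cpp
#include <cassert>
#include <ios>
#include <istream>
#include <limits>
#include <ostream>
#include <sstream>
#include <string>

// Restores a stream's precision when it goes out of scope.
class PrecisionGuard {
    std::ostream &os_;
    std::streamsize saved_;
  public:
    PrecisionGuard(std::ostream &os, std::streamsize p) : os_(os), saved_(os.precision(p)) {}
    ~PrecisionGuard(){ os_.precision(saved_); }
    PrecisionGuard(const PrecisionGuard &) = delete;
    PrecisionGuard &operator=(const PrecisionGuard &) = delete;
};

// Hypothetical stand-in for a model with two hyperparameters.
struct ToyModel {
    long max_depth = 10;
    double alpha = 0.05;

    bool write_to(std::ostream &os) const {
        // max_digits10 is enough decimal digits to recover the exact double.
        PrecisionGuard guard(os, std::numeric_limits<double>::max_digits10);
        os << "toy_model_v1\n"
           << "max_depth " << max_depth << "\n"
           << "alpha " << alpha << "\n";
        return !os.fail();
    }

    bool read_from(std::istream &is){
        std::string header, label;
        long depth = 0;
        double a = 0.0;
        if(!(is >> header) || header != "toy_model_v1") return false;
        if(!(is >> label >> depth) || label != "max_depth") return false;
        if(!(is >> label >> a) || label != "alpha") return false;
        // Re-validate invariants before committing anything to the object.
        if(depth <= 0 || a <= 0.0 || a >= 1.0) return false;
        max_depth = depth;
        alpha = a;
        return true;
    }
};

// Round-trips a model and checks that a malformed stream is rejected.
void test_round_trip(){
    ToyModel a;
    a.max_depth = 7;
    a.alpha = 0.1 + 0.2; // needs full precision to survive the text format
    std::stringstream ss;
    assert(a.write_to(ss));

    ToyModel b;
    assert(b.read_from(ss));
    assert(b.max_depth == a.max_depth && b.alpha == a.alpha);

    std::istringstream bad("toy_model_v1\nmax_depth 0\nalpha 0.05\n");
    assert(!b.read_from(bad));
    assert(b.max_depth == 7); // a failed load leaves the object untouched
}
```

Parsing into locals and assigning only after validation means a rejected file cannot leave the model half-loaded, which is what makes the invalid-input assertion meaningful.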
            min_samples_split = std::stoll(optarg);
        }));
        arger.push_back(std::make_tuple(2, 'f', "max-features", true, "<int>",
            "Maximum features per split; -1 for all (default: -1).",
The --max-features help text says "-1 for all", but Stats::StochasticForests treats max_features <= 0 as sqrt(n_features) (see library docs/implementation). This makes the CLI misleading and will produce different behavior than users expect. Update the flag semantics and/or help text (e.g., document <=0 as sqrt, and provide an explicit way to request "all features").
    - "Maximum features per split; -1 for all (default: -1).",
    + "Maximum features per split; <=0 uses sqrt(number of features) (default: -1).",
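To make the documented semantics concrete, the following sketch shows how a non-positive request might be resolved to a feature count. `resolve_max_features` is a hypothetical helper, and the rounding (floor of the square root, clamped to at least 1) is an assumption; the library may round differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical helper: resolves how many features are tried per split.
// Assumption: non-positive requests fall back to floor(sqrt(n_features)), at least 1.
long resolve_max_features(long max_features, long n_features){
    assert(n_features > 0);
    if(max_features <= 0){
        const long root = static_cast<long>(std::sqrt(static_cast<double>(n_features)));
        return std::max(1L, root);
    }
    // Positive requests are taken literally but cannot exceed the feature count.
    return std::min(max_features, n_features);
}
```

Under this reading, a user who wants every feature considered would pass the feature count itself rather than -1.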
    std::stringstream ss(line);
    std::string token;
    while(std::getline(ss, token, delimiter)){
        vals.push_back(std::stod(token));
    }
CSV/TSV numeric parsing uses std::stod(token) without handling std::invalid_argument/std::out_of_range. If the input contains a non-numeric token (including empty fields), this will throw and terminate the program without a clear, row/column-specific error. Wrap conversions in a try/catch and rethrow a runtime_error that includes row/column context (and consider trimming whitespace).
    std::stringstream ss(line);
    std::string token;
    while(std::getline(ss, token, delimiter)){
        vals.push_back(std::stod(token));
    }
CSV/TSV numeric parsing uses std::stod(token) without handling std::invalid_argument/std::out_of_range. If the input contains a non-numeric token (including empty fields), this will throw and terminate the program without a clear, row/column-specific error. Wrap conversions in a try/catch and rethrow a runtime_error that includes row/column context (and consider trimming whitespace).
    std::stringstream ss(line);
    std::string token;
    while(std::getline(ss, token, delimiter)){
        vals.push_back(std::stod(token));
    }
CSV/TSV numeric parsing uses std::stod(token) without handling std::invalid_argument/std::out_of_range. If the input contains a non-numeric token (including empty fields), this will throw and terminate the program without a clear, row/column-specific error. Wrap conversions in a try/catch and rethrow a runtime_error that includes row/column context (and consider trimming whitespace).
Expose stochastic forests, conditional forests, and conditional inference trees to end-users via command-line binaries, without requiring C++ development.
New binaries
Six binaries, all prefixed `ygor_` and installed to `CMAKE_INSTALL_BINDIR`:
- `ygor_stochastic_forest_train`
- `ygor_stochastic_forest_predict`
- `ygor_conditional_forest_train`
- `ygor_conditional_forest_predict`
- `ygor_ci_tree_train`
- `ygor_ci_tree_predict`

All binaries use `ArgumentHandler` for CLI parsing, accept CSV/TSV with auto-detection (tab presence), and support `--header` to skip a header row. All model hyperparameters are exposed as CLI flags.

Training binaries: last column = response, preceding columns = features. Displays feature importances post-fit when applicable.

Prediction binaries: all columns = features, outputs one prediction per row to stdout.
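The input conventions described above (tab presence selects TSV, an optional header row is skipped, the last column is the response) can be illustrated with a short sketch. `detect_delimiter`, `Table`, and `read_training_table` are hypothetical names, and detecting the delimiter from the first non-empty line is an assumption about how the binaries behave.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// A tab anywhere in the line selects TSV; otherwise the line is treated as CSV.
char detect_delimiter(const std::string &line){
    return (line.find('\t') != std::string::npos) ? '\t' : ',';
}

struct Table {
    std::vector<std::vector<double>> features; // one row per sample
    std::vector<double> response;              // last column of each row
};

// Reads a training table. Error reporting is kept minimal here:
// std::stod throws on a non-numeric field.
Table read_training_table(std::istream &is, bool has_header){
    Table out;
    std::string line;
    char delim = ',';
    bool first = true;
    while(std::getline(is, line)){
        if(line.empty()) continue;
        if(first){
            first = false;
            delim = detect_delimiter(line); // assumed: decided once, from the first line
            if(has_header) continue;        // skip the header row
        }
        std::vector<double> vals;
        std::stringstream ss(line);
        std::string token;
        while(std::getline(ss, token, delim)) vals.push_back(std::stod(token));
        if(vals.size() < 2) throw std::runtime_error("need at least one feature and a response");
        out.response.push_back(vals.back()); // last column = response
        vals.pop_back();                     // preceding columns = features
        out.features.push_back(std::move(vals));
    }
    assert(out.features.size() == out.response.size());
    return out;
}
```

A prediction-side reader would differ only in keeping every column as a feature and producing no response vector.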
CITrees serialization
`ConditionalInferenceTrees` lacked `write_to`/`read_from`; these were added following the identical pattern used by `StochasticForests` and `ConditionalRandomForests` (versioned header, parameter block, recursive tree node serialization with `L`/`I` node types, RAII precision guard).
Build
Updated `binaries/CMakeLists.txt` with all six targets linked against `ygor`, `m`, `Threads::Threads` and added to the install target.