diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt index 19c8483..00a45ac 100644 --- a/binaries/CMakeLists.txt +++ b/binaries/CMakeLists.txt @@ -48,13 +48,90 @@ target_link_libraries(parse_TAR_files Threads::Threads ) +add_executable(ygor_stochastic_forest_train + Ygor_Stochastic_Forest_Train.cc +) +target_include_directories(ygor_stochastic_forest_train + SYSTEM PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) +target_link_libraries(ygor_stochastic_forest_train + ygor + m + Threads::Threads +) + +add_executable(ygor_stochastic_forest_predict + Ygor_Stochastic_Forest_Predict.cc +) +target_include_directories(ygor_stochastic_forest_predict + SYSTEM PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) +target_link_libraries(ygor_stochastic_forest_predict + ygor + m + Threads::Threads +) + +add_executable(ygor_conditional_forest_train + Ygor_Conditional_Forest_Train.cc +) +target_include_directories(ygor_conditional_forest_train + SYSTEM PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) +target_link_libraries(ygor_conditional_forest_train + ygor + m + Threads::Threads +) + +add_executable(ygor_conditional_forest_predict + Ygor_Conditional_Forest_Predict.cc +) +target_include_directories(ygor_conditional_forest_predict + SYSTEM PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) +target_link_libraries(ygor_conditional_forest_predict + ygor + m + Threads::Threads +) + +add_executable(ygor_ci_tree_train + Ygor_CI_Tree_Train.cc +) +target_include_directories(ygor_ci_tree_train + SYSTEM PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) +target_link_libraries(ygor_ci_tree_train + ygor + m + Threads::Threads +) + +add_executable(ygor_ci_tree_predict + Ygor_CI_Tree_Predict.cc +) +target_include_directories(ygor_ci_tree_predict + SYSTEM PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) +target_link_libraries(ygor_ci_tree_predict + ygor + m + Threads::Threads +) + install(TARGETS fits_replace_nans twot_pvalue regex_tester parse_TAR_files + ygor_stochastic_forest_train + ygor_stochastic_forest_predict + ygor_conditional_forest_train + ygor_conditional_forest_predict + ygor_ci_tree_train + ygor_ci_tree_predict ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ) - diff --git a/binaries/Ygor_CI_Tree_Predict.cc b/binaries/Ygor_CI_Tree_Predict.cc new file mode 100644 index 0000000..ae33756 --- /dev/null +++ b/binaries/Ygor_CI_Tree_Predict.cc @@ -0,0 +1,87 @@ +//Ygor_CI_Tree_Predict.cc -- A command-line utility to predict using a trained conditional inference tree model. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "YgorArguments.h" +#include "YgorMath.h" +#include "YgorMathIOCSV.h" +#include "YgorStatsCITrees.h" + +int main(int argc, char **argv){ + + std::string model_file; + std::string input_file; + bool has_header = false; + + ArgumentHandler arger; + arger.description = "Predict using a trained conditional inference tree model."; + + arger.push_back(std::make_tuple(1, 'm', "model", true, "", + "Trained model file to load.", + [&](const std::string &optarg) -> void { + model_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'i', "input", true, "", + "Input CSV/TSV file with feature values.", + [&](const std::string &optarg) -> void { + input_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'H', "header", false, "", + "Indicate that the first row is a header (will be skipped).", + [&](const std::string &) -> void { + has_header = true; + })); + + arger.Launch(argc, argv); + + if(model_file.empty()){ + throw std::runtime_error("A model file must be specified via -m or --model."); + } + if(input_file.empty()){ + throw std::runtime_error("An input file must be specified via -i or --input."); + } + + // Load the model. + Stats::ConditionalInferenceTrees model; + { + std::ifstream fm(model_file); + if(!fm.good()){ + throw std::runtime_error("Unable to open model file '" + model_file + "'."); + } + if(!model.read_from(fm)){ + throw std::runtime_error("Failed to read model from '" + model_file + "'."); + } + } + + // Read the input file. + std::ifstream fi(input_file); + if(!fi.good()){ + throw std::runtime_error("Unable to open input file '" + input_file + "'."); + } + + auto csv_result = ReadNumArrayFromCSV(fi, has_header); + const auto &all_data = csv_result.data; + const int64_t n_rows = all_data.num_rows(); + const int64_t n_cols = all_data.num_cols(); + + // Reuse a single 1 x n_cols buffer for each prediction to avoid + // repeatedly allocating and copying subarrays. + num_array x(1, n_cols); + for(int64_t r = 0; r < n_rows; ++r){ + for(int64_t c = 0; c < n_cols; ++c){ + x(0, c) = all_data(r, c); + } + double prediction = model.predict(x); + std::cout << prediction << std::endl; + } + + return 0; +} diff --git a/binaries/Ygor_CI_Tree_Train.cc b/binaries/Ygor_CI_Tree_Train.cc new file mode 100644 index 0000000..9c734b8 --- /dev/null +++ b/binaries/Ygor_CI_Tree_Train.cc @@ -0,0 +1,119 @@ +//Ygor_CI_Tree_Train.cc -- A command-line utility to train a conditional inference tree model. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "YgorArguments.h" +#include "YgorMath.h" +#include "YgorMathIOCSV.h" +#include "YgorStatsCITrees.h" + +int main(int argc, char **argv){ + + std::string input_file; + std::string output_file; + bool has_header = false; + int64_t max_depth = 10; + int64_t min_samples_split = 2; + double alpha = 0.05; + int64_t n_permutations = 1000; + uint64_t random_seed = 42; + + ArgumentHandler arger; + arger.description = "Train a conditional inference tree model from tabular data (CSV/TSV)."; + + arger.push_back(std::make_tuple(1, 'i', "input", true, "", + "Input CSV/TSV file. Last column is the response variable.", + [&](const std::string &optarg) -> void { + input_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'o', "output", true, "", + "Output file for the trained model.", + [&](const std::string &optarg) -> void { + output_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'H', "header", false, "", + "Indicate that the first row is a header (will be skipped).", + [&](const std::string &) -> void { + has_header = true; + })); + arger.push_back(std::make_tuple(2, 'd', "max-depth", true, "", + "Maximum tree depth (default: 10).", + [&](const std::string &optarg) -> void { + max_depth = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 's', "min-samples-split", true, "", + "Minimum samples to split a node (default: 2).", + [&](const std::string &optarg) -> void { + min_samples_split = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'a', "alpha", true, "", + "Significance level for conditional inference tests (default: 0.05).", + [&](const std::string &optarg) -> void { + alpha = std::stod(optarg); + })); + arger.push_back(std::make_tuple(2, 'p', "n-permutations", true, "", + "Number of permutations for hypothesis tests (default: 1000).", + [&](const std::string &optarg) -> void { + n_permutations = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'r', "random-seed", true, "", + "Random seed (default: 42).", + [&](const std::string &optarg) -> void { + random_seed = std::stoull(optarg); + })); + + arger.Launch(argc, argv); + + if(input_file.empty()){ + throw std::runtime_error("An input file must be specified via -i or --input."); + } + if(output_file.empty()){ + throw std::runtime_error("An output file must be specified via -o or --output."); + } + + // Read the input file. + std::ifstream fi(input_file); + if(!fi.good()){ + throw std::runtime_error("Unable to open input file '" + input_file + "'."); + } + + auto csv_result = ReadNumArrayFromCSV(fi, has_header); + const auto &all_data = csv_result.data; + const int64_t n_rows = all_data.num_rows(); + const int64_t n_cols = all_data.num_cols(); + if(n_cols < 2){ + throw std::runtime_error("Input must have at least two columns (features + response)."); + } + const int64_t n_features = n_cols - 1; + + num_array X = all_data.subarray(0, n_rows, 0, n_features); + num_array y = all_data.subarray(0, n_rows, n_features, n_cols); + + std::cout << "Training conditional inference tree with " << n_rows << " samples and " + << n_features << " features." << std::endl; + + // Train the model. + Stats::ConditionalInferenceTrees model(max_depth, min_samples_split, alpha, + n_permutations, random_seed); + model.fit(X, y); + + // Save the model. + std::ofstream fo(output_file); + if(!fo.good()){ + throw std::runtime_error("Unable to open output file '" + output_file + "'."); + } + if(!model.write_to(fo)){ + throw std::runtime_error("Failed to write model to output file."); + } + std::cout << "Model saved to '" << output_file << "'." << std::endl; + + return 0; +} diff --git a/binaries/Ygor_Conditional_Forest_Predict.cc b/binaries/Ygor_Conditional_Forest_Predict.cc new file mode 100644 index 0000000..1abb520 --- /dev/null +++ b/binaries/Ygor_Conditional_Forest_Predict.cc @@ -0,0 +1,86 @@ +//Ygor_Conditional_Forest_Predict.cc -- A command-line utility to predict using a trained conditional random forest model. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "YgorArguments.h" +#include "YgorMath.h" +#include "YgorMathIOCSV.h" +#include "YgorStatsConditionalForests.h" + +int main(int argc, char **argv){ + + std::string model_file; + std::string input_file; + bool has_header = false; + + ArgumentHandler arger; + arger.description = "Predict using a trained conditional random forest model."; + + arger.push_back(std::make_tuple(1, 'm', "model", true, "", + "Trained model file to load.", + [&](const std::string &optarg) -> void { + model_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'i', "input", true, "", + "Input CSV/TSV file with feature values.", + [&](const std::string &optarg) -> void { + input_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'H', "header", false, "", + "Indicate that the first row is a header (will be skipped).", + [&](const std::string &) -> void { + has_header = true; + })); + + arger.Launch(argc, argv); + + if(model_file.empty()){ + throw std::runtime_error("A model file must be specified via -m or --model."); + } + if(input_file.empty()){ + throw std::runtime_error("An input file must be specified via -i or --input."); + } + + // Load the model. + Stats::ConditionalRandomForests model; + { + std::ifstream fm(model_file); + if(!fm.good()){ + throw std::runtime_error("Unable to open model file '" + model_file + "'."); + } + if(!model.read_from(fm)){ + throw std::runtime_error("Failed to read model from '" + model_file + "'."); + } + } + + // Read the input file. + std::ifstream fi(input_file); + if(!fi.good()){ + throw std::runtime_error("Unable to open input file '" + input_file + "'."); + } + + auto csv_result = ReadNumArrayFromCSV(fi, has_header); + const auto &all_data = csv_result.data; + const int64_t n_rows = all_data.num_rows(); + const int64_t n_cols = all_data.num_cols(); + + // Reuse a single 1xN buffer to avoid per-row subarray allocation and copy. + num_array x(1, n_cols); + for(int64_t r = 0; r < n_rows; ++r){ + for(int64_t c = 0; c < n_cols; ++c){ + x(0, c) = all_data(r, c); + } + double prediction = model.predict(x); + std::cout << prediction << std::endl; + } + + return 0; +} diff --git a/binaries/Ygor_Conditional_Forest_Train.cc b/binaries/Ygor_Conditional_Forest_Train.cc new file mode 100644 index 0000000..5e4787c --- /dev/null +++ b/binaries/Ygor_Conditional_Forest_Train.cc @@ -0,0 +1,174 @@ +//Ygor_Conditional_Forest_Train.cc -- A command-line utility to train a conditional random forest model. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "YgorArguments.h" +#include "YgorMath.h" +#include "YgorMathIOCSV.h" +#include "YgorStatsConditionalForests.h" + +int main(int argc, char **argv){ + + std::string input_file; + std::string output_file; + bool has_header = false; + int64_t n_trees = 100; + int64_t max_depth = 10; + int64_t min_samples_split = 2; + double alpha = 0.05; + int64_t n_permutations = 1000; + int64_t max_features = -1; + double subsample_fraction = 0.632; + double correlation_threshold = 0.20; + uint64_t random_seed = 42; + std::string importance_str = "none"; + + ArgumentHandler arger; + arger.description = "Train a conditional random forest model from tabular data (CSV/TSV)."; + + arger.push_back(std::make_tuple(1, 'i', "input", true, "", + "Input CSV/TSV file. Last column is the response variable.", + [&](const std::string &optarg) -> void { + input_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'o', "output", true, "", + "Output file for the trained model.", + [&](const std::string &optarg) -> void { + output_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'H', "header", false, "", + "Indicate that the first row is a header (will be skipped).", + [&](const std::string &) -> void { + has_header = true; + })); + arger.push_back(std::make_tuple(2, 't', "n-trees", true, "", + "Number of trees (default: 100).", + [&](const std::string &optarg) -> void { + n_trees = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'd', "max-depth", true, "", + "Maximum tree depth (default: 10).", + [&](const std::string &optarg) -> void { + max_depth = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 's', "min-samples-split", true, "", + "Minimum samples to split a node (default: 2).", + [&](const std::string &optarg) -> void { + min_samples_split = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'a', "alpha", true, "", + "Significance level for conditional inference tests (default: 0.05).", + [&](const std::string &optarg) -> void { + alpha = std::stod(optarg); + })); + arger.push_back(std::make_tuple(2, 'p', "n-permutations", true, "", + "Number of permutations for hypothesis tests (default: 1000).", + [&](const std::string &optarg) -> void { + n_permutations = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'f', "max-features", true, "", + "Maximum features per split; -1 for all (default: -1).", + [&](const std::string &optarg) -> void { + max_features = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'S', "subsample-fraction", true, "", + "Fraction of samples to draw per tree (default: 0.632).", + [&](const std::string &optarg) -> void { + subsample_fraction = std::stod(optarg); + })); + arger.push_back(std::make_tuple(2, 'c', "correlation-threshold", true, "", + "Correlation threshold for conditional importance (default: 0.20).", + [&](const std::string &optarg) -> void { + correlation_threshold = std::stod(optarg); + })); + arger.push_back(std::make_tuple(2, 'r', "random-seed", true, "", + "Random seed (default: 42).", + [&](const std::string &optarg) -> void { + random_seed = std::stoull(optarg); + })); + arger.push_back(std::make_tuple(2, 'I', "importance", true, "", + "Feature importance method: none, permutation, or conditional (default: none).", + [&](const std::string &optarg) -> void { + importance_str = optarg; + })); + + arger.Launch(argc, argv); + + if(input_file.empty()){ + throw std::runtime_error("An input file must be specified via -i or --input."); + } + if(output_file.empty()){ + throw std::runtime_error("An output file must be specified via -o or --output."); + } + + // Parse the importance method. + Stats::ConditionalImportanceMethod importance_method = Stats::ConditionalImportanceMethod::none; + if(importance_str == "none"){ + importance_method = Stats::ConditionalImportanceMethod::none; + }else if(importance_str == "permutation"){ + importance_method = Stats::ConditionalImportanceMethod::permutation; + }else if(importance_str == "conditional"){ + importance_method = Stats::ConditionalImportanceMethod::conditional; + }else{ + throw std::runtime_error("Unknown importance method '" + importance_str + "'. Use none, permutation, or conditional."); + } + + // Read the input file. + std::ifstream fi(input_file); + if(!fi.good()){ + throw std::runtime_error("Unable to open input file '" + input_file + "'."); + } + + auto csv_result = ReadNumArrayFromCSV(fi, has_header); + const auto &all_data = csv_result.data; + const int64_t n_rows = all_data.num_rows(); + const int64_t n_cols = all_data.num_cols(); + if(n_cols < 2){ + throw std::runtime_error("Input must have at least two columns (features + response)."); + } + const int64_t n_features = n_cols - 1; + + num_array X = all_data.subarray(0, n_rows, 0, n_features); + num_array y = all_data.subarray(0, n_rows, n_features, n_cols); + + std::cout << "Training conditional random forest with " << n_rows << " samples and " + << n_features << " features." << std::endl; + + // Train the model. + Stats::ConditionalRandomForests model(n_trees, max_depth, min_samples_split, + alpha, n_permutations, max_features, + subsample_fraction, correlation_threshold, + random_seed); + model.set_importance_method(importance_method); + model.fit(X, y); + + // Compute and display feature importances. + if(importance_method != Stats::ConditionalImportanceMethod::none){ + model.compute_importance(X, y); + std::vector importances = model.get_feature_importances(); + std::cout << "Feature importances:" << std::endl; + for(int64_t c = 0; c < static_cast(importances.size()); ++c){ + std::cout << " feature " << c << ": " << importances[c] << std::endl; + } + } + + // Save the model. + std::ofstream fo(output_file); + if(!fo.good()){ + throw std::runtime_error("Unable to open output file '" + output_file + "'."); + } + if(!model.write_to(fo)){ + throw std::runtime_error("Failed to write model to output file."); + } + std::cout << "Model saved to '" << output_file << "'." << std::endl; + + return 0; +} diff --git a/binaries/Ygor_Stochastic_Forest_Predict.cc b/binaries/Ygor_Stochastic_Forest_Predict.cc new file mode 100644 index 0000000..e14b2d2 --- /dev/null +++ b/binaries/Ygor_Stochastic_Forest_Predict.cc @@ -0,0 +1,86 @@ +//Ygor_Stochastic_Forest_Predict.cc -- A command-line utility to predict using a trained stochastic forest model. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "YgorArguments.h" +#include "YgorMath.h" +#include "YgorMathIOCSV.h" +#include "YgorStatsStochasticForests.h" + +int main(int argc, char **argv){ + + std::string model_file; + std::string input_file; + bool has_header = false; + + ArgumentHandler arger; + arger.description = "Predict using a trained stochastic forest model."; + + arger.push_back(std::make_tuple(1, 'm', "model", true, "", + "Trained model file to load.", + [&](const std::string &optarg) -> void { + model_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'i', "input", true, "", + "Input CSV/TSV file with feature values.", + [&](const std::string &optarg) -> void { + input_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'H', "header", false, "", + "Indicate that the first row is a header (will be skipped).", + [&](const std::string &) -> void { + has_header = true; + })); + + arger.Launch(argc, argv); + + if(model_file.empty()){ + throw std::runtime_error("A model file must be specified via -m or --model."); + } + if(input_file.empty()){ + throw std::runtime_error("An input file must be specified via -i or --input."); + } + + // Load the model. + Stats::StochasticForests model; + { + std::ifstream fm(model_file); + if(!fm.good()){ + throw std::runtime_error("Unable to open model file '" + model_file + "'."); + } + if(!model.read_from(fm)){ + throw std::runtime_error("Failed to read model from '" + model_file + "'."); + } + } + + // Read the input file. + std::ifstream fi(input_file); + if(!fi.good()){ + throw std::runtime_error("Unable to open input file '" + input_file + "'."); + } + + auto csv_result = ReadNumArrayFromCSV(fi, has_header); + const auto &all_data = csv_result.data; + const int64_t n_rows = all_data.num_rows(); + const int64_t n_cols = all_data.num_cols(); + + // Reuse a single 1xN buffer for each prediction to avoid per-row allocations. + num_array x(1, n_cols); + for(int64_t r = 0; r < n_rows; ++r){ + for(int64_t c = 0; c < n_cols; ++c){ + x(0, c) = all_data(r, c); + } + double prediction = model.predict(x); + std::cout << prediction << std::endl; + } + + return 0; +} diff --git a/binaries/Ygor_Stochastic_Forest_Train.cc b/binaries/Ygor_Stochastic_Forest_Train.cc new file mode 100644 index 0000000..eaab75f --- /dev/null +++ b/binaries/Ygor_Stochastic_Forest_Train.cc @@ -0,0 +1,149 @@ +//Ygor_Stochastic_Forest_Train.cc -- A command-line utility to train a stochastic forest model. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "YgorArguments.h" +#include "YgorMath.h" +#include "YgorMathIOCSV.h" +#include "YgorStatsStochasticForests.h" + +int main(int argc, char **argv){ + + std::string input_file; + std::string output_file; + bool has_header = false; + int64_t n_trees = 100; + int64_t max_depth = 10; + int64_t min_samples_split = 2; + int64_t max_features = -1; + uint64_t random_seed = 42; + std::string importance_str = "none"; + + ArgumentHandler arger; + arger.description = "Train a stochastic forest model from tabular data (CSV/TSV)."; + + arger.push_back(std::make_tuple(1, 'i', "input", true, "", + "Input CSV/TSV file. Last column is the response variable.", + [&](const std::string &optarg) -> void { + input_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'o', "output", true, "", + "Output file for the trained model.", + [&](const std::string &optarg) -> void { + output_file = optarg; + })); + arger.push_back(std::make_tuple(1, 'H', "header", false, "", + "Indicate that the first row is a header (will be skipped).", + [&](const std::string &) -> void { + has_header = true; + })); + arger.push_back(std::make_tuple(2, 't', "n-trees", true, "", + "Number of trees (default: 100).", + [&](const std::string &optarg) -> void { + n_trees = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'd', "max-depth", true, "", + "Maximum tree depth (default: 10).", + [&](const std::string &optarg) -> void { + max_depth = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 's', "min-samples-split", true, "", + "Minimum samples to split a node (default: 2).", + [&](const std::string &optarg) -> void { + min_samples_split = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'f', "max-features", true, "", + "Maximum features per split; <=0 uses sqrt(number of features) (default: -1).", + [&](const std::string &optarg) -> void { + max_features = std::stoll(optarg); + })); + arger.push_back(std::make_tuple(2, 'r', "random-seed", true, "", + "Random seed (default: 42).", + [&](const std::string &optarg) -> void { + random_seed = std::stoull(optarg); + })); + arger.push_back(std::make_tuple(2, 'I', "importance", true, "", + "Feature importance method: none, gini, or permutation (default: none).", + [&](const std::string &optarg) -> void { + importance_str = optarg; + })); + + arger.Launch(argc, argv); + + if(input_file.empty()){ + throw std::runtime_error("An input file must be specified via -i or --input."); + } + if(output_file.empty()){ + throw std::runtime_error("An output file must be specified via -o or --output."); + } + + // Parse the importance method. + Stats::ImportanceMethod importance_method = Stats::ImportanceMethod::none; + if(importance_str == "none"){ + importance_method = Stats::ImportanceMethod::none; + }else if(importance_str == "gini"){ + importance_method = Stats::ImportanceMethod::gini; + }else if(importance_str == "permutation"){ + importance_method = Stats::ImportanceMethod::permutation; + }else{ + throw std::runtime_error("Unknown importance method '" + importance_str + "'. Use none, gini, or permutation."); + } + + // Read the input file. + std::ifstream fi(input_file); + if(!fi.good()){ + throw std::runtime_error("Unable to open input file '" + input_file + "'."); + } + + auto csv_result = ReadNumArrayFromCSV(fi, has_header); + const auto &all_data = csv_result.data; + const int64_t n_rows = all_data.num_rows(); + const int64_t n_cols = all_data.num_cols(); + if(n_cols < 2){ + throw std::runtime_error("Input must have at least two columns (features + response)."); + } + const int64_t n_features = n_cols - 1; + + num_array X = all_data.subarray(0, n_rows, 0, n_features); + num_array y = all_data.subarray(0, n_rows, n_features, n_cols); + + std::cout << "Training stochastic forest with " << n_rows << " samples and " + << n_features << " features." << std::endl; + + // Train the model. + Stats::StochasticForests model(n_trees, max_depth, min_samples_split, max_features, random_seed); + model.set_importance_method(importance_method); + model.fit(X, y); + + // Compute and display feature importances. + if(importance_method == Stats::ImportanceMethod::permutation){ + model.compute_permutation_importance(X, y); + } + if(importance_method != Stats::ImportanceMethod::none){ + std::vector importances = model.get_feature_importances(); + std::cout << "Feature importances:" << std::endl; + for(int64_t c = 0; c < static_cast(importances.size()); ++c){ + std::cout << " feature " << c << ": " << importances[c] << std::endl; + } + } + + // Save the model. + std::ofstream fo(output_file); + if(!fo.good()){ + throw std::runtime_error("Unable to open output file '" + output_file + "'."); + } + if(!model.write_to(fo)){ + throw std::runtime_error("Failed to write model to output file."); + } + std::cout << "Model saved to '" << output_file << "'." << std::endl; + + return 0; +} diff --git a/src/YgorMath.cc b/src/YgorMath.cc index 1de87bc..12fb526 100644 --- a/src/YgorMath.cc +++ b/src/YgorMath.cc @@ -9214,6 +9214,28 @@ num_array::read_from(std::istream &is){ template bool num_array::read_from(std::istream &); #endif +template +num_array +num_array::subarray(int64_t row_begin, int64_t row_end, int64_t col_begin, int64_t col_end) const { + if( row_begin < 0 || row_end > this->rows || row_begin >= row_end + || col_begin < 0 || col_end > this->cols || col_begin >= col_end ){ + throw std::invalid_argument("Subarray indices are out of range or produce an empty sub-array."); + } + const int64_t sub_rows = row_end - row_begin; + const int64_t sub_cols = col_end - col_begin; + num_array out(sub_rows, sub_cols); + for(int64_t r = 0; r < sub_rows; ++r){ + for(int64_t c = 0; c < sub_cols; ++c){ + out.coeff(r, c) = this->read_coeff(r + row_begin, c + col_begin); + } + } + return out; +} +#ifndef YGORMATH_DISABLE_ALL_SPECIALIZATIONS + template num_array num_array::subarray(int64_t, int64_t, int64_t, int64_t) const; + template num_array num_array::subarray(int64_t, int64_t, int64_t, int64_t) const; +#endif + //--------------------------------------------------------------------------------------------------------------------- //-------------------------- affine_transform: a class that holds an affine transformation ---------------------------- //--------------------------------------------------------------------------------------------------------------------- diff --git a/src/YgorMath.h b/src/YgorMath.h index 2d3a029..d63d078 100644 --- a/src/YgorMath.h +++ b/src/YgorMath.h @@ -780,6 +780,20 @@ class num_array { // Serialize and deserialize to a human- and machine-readable format. bool write_to( std::ostream &os ) const; bool read_from( std::istream &is ); + + // Extract a sub-array (block) from the matrix as a new copy. + // + // Parameters: + // row_begin: Starting row (inclusive). + // row_end: Ending row (exclusive). + // col_begin: Starting column (inclusive). + // col_end: Ending column (exclusive). + // + // Returns a new num_array containing the specified sub-region. + // + // Throws: + // std::invalid_argument if indices are out of range or produce an empty sub-array. + num_array subarray(int64_t row_begin, int64_t row_end, int64_t col_begin, int64_t col_end) const; }; diff --git a/src/YgorMathIOCSV.cc b/src/YgorMathIOCSV.cc new file mode 100644 index 0000000..fcf62a6 --- /dev/null +++ b/src/YgorMathIOCSV.cc @@ -0,0 +1,150 @@ +//YgorMathIOCSV.cc - Routines for reading tabular CSV/TSV data into num_array. +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "YgorDefinitions.h" +#include "YgorMath.h" +#include "YgorMathIOCSV.h" + + +template +csv_load_result +ReadNumArrayFromCSV(std::istream &is, + bool has_header, + csv_non_numeric_callback_t non_numeric_cb){ + + csv_load_result result; + std::map &string_to_int = result.string_to_int; + std::map>> &int_to_locations = result.int_to_locations; + int64_t next_mapped_int = 0; + + // Default callback for non-numeric tokens. + const auto default_cb = [&](const std::string &token, int64_t row, int64_t col) -> double { + // Case-insensitive comparison for NaN. + std::string lower; + lower.reserve(token.size()); + for(const auto &ch : token){ + lower.push_back(static_cast(std::tolower(static_cast(ch)))); + } + + if(lower.empty() || lower == "nan"){ + return static_cast(std::numeric_limits::quiet_NaN()); + } + if(lower == "inf" || lower == "+inf"){ + return static_cast(std::numeric_limits::infinity()); + } + if(lower == "-inf"){ + return static_cast(-std::numeric_limits::infinity()); + } + + // Map distinct strings to distinct integers (case sensitive on the original token). + auto it = string_to_int.find(token); + if(it == string_to_int.end()){ + string_to_int[token] = next_mapped_int; + it = string_to_int.find(token); + ++next_mapped_int; + } + int_to_locations[it->second].emplace_back(row, col); + return static_cast(it->second); + }; + + const auto &cb = non_numeric_cb ? non_numeric_cb : default_cb; + + // Helper to trim leading/trailing whitespace from a token. + const auto trim = [](const std::string &s) -> std::string { + const auto begin = s.find_first_not_of(" \t\r\n"); + if(begin == std::string::npos) return ""; + const auto end = s.find_last_not_of(" \t\r\n"); + return s.substr(begin, end - begin + 1); + }; + + std::vector> rows; + std::string line; + bool first_data_line = true; + char delimiter = ','; + + while(std::getline(is, line)){ + if(line.empty()) continue; + if(first_data_line){ + // Auto-detect delimiter from first non-empty line. + if(line.find('\t') != std::string::npos){ + delimiter = '\t'; + } + if(has_header){ + first_data_line = false; + continue; + } + first_data_line = false; + } + + const int64_t row_idx = static_cast(rows.size()); + std::vector vals; + std::stringstream ss(line); + std::string token; + int64_t col_idx = 0; + while(std::getline(ss, token, delimiter)){ + const std::string trimmed = trim(token); + // First try to parse as a number. + bool parsed = false; + if(!trimmed.empty()){ + try{ + size_t pos = 0; + const double val = std::stod(trimmed, &pos); + if(pos == trimmed.size()){ + vals.push_back(static_cast(val)); + parsed = true; + } + }catch(const std::invalid_argument &){ + }catch(const std::out_of_range &){ + } + } + if(!parsed){ + // Use the callback for non-numeric tokens. + const double mapped_val = cb(trimmed, row_idx, col_idx); + vals.push_back(static_cast(mapped_val)); + } + ++col_idx; + } + if(!vals.empty()){ + rows.push_back(vals); + } + } + + if(rows.empty()){ + throw std::runtime_error("No data rows found in CSV/TSV input."); + } + + const int64_t n_rows = static_cast(rows.size()); + const int64_t n_cols = static_cast(rows.front().size()); + + for(int64_t r = 0; r < n_rows; ++r){ + if(static_cast(rows[r].size()) != n_cols){ + throw std::runtime_error("Row " + std::to_string(r) + " has " + + std::to_string(rows[r].size()) + " columns, expected " + std::to_string(n_cols) + "."); + } + } + + result.data = num_array(n_rows, n_cols, static_cast(0)); + for(int64_t r = 0; r < n_rows; ++r){ + for(int64_t c = 0; c < n_cols; ++c){ + result.data.coeff(r, c) = rows[r][c]; + } + } + + return result; +} +#ifndef YGORMATHIOCSV_DISABLE_ALL_SPECIALIZATIONS + template csv_load_result ReadNumArrayFromCSV(std::istream &, bool, csv_non_numeric_callback_t); + template csv_load_result ReadNumArrayFromCSV(std::istream &, bool, csv_non_numeric_callback_t); +#endif diff --git a/src/YgorMathIOCSV.h b/src/YgorMathIOCSV.h new file mode 100644 index 0000000..f999719 --- /dev/null +++ b/src/YgorMathIOCSV.h @@ -0,0 +1,83 @@ +//YgorMathIOCSV.h - Written by hal clark in 2026. +// +// Routines for reading tabular CSV/TSV data into num_array. +// + +#pragma once + +#ifndef YGOR_MATH_IO_CSV_HDR_GRD_H +#define YGOR_MATH_IO_CSV_HDR_GRD_H + +#include +#include +#include +#include +#include +#include +#include + +#include "YgorDefinitions.h" +#include "YgorMath.h" + + +// Result of loading a CSV/TSV file into a num_array. +// +// In addition to the numeric matrix, this structure provides metadata about +// non-numeric values that were mapped to integers by the default callback. +// +template +struct csv_load_result { + num_array data; + + // Mapping from each distinct non-numeric string to its assigned integer value. + // Only populated by the default non-numeric callback. + std::map string_to_int; + + // Reverse mapping: for each assigned integer, a list of (row, col) locations + // in the matrix where that mapped value appears. + std::map>> int_to_locations; +}; + + +// Callback type for converting a non-numeric token to a numeric value. +// +// Parameters: +// token: the raw (whitespace-trimmed) cell text that could not be parsed as a number. +// row: the 0-based row index of the cell. +// col: the 0-based column index of the cell. +// +// Returns: the numeric value to store in the matrix. +// +using csv_non_numeric_callback_t = std::function; + + +// Read a CSV or TSV stream into a num_array. +// +// The delimiter is auto-detected: if the first non-empty line (header or data) contains +// a tab character the delimiter is tab; otherwise comma. +// +// Parameters: +// is: Input stream to read from. +// has_header: If true, the first non-empty line is treated as a header and skipped. +// non_numeric_cb: Optional callback for non-numeric tokens. If not provided, the +// default callback maps: +// - "nan" / "NaN" / "NAN" (case insensitive) and empty cells to quiet_NaN. +// - "inf" / "+inf" / "-inf" (case insensitive) to C++ positive/negative infinity. +// - Any other non-numeric token (including punctuation/quoted strings, case +// sensitive) to a distinct integer, with the mapping stored in the returned +// csv_load_result. +// Returns: +// A csv_load_result containing the loaded num_array and any string-to-int mappings. +// +// Throws: +// std::runtime_error if the stream contains no data rows, or if rows have inconsistent +// column counts. +// +template +csv_load_result +ReadNumArrayFromCSV(std::istream &is, + bool has_header = false, + csv_non_numeric_callback_t non_numeric_cb = {}); + + +#endif // YGOR_MATH_IO_CSV_HDR_GRD_H diff --git a/src/YgorStatsCITrees.cc b/src/YgorStatsCITrees.cc index c13f7a0..18f7e95 100644 --- a/src/YgorStatsCITrees.cc +++ b/src/YgorStatsCITrees.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -565,3 +566,177 @@ int64_t Stats::ConditionalInferenceTrees::get_n_permutations() const { template int64_t Stats::ConditionalInferenceTrees::get_n_permutations() const; template int64_t Stats::ConditionalInferenceTrees::get_n_permutations() const; #endif + + +template +bool Stats::ConditionalInferenceTrees::write_tree_node(std::ostream &os, const TreeNode *node) const { + if(node == nullptr){ + return false; + } + if(node->is_leaf){ + os << "L " << node->value << "\n"; + }else{ + os << "I " << node->split_feature << " " << node->split_threshold << "\n"; + if(!write_tree_node(os, node->left.get())) return false; + if(!write_tree_node(os, node->right.get())) return false; + } + return (!os.fail()); +} +#ifndef YGOR_STATS_CI_TREES_DISABLE_ALL_SPECIALIZATIONS + template bool Stats::ConditionalInferenceTrees::write_tree_node(std::ostream &, const TreeNode *) const; + template bool Stats::ConditionalInferenceTrees::write_tree_node(std::ostream &, const TreeNode *) const; +#endif + + +template +std::unique_ptr::TreeNode> +Stats::ConditionalInferenceTrees::read_tree_node(std::istream &is) { + std::string node_type; + is >> node_type; + if(is.fail()) return nullptr; + + auto node = std::make_unique(); + try{ + if(node_type == "L"){ + node->is_leaf = true; + std::string val_str; + is >> val_str; + if(is.fail()) return nullptr; + node->value = static_cast(std::stold(val_str)); + }else if(node_type == "I"){ + node->is_leaf = false; + is >> node->split_feature; + if(is.fail()) return nullptr; + // Validate split_feature against n_features_trained to prevent out-of-bounds + // access during prediction. n_features_trained is set before tree deserialization. + if(node->split_feature < 0 || node->split_feature >= this->n_features_trained){ + return nullptr; + } + std::string thresh_str; + is >> thresh_str; + if(is.fail()) return nullptr; + node->split_threshold = static_cast(std::stold(thresh_str)); + node->left = read_tree_node(is); + node->right = read_tree_node(is); + if(!node->left || !node->right) return nullptr; + }else{ + return nullptr; + } + }catch(const std::invalid_argument &){ + return nullptr; + }catch(const std::out_of_range &){ + return nullptr; + } + return node; +} +#ifndef YGOR_STATS_CI_TREES_DISABLE_ALL_SPECIALIZATIONS + template std::unique_ptr::TreeNode> + Stats::ConditionalInferenceTrees::read_tree_node(std::istream &); + template std::unique_ptr::TreeNode> + Stats::ConditionalInferenceTrees::read_tree_node(std::istream &); +#endif + + +template +bool Stats::ConditionalInferenceTrees::write_to(std::ostream &os) const { + const auto original_precision = os.precision(); + os.precision( std::numeric_limits::max_digits10 ); + + // RAII guard to restore stream precision on all exit paths. + struct precision_guard { + std::ostream &s; + std::streamsize p; + ~precision_guard(){ s.precision(p); } + } guard{os, original_precision}; + + os << "ConditionalInferenceTrees_v1" << "\n"; + os << "max_depth " << this->max_depth << "\n"; + os << "min_samples_split " << this->min_samples_split << "\n"; + os << "alpha " << this->alpha << "\n"; + os << "n_permutations " << this->n_permutations << "\n"; + os << "n_features_trained " << this->n_features_trained << "\n"; + os << "random_seed " << this->random_seed << "\n"; + + // Write tree. + os << "begin_tree\n"; + if(!write_tree_node(os, this->root.get())) return false; + os << "end_tree\n"; + + os.flush(); + return (!os.fail()); +} +#ifndef YGOR_STATS_CI_TREES_DISABLE_ALL_SPECIALIZATIONS + template bool Stats::ConditionalInferenceTrees::write_to(std::ostream &) const; + template bool Stats::ConditionalInferenceTrees::write_to(std::ostream &) const; +#endif + + +template +bool Stats::ConditionalInferenceTrees::read_from(std::istream &is) { + try{ + std::string label; + + // Read and validate header. + is >> label; + if(is.fail() || label != "ConditionalInferenceTrees_v1") return false; + + // Read parameters. + is >> label >> this->max_depth; + if(is.fail() || label != "max_depth") return false; + + is >> label >> this->min_samples_split; + if(is.fail() || label != "min_samples_split") return false; + + { + std::string val_str; + is >> label >> val_str; + if(is.fail() || label != "alpha") return false; + this->alpha = static_cast(std::stold(val_str)); + } + + is >> label >> this->n_permutations; + if(is.fail() || label != "n_permutations") return false; + + is >> label >> this->n_features_trained; + if(is.fail() || label != "n_features_trained") return false; + + is >> label >> this->random_seed; + if(is.fail() || label != "random_seed") return false; + + // Validate hyperparameter invariants after deserialization. + if(this->max_depth <= 0){ + return false; + } + if(this->min_samples_split < 2){ + return false; + } + if(this->alpha <= static_cast(0) || static_cast(1) <= this->alpha){ + return false; + } + if(this->n_permutations <= 0){ + return false; + } + if(this->n_features_trained <= 0){ + return false; + } + // Read tree. + is >> label; + if(is.fail() || label != "begin_tree") return false; + + this->root = read_tree_node(is); + if(!this->root) return false; + + is >> label; + if(is.fail() || label != "end_tree") return false; + + return (!is.fail()); + }catch(const std::invalid_argument &){ + return false; + }catch(const std::out_of_range &){ + return false; + } +} +#ifndef YGOR_STATS_CI_TREES_DISABLE_ALL_SPECIALIZATIONS + template bool Stats::ConditionalInferenceTrees::read_from(std::istream &); + template bool Stats::ConditionalInferenceTrees::read_from(std::istream &); +#endif diff --git a/src/YgorStatsCITrees.h b/src/YgorStatsCITrees.h index 9dc3269..3c3d273 100644 --- a/src/YgorStatsCITrees.h +++ b/src/YgorStatsCITrees.h @@ -13,6 +13,7 @@ #define YGOR_STATS_CI_TREES_HDR_GRD_H #include +#include #include #include #include @@ -109,6 +110,10 @@ class ConditionalInferenceTrees { // Predict using the tree from a given node. T predict_tree(const TreeNode *node, const num_array &x) const; + // Serialization helpers. + bool write_tree_node(std::ostream &os, const TreeNode *node) const; + std::unique_ptr read_tree_node(std::istream &is); + public: // Constructor. // @@ -160,6 +165,31 @@ class ConditionalInferenceTrees { // Get the number of permutations. int64_t get_n_permutations() const; + + // Write the model to a text stream. + // + // Serializes all data members, parameters, and tree structure to a human-readable + // text format. The model can be restored exactly using read_from() without any loss + // in function or accuracy. Floating point values are written with maximum precision. + // + // Parameters: + // os: Output stream to write to. + // + // Returns: + // true on success, false if the stream enters a fail state. + bool write_to(std::ostream &os) const; + + // Read a model from a text stream. + // + // Restores a model previously written by write_to(). All parameters and tree + // structure are restored exactly. + // + // Parameters: + // is: Input stream to read from. + // + // Returns: + // true on success, false if the stream format is invalid or enters a fail state. + bool read_from(std::istream &is); }; } //namespace Stats. diff --git a/tests2/YgorMathIOCSV.cc b/tests2/YgorMathIOCSV.cc new file mode 100644 index 0000000..eb3d9bb --- /dev/null +++ b/tests2/YgorMathIOCSV.cc @@ -0,0 +1,182 @@ + +#include +#include +#include +#include + +#include +#include + +#include "doctest/doctest.h" + + +TEST_CASE( "ReadNumArrayFromCSV basic CSV" ){ + std::stringstream ss("1.0,2.0,3.0\n4.0,5.0,6.0\n"); + + auto result = ReadNumArrayFromCSV(ss); + REQUIRE( result.data.num_rows() == 2 ); + REQUIRE( result.data.num_cols() == 3 ); + REQUIRE( result.data.read_coeff(0, 0) == doctest::Approx(1.0) ); + REQUIRE( result.data.read_coeff(0, 2) == doctest::Approx(3.0) ); + REQUIRE( result.data.read_coeff(1, 0) == doctest::Approx(4.0) ); + REQUIRE( result.data.read_coeff(1, 2) == doctest::Approx(6.0) ); +} + + +TEST_CASE( "ReadNumArrayFromCSV TSV auto-detection" ){ + std::stringstream ss("1.0\t2.0\t3.0\n4.0\t5.0\t6.0\n"); + + auto result = ReadNumArrayFromCSV(ss); + REQUIRE( result.data.num_rows() == 2 ); + REQUIRE( result.data.num_cols() == 3 ); + REQUIRE( result.data.read_coeff(0, 0) == doctest::Approx(1.0) ); + REQUIRE( result.data.read_coeff(1, 2) == doctest::Approx(6.0) ); +} + + +TEST_CASE( "ReadNumArrayFromCSV with header" ){ + std::stringstream ss("x,y,z\n1.0,2.0,3.0\n4.0,5.0,6.0\n"); + + auto result = ReadNumArrayFromCSV(ss, true); + REQUIRE( result.data.num_rows() == 2 ); + REQUIRE( result.data.num_cols() == 3 ); + REQUIRE( result.data.read_coeff(0, 0) == doctest::Approx(1.0) ); +} + + +TEST_CASE( "ReadNumArrayFromCSV NaN handling" ){ + std::stringstream ss("1.0,nan,NaN\n,NAN,2.0\n"); + + auto result = ReadNumArrayFromCSV(ss); + REQUIRE( result.data.num_rows() == 2 ); + REQUIRE( result.data.num_cols() == 3 ); + REQUIRE( result.data.read_coeff(0, 0) == doctest::Approx(1.0) ); + REQUIRE( std::isnan(result.data.read_coeff(0, 1)) ); + REQUIRE( std::isnan(result.data.read_coeff(0, 2)) ); + REQUIRE( std::isnan(result.data.read_coeff(1, 0)) ); // empty cell -> NaN + REQUIRE( std::isnan(result.data.read_coeff(1, 1)) ); + REQUIRE( result.data.read_coeff(1, 2) == doctest::Approx(2.0) ); +} + + +TEST_CASE( "ReadNumArrayFromCSV infinity handling" ){ + std::stringstream ss("inf,-inf,+inf\n"); + + auto result = ReadNumArrayFromCSV(ss); + REQUIRE( result.data.num_rows() == 1 ); + REQUIRE( result.data.num_cols() == 3 ); + REQUIRE( std::isinf(result.data.read_coeff(0, 0)) ); + REQUIRE( result.data.read_coeff(0, 0) > 0.0 ); + REQUIRE( std::isinf(result.data.read_coeff(0, 1)) ); + REQUIRE( result.data.read_coeff(0, 1) < 0.0 ); + REQUIRE( std::isinf(result.data.read_coeff(0, 2)) ); + REQUIRE( result.data.read_coeff(0, 2) > 0.0 ); +} + + +TEST_CASE( "ReadNumArrayFromCSV non-numeric string mapping" ){ + std::stringstream ss("apple,1.0\nbanana,2.0\napple,3.0\n"); + + auto result = ReadNumArrayFromCSV(ss); + REQUIRE( result.data.num_rows() == 3 ); + REQUIRE( result.data.num_cols() == 2 ); + + // "apple" and "banana" should be mapped to distinct integers. + REQUIRE( result.string_to_int.size() == 2 ); + REQUIRE( result.string_to_int.count("apple") == 1 ); + REQUIRE( result.string_to_int.count("banana") == 1 ); + REQUIRE( result.string_to_int.at("apple") != result.string_to_int.at("banana") ); + + // Both "apple" cells should have the same value. + REQUIRE( result.data.read_coeff(0, 0) == result.data.read_coeff(2, 0) ); + + // The numeric columns should still be parsed correctly. + REQUIRE( result.data.read_coeff(0, 1) == doctest::Approx(1.0) ); + REQUIRE( result.data.read_coeff(1, 1) == doctest::Approx(2.0) ); + REQUIRE( result.data.read_coeff(2, 1) == doctest::Approx(3.0) ); + + // Locations should be tracked. + const int64_t apple_int = result.string_to_int.at("apple"); + REQUIRE( result.int_to_locations.count(apple_int) == 1 ); + REQUIRE( result.int_to_locations.at(apple_int).size() == 2 ); +} + + +TEST_CASE( "ReadNumArrayFromCSV custom callback" ){ + std::stringstream ss("hello,1.0\nworld,2.0\n"); + + auto cb = [](const std::string &token, int64_t row, int64_t col) -> double { + return -999.0; // Map all non-numeric values to -999. + }; + + auto result = ReadNumArrayFromCSV(ss, false, cb); + REQUIRE( result.data.num_rows() == 2 ); + REQUIRE( result.data.read_coeff(0, 0) == doctest::Approx(-999.0) ); + REQUIRE( result.data.read_coeff(1, 0) == doctest::Approx(-999.0) ); + REQUIRE( result.data.read_coeff(0, 1) == doctest::Approx(1.0) ); +} + + +TEST_CASE( "ReadNumArrayFromCSV empty input" ){ + std::stringstream ss(""); + REQUIRE_THROWS( ReadNumArrayFromCSV(ss) ); +} + + +TEST_CASE( "ReadNumArrayFromCSV inconsistent columns" ){ + std::stringstream ss("1.0,2.0\n1.0,2.0,3.0\n"); + REQUIRE_THROWS( ReadNumArrayFromCSV(ss) ); +} + + +TEST_CASE( "ReadNumArrayFromCSV whitespace trimming" ){ + std::stringstream ss(" 1.0 , 2.0 , 3.0 \n"); + + auto result = ReadNumArrayFromCSV(ss); + REQUIRE( result.data.num_rows() == 1 ); + REQUIRE( result.data.num_cols() == 3 ); + REQUIRE( result.data.read_coeff(0, 0) == doctest::Approx(1.0) ); + REQUIRE( result.data.read_coeff(0, 1) == doctest::Approx(2.0) ); + REQUIRE( result.data.read_coeff(0, 2) == doctest::Approx(3.0) ); +} + + +TEST_CASE( "num_array subarray basic" ){ + num_array m(3, 4, 0.0); + for(int64_t r = 0; r < 3; ++r){ + for(int64_t c = 0; c < 4; ++c){ + m.coeff(r, c) = static_cast(r * 10 + c); + } + } + + SUBCASE("full subarray equals original"){ + auto sub = m.subarray(0, 3, 0, 4); + REQUIRE( sub.num_rows() == 3 ); + REQUIRE( sub.num_cols() == 4 ); + REQUIRE( sub == m ); + } + + SUBCASE("single row slice"){ + auto row = m.subarray(1, 2, 0, 4); + REQUIRE( row.num_rows() == 1 ); + REQUIRE( row.num_cols() == 4 ); + REQUIRE( row.read_coeff(0, 0) == doctest::Approx(10.0) ); + REQUIRE( row.read_coeff(0, 3) == doctest::Approx(13.0) ); + } + + SUBCASE("column range slice"){ + auto cols = m.subarray(0, 3, 1, 3); + REQUIRE( cols.num_rows() == 3 ); + REQUIRE( cols.num_cols() == 2 ); + REQUIRE( cols.read_coeff(0, 0) == doctest::Approx(1.0) ); + REQUIRE( cols.read_coeff(2, 1) == doctest::Approx(22.0) ); + } + + SUBCASE("invalid ranges are rejected"){ + REQUIRE_THROWS( m.subarray(-1, 3, 0, 4) ); + REQUIRE_THROWS( m.subarray(0, 4, 0, 4) ); + REQUIRE_THROWS( m.subarray(0, 3, 0, 5) ); + REQUIRE_THROWS( m.subarray(2, 1, 0, 4) ); + REQUIRE_THROWS( m.subarray(0, 3, 3, 2) ); + } +} diff --git a/tests2/YgorStatsCITrees.cc b/tests2/YgorStatsCITrees.cc index f92e664..db7ab3a 100644 --- a/tests2/YgorStatsCITrees.cc +++ b/tests2/YgorStatsCITrees.cc @@ -1,6 +1,8 @@ #include #include +#include +#include #include #include @@ -736,3 +738,184 @@ TEST_CASE( "ConditionalInferenceTrees handles two identical samples" ){ const double pred = ct.predict(x_test); REQUIRE( std::abs(pred - 5.0) < 0.01 ); } + + +TEST_CASE( "ConditionalInferenceTrees write_to and read_from roundtrip" ){ + // Train a model, write it, read it back, and verify predictions match. + Stats::ConditionalInferenceTrees ct(5, 2, 0.10, 100, 42); + + const int64_t n_samples = 30; + const int64_t n_features = 3; + num_array X(n_samples, n_features); + num_array y(n_samples, 1); + + for(int64_t i = 0; i < n_samples; ++i){ + X.coeff(i, 0) = static_cast(i) * 0.1; + X.coeff(i, 1) = static_cast(i) * 0.2; + X.coeff(i, 2) = static_cast(i) * 0.15; + y.coeff(i, 0) = X.read_coeff(i, 0) + 2.0 * X.read_coeff(i, 1) + 3.0 * X.read_coeff(i, 2); + } + + ct.fit(X, y); + + SUBCASE("roundtrip preserves predictions exactly"){ + // Write model. + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + + // Read model. + Stats::ConditionalInferenceTrees ct_loaded; + REQUIRE( ct_loaded.read_from(ss) ); + + // Verify predictions are identical. + for(int64_t i = 0; i < n_samples; ++i){ + num_array x_test(1, n_features); + x_test.coeff(0, 0) = X.read_coeff(i, 0); + x_test.coeff(0, 1) = X.read_coeff(i, 1); + x_test.coeff(0, 2) = X.read_coeff(i, 2); + + const double pred_original = ct.predict(x_test); + const double pred_loaded = ct_loaded.predict(x_test); + REQUIRE( pred_original == pred_loaded ); + } + } + + SUBCASE("roundtrip preserves predictions on unseen data"){ + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + + Stats::ConditionalInferenceTrees ct_loaded; + REQUIRE( ct_loaded.read_from(ss) ); + + // Test on unseen data. + for(int i = 0; i < 10; ++i){ + num_array x_test(1, n_features); + x_test.coeff(0, 0) = static_cast(i) * 0.05 + 0.025; + x_test.coeff(0, 1) = static_cast(i) * 0.1 + 0.05; + x_test.coeff(0, 2) = static_cast(i) * 0.075 + 0.0375; + + const double pred_original = ct.predict(x_test); + const double pred_loaded = ct_loaded.predict(x_test); + REQUIRE( pred_original == pred_loaded ); + } + } + + SUBCASE("written format is text-based"){ + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + const std::string content = ss.str(); + + // Verify it starts with the expected header. + REQUIRE( content.substr(0, 28) == "ConditionalInferenceTrees_v1" ); + + // Verify it contains expected labels. + REQUIRE( content.find("max_depth 5") != std::string::npos ); + REQUIRE( content.find("min_samples_split 2") != std::string::npos ); + REQUIRE( content.find("begin_tree") != std::string::npos ); + REQUIRE( content.find("end_tree") != std::string::npos ); + } + + SUBCASE("read_from rejects invalid input"){ + std::stringstream bad_ss("not a valid model"); + Stats::ConditionalInferenceTrees ct_bad; + REQUIRE( !ct_bad.read_from(bad_ss) ); + } + + SUBCASE("read_from rejects invalid hyperparameters"){ + // Corrupt max_depth to 0. + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + std::string content = ss.str(); + auto pos = content.find("max_depth 5"); + REQUIRE( pos != std::string::npos ); + content.replace(pos, 11, "max_depth 0"); + + std::stringstream bad_ss(content); + Stats::ConditionalInferenceTrees ct_bad; + REQUIRE( !ct_bad.read_from(bad_ss) ); + } + + SUBCASE("read_from rejects invalid alpha"){ + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + std::string content = ss.str(); + // Replace "alpha 0.1" with "alpha 0" (invalid: alpha must be in (0,1)). + auto pos = content.find("alpha "); + REQUIRE( pos != std::string::npos ); + auto end = content.find('\n', pos); + content.replace(pos, end - pos, "alpha 0"); + + std::stringstream bad_ss(content); + Stats::ConditionalInferenceTrees ct_bad; + REQUIRE( !ct_bad.read_from(bad_ss) ); + } + + SUBCASE("read_from rejects out-of-range split_feature"){ + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + std::string content = ss.str(); + // Replace a valid internal node split_feature with an out-of-range value. + // The format for internal nodes is "I ". + auto pos = content.find("\nI "); + REQUIRE( pos != std::string::npos ); + // Find the feature index after "I ". + auto feat_start = pos + 3; + auto feat_end = content.find(' ', feat_start); + REQUIRE( feat_end != std::string::npos ); + // Replace the feature index with 999 (out of range for 3 features). + content.replace(feat_start, feat_end - feat_start, "999"); + + std::stringstream bad_ss(content); + Stats::ConditionalInferenceTrees ct_bad; + REQUIRE( !ct_bad.read_from(bad_ss) ); + } + + SUBCASE("read_from rejects negative split_feature"){ + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + std::string content = ss.str(); + auto pos = content.find("\nI "); + REQUIRE( pos != std::string::npos ); + auto feat_start = pos + 3; + auto feat_end = content.find(' ', feat_start); + REQUIRE( feat_end != std::string::npos ); + content.replace(feat_start, feat_end - feat_start, "-1"); + + std::stringstream bad_ss(content); + Stats::ConditionalInferenceTrees ct_bad; + REQUIRE( !ct_bad.read_from(bad_ss) ); + } +} + + +TEST_CASE( "ConditionalInferenceTrees write_to and read_from with float" ){ + Stats::ConditionalInferenceTrees ct(4, 2, 0.10f, 100, 99); + + const int64_t n_samples = 15; + num_array X(n_samples, 2); + num_array y(n_samples, 1); + + for(int64_t i = 0; i < n_samples; ++i){ + X.coeff(i, 0) = static_cast(i) * 0.5f; + X.coeff(i, 1) = static_cast(i) * 0.3f; + y.coeff(i, 0) = X.read_coeff(i, 0) + X.read_coeff(i, 1); + } + + ct.fit(X, y); + + std::stringstream ss; + REQUIRE( ct.write_to(ss) ); + + Stats::ConditionalInferenceTrees ct_loaded; + REQUIRE( ct_loaded.read_from(ss) ); + + for(int64_t i = 0; i < n_samples; ++i){ + num_array x_test(1, 2); + x_test.coeff(0, 0) = X.read_coeff(i, 0); + x_test.coeff(0, 1) = X.read_coeff(i, 1); + + const float pred_original = ct.predict(x_test); + const float pred_loaded = ct_loaded.predict(x_test); + REQUIRE( pred_original == pred_loaded ); + } +} diff --git a/tests2/compile_and_run.sh b/tests2/compile_and_run.sh index c85e684..030d93c 100755 --- a/tests2/compile_and_run.sh +++ b/tests2/compile_and_run.sh @@ -45,6 +45,7 @@ g++ \ YgorMathIOPLY.cc \ YgorMathIOSTL.cc \ YgorMathIOXYZ.cc \ + YgorMathIOCSV.cc \ YgorMeshesConvexHull.cc \ YgorMeshesHoles.cc \ YgorMeshesOrient.cc \