diff --git a/generate_cpp_code.py b/generate_cpp_code.py index 6ef3e68..44a9b9e 100644 --- a/generate_cpp_code.py +++ b/generate_cpp_code.py @@ -1,6 +1,7 @@ import sys import argparse import re +import numpy as np def get_single_booster_cpp_code(booster_tree, branch_id, class_index, indentation_level=0): level = booster_tree[branch_id].split() @@ -18,8 +19,9 @@ def get_single_booster_cpp_code(booster_tree, branch_id, class_index, indentatio # Get feature index and limit value feature_index = re.search('f(\d+)', level[0]).group(1) - comparison = re.search('[^0-9a-zA-Z:[]+[0-9]*[0-9.]*', level[0]).group(0) - + comparison = re.search('[^0-9a-zA-Z:[]+-?[\d.]+(?:e[-, +]?\d+)', level[0]).group(0) + comparison = comparison[0] + np.format_float_positional(float(comparison[1:])) + booster_code += "{0}if (sample[{1}] {2}) {{\n".format(" " * indentation_level, feature_index, comparison) booster_code += get_single_booster_cpp_code(booster_tree, yes_branch_id, class_index, indentation_level + 1) booster_code += "{0}}} else {{\n".format(" " * indentation_level) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b26e81f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +numpy +