33from .common import Classifier , Regressor
44
55
6+ SPLIT_KIND_THRESHOLD = "threshold"
7+ SPLIT_KIND_CATEGORY_SET = "category_set"
8+ UNSEEN_CATEGORY_VALUE = - 1.0
9+
10+
611def get_terminal_node_lt (node , data ):
712 # XGBoost
813 current = node
914 while not current .is_leaf :
1015 if current .is_missing (data ):
1116 current = current .left_child if current .missing_goes_left else current .right_child
17+ elif current .split_kind == SPLIT_KIND_CATEGORY_SET :
18+ # XGBoost routes matching categories to the right child.
19+ # Unseen encoded categories are treated as missing values and follow missing_goes_left when available.
20+ if data [current .feature_idx ] == UNSEEN_CATEGORY_VALUE and current .missing_goes_left is not None :
21+ current = current .left_child if current .missing_goes_left else current .right_child
22+ elif current .has_category (data ):
23+ current = current .right_child
24+ else :
25+ current = current .left_child
1226 elif data [current .feature_idx ] < current .threshold :
1327 current = current .left_child
1428 else :
@@ -22,6 +36,14 @@ def get_terminal_node_lte(node, data):
2236 while not current .is_leaf :
2337 if current .is_missing (data ):
2438 current = current .left_child if current .missing_goes_left else current .right_child
39+ elif current .split_kind == SPLIT_KIND_CATEGORY_SET :
40+ # LightGBM routes matching categories to the left child and treats unseen encoded categories as missing.
41+ if data [current .feature_idx ] == UNSEEN_CATEGORY_VALUE and current .missing_goes_left is not None :
42+ current = current .left_child if current .missing_goes_left else current .right_child
43+ elif current .has_category (data ):
44+ current = current .left_child
45+ else :
46+ current = current .right_child
2547 elif data [current .feature_idx ] <= current .threshold :
2648 current = current .left_child
2749 else :
@@ -35,7 +57,9 @@ class Node:
3557 label: the 'value' of the node
3658 """
3759
38- def __init__ (self , feature_idx = None , threshold = np .nan , left_child = None , right_child = None , label = None , is_leaf = None , missing_goes_left = None , missing_value = np .nan ):
60+ def __init__ (self , feature_idx = None , threshold = np .nan , left_child = None , right_child = None , label = None ,
61+ is_leaf = None , missing_goes_left = None , missing_value = np .nan , split_kind = SPLIT_KIND_THRESHOLD ,
62+ category_set = None ):
3963 self .label = label
4064 self .feature_idx = feature_idx
4165 self .threshold = threshold
@@ -44,13 +68,18 @@ def __init__(self, feature_idx=None, threshold=np.nan, left_child=None, right_ch
4468 self .is_leaf = is_leaf
4569 self .missing_goes_left = missing_goes_left
4670 self .missing_value = missing_value
71+ self .split_kind = split_kind
72+ self .category_set = None if category_set is None else frozenset (float (v ) for v in category_set )
4773
def is_missing(self, data):
    """Tell whether this node's split feature is missing in ``data``.

    The feature value is read from ``data[self.feature_idx]``.  When the
    node's ``missing_value`` is NaN, the value counts as missing if it is
    NaN itself; otherwise it counts as missing when it equals
    ``missing_value``.

    Args:
        data: indexable row of feature values.

    Returns:
        The result of the NaN test or of the equality comparison.
    """
    sentinel = self.missing_value
    # NaN never compares equal to itself, so a NaN sentinel needs np.isnan.
    sentinel_is_nan = np.isnan(sentinel)
    observed = data[self.feature_idx]
    return np.isnan(observed) if sentinel_is_nan else observed == sentinel
5379
def has_category(self, data):
    """Tell whether this node's split feature falls in its category set.

    The feature value is read from ``data[self.feature_idx]`` and converted
    to a Python ``float`` before the membership test against
    ``self.category_set``.

    Args:
        data: indexable row of feature values.

    Returns:
        bool: ``True`` when the value belongs to ``category_set``; ``False``
        otherwise, including when no category set is configured.
    """
    categories = self.category_set
    # A node built without a category set matches no category; without this
    # guard the membership test raises TypeError on None.
    if categories is None:
        return False
    return float(data[self.feature_idx]) in categories
82+
5483class DecisionTreeModel (Classifier , Regressor ):
5584 """
5685 This class handles the cases where the tree model was trained by XGBoost, LightGBM or scikit-learn.
@@ -77,15 +106,15 @@ class DecisionTreeModel(Classifier, Regressor):
77106 def __init__ (self , model_parameters ):
78107 self .init_tree (model_parameters )
79108 if self .variant == "XGBOOST" :
80- self .feature_converter = np .float32
109+ self .feature_converter = lambda x : np .asarray ( x , dtype = np . float32 )
81110 self .get_terminal_node = get_terminal_node_lt
82111 self .label_dtype = np .float32
83112 elif self .variant == "LIGHTGBM" :
84- self .feature_converter = np .float64
113+ self .feature_converter = lambda x : np .asarray ( x , dtype = np . float64 )
85114 self .get_terminal_node = get_terminal_node_lte
86115 self .label_dtype = np .float64
87116 elif self .variant == "SKLEARN" :
88- self .feature_converter = lambda x : np .float64 ( np .float32 ( x ) )
117+ self .feature_converter = lambda x : np .asarray ( x , dtype = np .float32 ). astype ( np . float64 )
89118 self .get_terminal_node = get_terminal_node_lte
90119 self .label_dtype = np .float64
91120
@@ -117,10 +146,29 @@ def init_tree(self, model_parameters):
117146 list_missing_goes_left = [None ] * len (model_parameters ["node_id" ])
118147 else :
119148 list_missing_goes_left = [v == "l" for v in missing ]
149+ split_kinds = model_parameters .get ("split_kind" )
150+ if split_kinds is None or len (split_kinds ) == 0 :
151+ split_kinds = [SPLIT_KIND_THRESHOLD ] * len (model_parameters ["node_id" ])
152+ category_sets = model_parameters .get ("category_set" )
153+ if category_sets is None or len (category_sets ) == 0 :
154+ category_sets = [None ] * len (model_parameters ["node_id" ])
120155 nodes_with_children = {
121- node_id : Node (feature_idx = feature , threshold = convert_threshold (threshold ), is_leaf = False , missing_goes_left = missing_goes_left , missing_value = missing_value )
122- for node_id , feature , threshold , missing_goes_left in zip (
123- model_parameters ["node_id" ], model_parameters ["feature" ], model_parameters ["threshold" ], list_missing_goes_left )
156+ node_id : Node (
157+ feature_idx = feature ,
158+ threshold = convert_threshold (threshold ) if split_kind == SPLIT_KIND_THRESHOLD else threshold ,
159+ is_leaf = False ,
160+ missing_goes_left = missing_goes_left ,
161+ missing_value = missing_value ,
162+ split_kind = split_kind ,
163+ category_set = category_set
164+ )
165+ for node_id , feature , threshold , missing_goes_left , split_kind , category_set in zip (
166+ model_parameters ["node_id" ],
167+ model_parameters ["feature" ],
168+ model_parameters ["threshold" ],
169+ list_missing_goes_left ,
170+ split_kinds ,
171+ category_sets )
124172 }
125173
126174 # Connect the nodes to their children
@@ -129,8 +177,12 @@ def init_tree(self, model_parameters):
129177 node .right_child = nodes_with_children .get (node_id * 2 + 2 , leaves .get (node_id * 2 + 2 ))
130178
131179 # Validation
132- if node .left_child is None or node .right_child is None or np .isnan (node .threshold ):
133- raise ValueError ("Tree node is not valid" )
180+ if node .left_child is None or node .right_child is None :
181+ raise ValueError ("Tree split node is missing a child" )
182+ if node .split_kind == SPLIT_KIND_THRESHOLD and np .isnan (node .threshold ):
183+ raise ValueError ("Threshold split node is missing a threshold" )
184+ if node .split_kind == SPLIT_KIND_CATEGORY_SET and node .category_set is None :
185+ raise ValueError ("Category-set split node is missing a category_set" )
134186 if (node .left_child .is_leaf and node .left_child .label is None ) or (
135187 node .right_child .is_leaf and node .right_child .label is None ):
136188 raise ValueError ("Leaf node does not have a label" )
0 commit comments