# Source code for omlt.linear_tree.lt_definition

import numpy as np
import lineartree


class LinearTreeDefinition:
    """
    Class to represent a linear tree model trained in the linear-tree package

    Attributes:
        __model (linear-tree model) : Linear Tree Model trained in linear-tree
        __splits (dict) : Dict containing split node information
        __leaves (dict) : Dict containing leaf node information
        __thresholds (dict) : Dict containing splitting threshold information
        __scaling_object (scaling object) : Scaling object to ensure scaled
            data match units of broader optimization problem
        __scaled_input_bounds (dict): Dict containing scaled input bounds
        __unscaled_input_bounds (dict): Dict containing unscaled input bounds
        __n_inputs (int): Number of inputs (length of a leaf's slope vector)
        __n_outputs (int): Number of outputs (always 1)

    References:
        * linear-tree : https://github.com/cerlymarco/linear-tree
    """

    def __init__(
        self,
        lt_regressor,
        scaling_object=None,
        scaled_input_bounds=None,
        unscaled_input_bounds=None,
    ):
        """Create a LinearTreeDefinition object and define attributes based on
        the trained linear model decision tree.

        Arguments:
            lt_regressor -- A LinearTreeRegressor model that is trained by the
                linear-tree package, or the dict returned by its summary()
                method

        Keyword Arguments:
            scaling_object -- A scaling object to specify the scaling
                parameters for the linear model tree inputs and outputs. If
                None, then no scaling is performed. (default: {None})
            scaled_input_bounds -- A dict that contains the bounds on the
                scaled variables (the direct inputs to the tree). If None,
                then the user must specify the bounds via the
                unscaled_input_bounds argument. (default: {None})
            unscaled_input_bounds -- A dict that contains the bounds on the
                unscaled variables. If None, then the user must specify the
                scaled bounds via the scaled_input_bounds argument.
                (default: {None})

        Raises:
            ValueError: If both unscaled_input_bounds and scaled_input_bounds
                are None.
        """
        self.__model = lt_regressor
        self.__scaling_object = scaling_object

        # Process input bounds to ensure scaled input bounds exist for
        # formulations
        if scaled_input_bounds is None:
            if unscaled_input_bounds is None:
                raise ValueError(
                    "Input Bounds needed to represent linear trees as MIPs"
                )
            if scaling_object is None:
                # No scaler provided: the scaled bounds are simply the
                # unscaled bounds
                scaled_input_bounds = unscaled_input_bounds
            else:
                lbs = scaling_object.get_scaled_input_expressions(
                    {k: t[0] for k, t in unscaled_input_bounds.items()}
                )
                ubs = scaling_object.get_scaled_input_expressions(
                    {k: t[1] for k, t in unscaled_input_bounds.items()}
                )
                scaled_input_bounds = {
                    k: (lbs[k], ubs[k]) for k in unscaled_input_bounds
                }

        self.__unscaled_input_bounds = unscaled_input_bounds
        self.__scaled_input_bounds = scaled_input_bounds

        self.__splits, self.__leaves, self.__thresholds = _parse_tree_data(
            lt_regressor, scaled_input_bounds
        )

        self.__n_inputs = _find_n_inputs(self.__leaves)
        self.__n_outputs = 1

    @property
    def scaling_object(self):
        """Returns scaling object"""
        return self.__scaling_object

    @property
    def scaled_input_bounds(self):
        """Returns dict containing scaled input bounds"""
        return self.__scaled_input_bounds

    @property
    def unscaled_input_bounds(self):
        """Returns dict containing unscaled input bounds (None if not given)"""
        return self.__unscaled_input_bounds

    @property
    def splits(self):
        """Returns dict containing split information"""
        return self.__splits

    @property
    def leaves(self):
        """Returns dict containing leaf information"""
        return self.__leaves

    @property
    def thresholds(self):
        """Returns dict containing threshold information"""
        return self.__thresholds

    @property
    def n_inputs(self):
        """Returns number of inputs to the linear tree"""
        return self.__n_inputs

    @property
    def n_outputs(self):
        """Returns number of outputs to the linear tree"""
        return self.__n_outputs


def _find_all_children_splits(split, splits_dict):
    """
    This helper function finds all multigeneration children splits for an
    argument split.

    Arguments:
        split -- The split for which you are trying to find children splits
        splits_dict -- A dictionary of all the splits in the tree

    Returns:
        A list containing the Node IDs of all children splits (left subtree
        first, then right subtree)
    """
    children = splits_dict[split]["children"]
    all_splits = []

    # A child that is itself a split is recorded, followed by everything
    # found recursively beneath it. Left child first, then right child.
    for child in (children[0], children[1]):
        if child in splits_dict:
            all_splits.append(child)
            all_splits.extend(_find_all_children_splits(child, splits_dict))

    return all_splits


def _find_all_children_leaves(split, splits_dict, leaves_dict):
    """
    This helper function finds all multigeneration children leaves for an
    argument split.

    Arguments:
        split -- The split for which you are trying to find children leaves
        splits_dict -- A dictionary of all the split info in the tree
        leaves_dict -- A dictionary of all the leaf info in the tree; every
            leaf must already carry a "parent" entry

    Returns:
        A list containing all the Node IDs of all children leaves, in the
        iteration order of leaves_dict
    """
    # All splits in the subtree rooted at (and including) the argument split.
    # A set gives constant-time membership tests below.
    subtree_splits = set(_find_all_children_splits(split, splits_dict))
    subtree_splits.add(split)

    # A leaf belongs to the subtree exactly when its parent is one of the
    # subtree's splits.
    return [
        leaf
        for leaf, leaf_info in leaves_dict.items()
        if leaf_info["parent"] in subtree_splits
    ]


def _find_n_inputs(leaves):
    """
    Finds the number of inputs using the length of the slope vector in the
    first leaf

    Arguments:
        leaves -- Dictionary of leaf information, nested as
            {tree_id: {leaf_id: leaf_info}}

    Returns:
        Number of inputs
    """
    first_tree = list(leaves.values())[0]
    first_leaf = list(first_tree.values())[0]
    return len(first_leaf["slope"])


def _reassign_none_bounds(leaves, input_bounds):
    """
    This helper function reassigns bounds that are None to the bounds
    input by the user

    Arguments:
        leaves -- The dictionary of leaf information ({leaf_id: leaf_info}).
            It is modified in place.
        input_bounds -- Bounds on the tree inputs, indexable by feature
            number and then by 0 (lower) / 1 (upper)

    Returns:
        The modified leaves dict without any bounds that are listed as None
    """
    leaf_ids = list(leaves)
    # The number of features is taken from the first leaf's slope vector
    n_features = len(leaves[leaf_ids[0]]["slope"])

    for leaf_id in leaf_ids:
        leaf_bounds = leaves[leaf_id]["bounds"]
        for feat in range(n_features):
            pair = leaf_bounds[feat]
            if pair[0] is None:
                pair[0] = input_bounds[feat][0]
            if pair[1] is None:
                pair[1] = input_bounds[feat][1]

    return leaves


def _split_summary_dict(model):
    """
    Validate a linear-tree summary dict and separate it into splits and
    leaves without modifying the caller's dict.

    Arguments:
        model -- Dict of node information, e.g. dict = model.summary()

    Returns:
        Tuple (splits, leaves) of dicts holding shallow copies of the
        per-node dicts, in the iteration order of the input

    Raises:
        ValueError: If the dict has no split nodes or a split refers to a
            child that is not in the dict
    """
    num_splits_in_model = 0
    num_missing_children = 0
    for node_info in model.values():
        if "children" in node_info:
            num_splits_in_model += 1
            left_child = node_info["children"][0]
            right_child = node_info["children"][1]
            if left_child not in model or right_child not in model:
                num_missing_children += 1

    if num_missing_children > 0 or num_splits_in_model == 0:
        raise ValueError(
            "Input dict must be the summary of the linear-tree model"
            " e.g. dict = model.summary()"
        )

    # Copy each node dict so that the keys added during parsing (slope,
    # parent, bounds, ...) never leak into the caller's summary. This keeps
    # the same summary dict reusable for further definitions.
    splits = {}
    leaves = {}
    for node_id, node_info in model.items():
        target = splits if "children" in node_info else leaves
        target[node_id] = dict(node_info)
    return splits, leaves


def _parse_tree_data(model, input_bounds):
    """
    This function creates the data structures with the information required
    for creation of the variables, sets, and constraints in the pyomo
    reformulation of the linear model decision trees. Note that these data
    structures are attributes of the LinearTreeDefinition Class.

    Arguments:
        model -- Trained linear-tree model or dict containing linear-tree
            model summary (e.g. dict = model.summary()). A dict argument is
            not modified.
        input_bounds -- Bounds on the tree inputs, indexable by feature
            number and then by 0 (lower) / 1 (upper). Used wherever the tree
            itself imposes no bound on a feature in a leaf.

    Returns:
        Tuple (splits, leaves, splitting_thresholds). Each element is nested
        under the single tree index 0.

        splits - Dict containing the following information for each split:
            1) 'children' - The child nodes of that split
            2) 'col' - The variable(feature) to split on (beginning at 0)
            3) 'left_leaves' - All the leaves to the left of that split
            4) 'right_leaves' - All the leaves to the right of that split
            5) 'parent' - The parent node of the split. Node zero has no
               parent
            6) 'th' - The threshold of the split
            7) 'y_index' - Indices corresponding to Mistry et. al. y binary
               variable
        leaves - Dict containing the following information for each leaf:
            1) 'slope' - The slope of the fitted line at that leaf
            2) 'intercept' - The intercept of the line at that leaf
            3) 'parent' - The parent split or node of that leaf
            4) 'bounds' - Dict mapping each feature to its [lower, upper]
               bound within that leaf
        splitting_thresholds - Dict of tree inputs and their respective
            thresholds, sorted by threshold value

    Raises:
        ValueError: If input dict is not equal to model.summary()
        TypeError: If input model is not a dict or linear-tree instance
    """
    # Create the initial leaves and splits dictionaries depending on the
    # instance of the model (can be either a LinearTreeRegressor or dict).
    # The dict case is checked first so that summary dicts never need the
    # linear-tree class.
    if isinstance(model, dict):
        splits, leaves = _split_summary_dict(model)
    elif isinstance(model, lineartree.lineartree.LinearTreeRegressor):
        leaves = model.summary(only_leaves=True)
        splits = {
            node_id: node_info
            for node_id, node_info in model.summary().items()
            if node_id not in leaves
        }
    else:
        raise TypeError("Model entry must be dict or linear-tree instance")

    # Add keys for the slopes and intercept of every leaf
    for leaf_info in leaves.values():
        leaf_info["slope"] = list(leaf_info["models"].coef_)
        leaf_info["intercept"] = leaf_info["models"].intercept_

    # Create a parent node id entry for each node in the tree
    for split_id, split_info in splits.items():
        children = split_info["children"]
        for child in (children[0], children[1]):
            owner = splits if child in splits else leaves
            owner[child]["parent"] = split_id

    # Create an entry for all the leaves to the left and right of a split
    for split_info in splits.values():
        left_child = split_info["children"][0]
        right_child = split_info["children"][1]
        split_info["left_leaves"] = (
            _find_all_children_leaves(left_child, splits, leaves)
            if left_child in splits
            else [left_child]
        )
        split_info["right_leaves"] = (
            _find_all_children_leaves(right_child, splits, leaves)
            if right_child in splits
            else [right_child]
        )

    # For each variable that appears in the tree, record the threshold of
    # every split on it, then sort each nested dict by threshold value
    thresholds_by_var = {}
    for split_id, split_info in splits.items():
        thresholds_by_var.setdefault(split_info["col"], {})[split_id] = (
            split_info["th"]
        )
    splitting_thresholds = {
        var: dict(sorted(var_thresholds.items(), key=lambda item: item[1]))
        for var, var_thresholds in thresholds_by_var.items()
    }

    # NOTE: Can eliminate if not implementing the Mistry et. al. formulations
    # Record the ordered indices of the binary variable y. The first index
    # is the splitting variable. The second index is its location in the
    # ordered dictionary of thresholds for that variable.
    rank_by_var = {
        var: {split_id: rank for rank, split_id in enumerate(var_thresholds)}
        for var, var_thresholds in splitting_thresholds.items()
    }
    for split_id, split_info in splits.items():
        var = split_info["col"]
        split_info["y_index"] = [var, rank_by_var[var][split_id]]

    # For each feature in each leaf, initialize lower and upper bounds to
    # None
    n_features = _find_n_inputs({0: leaves})
    for leaf_info in leaves.values():
        leaf_info["bounds"] = {feat: [None, None] for feat in range(n_features)}

    # Finally, go through each split and assign its threshold value as the
    # upper bound to all the leaves descending to the left of the split and
    # as the lower bound to all the leaves descending to the right.
    for split_info in splits.values():
        var = split_info["col"]
        threshold = split_info["th"]
        for leaf in split_info["left_leaves"]:
            leaves[leaf]["bounds"][var][1] = threshold
        for leaf in split_info["right_leaves"]:
            leaves[leaf]["bounds"][var][0] = threshold

    leaves = _reassign_none_bounds(leaves, input_bounds)

    # We use the same formulations developed for gradient boosted linear trees
    # so we nest the leaves, splits, and thresholds attributes in a "one-tree"
    # tree.
    return {0: splits}, {0: leaves}, {0: splitting_thresholds}