# Copyright 2023 AntGroup CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Dict, List
from collections import OrderedDict
from pyhocon import ConfigTree
import numpy as np
from .tree_node import GradientCausalTreeNode
from .dataset import Dataset
from .histogram import Histogram
from .information import CausalDataInfo
from .bin import BinMapper
from .splitting_losses import *
from .cppnode import create_didnode_from_dict, predict
from .losses import Loss
from .utils import *
[docs]def _filter(leaves, leaves_range):
inner_leaves_range = []
inner_leaves = []
splitting = []
for i, leaf in enumerate(leaves):
if leaf.is_leaf is False:
inner_leaves_range.append(leaves_range[i])
inner_leaves.append(leaf)
splitting.append(True)
else:
splitting.append(False)
if len(inner_leaves_range) > 0:
inner_leaves_range = np.stack(inner_leaves_range, axis=0)
else:
return [], None, splitting
return inner_leaves, inner_leaves_range, splitting
[docs]class GradientDebiasedCausalTree:
"""
GradientDebiasedCausalTree is a class that represents a gradient-based debiased causal tree model.
Arguments:
conf (ConfigTree): The configuration tree.
bin_mapper (BinMapper): The BinMapper instance.
kwargs: Additional keyword arguments.
"""
[docs] def __init__(self, conf: ConfigTree = None, bin_mapper: BinMapper = None, **kwargs):
self.conf = conf
self.info = CausalDataInfo(conf)
self.verbose = kwargs.get('verbose', False)
self.feature_used = []
self.feature_used_map = {} # key: sub-feature index, value: original feature index
self.bin_mapper = bin_mapper
conf_tree = conf.get('tree', conf)
self.op_loss = Loss.new_instance(conf_tree)
self.w_monotonic = 0
self.did = False
self.nthreads = kwargs.get('nthreads', conf.get('nthreads', 32))
[docs] def fit(self, gradients, cgradients, data: Dataset, eta=None):
"""
Fit the GradientDebiasedCausalTree model.
Arguments:
gradients: The gradients.
cgradients: The counterfactal gradients.
data (Dataset): The training dataset.
eta: The eta values.
Returns:
None
"""
hist, idx_map = self.preprocess(gradients, cgradients, data, eta, 1, self.info.feature_ratio)
root = GradientCausalTreeNode(self.conf, leaf_id=0, level_id=0)
self.root = root
leaves = [root]
leaves_range = np.array([[0, self.inst_used]], np.int32)
# calculate loss
leaf_id = root.leaf_id
gsum = hist.bin_grad_hist[leaf_id, 0].sum(0)
hsum = hist.bin_hess_hist[leaf_id, 0].sum(0)
root.theta = root.estimate(gsum, hsum, lambd=self.info.lambd)
if eta is not None:
self.root.eta = np.zeros([self.info.n_treatment, 1])
for i in range(self.info.max_depth):
if self.verbose:
TRACE(f'{"--"*10} the {i}-th iterations {"--"*10}')
split_conds = self.split(leaves, hist)
leaves, leaves_range = self.updater(
split_conds, gradients, cgradients, data, hist, idx_map, leaves, leaves_range, eta
)
if len(leaves) == 0 and self.verbose:
TRACE(f'{i}-th level early stop!')
break
for leaf in leaves:
leaf.is_leaf = True
self.postprocess()
[docs] def updater(
self,
split_conds: Dict,
gradients,
cgradients,
tr_data,
hist: Histogram,
idx_map,
leaves: List[GradientCausalTreeNode],
leaves_range,
eta=None,
):
"""
Update the tree nodes.
Arguments:
split_conds (Dict): The split conditions.
gradients: The gradients.
cgradients: The cgradients.
tr_data (Dataset): The training dataset.
hist (Histogram): The histogram.
idx_map: The index map.
leaves (List[GradientCausalTreeNode]): The list of tree nodes.
leaves_range: The range of each leaf.
eta: The eta values.
Returns:
Tuple[List[GradientCausalTreeNode], ndarray]: The updated tree nodes and updated leaf ranges.
"""
leaves, leaves_range, is_splitting = _filter(leaves, leaves_range)
n_leaf = len(split_conds)
if len(leaves) == 0 or len(split_conds) == 0:
return leaves, leaves_range
x_binned = to_row_major(tr_data.bin_features[self.feature_used], np.int32)
treatment = to_row_major(tr_data.treatment, np.int32)
sorted_split = OrderedDict(sorted(split_conds.items()))
split_info = np.asfarray([[info['feature'], info['threshold']] for _, info in sorted_split.items()]).astype(
np.int32
)
out = np.zeros([n_leaf * 2, 2], np.int32)
update_x_map(x_binned, idx_map, split_info, leaves_range, out)
leaves_range_new = out
hist.update_hists(
{
'bin_grad_hist': gradients[0],
'bin_hess_hist': gradients[1],
'bin_cgrad_hist': cgradients[0],
'bin_chess_hist': cgradients[1],
},
idx_map,
leaves_range_new,
treatment,
x_binned,
is_gradient=True,
is_splitting=is_splitting,
threads=self.nthreads,
)
# create new node
leaves_new = []
for i, leaf in enumerate(leaves):
ltheta, rtheta = split_conds[leaf.level_id]['theta']
l_eta, r_eta = split_conds[leaf.level_id]['eta']
leaf._children = [
GradientCausalTreeNode(
self.conf, leaf_id=leaf.leaf_id * 2 + 1, level_id=i * 2, theta=ltheta, eta=l_eta
),
GradientCausalTreeNode(
self.conf, leaf_id=leaf.leaf_id * 2 + 2, level_id=i * 2 + 1, theta=rtheta, eta=r_eta
),
]
leaves_new.extend(leaf._children)
fid, bin_id = split_info[i]
leaf.split_feature = self.feature_used_map[fid]
leaf.split_thresh = bin_id
leaf.split_rawthresh = self.bin_mapper.inverse_transform(bin_id, self.feature_used_map[fid])
return leaves_new, leaves_range_new
[docs] def split(self, leaves: List[GradientCausalTreeNode], hist: Histogram):
"""
Split the tree nodes.
Arguments:
leaves (List[GradientCausalTreeNode]): The list of tree nodes.
hist (Histogram): The histogram.
Returns:
Dict: The split conditions.
"""
return self._split_cpp(leaves, hist)
def _split_cpp(self, leaves: List[GradientCausalTreeNode], hist: Histogram):
"""
Split the tree nodes using C++ implementation.
Arguments:
leaves (List[GradientCausalTreeNode]): The list of tree nodes.
hist (Histogram): The histogram.
Returns:
Dict: The split conditions.
"""
info = self.info
n_leaves, m, n_bins, n_w, n_y = hist.bin_grad_hist.shape
t0, T, n_w = info.treat_dt, info.n_period, info.n_treatment
lambd = info.lambd
coef = info.coef
min_num = self.info.min_point_num_node
configs = {leaf.level_id: {fid: [0, n_bins] for fid in range(m)} for leaf in leaves}
parameters = f"""{{
"tree": {{
"lambd": {lambd},
"coeff": {coef},
"min_point_num_node": {min_num},
"min_var_rate": {0.1},
"monotonic_constraints": {self.w_monotonic}
}},
"threads": {self.nthreads},
"dataset": {{
"treat_dt": {t0}
}}
}}"""
res = gbct_splitting_losses(
configs,
hist.bin_grad_hist,
hist.bin_hess_hist,
hist.bin_cgrad_hist,
hist.bin_chess_hist,
hist.bin_counts,
parameters,
)
split_conds = {}
for leaf in leaves:
level_id = leaf.level_id
if level_id not in res:
leaf.is_leaf = True
continue
# opt_feature, opt_bin_idx, opt_loss
if bool(np.isinf(res[level_id][2])) is False:
split_conds[level_id] = {
'feature': res[level_id][0],
'threshold': res[level_id][1],
'loss': res[level_id][2],
'theta': (res[level_id][3][0], res[level_id][3][1]),
'eta': (None, None),
}
else:
leaf.is_leaf = True
return split_conds
[docs] def preprocess(self, gradients, cgradients, tr_data: Dataset, eta=None, subsample=1, subfeature=1):
"""
Preprocess the data before fitting the tree.
Arguments:
gradients: The gradients.
cgradients: The cgradients.
tr_data (Dataset): The training dataset.
eta: The eta values.
subsample: The subsample ratio.
subfeature: The subfeature ratio.
Returns:
Tuple[Histogram, ndarray]: The histogram and index.
"""
n, m = tr_data.features.shape
index = np.random.permutation(n).astype(np.int32)
n_used, m_used = np.math.ceil(n * subsample), np.math.ceil(m * subfeature)
features = self.info.feature_columns
# subsampling
if m_used < m:
tmp_feat = np.random.choice(self.info.feature_columns, m_used, replace=False)
features = [f for f in self.info.feature_columns if (f in tmp_feat)]
else:
features = self.info.feature_columns
hist = Histogram(self.conf)
if tr_data.bin_features is None:
self.bin_mapper.fit_dataset(tr_data)
else:
hist.columns = list(tr_data.feature_columns)
x_binned = to_row_major(tr_data.bin_features[features], np.int32)
self.feature_used = features
self.inst_used = n_used
orig_features = list(tr_data.feature_columns)
self.feature_used_map = {i: orig_features.index(f) for i, f in enumerate(features)}
# calculate histogram for outcome
w = to_row_major(tr_data.treatment, np.int32)
hist.update_hists(
{
'bin_grad_hist': gradients[0],
'bin_hess_hist': gradients[1],
'bin_cgrad_hist': cgradients[0],
'bin_chess_hist': cgradients[1],
},
index,
np.array([[0, n_used]], np.int32),
w,
x_binned,
True,
[],
self.nthreads,
)
return hist, index
[docs] def export(self):
"""
Export the tree model.
Returns:
Tuple[List[DidNode], List[Dict]]: The exported C++ nodes and Python nodes.
"""
nodes, queue = [], [self.root]
while len(queue) > 0:
nodes.append(queue.pop(0))
for child in nodes[-1].children:
queue.append(child)
# encode for each node
slim_cppnodes, slim_nodes = [], []
t_0 = self.info.treat_dt
for child in nodes:
bias, effect, debiased_effect = [], [], []
for _w in range(1, self.info.n_treatment):
if t_0 > 0:
bias.append(np.mean(child.theta[_w, :t_0] - child.theta[0, :t_0]))
else:
bias.append(0)
effect.append(child.theta[_w] - child.theta[0])
if isinstance(self, GradientDebiasedCausalTree):
debiased_effect.append(effect[-1][t_0:] - bias[-1])
else:
debiased_effect.append(effect[t_0:] + child.eta[_w])
info = {
'leaf_id': child.leaf_id,
'level_id': child.level_id,
'outcomes': child.theta,
'predict': child.theta,
'bias': np.array(bias),
'eta': child.eta,
'effect': np.array(effect),
'debiased_effect': np.array(debiased_effect),
'is_leaf': child.is_leaf,
'children': [-1, -1],
'split_feature': -1,
'split_thresh': -1,
}
if child.is_leaf is False:
info['children'] = [nodes.index(child.children[0]), nodes.index(child.children[1])]
info['split_feature'] = child.split_feature
info['split_thresh'] = child.split_rawthresh
slim_cppnodes.append(create_didnode_from_dict(info))
info['split_feature'] = self.info.feature_columns[info['split_feature']]
slim_nodes.append(info)
return slim_cppnodes, slim_nodes
[docs] def postprocess(self):
"""
Perform post-processing steps after fitting the tree.
Returns:
None
"""
nodes, queue = [], [self.root]
while len(queue) > 0:
nodes.append(queue.pop(0))
for child in nodes[-1].children:
queue.append(child)
self.nodes = nodes
def _predict(self, nodes, x, key='effect', out=None):
"""
Internal method to predict using the tree nodes.
Arguments:
nodes: The tree nodes.
x: The input data.
key: The prediction key.
out: The output array.
Returns:
ndarray: The predicted values.
Raises:
NotImplementedError: If the prediction key is not implemented.
"""
assert isinstance(nodes, list) and len(nodes) > 0, f'nodes must be list and at least one element!'
if key == 'effect':
shape = (x.shape[0],) + nodes[0].outcomes.shape
elif key == 'leaf_id':
shape = (x.shape[0], 1, 1)
elif key == 'outcomes':
shape = (x.shape[0],) + nodes[0].outcomes.shape
elif key == 'eta':
shape = (x.shape[0], 2, 1)
else:
raise NotImplementedError
if x.flags.c_contiguous is False:
x = to_row_major(x)
if out is None:
out = np.zeros(shape, np.float64)
predict(nodes, x, out, key)
return out
[docs] def predict(self, x, w=None, key='effect', out=None):
"""
Predict the treatment effect or other values.
Arguments:
x: The input data.
w: The treatment weights.
key: The prediction key.
out: The output array.
Returns:
ndarray: The predicted values.
"""
if key == 'cf_outcomes':
outcome = self.predict(x, key='outcomes')
cur_pred, cur_cpred = out
indexbyarray(outcome, w, cur_pred, cur_cpred)
return cur_pred, cur_cpred, None
return self._predict(self.export()[0], x, key, out)
[docs] def gradients(self, target, prediction, **kwargs):
"""
Compute the gradients of the loss function.
Arguments:
target: The target values.
prediction: The predicted values.
kwargs: Additional keyword arguments.
Returns:
ndarray: The gradients.
"""
return self.op_loss.gradients(target, prediction, **kwargs)
[docs] def loss(self, grad, hess, y_hat=None, **kwargs):
"""
Compute the loss function.
Arguments:
grad: The gradients.
hess: The hessians.
y_hat: The predicted values.
kwargs: Additional keyword arguments.
Returns:
ndarray: The loss values.
"""
lambd = kwargs.pop('lambd', 0)
return self.root.fast_loss(y_hat=y_hat, grad=grad, hess=hess, lambd=lambd, **kwargs)
[docs]class GradientDiDCausalTree(GradientDebiasedCausalTree):
"""
GradientDiDCausalTree is a class that represents a gradient-based debiased causal tree model with difference in
differences. It inherits from the GradientDebiasedCausalTree class.
Arguments:
conf (ConfigTree): The configuration tree.
bin_mapper (BinMapper): The BinMapper instance.
kwargs: Additional keyword arguments.
"""
[docs] def preprocess(self, gradients, cgradients, tr_data: Dataset, eta=None, subsample=1, subfeature=1):
"""
Preprocesses the data for the GradientDiDCausalTree model.
Arguments:
gradients: The gradients.
cgradients: The counterfactual gradients.
tr_data (Dataset): The training dataset.
eta: The parallel interval between the treated and control group.
subsample (float): The subsampling ratio for instances (default: 1).
subfeature (float): The subsampling ratio for features (default: 1).
Returns:
hist (Histogram): The constructed histogram.
index (ndarray): The permutation index.
"""
n, m = tr_data.features.shape
index = np.random.permutation(n).astype(np.int32)
n_used, m_used = np.math.ceil(n * subsample), np.math.ceil(m * subfeature)
features = self.info.feature_columns
# subsampling
if m_used < m:
tmp_feat = np.random.choice(self.info.feature_columns, m_used, replace=False)
features = [f for f in self.info.feature_columns if (f in tmp_feat)]
else:
features = self.info.feature_columns
if self.verbose:
INFO(f'=' * 50)
INFO(f'{"=" * 5} Used instance & features: {n_used}, {m_used}, {"=" * 5}')
INFO(f'=' * 50)
hist = Histogram(self.conf)
if tr_data.bin_features is None:
self.bin_mapper.fit_dataset(tr_data)
else:
hist.columns = list(tr_data.feature_columns)
x_binned = to_row_major(tr_data.bin_features[features], np.int32)
self.feature_used = features
self.inst_used = n_used
orig_features = list(tr_data.feature_columns)
self.feature_used_map = {i: orig_features.index(f) for i, f in enumerate(features)}
# calculate histogram for outcome
w = to_row_major(tr_data.treatment, np.int32)
hist.update_hists(
{
'bin_grad_hist': gradients[0],
'bin_hess_hist': gradients[1],
'bin_cgrad_hist': cgradients[0],
'bin_chess_hist': cgradients[1],
'bin_eta_hist': eta,
},
index,
np.array([[0, n_used]], np.int32),
w,
x_binned,
True,
[],
self.nthreads,
)
return hist, index
[docs] def updater(
self,
split_conds: Dict,
gradients,
cgradients,
tr_data,
hist: Histogram,
idx_map,
leaves: List[GradientCausalTreeNode],
leaves_range,
eta,
):
"""
Update the GradientCausalTree by performing splitting and updating histograms.
Arguments:
split_conds (Dict): The split conditions.
gradients: The gradients.
cgradients: The counterfactual gradients.
tr_data: The training dataset.
hist (Histogram): The histogram object.
idx_map: The index mapping.
leaves (List[GradientCausalTreeNode]): The list of leaves.
leaves_range: The range of leaves.
eta: The parallel interval between the treated and control group.
Returns:
leaves_new (List[GradientCausalTreeNode]): The new leaves.
leaves_range_new: The new range of leaves.
"""
leaves, leaves_range, is_splitting = _filter(leaves, leaves_range)
n_leaf = len(split_conds)
if len(leaves) == 0 or len(split_conds) == 0:
return leaves, leaves_range
x_binned = to_row_major(tr_data.bin_features[self.feature_used], np.int32)
treatment = to_row_major(tr_data.treatment, np.int32)
sorted_split = OrderedDict(sorted(split_conds.items()))
split_info = np.asarray([[info['feature'], info['threshold']] for _, info in sorted_split.items()]).astype(
np.int32
)
out = np.zeros([n_leaf * 2, 2], np.int32)
update_x_map(x_binned, idx_map, split_info, leaves_range, out)
leaves_range_new = out
hist.update_hists(
{
'bin_grad_hist': gradients[0],
'bin_hess_hist': gradients[1],
'bin_cgrad_hist': cgradients[0],
'bin_chess_hist': cgradients[1],
'bin_eta_hist': eta,
},
idx_map,
leaves_range_new,
treatment,
x_binned,
is_gradient=True,
is_splitting=is_splitting,
threads=self.nthreads,
)
# create new node
leaves_new = []
for i, leaf in enumerate(leaves):
ltheta, rtheta = split_conds[leaf.level_id]['theta']
l_eta, r_eta = split_conds[leaf.level_id]['eta']
leaf._children = [
GradientCausalTreeNode(
self.conf, leaf_id=leaf.leaf_id * 2 + 1, level_id=i * 2, theta=ltheta, eta=l_eta
),
GradientCausalTreeNode(
self.conf, leaf_id=leaf.leaf_id * 2 + 2, level_id=i * 2 + 1, theta=rtheta, eta=r_eta
),
]
leaves_new.extend(leaf._children)
fid, bin_id = split_info[i]
leaf.split_feature = self.feature_used_map[fid]
leaf.split_thresh = bin_id
leaf.split_rawthresh = self.bin_mapper.inverse_transform(bin_id, self.feature_used_map[fid])
return leaves_new, leaves_range_new
def _split_cpp(self, leaves: List[GradientCausalTreeNode], hist: Histogram):
"""
Split the leaf nodes at the current level using C++ implementation.
Arguments:
leaves: The list of leaf nodes.
hist: The histogram object.
Returns:
split_conds: The split conditions.
"""
# Step 1: Collect all the split points that need to calculate the loss
# Step 2: Perform the split
info = self.info
n_leaves, m, n_bins, n_w, n_y = hist.bin_grad_hist.shape
t0, T, n_w = info.treat_dt, info.n_period, info.n_treatment
lambd = info.lambd
coef = info.coef
min_num = self.info.min_point_num_node
configs = {leaf.level_id: {fid: [0, n_bins] for fid in range(m)} for leaf in leaves}
parameters = f"""{{
"tree": {{
"lambd": {lambd},
"coeff": {coef},
"min_point_num_node": {min_num},
"min_var_rate": {0.1},
"monotonic_constraints": {self.w_monotonic},
"parallel_l2": {self.info.parallel_l2}
}},
"threads": {self.nthreads},
"dataset": {{
"treat_dt": {t0}
}}
}}"""
res = didtree_splitting_losses(
configs,
hist.bin_grad_hist,
hist.bin_hess_hist,
hist.bin_cgrad_hist,
hist.bin_chess_hist,
hist.bin_eta_hist,
hist.bin_counts,
parameters,
)
split_conds = {}
for leaf in leaves:
level_id = leaf.level_id
if level_id not in res:
leaf.is_leaf = True
continue
# opt_feature, opt_bin_idx, opt_loss
if bool(np.isinf(res[level_id][2])) is False:
etas = res[level_id][4]
split_conds[level_id] = {
'feature': res[level_id][0],
'threshold': res[level_id][1],
'loss': res[level_id][2],
'theta': (res[level_id][3][0], res[level_id][3][1]),
'eta': (np.asarray([etas[0], -etas[0]]), np.asarray([etas[1], -etas[1]])),
}
else:
leaf.is_leaf = True
return split_conds
[docs] def split(self, leaves: List[GradientCausalTreeNode], hist: Histogram):
"""
Split the leaf nodes.
Arguments:
leaves: The list of leaf nodes.
hist: The histogram object.
Returns:
split_conds: The split conditions.
"""
return self._split_cpp(leaves, hist)
[docs] def predict(self, x, w=None, key='effect', out=None):
"""
Predict the treatment effect or other outcomes for given data.
Arguments:
x: The input data.
w: The treatment assignments (default: None).
key: The key specifying the prediction type (default: 'effect').
out: The output array (default: None).
Returns:
pred: The predicted values.
cpred: The counterfactual predicted values.
eta: The optimal parallel interval between the treated and control group.
"""
if key == 'cf_outcomes':
pred, cpred, _ = super().predict(x, w, key, out)
eta = self.predict(x, key='eta')
temp = np.zeros([cpred.shape[0], 1], dtype=cpred.dtype)
indexbyarray(eta, w, temp)
cpred -= temp
return pred, cpred, eta
return self._predict(self.export()[0], x, key, out)
def _predict(self, nodes, x, key='effect', out=None):
"""
Predicting for the given data using the exported nodes.
Arguments:
nodes: The exported nodes.
x: The input data.
key: The key specifying the prediction type (default: 'effect').
out: The output array (default: None).
Returns:
The predictions.
"""
assert isinstance(nodes, list) and len(nodes) > 0, f'nodes must be list and at least one element!'
if key == 'eta':
shape = (x.shape[0], self.info.n_treatment)
x = to_row_major(x)
if out is None:
out = np.zeros(shape, np.float64)
predict(nodes, x, out, key)
return out
return super()._predict(nodes, x, key, out)
[docs] def export(self):
"""
Export the GradientDiDCausalTree.
Returns:
slim_cppnodes: The exported nodes in C++ object.
slim_nodes: The exported nodes in python object.
"""
nodes, queue = [], [self.root]
while len(queue) > 0:
nodes.append(queue.pop(0))
for child in nodes[-1].children:
queue.append(child)
# encode for each node
slim_cppnodes, slim_nodes = [], []
for child in nodes:
effect, debiased_effect = [], []
for _w in range(1, self.info.n_treatment):
effect.append(child.theta[_w] - child.theta[0])
debiased_effect.append(effect[-1] + child.eta[_w])
info = {
'leaf_id': child.leaf_id,
'level_id': child.level_id,
'outcomes': child.theta,
'predict': child.theta,
'eta': child.eta,
'effect': np.array(effect),
'debiased_effect': np.array(debiased_effect),
'is_leaf': child.is_leaf,
'children': [-1, -1],
'split_feature': -1,
'split_thresh': -1,
}
if child.is_leaf is False:
info['children'] = [nodes.index(child.children[0]), nodes.index(child.children[1])]
info['split_feature'] = child.split_feature
info['split_thresh'] = child.split_rawthresh
slim_cppnodes.append(create_didnode_from_dict(info))
info['split_feature'] = self.info.feature_columns[info['split_feature']]
slim_nodes.append(info)
return slim_cppnodes, slim_nodes