# Source code for openasce.inference.tree.histogram

#    Copyright 2023 AntGroup CO., Ltd.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import numpy as np
from pyhocon import ConfigTree

from .utils import update_histogram, update_histograms
from .dataset import Dataset
from .information import CausalDataInfo


class Histogram(object):
    """Per-leaf, per-feature histograms of target values and sample counts.

    Target histograms are stored in ``bin_hists`` (one array per target key)
    and can also be read as attributes through ``__getattr__``. Sample counts
    are stored in ``bin_counts``.
    """

    def __init__(self, conf: ConfigTree):
        """
        Read the binning settings from the configuration.

        Arguments:
            conf (ConfigTree): Configuration. Its ``histogram`` entry is used
                for the binning settings when present, otherwise ``conf``
                itself is used.
        """
        hist_conf = conf.get('histogram', conf)
        self.conf = conf
        self.info = CausalDataInfo(conf)
        self.tr_dts = []
        self.max_bin_num = hist_conf.max_bin_num  # Maximum number of bins
        self.min_point_per_bin = hist_conf.min_point_per_bin  # Minimum number of points for binning
        # Sample counts, laid out as [leaf, feature, bin, treatment]
        self.bin_counts = None
        # Target key -> histogram laid out as [leaf, feature, bin, treatment, target]
        self.bin_hists = {}
        self._data = None

    def update_hists(self, target, index, leaves_range, treatment, bin_features, is_gradient, is_splitting,
                     threads):
        """
        Update histograms for all nodes in the same level of a tree.

        Only the even-positioned leaves are accumulated from the data. When
        the level holds more than one leaf, every odd-positioned leaf is
        derived by subtracting its even neighbour from the previously stored
        histogram selected with ``is_splitting``. A level with more than one
        leaf therefore requires that the previous level's ``bin_hists``
        entries and ``bin_counts`` are already stored on this instance.

        Arguments:
            target (dict): Maps a key to an array with at least two
                dimensions; the size of its second axis becomes the last axis
                of that key's histogram.
            index: Forwarded unchanged to the histogram helpers.
            leaves_range: One entry per leaf of the current level; only its
                length is used here, the value itself is forwarded to the
                helpers.
            treatment: Forwarded unchanged to the histogram helpers.
            bin_features: 2-D array (presumably samples x features of bin
                indices -- confirm against the helpers).
            is_gradient (bool): Accepted for interface compatibility; both
                settings perform the same update.
            is_splitting: Index applied to the first axis of the previously
                stored histograms and counts (presumably selects the parent
                leaves that were split -- confirm against the caller).
            threads: Forwarded unchanged to the histogram helpers.

        Raises:
            AssertionError: If ``target`` is not a dict (only when assertions
                are enabled).
            KeyError: If the level has more than one leaf and no histogram is
                stored yet for one of the target keys.

        Returns:
            Histogram: This instance, with ``bin_hists`` and ``bin_counts``
            replaced by the arrays of the current level.
        """
        n, m = bin_features.shape
        n_w = self.info.n_treatment
        n_leaves = len(leaves_range)
        # Only every second leaf is accumulated directly from the data.
        leaves = list(range(0, n_leaves, 2))
        n_bins = self.max_bin_num

        assert isinstance(target, dict), 'target should be a dict!'
        keys = list(target)
        targets = [target[k] for k in keys]
        outs = [np.zeros([n_leaves, m, n_bins, n_w, t.shape[1]], t.dtype) for t in targets]
        # update histogram of target
        update_histograms(targets, bin_features, index, leaves_range, treatment, outs, leaves, n_w, n_bins,
                          threads)
        for k, hist in zip(keys, outs):
            if n_leaves > 1:
                # Odd leaf = previously stored histogram - even neighbour.
                hist[1::2] = self.bin_hists[k][is_splitting] - hist[::2]
            self.bin_hists[k] = hist

        # update counts
        counts = np.zeros([n_leaves, m, n_bins, n_w, 1], np.int32)
        update_histogram(np.ones([n, 1], np.int32), bin_features, index, leaves_range, treatment, counts, leaves,
                         n_w, n_bins, threads)
        if n_leaves > 1:
            counts[1::2] = np.expand_dims(self.bin_counts[is_splitting], -1) - counts[::2]
        # Drop the trailing singleton axis: [leaf, feature, bin, treatment].
        self.bin_counts = counts[:, :, :, :, 0]
        return self

    def __getattr__(self, __name: str):
        """
        Get the attribute value.

        Only invoked when normal attribute lookup fails; exposes every key of
        ``bin_hists`` as an attribute.

        Arguments:
            __name (str): The name of the attribute.

        Returns:
            ndarray: The attribute value.

        Raises:
            AttributeError: If the attribute is not found.
        """
        # Read bin_hists from __dict__ directly: a plain ``self.bin_hists``
        # would re-enter __getattr__ while the attribute is not set yet (for
        # example while copy/pickle rebuild the instance) and recurse until
        # RecursionError.
        hists = self.__dict__.get('bin_hists')
        if hists is not None and __name in hists:
            return hists[__name]
        raise AttributeError(f'{type(self).__name__!r} object has no attribute {__name!r}')

    @classmethod
    def new_instance(cls, dataset: Dataset, conf: ConfigTree = None, **kwargs):
        """
        Create a new instance of the histogram.

        Arguments:
            dataset (Dataset): The dataset.
            conf (ConfigTree): The configuration tree.
            kwargs: Additional keyword arguments (currently unused).

        Returns:
            Histogram: The new instance of the histogram.
        """
        # NOTE(review): ``__init__`` as defined in this class accepts only
        # ``conf``, and no ``binning`` method is defined in the visible part
        # of the class, so this works only for a subclass (or code outside
        # this view) that provides both -- confirm before relying on it.
        hist = cls(conf, dataset.treatment, dataset.targets)
        hist.binning(dataset)
        return hist