# Source code for openasce.inference.graph_inference

#    Copyright 2023 AntGroup CO., Ltd.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import copy
from functools import reduce
from typing import Dict, Iterable, List

import numpy as np

from openasce.discovery.causal_graph import CausalGraph
from openasce.discovery.discovery import Discovery
from openasce.discovery.graph_node_form import GraphNodeForm
from openasce.inference.inference_model import InferenceModel
from openasce.utils.logger import logger


class GraphInferModel(InferenceModel):
    """Causal inference by message passing over a causal graph.

    Attributes:
        graph: The causal graph. If it is not set, ``fit`` tries to discover
            it with the assigned ``discovery`` object.
        column_names: Names of all sample columns (graph nodes), including the
            treatment column and the label column.
        treatment_name: Name of the treatment column in ``column_names``.
        label_name: Name of the target (label) column in ``column_names``.
    """

    def __init__(
        self,
        *,
        graph: CausalGraph = None,
        column_names: List[str] = None,
        treatment_name: str = None,
        label_name: str = None,
        num_iteration=20,
    ) -> None:
        """
        Arguments:
            graph: Causal graph. May be left unset when a discovery object is
                assigned before ``fit``.
            column_names: Names of all columns.
            treatment_name: Name of the treatment column.
            label_name: Name of the target column.
            num_iteration: Maximum number of message-passing iterations that
                ``estimate`` runs before stopping.
        """
        super().__init__()
        self._graph = graph
        self._column_names = column_names
        self._treatment_name = treatment_name
        self._label_name = label_name
        self._discovery = None
        self._data = None
        self._num_iteration = num_iteration
        self._label_value = None

    @property
    def data(self):
        """Sample matrix stacked by ``fit``/``estimate``; asserts it is set."""
        assert self._data is not None, "Must have sample data."
        return self._data

    @property
    def graph(self):
        """The causal graph; asserts it has been set."""
        assert self._graph is not None, "The graph object should be set"
        return self._graph

    @graph.setter
    def graph(self, value):
        assert self._graph is None, "The graph object should be set once only"
        self._graph = value
        # The graph is available: derive the column names from its nodes,
        # unless they were already provided (via __init__ or by ``fit`` from
        # the discovery object). Re-assigning them would trip the set-once
        # assertion of the ``column_names`` setter.
        graph_names = list(self.graph.names_to_index.keys())
        if self._column_names is None:
            self.column_names = graph_names

    @property
    def column_names(self):
        """All nodes' names.

        Note: should include the treatment node and the label node.
        """
        assert self._column_names is not None, "The column names should be set"
        return self._column_names

    @column_names.setter
    def column_names(self, value: List[str]):
        assert self._column_names is None, "The column names should be set once only"
        self._column_names = value

    @property
    def treatment_name(self):
        """Name of the treatment column; asserts it has been set."""
        assert self._treatment_name is not None, "The treatment name should be set"
        return self._treatment_name

    @treatment_name.setter
    def treatment_name(self, value: str):
        assert (
            self._treatment_name is None
        ), "The treatment name should be set once only"
        self._treatment_name = value

    @property
    def label_name(self):
        """Name of the label column; asserts it has been set."""
        assert self._label_name is not None, "The label name should be set"
        return self._label_name

    @label_name.setter
    def label_name(self, value: str):
        assert self._label_name is None, "The label name should be set once only"
        self._label_name = value

    @property
    def discovery(self) -> Discovery:
        """Discovery object used by ``fit`` when no graph is set."""
        assert self._discovery is not None, "The discovery object should be set"
        return self._discovery

    @discovery.setter
    def discovery(self, value: Discovery):
        self._discovery = value

    def fit(
        self,
        *,
        X: Iterable[np.ndarray],
        Y: Iterable[np.ndarray] = None,
        T: Iterable[np.ndarray] = None,
        **kwargs,
    ) -> None:
        """Feed the sample data and prepare the causal graph.

        If a graph is set, its parameters are calculated from the data when it
        has none yet. Otherwise the graph is discovered with the discovery
        object, and then its parameters are calculated.

        Arguments:
            X: All features of the samples, including the treatment and the
                label node. The arrays are stacked row-wise with ``np.vstack``.
            Y: Ignored in causal graph inference.
            T: Ignored in causal graph inference.

        Raises:
            ValueError: If neither a graph nor a discovery object is set.
        """
        if Y is not None or T is not None:
            logger.info(
                "All columns used in the casual graph discovery should be in X and Y/T is ignore"
            )
        self._data = np.vstack(list(X))
        # Check the private attributes: the ``graph``/``discovery`` properties
        # assert when unset, which would make the fallback branches unreachable.
        if self._graph is not None:
            if not self._graph.para:
                self._graph.calculate_parameter(data=self._data)
        elif self._discovery is not None:
            if self._column_names is None:
                self.column_names = self.discovery.node_names
            logger.info("Begin to discover the causal graph.")
            self.discovery.fit(self._data)
            self.graph, _ = self.discovery.get_result()
            if not self.graph.para:
                self.graph.calculate_parameter(data=self._data)
        else:
            raise ValueError("There is neither causal graph nor discovery.")

    def estimate(
        self,
        *,
        X: Iterable[np.ndarray] = None,
        Y: Iterable[np.ndarray] = None,
        T: Iterable[np.ndarray] = None,
        **kwargs,
    ) -> None:
        """Feed the sample data and estimate the outcome under the treatment.

        The result, a dict of node name -> estimated distribution, is stored
        and can be read back with ``get_result``.

        Arguments:
            X: Features of the samples. When omitted (or empty), the data fed
                to ``fit`` is reused.
            Y: Ignored in causal graph inference.
            T: Ignored in causal graph inference.
            kwargs: May contain ``InferenceModel.TREATMENT_VALUE`` (the
                do-value of the treatment; only one treatment is supported),
                ``InferenceModel.LABEL_VALUE``, and
                ``InferenceModel.CONDITION_DICT_NAME`` (a dict of node name ->
                observed value or list of values).

        Raises:
            ValueError: If the column, treatment or label names are unset, or
                the treatment/label is not among the columns.
        """
        # Use the private attributes so that unset names raise the documented
        # ValueError instead of tripping the properties' assertions.
        if (
            not self._column_names
            or not self._treatment_name
            or not self._label_name
            or self._treatment_name not in self._column_names
            or self._label_name not in self._column_names
        ):
            raise ValueError(
                "Either label or treatment is not set, or treat or label is not in columns."
            )
        # ``if X`` is ambiguous for numpy arrays, so test for None explicitly.
        if X is not None:
            rows = list(X)
            if rows:
                self._data = np.vstack(rows)
        self._label_value = kwargs.get(InferenceModel.LABEL_VALUE, None)
        # Support one treatment value for now
        treatment_value = kwargs.get(InferenceModel.TREATMENT_VALUE)
        raw_condition = kwargs.get(InferenceModel.CONDITION_DICT_NAME, {})
        if not isinstance(raw_condition, dict):
            raw_condition = {}
        # Normalise every observed value to a list of allowed values.
        condition = {
            name: value if isinstance(value, list) else [value]
            for name, value in raw_condition.items()
        }
        result = self._do_lbp(
            do_condition={
                self.treatment_name: treatment_value
                if isinstance(treatment_value, list)
                else [treatment_value]
            },
            condition=condition,
        )
        self._result = {
            self.graph.index_to_names.get(index): distribution
            for index, distribution in result.items()
        }

    def get_result(self):
        """Get the estimated result.

        Returns:
            Only the label node's estimated distribution when a label value
            was passed to ``estimate``; otherwise the whole dict of node
            name -> estimated distribution.
        """
        if self._label_value is not None:
            return self._result.get(self.label_name)
        return self._result

    def output(self):
        """Output the estimated result to files.

        Not implemented for this model.
        """
        raise NotImplementedError("Not implement for abstract class")

    def _do_lbp(self, *, do_condition: Dict, condition: Dict, tolerance: float = 0.00001):
        """Compute node distributions under a do-intervention by message passing.

        This is loopy-belief-propagation style message passing. Edges into
        the do-node are cut: the do-node has no parents here and is never
        linked as a child.

        Arguments:
            do_condition: {treatment node name: [do value(s)]}; exactly one
                entry is supported.
            condition: {node name: [observed value(s)]} evidence.
            tolerance: Stop early once the summed per-node deviation between
                two successive iterations falls below this value.

        Returns:
            dict of node index -> normalised GraphNodeForm distribution.
        """
        n_iter = self._num_iteration
        data = self.data
        graph = self.graph
        do_node_list = list(do_condition.keys())
        assert len(do_node_list) == 1, "Only one DO node is supported for now"
        do_node_n = [graph.names_to_index[name] for name in do_node_list]
        all_conditions = copy.deepcopy(condition)
        all_conditions.update(do_condition)

        # node_link[i]: node i itself plus every non-do child j of i.
        node_link = {
            i: {
                j
                for j in range(graph.n)
                if j == i or (j not in do_node_n and i in graph.parents[j])
            }
            for i in range(graph.n)
        }

        # Observed values and empirical frequencies of every column, computed
        # once instead of once per (node, parent) pair.
        n_rows = len(data)
        support, frequency = {}, {}
        for i in range(graph.n):
            values, counts = np.unique(data[:, i], return_counts=True)
            support[i] = values
            frequency[i] = [count / n_rows for count in counts.tolist()]

        def form_for(index, scores):
            # Fresh single-variable form for node ``index``, restricted to the
            # evidence and renormalised.
            return self._strict_to_condition(
                all_conditions,
                GraphNodeForm(
                    [[value, score] for value, score in zip(support[index], scores)],
                    columns=[graph.index_to_names[index], GraphNodeForm.SCORE_COLUMN_NAME],
                ),
            )

        p_node, s_ij, s_ji = {}, {}, {}
        for i in range(graph.n):
            n_values = len(support[i])
            # Uniform prior over the node's observed values.
            p_node[i] = form_for(i, [1 / n_values] * n_values)
            # A do-node keeps only itself: its incoming edges are cut.
            par_set = {i} if i in do_node_n else set(list(graph.parents[i]) + [i])
            s_ij[i] = {k: form_for(k, frequency[k]) for k in par_set}
            s_ji[i] = {k: form_for(i, frequency[i]) for k in node_link[i]}

        p_factor = copy.deepcopy(graph.para)
        for node_i in p_factor:
            p_factor[node_i] = self._strict_to_condition(
                all_conditions, p_factor[node_i]
            )
            if node_i in do_node_n:  # it is do_node
                p_factor[node_i].set_groupby_sum(do_node_list[0])

        def combine(initial, forms):
            # Multiply ``initial`` by every form in turn and renormalise.
            return self._update_normalize(reduce(self._update_multipy, forms, initial))

        run_count = 0
        ori_dis = copy.deepcopy(p_node)
        while run_count < n_iter:
            for num in range(graph.n):
                for out_j in s_ji[num]:
                    s_ji[num][out_j] = combine(
                        p_factor[out_j],
                        [s_ij[out_j][in_j] for in_j in s_ij[out_j] if in_j != num],
                    )
            for num in range(graph.n):
                for out_i in s_ij[num]:
                    s_ij[num][out_i] = combine(
                        p_node[out_i],
                        [s_ji[out_i][in_i] for in_i in s_ji[out_i] if in_i != num],
                    )
            error_list = []
            for num in range(graph.n):
                belief = combine(p_node[num], list(s_ji[num].values()))
                error_list.append(ori_dis[num].get_score_deviation(belief))
                ori_dis[num] = belief
            if sum(error_list) < tolerance:
                break
            run_count += 1
        logger.info("Finish the lbp process. ")
        return ori_dis

    def _strict_to_condition(self, condition, p_r):
        """Apply the evidence in ``condition`` to form ``p_r`` and renormalise it.

        For every conditioned node present in ``p_r``, ``set_flag_zero`` is
        called (presumably zeroing rows whose value is not allowed — confirm
        in GraphNodeForm). ``p_r`` is modified in place and returned.
        """
        column_list: List[str] = p_r.columns
        for key, value in condition.items():
            if key in column_list:
                p_r.set_flag_zero(key, value)
        p_r.set_norm()
        return p_r

    def _update_multipy(self, form_a: GraphNodeForm, form_b: GraphNodeForm):
        """Multiply ``form_a`` by ``form_b`` and aggregate equal assignments.

        ``form_a`` is deep-copied and left untouched. When the two forms have
        different column sets, ``form_b``'s variable is summed out of the
        product. Otherwise rows with identical assignments are merged.
        """

        def _update_multipy_internal(form1: GraphNodeForm, form2: GraphNodeForm, flag):
            # Pick up the key except GraphNodeForm.SCORE_COLUMN_NAME
            # (assumes form2 holds one variable column plus the score column).
            key_index_in_form2 = 1 if form2.score_column_index == 0 else 0
            unique_key = form2.columns[key_index_in_form2]
            form1.multiply_score_column(unique_key, form2)
            form1.sort_by_column(unique_key)
            score_index = form1.score_column_index
            # Columns identifying a group: everything except the score column
            # and, when summing out, the unique key itself.
            key_positions = [
                idx
                for idx, name in enumerate(form1.columns)
                if not (
                    (name == unique_key and flag)
                    or name == GraphNodeForm.SCORE_COLUMN_NAME
                )
            ]
            groups = {}
            for row in form1.data:
                # A tuple key keeps e.g. (1, 11) and (11, 1) apart, whereas the
                # old concatenated-digit string mapped both to "111".
                # Values are truncated with int(), so integer-coded states are
                # assumed.
                group_key = tuple(int(row[idx]) for idx in key_positions)
                groups.setdefault(group_key, []).append(row)
            unique_key_index = form1.index(unique_key)
            columns_list = form1.columns
            res_r = []
            for rows in groups.values():
                merged = rows[0]
                merged[score_index] = sum(row[score_index] for row in rows)
                if flag:
                    merged = np.delete(merged, unique_key_index)
                res_r.append(merged)
            if flag:
                columns_list.pop(unique_key_index)
            return GraphNodeForm(res_r, columns_list)

        flag = set(form_a.columns) != set(form_b.columns)
        return _update_multipy_internal(copy.deepcopy(form_a), form_b, flag)

    def _update_normalize(self, node: GraphNodeForm):
        """Normalise ``node`` in place via ``set_norm`` and return it."""
        node.set_norm()
        return node