# Copyright 2023 AntGroup CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import List, Tuple
import numpy as np
from openasce.discovery.causal_graph import CausalGraph
from openasce.utils.logger import logger
class Strategy(object):
    """General class to implement different structure learning methods.

    Only the K2 greedy search is implemented; ``run`` always dispatches to
    ``k2``.

    Attributes:
        node_names (List[str]): the names of the nodes.
        strategy_name (str): name of the search strategy, always ``"k2"``.
        edge_gain (float): the minimal gain of adding edge (default 20).
        target_name (str): the name of the node that will be label, or None.
        target_index (int): position of ``target_name`` in ``node_names``,
            or None when no target is set.
    """

    def __init__(self, node_names: List[str], **kwargs):
        """Constructor.

        Arguments:
            node_names: the name of nodes
            **kwargs: optional settings; ``edge_gain`` (default 20) and
                ``target_name`` (default None) are read.

        Raises:
            ValueError: if ``target_name`` is given but is not in ``node_names``.
        """
        self.node_names = node_names
        self.strategy_name = "k2"
        self.edge_gain = kwargs.get("edge_gain", 20)
        self.target_name = kwargs.get("target_name", None)
        self.target_index = (
            self.node_names.index(self.target_name) if self.target_name else None
        )

    def run(self, data: np.ndarray, **kwargs) -> Tuple:
        """Run the actual strategy.

        Arguments:
            data: the features of samples
            **kwargs (dict): dictionary with method specific args; forwarded
                unchanged to ``k2`` (which reads ``max_parents``).

        Returns:
            A ``(graph, score)`` tuple as produced by ``k2``.
        """
        g, s = self.k2(data=data, **kwargs)
        logger.info(f"Best score is {s}")
        return g, s

    def best_parent(self, *, g, s, node_i, data, max_parents, r, s_i):
        """Search for the best single parent to add to ``node_i``.

        Every candidate node is tried, in random order, as a new parent of
        ``node_i`` on a copy of ``g``; the copy with the highest resulting
        total score is kept, provided it beats the running best by more than
        ``edge_gain``.

        Arguments:
            g: the current graph (not modified; candidates work on copies).
            s: the current total score of ``g``.
            node_i: index of the node receiving a parent.
            data: the features of samples.
            max_parents: parent limit forwarded to ``add_edge``.
            r: value returned by ``compute_r`` for ``data``, forwarded to
                ``score_node``.
            s_i: the score of ``node_i`` inside ``g``; it must match the
                current parents of ``node_i`` for the update to be correct.

        Returns:
            ``(graph, score, found_new)``: the best graph and its score, and
            whether any edge was added. When nothing qualifies the inputs
            ``g`` and ``s`` are returned with ``found_new`` False.
        """
        found_new = False
        g_max = g
        s_max = s
        candidates = np.random.permutation(range(g.n))
        if self.target_index is not None:
            # The target node can't be any node's parent if set target, so
            # remove it from the node candidates.
            candidates = np.delete(
                candidates, np.where(candidates == self.target_index)
            )
        edge_gain = self.edge_gain
        for j in candidates:
            if j == node_i or j in g.parents[node_i]:
                continue
            g_work = CausalGraph(bn=g)
            # Try to add one edge between (j, node_i); a falsy result is
            # treated as "edge rejected" (presumably cycle / parent limit --
            # confirm against CausalGraph.add_edge).
            if not g_work.add_edge(j, node_i, max_parents):
                continue
            new_score = g_work.score_node(node_i, data, r)
            logger.debug(f"new_score={new_score}")
            # Swap the old contribution of node_i for the new one.
            s_new = s - s_i + new_score
            if s_new > s_max + edge_gain:
                found_new = True
                g_max = g_work
                s_max = s_new
        return g_max, s_max, found_new

    def _node_ordering(self) -> np.ndarray:
        """Return a random visiting order of node indices, target node first.

        When no target is set the order is a plain random permutation.
        """
        ordering = np.random.permutation(range(len(self.node_names)))
        # Compare against None explicitly: a target at index 0 is valid.
        if self.target_index is not None:
            ordering = np.delete(ordering, np.where(ordering == self.target_index))
            ordering = np.insert(ordering, 0, self.target_index)
        return ordering

    def k2(self, data: np.ndarray, **kwargs):
        """Implements k2 algorithm.

        Nodes are visited in a random order (target first when one is set)
        and each node greedily receives parents through ``best_parent`` until
        no candidate improves the score or the parent limit is reached.

        Arguments:
            data: the features of samples
            **kwargs: ``max_parents`` may set the parent limit; when missing
                or falsy it defaults to half the number of nodes.

        Returns:
            ``(graph, score)``: the learned graph and its total score.
        """
        names = self.node_names
        # True division is kept on purpose: the limit may be fractional
        # (e.g. 2.5 for five nodes), exactly as before.
        half_nodes = len(names) / 2
        global_max_parents = kwargs.get("max_parents") or half_nodes
        max_parents = global_max_parents
        logger.info(
            f"current max parent number: {global_max_parents}, target_name={self.target_name}"
        )
        ordering = self._node_ordering()
        if self.target_index is not None:
            # The target is explored first and is capped at half the nodes.
            max_parents = min(max_parents, half_nodes)
        g = CausalGraph(names)
        global_s = g.score(data)
        logger.info(f"initial graph score:{global_s}")
        curr_data_r = g.compute_r(data)
        logger.info(f"graph curr_data_r={curr_data_r}")
        for node_i in ordering:
            s_i = g.score_node(node_i, data, curr_data_r)
            logger.info(f"Begin to explore node {node_i}, s_i={s_i}")
            curr_parent_count, found_new = 0, True
            while found_new and curr_parent_count < max_parents:
                g, global_s, found_new = self.best_parent(
                    g=g,
                    s=global_s,
                    node_i=node_i,
                    data=data,
                    max_parents=global_max_parents,
                    r=curr_data_r,
                    s_i=s_i,
                )
                if found_new:
                    # node_i gained a parent, so its score changed; refresh
                    # it so the next round does not re-count the old gain.
                    s_i = g.score_node(node_i, data, curr_data_r)
                curr_parent_count += 1
            # Only the first explored node uses the target-specific cap.
            max_parents = global_max_parents
        return g, global_s