Source code for openasce.inference.tree.cppnode

#    Copyright 2023 AntGroup CO., Ltd.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from typing import List

from .gbct_utils import common


[docs]def create_didnode_from_dict(info): """ Create a CppDebiasNode from a dictionary. Arguments: info (Dict): The node information. Returns: CppDebiasNode: The CppDebiasNode instance. """ # basic information for tree node assert len(info['children']) == 2, f'children should be 2!' node = common.CppDebiasNode() basic_keys = node.get_property_keys() for k in info.keys(): if k in basic_keys: setattr(node, k, info[k]) else: node.set_info(k, info[k]) return node
[docs]def predict(nodes: List[common.CppDiDNode], x, out, key, threads=20): """ Predict using the tree nodes. Arguments: nodes (List): The list of tree nodes. x (ndarray): The input data. out (ndarray): The output array. key (ndarray): The prediction key. threads (int): The number of threads. Returns: ndarray: The predicted values. Raises: RuntimeError: If the number of nodes is less than or equal to 0. ValueError: If the node type is not supported. """ if len(nodes) <= 0: raise RuntimeError(f'The number of nodes must be greater than 0!') elif isinstance(nodes[0], list) is False: nodes = [nodes] if isinstance(nodes[0][0], common.CppDiDNode): return common.predict_did(nodes, out, x, key, threads) elif isinstance(nodes[0][0], common.CppDebiasNode): return common.predict_debias(nodes, out, x, key, threads) else: raise ValueError(f'{type(nodes[0][0])} is not supported!')