# Source code for openasce.extension.debias_model

#    Copyright 2023 AntGroup CO., Ltd.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
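#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.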

from typing import Dict, Iterable, Union, Tuple

import numpy as np

from openasce.core.runtime import Runtime
from openasce.utils.logger import logger


class CausalDebiasModel(Runtime):
    """Debias inference base class.

    Base class of the causal debias models. Subclasses implement ``_call``,
    which this class invokes once per batch: from ``fit`` with
    ``training=True`` and from ``predict`` with ``training=False``. This class
    only handles input validation, batching and epoch iteration.

    Attributes:
        _X, _Y, _C, _Z: Sample sources passed to the most recent ``fit`` or
            ``predict`` call (``predict`` resets ``_Y`` to ``None``).
        _result: Merged prediction output of the most recent ``predict`` call.
    """

    def __init__(self) -> None:
        super().__init__()

    def fit(
        self,
        *,
        X: Iterable[np.ndarray] = None,
        Y: Iterable[np.ndarray] = None,
        C: Dict[str, Iterable[np.ndarray]] = None,
        Z: Iterable[Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]] = None,
        num_epochs: int = 1,
        **kwargs,
    ) -> None:
        """Feed the sample data and train the model on the samples.

        Samples are given either column-wise through ``X``/``Y``/``C`` or as a
        single iterable of batches through ``Z``. When ``Z`` is provided it is
        used for batching and ``X``/``Y``/``C`` are ignored.

        Arguments:
            X: Batches of sample features.
            Y: Batches of sample outcomes.
            C: Other concerned columns of the samples, mapping a column name to
                its batches, e.g. ``{'weight': Iterable[np.ndarray]}``.
            Z: Iterable yielding ``(x_batch, y_batch, c_batch)`` tuples. Shorter
                tuples are accepted; missing entries are passed as ``None``.
            num_epochs: Number of passes over the samples.
            **kwargs: Forwarded to ``_train_loop``.

        Returns:
            None

        Raises:
            ValueError: If ``C`` is not a dict, or if none of ``X``, ``Y``,
                ``C``, ``Z`` is provided (``None`` or empty).
        """
        if C is not None and not isinstance(C, dict):
            raise ValueError("C should be dict.")
        if not any(self._is_provided(v) for v in (X, Y, C, Z)):
            raise ValueError("One of (X, Y, C, Z) should be set.")
        self._X = X
        self._Y = Y
        self._C = C
        self._Z = Z
        self._train_loop(num_epochs=num_epochs, **kwargs)

    def predict(
        self,
        *,
        X: Iterable[np.ndarray] = None,
        C: Dict[str, Iterable[np.ndarray]] = None,
        Z: Iterable[Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]] = None,
        **kwargs,
    ) -> None:
        """Feed the sample data and estimate the effect on the samples.

        The merged prediction is stored and can be fetched with ``get_result``.

        Arguments:
            X: Batches of sample features.
            C: Other concerned columns of the samples, mapping a column name to
                its batches, e.g. ``{'weight': Iterable[np.ndarray]}``.
            Z: Iterable yielding ``(x_batch, y_batch, c_batch)`` tuples. Shorter
                tuples are accepted; missing entries are passed as ``None``.
            **kwargs: Forwarded to ``_predict_loop``.

        Returns:
            None

        Raises:
            ValueError: If ``C`` is not a dict, or if none of ``X``, ``C``,
                ``Z`` is provided (``None`` or empty).
        """
        if C is not None and not isinstance(C, dict):
            raise ValueError("C should be dict.")
        if not any(self._is_provided(v) for v in (X, C, Z)):
            raise ValueError("One of (X, C, Z) should be set.")
        self._X = X
        self._C = C
        self._Y = None
        self._Z = Z
        self._result = self._predict_loop(**kwargs)

    def get_result(self):
        """Get the prediction result.

        Returns:
            The value stored by the most recent ``predict`` call: a dict mapping
            each key returned by ``_call`` to its outputs merged across batches.
        """
        return self._result

    def _call(
        self, *, x: np.ndarray, y: np.ndarray, c: Dict[str, np.ndarray], training: bool
    ) -> Union[None, Dict[str, np.ndarray]]:
        """Process one batch; derived classes must override this method.

        Implementations train the model (e.g. with their loss object and
        optimizer) when ``training`` is True, and predict on the batch otherwise.

        Arguments:
            x: One batch of features, or ``None`` if no features were given.
            y: One batch of labels, or ``None`` if no labels were given.
            c: One batch for each concerned column, e.g. ``{'weight': np.ndarray}``.
                When no concerned columns were given, this is
                ``{"nonsense_placeholder": None}`` (or whatever ``Z`` yields).
            training: True for training, False for prediction.

        Returns:
            None for training; a dict of per-batch outputs for prediction.

        Raises:
            StopIteration: The process can be finished. NOTE: the loops in this
                class do not catch it, so it propagates to the caller of
                ``fit``/``predict``.
        """
        raise NotImplementedError("Not implementation for _call method")

    @staticmethod
    def _is_provided(value) -> bool:
        """Return whether a sample source was given and is not known to be empty.

        ``None`` and empty sized containers (``[]``, ``()``, ``{}``) count as not
        provided, which matches plain truthiness for those types. Unlike
        ``bool(value)``, this does not raise for multi-element ``np.ndarray``
        inputs. Unsized iterables such as generators always count as provided.
        """
        if value is None:
            return False
        try:
            return len(value) > 0
        except TypeError:  # unsized iterable, e.g. a generator
            return True

    @staticmethod
    def _batch_kwargs(z) -> Dict[str, object]:
        """Map one batch tuple ``(x, y, c)`` to ``_call`` keyword arguments.

        Entries missing from a shorter tuple are passed as ``None``.
        """
        return {
            "x": z[0] if len(z) > 0 else None,
            "y": z[1] if len(z) > 1 else None,
            "c": z[2] if len(z) > 2 else None,
        }

    def _train_loop(self, *, num_epochs, **kwargs):
        """Main loop for training: call ``_call(training=True)`` on every batch.

        ``_generator`` is rebuilt for every epoch. NOTE(review): if the sources
        are one-shot iterators (e.g. generators), they are exhausted after the
        first epoch and later epochs see no batches.
        """
        for curr_epoch in range(num_epochs):
            for z in self._generator():
                self._call(**self._batch_kwargs(z), training=True)
            logger.info(f"Finish epoch {curr_epoch}.")

    def _predict_loop(self, **kwargs):
        """Main loop for prediction: merge ``_call(training=False)`` outputs.

        Outputs are grouped by key in batch order and concatenated once with
        ``np.hstack``. Stacking once avoids re-copying the growing result on
        every batch. A key produced by only one batch is returned unchanged.

        Arguments:
            **kwargs: Accepted so that ``predict`` can forward its extra keyword
                arguments; currently unused.
        """
        collected: Dict[str, list] = {}
        for z in self._generator():
            result = self._call(**self._batch_kwargs(z), training=False)
            for k, v in result.items():
                collected.setdefault(k, []).append(v)
        return {
            k: parts[0] if len(parts) == 1 else np.hstack(parts)
            for k, parts in collected.items()
        }

    def _generator(self, **kwargs):
        """Build the batch iterator used by the train and predict loops.

        Returns:
            ``iter(self._Z)`` when ``Z`` was provided. Otherwise, an iterator of
            ``(x, y, c)`` tuples drawn in lockstep from ``X``, ``Y`` and every
            column of ``C``. A missing ``X`` or ``Y`` contributes ``None``. A
            missing ``C`` makes ``c`` equal to ``{"nonsense_placeholder": None}``.
            Iteration stops as soon as any provided source is exhausted.
        """

        def none_generator():
            # Endless stand-in for a source that was not provided.
            while True:
                yield None

        if self._is_provided(self._Z):
            return iter(self._Z)

        ix = iter(self._X) if self._is_provided(self._X) else none_generator()
        iy = iter(self._Y) if self._is_provided(self._Y) else none_generator()
        if self._is_provided(self._C):
            c_keys = list(self._C)
            c_iters = [iter(self._C[k]) for k in c_keys]
        else:
            # A list of (key, iterator) pairs, not a dict: iterating a dict yields
            # only keys, which previously broke the ``for k, v`` unpacking.
            c_keys = ["nonsense_placeholder"]
            c_iters = [none_generator()]

        # zip stops at the first exhausted source; at least one source is finite
        # because fit/predict require one of them to be provided.
        return (
            (x, y, dict(zip(c_keys, c_batch)))
            for x, y, *c_batch in zip(ix, iy, *c_iters)
        )