# Source code for openasce.inference.learner.dml

#    Copyright 2023 AntGroup CO., Ltd.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


from typing import Iterable, Any, NoReturn

import numpy as np
from econml.dml import DML as _DML
from econml.dml._rlearner import _ModelNuisance
from econml.utilities import filter_none_kwargs

from openasce.inference.inference_model import InferenceModel


class DML(_DML, InferenceModel):
    """econml ``DML`` estimator whose nuisance fits are dispatched via ``launch``.

    During ``fit``, the fitting of the two nuisance models (``model_t`` and
    ``model_y``) is replaced by two tasks started with ``self.launch(num=2, ...)``.
    Each task is handled by ``todo``.
    """

    def fit(
        self,
        *,
        X: Iterable[np.ndarray],
        Y: Iterable[np.ndarray],
        T: Iterable[np.ndarray],
        **kwargs,
    ):
        """Feed the sample data and train the model used to effect on the samples.

        Arguments:
            X: Features of the samples.
            Y: Outcomes of the samples.
            T: Treatments of the samples.
            **kwargs: Forwarded unchanged to econml's ``DML.fit``.

        Returns:
            None
        """

        def _nuisance_fit(
            _self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None
        ):
            # Stand-in for ``_ModelNuisance.fit``. Instead of fitting model_t and
            # model_y directly, it hands both models plus the data to
            # ``self.launch`` and stores the fitted models it gets back.
            assert Z is None, "Cannot accept instrument!"
            param = {
                "X": X,
                "W": W,
                "T": T,
                "Y": Y,
                "model_t": _self._model_t,
                "model_y": _self._model_y,
                "sample_weight": sample_weight,
                "groups": groups,
            }
            results = self.launch(num=2, param=param, dataset=None)
            for r in results:
                if "model_t" in r:
                    _self._model_t = r["model_t"]
                elif "model_y" in r:
                    _self._model_y = r["model_y"]
            return _self

        # The patch is installed on the class, so it affects every
        # _ModelNuisance instance in the process. Restore the original
        # afterwards, even if fitting raises, so the closure over this
        # ``self`` does not leak into later, unrelated fits.
        original_fit = _ModelNuisance.fit
        _ModelNuisance.fit = _nuisance_fit
        try:
            super().fit(Y, T, X=X, **kwargs)
        finally:
            _ModelNuisance.fit = original_fit

    def todo(self, idx: int, total_num: int, param: Any, dataset: Iterable) -> Any:
        """Fit one nuisance model; executed once per task launched from ``fit``.

        Arguments:
            idx: Task index. 0 fits ``model_t`` against T, 1 fits ``model_y``
                against Y; any other index fits nothing.
            total_num: Presumably the number of launched tasks (2 from ``fit``);
                unused here.
            param: Dict built in ``fit`` with keys "X", "W", "T", "Y",
                "model_t", "model_y", "sample_weight" and "groups". It is read
                but not modified.
            dataset: Unused by this method (``fit`` launches with None).

        Returns:
            Dict with "idx", plus "model_t" (idx 0) or "model_y" (idx 1) holding
            the fitted model.
        """
        # Index instead of pop(): the caller's dict must not be mutated. If
        # ``launch`` hands the same dict to both tasks (not visible here),
        # popping in the first task would make the second raise KeyError.
        X, Y, T, W = param["X"], param["Y"], param["T"], param["W"]
        fit_kwargs = filter_none_kwargs(
            sample_weight=param["sample_weight"], groups=param["groups"]
        )
        result = {"idx": idx}
        if idx == 0:
            model_t = param["model_t"]
            model_t.fit(X, W, T, **fit_kwargs)
            result["model_t"] = model_t
        elif idx == 1:
            model_y = param["model_y"]
            model_y.fit(X, W, Y, **fit_kwargs)
            result["model_y"] = model_y
        return result

    def estimate(self, *, X: Iterable[np.ndarray]) -> None:
        """Feed the sample data and estimate the effect on the samples.

        The result of ``self.const_marginal_effect(X)`` is stored in
        ``self._estimate_result`` rather than returned.

        Arguments:
            X: Features of the samples.

        Returns:
            None
        """
        self._estimate_result = self.const_marginal_effect(X)