# Copyright 2023 AntGroup CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import typing
import numpy as np
import tensorflow as tf
from openasce.extension.debias.common.utils import DNNModel, mmd_rbf
from openasce.extension.debias_model import CausalDebiasModel
[docs]class CFRModel(CausalDebiasModel):
"""Building a CFR model.
Model: CFR (CounterFactual Regression).
Paper: Estimating individual treatment effect: generalization bounds and algorithms.
Link: http://proceedings.mlr.press/v70/shalit17a/shalit17a.pdf.
Author: Uri Shalit, Fredrik D. Johansson and David Sontag.
"""
[docs] def __init__(self, params: typing.Dict) -> None:
"""Initialize.
Args:
params: parameter dict.
"""
super().__init__()
# initialize params.
self.hidden_units = params.get("hidden_units", [64, 16, 1])
self.hidden_units_emb = params.get("hidden_units_emb", [128, 64])
self.act_fn = params.get("act_fn", "relu")
self.l2_reg = params.get("l2_reg", 0.001)
self.dropout_rate = params.get("dropout_rate", 0)
self.use_bn = params.get("use_bn", False)
self.apply_final_act = params.get("apply_final_act", False)
self.apply_final_act_emb = params.get("apply_final_act_emb", True)
self.lr = params.get("lr", 0.0001)
self.proportion = params.get("proportion", 0.9)
self.w = params.get("w", 0.1)
# define model.
self.model_emb = DNNModel(
hidden_units=self.hidden_units_emb,
act_fn=self.act_fn,
l2_reg=self.l2_reg,
use_bn=self.use_bn,
dropout_rate=self.dropout_rate,
apply_final_act=self.apply_final_act_emb,
)
self.control_net = DNNModel(
hidden_units=self.hidden_units,
act_fn=self.act_fn,
l2_reg=self.l2_reg,
use_bn=self.use_bn,
dropout_rate=self.dropout_rate,
apply_final_act=self.apply_final_act,
)
self.treatment_nets = DNNModel(
hidden_units=self.hidden_units,
act_fn=self.act_fn,
l2_reg=self.l2_reg,
use_bn=self.use_bn,
dropout_rate=self.dropout_rate,
apply_final_act=self.apply_final_act,
)
self.optimizer = self.get_optimizer()
@property
def trainable_variables(self):
variables = (
self.model_emb.trainable_variables
+ self.control_net.trainable_variables
+ self.treatment_nets.trainable_variables
)
return variables
[docs] def forward(
self, x: tf.Tensor, c: typing.Dict[str, tf.Tensor], training: bool
) -> typing.Dict[str, tf.Tensor]:
feature, treatment = c.get("feature"), c.get("treatment")
features_emb = self.model_emb(feature, training=training)
logit_control = self.control_net(features_emb, training=training)
logit_treatment = self.treatment_nets(features_emb, training=training)
effect = logit_treatment - logit_control
predictions = {
"effect": effect,
"logit_treatment": logit_treatment,
"logit_control": logit_control,
"treatment": treatment,
"features_emb": features_emb,
}
return predictions
def _call(
self,
*,
x: np.ndarray,
y: np.ndarray,
c: typing.Dict[str, np.ndarray],
training: bool
) -> typing.Union[None, typing.Dict[str, np.ndarray]]:
"""Building a callable function.
fit and predict are the base class interface methods to be called by outside users, which should not be overloaded.
_call is used to implement the logic of the algorithm after it has been overloaded.
Args:
x: the original input feature.
y: the original input label.
c: the original input dict, here, {'feature': np.ndarray, 'treatment': np.ndarray}.
feature: train feature.
treatment: binary treatment or multiple discrete treatment.
training: bool, identify the status.
Returns:
A callable function,
for training, return loss, optimizer, and model;
for inference, return the prediction dict.
"""
def grad(x, c, training, labels):
with tf.GradientTape() as tape:
predictions = self.forward(x, c, training)
loss_value = self.loss(predictions, labels)
return loss_value, tape.gradient(loss_value, self.trainable_variables)
if training:
# train procedure
# calculate loss, gradient, optimizer updates model, etc.
# The framework doesn't care about return values.
loss_value, grads = grad(x, c, training, y)
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
else:
# inference procedure.
# calculate the prediction and return with a dict.
predictions = self.forward(x, c, training)
if y is not None:
predictions["label"] = y
return predictions
[docs] def loss(self, predictions: typing.Dict, labels: tf.Tensor):
"""Compute scalar loss tensors with respect to provided labels.
Args:
predictions: a dictionary holding predicted tensors.
labels: label tensor.
Returns:
A scalar loss or A dictionary mapping strings (loss names) to scalar loss.
"""
logit_treatment = predictions["logit_treatment"]
logit_control = predictions["logit_control"]
treatment = predictions["treatment"]
features_emb = predictions["features_emb"]
loss_treat = tf.compat.v1.losses.sigmoid_cross_entropy(
labels, logit_treatment, weights=treatment
)
loss_control = tf.compat.v1.losses.sigmoid_cross_entropy(
labels, logit_control, weights=1 - treatment
)
ipm_loss = mmd_rbf(features_emb, treatment, self.proportion, 2.0)
loss = loss_treat + loss_control + self.w * ipm_loss
return loss
[docs] def get_optimizer(self):
"""Build the optimizer.
Args:
Returns:
An optimizer.
"""
# lr hyper-parameters.
lr = self.lr
return tf.keras.optimizers.Adam(lr=lr)