# Copyright 2023 AntGroup CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Dict, Iterable, List, Tuple, Union
import numpy as np
import tensorflow as tf
from openasce.extension.debias.common.utils import DNNLayer
from openasce.inference.inference_model import InferenceModel
from openasce.utils.logger import logger
class TModel(InferenceModel):
    """T_model based on NN.

    A shared "base" network embeds the features and two heads on top of it
    ("test" and "control") each produce a logit. For every sample the logit
    of the arm selected by its treatment flag is trained against the
    observed outcome with sigmoid cross-entropy.
    """

    def __init__(
        self, hidden_units: Dict, lr: float = 0.1, name: str = "t_model"
    ) -> None:
        """Build the three sub-networks and the optimizer.

        Args:
            hidden_units (dict): layer sizes keyed by ``"base"``, ``"test"``
                and ``"control"``; each value is forwarded unchanged as
                ``hidden_units`` to the matching ``DNNLayer``. The last
                layer of ``"test"`` and ``"control"`` presumably has a
                single unit, because the two head outputs are paired with a
                two-column treatment mask -- TODO confirm against callers.
            lr (float): learning rate handed to ``get_optimizer``.
            name (str): stored on the instance as ``self.name``.
        """
        super().__init__()
        self.hidden_units = hidden_units
        self.name = name
        self.base_embed_layer = DNNLayer(
            hidden_units=hidden_units["base"],
            activation=tf.nn.leaky_relu,
            apply_final_act=True,
            name="base_embed_layer",
        )
        # NOTE(review): the test head applies the final activation while the
        # control head does not. Kept as in the original; confirm that this
        # asymmetry is intended.
        self.test_embed_layer = DNNLayer(
            hidden_units=hidden_units["test"],
            activation=tf.nn.leaky_relu,
            apply_final_act=True,
            name="test_embed_layer",
        )
        self.control_embed_layer = DNNLayer(
            hidden_units=hidden_units["control"],
            activation=tf.nn.leaky_relu,
            apply_final_act=False,
            name="control_embed_layer",
        )
        self.optimizer = self.get_optimizer(lr=lr)

    def fit(
        self,
        X: Iterable[tf.Tensor] = None,
        Y: Iterable[tf.Tensor] = None,
        T: Iterable[tf.Tensor] = None,
        *,
        Z: Iterable[Tuple[tf.Tensor, tf.Tensor, tf.Tensor]] = None,
        num_epochs: int = 1,
        **kwargs,
    ) -> None:
        """Feed the sample data and train the model on the samples.

        Arguments:
            X: Features of the samples, one batch per iteration step.
            Y: Outcomes of the samples, one batch per iteration step.
            T: Treatments of the samples, one batch per iteration step.
            Z: The iterable object returning (a batch of X, a batch of Y,
                a batch of T). When given it takes precedence over X/Y/T.
            num_epochs: number of the train epoch. The data source is
                iterated once per epoch, so a one-shot iterator only feeds
                the first epoch.
            **kwargs: forwarded to the training loop, which ignores them.

        Returns:
            None
        """
        self._X = X
        self._Y = Y
        self._T = T
        self._Z = Z
        self._train_loop(num_epochs=num_epochs, **kwargs)

    def estimate(
        self,
        X: Iterable[tf.Tensor] = None,
        T: Iterable[tf.Tensor] = None,
        *,
        Z: Iterable[Tuple[tf.Tensor, tf.Tensor, tf.Tensor]] = None,
        **kwargs,
    ) -> None:
        """Feed the sample data and run prediction on it.

        The per-key predictions are stored on ``self._estimate_result``
        (the original documentation states that ``get_result`` can be used
        to read the result of the prediction).

        Arguments:
            X: Features of the samples, one batch per iteration step.
            T: Treatments of the samples, one batch per iteration step.
            Z: The iterable object returning (a batch of X, a batch of Y,
                a batch of T). When given it takes precedence over X/T.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            None
        """
        self._X = X
        # Outcomes can only arrive through Z here; clear any labels left
        # over from an earlier fit() so they are not paired with new data.
        self._Y = None
        self._T = T
        self._Z = Z
        self._estimate_result = self._predict_loop(**kwargs)

    @staticmethod
    def _split_batch(z):
        """Return ``(x, y, t)`` from a batch tuple, padding with ``None``."""
        return tuple(z[i] if len(z) > i else None for i in range(3))

    def _predict_loop(self, **kwargs):
        """Run the forward pass over every batch and merge the outputs.

        Returns:
            Dict mapping each prediction key to its values. A key produced
            by a single batch keeps the object returned for that batch;
            otherwise the per-batch values are concatenated along the
            sample axis (axis 0) into one numpy array.
        """
        collected = {}
        for z in self._generator():
            x, y, t = self._split_batch(z)
            result = self._call(x=x, y=y, t=t, training=False)
            for key, value in result.items():
                collected.setdefault(key, []).append(value)

        merged = {}
        for key, parts in collected.items():
            if len(parts) == 1:
                merged[key] = parts[0]
            else:
                # Batches are stacked along axis 0 so that [batch, 1]
                # outputs stay one column and a shorter final batch is
                # accepted; atleast_1d keeps scalar entries concatenable.
                merged[key] = np.concatenate(
                    [np.atleast_1d(part) for part in parts], axis=0
                )
        return merged

    def _train_loop(self, *, num_epochs, **kwargs):
        """Run ``num_epochs`` passes over the data, one update per batch."""
        for curr_epoch in range(num_epochs):
            for z in self._generator():
                x, y, t = self._split_batch(z)
                self._call(x=x, y=y, t=t, training=True)
            logger.info(f"Finish epoch {curr_epoch}.")

    @property
    def trainable_variables(self):
        """Trainable variables of the base, test and control networks."""
        variables = (
            self.base_embed_layer.trainable_variables
            + self.test_embed_layer.trainable_variables
            + self.control_embed_layer.trainable_variables
        )
        return variables

    @staticmethod
    def _observed_arm_logits(control_logits, test_logits, treatment):
        """Pick, per sample, the logit of the arm given by ``treatment``.

        ``treatment`` is flattened to one column and cast to float32; a
        value of 1 selects ``test_logits`` and 0 selects ``control_logits``.
        The result has shape ``[-1, 1]``.
        """
        flag = tf.cast(tf.reshape(treatment, [-1, 1]), tf.float32)
        mask = tf.concat([1 - flag, flag], axis=1)
        logits = tf.concat([control_logits, test_logits], axis=1)
        return tf.reshape(
            tf.reduce_sum(tf.multiply(logits, mask), axis=1), [-1, 1]
        )

    def forward(
        self, x: tf.Tensor, t: tf.Tensor, training: bool
    ) -> Dict[str, tf.Tensor]:
        """Compute the logits of both arms for one batch.

        Arguments:
            x: one batch of features.
            t: one batch of treatments, or None when unknown.
            training: accepted for interface symmetry; it is not forwarded
                to the sub-networks.

        Returns:
            Dict with ``"test_logits"`` and ``"control_logits"``, plus
            ``"outcome_logits"`` (the logit of the arm selected by ``t``)
            when ``t`` is not None.
        """
        base_emb = self.base_embed_layer(x)
        test_logits = self.test_embed_layer(base_emb)
        control_logits = self.control_embed_layer(base_emb)
        predictions = {
            "test_logits": test_logits,
            "control_logits": control_logits,
        }
        if t is not None:
            predictions["outcome_logits"] = self._observed_arm_logits(
                control_logits, test_logits, t
            )
        return predictions

    def _call(
        self, *, x: tf.Tensor, y: tf.Tensor, t: tf.Tensor, training: bool
    ) -> Union[None, Dict[str, tf.Tensor]]:
        """Process one batch: a gradient step when training, else predict.

        Arguments:
            x: one batch of features
            y: one batch of labels, shape: [batch_size], outcome labels
            t: one batch of treatments
            training: True means training and False for predict

        Returns:
            None for training and Dict for predict. In predict mode the
            dict additionally carries ``y`` under ``"outcome"`` when ``y``
            is not None.
        """
        if not training:
            predictions = self.forward(x, t, training)
            if y is not None:
                predictions["outcome"] = y
            return predictions

        y = tf.reshape(y, [-1, 1])
        t = tf.reshape(t, [-1, 1])
        with tf.GradientTape() as tape:
            predictions = self.forward(x, t, training)
            loss_value = self.loss(predictions, [t, y])
        # Read the variables only after the forward pass, as the original
        # did, in case the layers create their weights on first call.
        variables = self.trainable_variables
        grads = tape.gradient(loss_value, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        return None

    def loss(self, predictions: Dict, labels: List[tf.Tensor]):
        """Sigmoid cross-entropy between observed-arm logits and outcomes.

        Arguments:
            predictions: dict holding ``"control_logits"`` and
                ``"test_logits"``.
            labels: ``[treatment, outcome]``; both are cast to float32 and
                reshaped to one column.

        Returns:
            The mean loss over the batch, also stored on
            ``self.outcome_loss``.
        """
        outcome_logits = self._observed_arm_logits(
            predictions["control_logits"],
            predictions["test_logits"],
            labels[0],
        )
        outcome_y = tf.reshape(tf.cast(labels[1], tf.float32), [-1, 1])
        self.outcome_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=outcome_logits, labels=outcome_y
            )
        )
        return self.outcome_loss

    def get_optimizer(self, lr: float = 0.01):
        """Build the optimizer.

        Args:
            lr (float): learning rate of the Adam optimizer.

        Returns:
            An optimizer.
        """
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        return optimizer

    def _generator(self, **kwargs):
        """Return an iterator over ``(x, y, t)`` batch tuples.

        ``Z`` is iterated directly when it was supplied. Otherwise X, Y and
        T are advanced in lockstep; a source that was not supplied yields
        ``None`` for every batch, and iteration stops with the shortest
        supplied source.

        Raises:
            ValueError: if none of X, Y, T and Z was supplied.
        """
        batches = getattr(self, "_Z", None)
        if batches is not None:
            return iter(batches)

        sources = [
            getattr(self, attr, None) for attr in ("_X", "_Y", "_T")
        ]
        if all(source is None for source in sources):
            raise ValueError(
                "No data to iterate: provide X/Y/T or Z before running."
            )

        def column(source):
            # A missing source pads every batch tuple with None.
            if source is None:
                while True:
                    yield None
            else:
                yield from source

        return zip(*(column(source) for source in sources))