# Copyright 2023 AntGroup CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# Some of the code implementation is referred from https://github.com/xunzheng/notears
# Modified by Ant Group in 2023
import numpy as np
import torch
import torch.nn as nn
from openasce.discovery.regression_discovery.locally_connected import LocallyConnected
from openasce.discovery.regression_discovery.trace_expm import trace_expm
from openasce.utils.logger import logger
class NotearsMLP(nn.Module):
    """MLP-based structural equation model in the style of NOTEARS.

    The first layer is a pair of dense linear maps (``fc1_pos`` / ``fc1_neg``)
    whose difference acts as the signed input weight; splitting the weight this
    way lets an L1 penalty be written as a plain sum (see ``fc1_l1_reg``).
    The remaining layers are ``LocallyConnected`` modules, one per consecutive
    pair of entries in ``dims[1:]``.
    """

    def __init__(self, dims, bias: bool = True):
        """Build the network.

        Args:
            dims: Layer sizes ``[d, m1, ..., 1]``. ``dims[0]`` is the number of
                variables ``d``, ``dims[1]`` the per-variable width of the
                first hidden layer, and the last entry must be ``1``.
            bias: Whether the linear / locally connected layers use a bias.

        Raises:
            ValueError: If ``dims`` has fewer than two entries.
            AssertionError: If the last entry of ``dims`` is not ``1``.
        """
        super().__init__()
        # dims[1] is read below; fail early with a clear message instead of a
        # bare IndexError when dims is too short.
        if len(dims) < 2:
            raise ValueError(
                "dims must contain at least two entries: [d, ..., 1]"
            )
        # Kept as an assert (same message) so existing callers see the same
        # exception type as before.
        assert dims[-1] == 1, "The node number in output layer should be one"
        d = dims[0]
        self.dims = dims
        # fc1: variable splitting for l1
        self.fc1_pos = nn.Linear(d, d * dims[1], bias=bias)
        self.fc1_neg = nn.Linear(d, d * dims[1], bias=bias)
        # Box constraints are attached to the weight tensors as a plain
        # attribute; presumably consumed by a bound-aware optimizer elsewhere
        # in the project -- TODO confirm against the training loop.
        self.fc1_pos.weight.bounds = self._bounds()
        self.fc1_neg.weight.bounds = self._bounds()
        # fc2: local linear layers, one per consecutive pair in dims[1:]
        layers = [
            LocallyConnected(d, dims[k + 1], dims[k + 2], bias=bias)
            for k in range(len(dims) - 2)
        ]
        self.fc2 = nn.ModuleList(layers)

    def _bounds(self):
        """Return per-entry ``(lower, upper)`` bounds for an fc1 weight.

        The list is ordered as output variable ``j``, then hidden unit, then
        input variable ``i``. Entries with ``i == j`` are pinned to ``(0, 0)``
        so a variable cannot feed its own equation; every other entry is
        ``(0, None)``, i.e. non-negative with no upper limit.
        """
        d = self.dims[0]
        m1 = self.dims[1]
        return [
            (0, 0) if i == j else (0, None)
            for j in range(d)
            for _ in range(m1)
            for i in range(d)
        ]

    def forward(self, x):  # [n, d] -> [n, d]
        """Apply the network to a batch ``x`` of shape ``[n, d]``.

        Returns:
            Tensor of shape ``[n, d]``.
        """
        x = self.fc1_pos(x) - self.fc1_neg(x)  # [n, d * m1]
        x = x.view(-1, self.dims[0], self.dims[1])  # [n, d, m1]
        for fc in self.fc2:
            x = torch.sigmoid(x)  # [n, d, m1]
            x = fc(x)  # [n, d, m2]
        # The last layer width is 1 (checked in __init__), so drop that axis.
        x = x.squeeze(dim=2)  # [n, d]
        return x

    def h_func(self):
        """Constrain 2-norm-squared of fc1 weights along m1 dim to be a DAG.

        Returns:
            ``trace_expm(A) - d`` where ``A[i, j]`` is the squared 2-norm, over
            the hidden units of variable ``j``, of the weights from input ``i``.
            In the NOTEARS formulation this quantity is driven to zero to
            enforce acyclicity (``trace_expm`` is a project helper; presumably
            the trace of the matrix exponential -- verify there).
        """
        d = self.dims[0]
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
        A = torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]
        h = trace_expm(A) - d
        return h

    def l2_reg(self):
        """Take 2-norm-squared of the fc1 and fc2 weights (biases excluded)."""
        reg = 0.0
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        reg += torch.sum(fc1_weight**2)
        for fc in self.fc2:
            reg += torch.sum(fc.weight**2)
        return reg

    def fc1_l1_reg(self):
        """Take l1 norm of fc1 weight.

        Computed as the plain sum of ``fc1_pos.weight + fc1_neg.weight``; this
        equals an L1 norm only while both tensors are non-negative, which is
        what the bounds from ``_bounds`` are meant to guarantee.
        """
        reg = torch.sum(self.fc1_pos.weight + self.fc1_neg.weight)
        return reg

    @torch.no_grad()
    def fc1_to_adj(self) -> np.ndarray:  # [j * m1, i] -> [i, j]
        """Get W from fc1 weights, take 2-norm over m1 dim.

        Returns:
            NumPy array of shape ``[d, d]`` where entry ``[i, j]`` is the
            2-norm, over the hidden units of variable ``j``, of the weights
            from input ``i``.
        """
        d = self.dims[0]
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
        A = torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]
        W = torch.sqrt(A)  # [i, j]
        W = W.cpu().detach().numpy()  # [i, j]
        return W