Source code for openasce.discovery.regression_discovery.notears_mlp

#    Copyright 2023 AntGroup CO., Ltd.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# Some of the code implementation is referred from https://github.com/xunzheng/notears
# Modified by Ant Group in 2023

import numpy as np
import torch
import torch.nn as nn

from openasce.discovery.regression_discovery.locally_connected import LocallyConnected
from openasce.discovery.regression_discovery.trace_expm import trace_expm
from openasce.utils.logger import logger


class NotearsMLP(nn.Module):
    """Nonlinear NOTEARS model: an MLP per variable with a DAG penalty on layer one.

    Adapted from https://github.com/xunzheng/notears. The model maps an input
    batch of shape ``[n, d]`` to an output of shape ``[n, d]``.

    The first layer is split into a "positive" part (``fc1_pos``) and a
    "negative" part (``fc1_neg``) so that an l1 penalty can be written as a
    plain sum (see :meth:`fc1_l1_reg`). The effective first-layer weight is
    ``fc1_pos.weight - fc1_neg.weight``, with shape ``[d * m1, d]``. The
    weighted adjacency matrix of the learned graph is read off that weight
    (see :meth:`fc1_to_adj`).

    Args:
        dims: Layer widths ``[d, m1, ..., 1]``.
            - ``dims[0]`` is the number of variables ``d``.
            - ``dims[1]`` is the first-layer hidden width per variable.
            - The last entry must be ``1``.
        bias: Whether the linear layers use a bias term.

    Raises:
        ValueError: If ``dims`` has fewer than two entries or its last entry
            is not ``1``.
    """

    def __init__(self, dims, bias=True):
        super().__init__()
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if len(dims) < 2:
            raise ValueError("dims should contain at least two entries: [d, ..., 1]")
        if dims[-1] != 1:
            raise ValueError("The node number in output layer should be one")
        d = dims[0]
        self.dims = dims
        # fc1: variable splitting for l1 (both parts map d -> d * m1).
        self.fc1_pos = nn.Linear(d, d * dims[1], bias=bias)
        self.fc1_neg = nn.Linear(d, d * dims[1], bias=bias)
        # Per-entry box constraints stored as a plain attribute on the weight
        # tensors. NOTE(review): presumably consumed by a bound-constrained
        # optimizer elsewhere (e.g. an L-BFGS-B wrapper) -- confirm the consumer.
        self.fc1_pos.weight.bounds = self._bounds()
        self.fc1_neg.weight.bounds = self._bounds()
        # fc2: local linear layers, one per consecutive pair of widths in dims[1:].
        self.fc2 = nn.ModuleList(
            LocallyConnected(d, dims[k + 1], dims[k + 2], bias=bias)
            for k in range(len(dims) - 2)
        )

    def _bounds(self):
        """Return one ``(low, high)`` bound per entry of a first-layer weight.

        The list is ordered like the flattened ``[d * m1, d]`` weight matrix:
        - row ``j * m1 + m`` (output variable ``j``, hidden unit ``m``);
        - column ``i`` (input variable ``i``).

        Bounds per entry:
        - Entries with ``i == j`` get ``(0, 0)``, pinning them to zero so a
          variable never feeds its own MLP.
        - All other entries get ``(0, None)``, i.e. non-negative and
          unbounded above.
        """
        d, m1 = self.dims[0], self.dims[1]
        # Loop order j -> m -> i must match the row-major weight layout.
        return [
            (0, 0) if i == j else (0, None)
            for j in range(d)
            for _ in range(m1)
            for i in range(d)
        ]

    def forward(self, x):
        """Apply the network.

        Args:
            x: Input tensor of shape ``[n, d]``.

        Returns:
            Tensor of shape ``[n, d]``.
        """
        x = self.fc1_pos(x) - self.fc1_neg(x)  # [n, d * m1]
        x = x.view(-1, self.dims[0], self.dims[1])  # [n, d, m1]
        for fc in self.fc2:
            x = torch.sigmoid(x)  # [n, d, m1]
            x = fc(x)  # [n, d, m2]
        x = x.squeeze(dim=2)  # [n, d]
        return x

    def _fc1_squared_norms(self):
        """Return ``A[i, j]``: squared 2-norm over ``m1`` of fc1 weights from input i to output j."""
        d = self.dims[0]
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
        return torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]

    def h_func(self):
        """Constrain 2-norm-squared of fc1 weights along m1 dim to be a DAG.

        Returns:
            Scalar tensor ``trace_expm(A) - d``, where ``A`` is the ``[d, d]``
            matrix of squared first-layer norms. NOTE(review): ``trace_expm``
            is a project helper, presumably the trace of the matrix
            exponential -- confirm.
        """
        A = self._fc1_squared_norms()  # [i, j]
        return trace_expm(A) - self.dims[0]

    def l2_reg(self):
        """Take 2-norm-squared of all parameters.

        This sums the squared effective fc1 weight and the squared ``weight``
        of every ``fc2`` layer. Biases are not included.
        """
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        reg = 0.0
        reg += torch.sum(fc1_weight**2)
        for fc in self.fc2:
            reg += torch.sum(fc.weight**2)
        return reg

    def fc1_l1_reg(self):
        """Take l1 norm of fc1 weight.

        Computed as ``sum(fc1_pos.weight + fc1_neg.weight)``. This equals the
        l1 norm only while both parts are non-negative, which the
        ``(0, ...)`` bounds from :meth:`_bounds` are meant to enforce.
        """
        return torch.sum(self.fc1_pos.weight + self.fc1_neg.weight)

    @torch.no_grad()
    def fc1_to_adj(self) -> np.ndarray:
        """Get W from fc1 weights, take 2-norm over m1 dim.

        Returns:
            NumPy array of shape ``[d, d]``. ``W[i, j]`` is the 2-norm of the
            fc1 weights from input ``i`` to output ``j``.
        """
        W = torch.sqrt(self._fc1_squared_norms())  # [i, j]
        # Already outside autograd (no_grad), so no detach() is needed.
        return W.cpu().numpy()