# Source code for openasce.discovery.regression_discovery.locally_connected

#    Copyright 2023 AntGroup CO., Ltd.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# Some of the code implementation is referred from https://github.com/xunzheng/notears
# Modified by Ant Group in 2023

import math

import torch
import torch.nn as nn


class LocallyConnected(nn.Module):
    """Bank of independent linear layers applied position-wise.

    The layer holds ``num_linear`` separate weight matrices (and optionally
    bias vectors). Slice ``input[:, j, :]`` is transformed only by the j-th
    weight matrix, so no parameters are shared across the ``d`` positions.

    Arguments:
        num_linear: number of local linear layers, i.e. ``d``
        input_features: size of each local input, ``m1``
        output_features: size of each local output, ``m2``
        bias: whether to include a per-layer additive bias

    Shape:
        - Input: [n, d, m1]
        - Output: [n, d, m2]

    Attributes:
        weight: [d, m1, m2]
        bias: [d, m2], or ``None`` when constructed with ``bias=False``
    """

    def __init__(
        self,
        num_linear: int,
        input_features: int,
        output_features: int,
        bias: bool = True,
    ) -> None:
        """Allocate the parameters and initialise them.

        Arguments:
            num_linear: number of local linear layers (``d``)
            input_features: size of each local input (``m1``)
            output_features: size of each local output (``m2``)
            bias: whether to create a learnable bias of shape [d, m2]
        """
        super().__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features

        # torch.empty leaves the storage uninitialised; reset_parameters()
        # below fills every entry before the module is used.
        self.weight = nn.Parameter(
            torch.empty(num_linear, input_features, output_features)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(num_linear, output_features))
        else:
            # Register the name so that `self.bias` exists and is None.
            self.register_parameter("bias", None)

        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Re-initialise weight and bias in place.

        Values are drawn from U(-b, b) with ``b = sqrt(1 / input_features)``.
        """
        bound = math.sqrt(1.0 / self.input_features)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply each local linear layer to its own slice of ``input``.

        Arguments:
            input: tensor of shape [n, d, m1]

        Returns:
            Tensor of shape [n, d, m2].
        """
        # [n, d, 1, m2] = [n, d, 1, m1] @ [1, d, m1, m2]
        out = torch.matmul(input.unsqueeze(dim=2), self.weight.unsqueeze(dim=0))
        out = out.squeeze(dim=2)
        if self.bias is not None:
            # [n, d, m2] += [d, m2]  (bias broadcasts over the batch axis)
            out += self.bias
        return out

    def extra_repr(self) -> str:
        """Return the layer configuration shown by ``repr(module)``."""
        return (
            f"num_linear={self.num_linear}, "
            f"input_features={self.input_features}, "
            f"output_features={self.output_features}, "
            f"bias={self.bias is not None}"
        )