Welcome to My Blog

这是模板文件


class GNNLayer(nn.Module):
    """Graph layer followed by batch normalisation and a ReLU activation.

    Wraps a ``GraphLayer`` (defined outside this block) and stores the edge
    index and attention weights returned by its most recent call in
    ``edge_index_1`` and ``att_weight_1``.

    Args:
        in_channel: Passed to ``GraphLayer`` as its first positional argument.
        out_channel: Passed to ``GraphLayer`` as its second positional
            argument and used as the feature count of the ``BatchNorm1d``.
        inter_dim: Forwarded to ``GraphLayer`` as ``inter_dim``.
        heads: Forwarded to ``GraphLayer`` as ``heads``.
        node_num: Not used by this class; kept so that existing callers that
            pass it keep working.
    """

    def __init__(self, in_channel, out_channel, inter_dim=0, heads=1, node_num=100):
        super(GNNLayer, self).__init__()

        # concat=False is fixed here; presumably GraphLayer then emits
        # out_channel features per node, matching the BatchNorm1d below --
        # TODO confirm against GraphLayer.
        self.gnn = GraphLayer(in_channel, out_channel, inter_dim=inter_dim, heads=heads, concat=False)

        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()
        # Not referenced by forward(); kept because code outside this block
        # may read the attribute.
        self.leaky_relu = nn.LeakyReLU()

        # Filled in by forward() with the values returned by the last gnn call.
        self.att_weight_1 = None
        self.edge_index_1 = None

    def forward(self, x, edge_index, embedding=None, node_num=0):
        """Run the graph layer, then batch-normalise and apply ReLU.

        Args:
            x: Input passed to the wrapped graph layer.
            edge_index: Edge index passed to the wrapped graph layer.
            embedding: Optional third positional argument for the graph layer.
            node_num: Not used; kept for interface compatibility.

        Returns:
            ``relu(bn(out))`` where ``out`` is the first element returned by
            the wrapped graph layer.
        """
        # The layer is asked for attention weights and is expected to return
        # ``(out, (edge_index, attention_weights))``.
        out, (new_edge_index, att_weight) = self.gnn(
            x, edge_index, embedding, return_attention_weights=True
        )
        self.att_weight_1 = att_weight
        self.edge_index_1 = new_edge_index
        out = self.bn(out)
        return self.relu(out)