-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgat_model.py
42 lines (37 loc) · 1.48 KB
/
gat_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch import nn
from torch_geometric.nn import GATConv
from constants import *
from transformer import *
class GAT(torch.nn.Module):
    def __init__(self, num_skills, num_layers=3, hidden_dim=skill_embd_dim):
        """
        Graph Attention Network (GAT) over the skill graph.

        Builds a learned skill embedding table followed by `num_layers`
        GATConv layers (each followed by a PReLU activation) and a final
        linear projection.

        Arguments:
        - num_skills: number of skill IDs to embed (int)
        - num_layers: number of GATConv layers (int, default 3).
          NOTE: the previous implementation accepted this argument but
          always built exactly 3 layers; it is now honored. Checkpoints
          saved with the old attribute names (conv1/prelu1/...) will need
          key remapping to convs.N/prelus.N.
        - hidden_dim: embedding and hidden dimension (int, default
          skill_embd_dim from constants)
        """
        super().__init__()
        self.tag = 'GAT'
        # Initial (pre-message-passing) skill embeddings, looked up by ID.
        self.pre_embs = nn.Embedding(num_skills, hidden_dim)
        # Stacked GAT layers with matching per-layer PReLU activations.
        self.convs = nn.ModuleList(
            GATConv(hidden_dim, hidden_dim) for _ in range(num_layers)
        )
        self.prelus = nn.ModuleList(nn.PReLU() for _ in range(num_layers))
        # Final projection back to hidden_dim.
        self.out = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index, edge_weight=None):
        """
        Runs a forward pass through the GAT.

        Arguments:
        - x: skill IDs (torch.Tensor of long, shape (num_nodes,))
        - edge_index: edges in skill graph (torch.Tensor, shape (2, num_edges))
        - edge_weight: edge weights of skill graph (torch.Tensor or None).
          NOTE(review): this argument was accepted but never used by the
          original implementation, and plain GATConv (no edge_dim) cannot
          consume scalar edge weights — it remains unused here (now
          optional for backward compatibility). If weighted attention is
          required, construct GATConv with edge_dim=1 and pass
          edge_weight.unsqueeze(-1) as edge_attr.

        Returns:
        - final node embedding for each skill
          (torch.Tensor, shape (num_nodes, hidden_dim))
        """
        h = self.pre_embs(x)
        last = len(self.convs) - 1
        for i, (conv, prelu) in enumerate(zip(self.convs, self.prelus)):
            h = prelu(conv(h, edge_index))
            # Match original behavior: dropout after every layer except the last.
            if i < last:
                h = self.dropout(h)
        return self.out(h)