Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add annotation #9

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added __init__.py
Empty file.
42 changes: 42 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from src.utils import tab_printer
from src.simgnn import SimGNNTrainer
from src.parser import parameter_parser


def main():
"""
Parsing command line parameters, reading data, fitting and scoring a SimGNN model.
"""
args = parameter_parser() # 解析命令行输入的参数
tab_printer(args) # 以表格的形式打印参数
trainer = SimGNNTrainer(args) # 构建SimGNNTrainer类
# 从下面的几个命令开始,根据不同的参数设置调用SimGNN类的实例的forward函数
if args.measure_time:
trainer.measure_time() # Measure average calculation time for one graph pair
else:
if args.load:
trainer.load() # Load a pretrained model
else:
trainer.fit() # training a model
trainer.score()
if args.save: # Store the model. Default is None.
trainer.save()

if args.notify: # 是否需要发送通知,根据操作系统的不同,使用不同的方法来发送通知
import os
import sys

if sys.platform == "linux": #Linux操作系统
os.system('notify-send SimGNN "Program is finished."')
elif sys.platform == "posix": #macOS操作系统
os.system(
"""
osascript -e 'display notification "SimGNN" with title "Program is finished."'
"""
)
else:
raise NotImplementedError("No notification support for this OS.")


if __name__ == "__main__":
main()
Empty file added src/__init__.py
Empty file.
27 changes: 15 additions & 12 deletions src/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ def forward(self, x, batch, size=None):
:param batch: Batch vector, which assigns each node to a specific example
:return representation: A graph level representation matrix.
"""
size = batch[-1].item() + 1 if size is None else size
mean = scatter_mean(x, batch, dim=0, dim_size=size)
transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))
size = batch[-1].item() + 1 if size is None else size #=128,表示批次大小
mean = scatter_mean(x, batch, dim=0, dim_size=size) # 按照 batch 维度对输入张量 x 进行分组并求取每组的平均值
# 相当于对每个图中节点的特征求平均值 [128,16]
transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix)) # 均值乘以可学习的权重得到全局特征 [128,16]*[16,16]=[128,16]

coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1))
weighted = coefs.unsqueeze(-1) * x
coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1)) # 用每一个节点的特征与全局特征作内积,得到相似度权重
weighted = coefs.unsqueeze(-1) * x # 根据相似度权重对每个点进行加权

return scatter_add(weighted, batch, dim=0, dim_size=size)
return scatter_add(weighted, batch, dim=0, dim_size=size) # 按batch对weighted进行求和汇总,得到图级别的全局特征

def get_coefs(self, x):
mean = x.mean(dim=0)
Expand Down Expand Up @@ -162,16 +163,18 @@ def forward(self, embedding_1, embedding_2):
:param embedding_2: Result of the 2nd embedding after attention.
:return scores: A similarity score vector.
"""
batch_size = len(embedding_1)
scoring = torch.matmul(
batch_size = len(embedding_1) #embedding_1=[128,16]
# self.weight_matrix.view(self.args.filters_3, -1).shape=[16,256],原始输入的两个实体都是16维的特征向量,现在用256维来表示它们的某种关系
scoring = torch.matmul( # [128,256]
embedding_1, self.weight_matrix.view(self.args.filters_3, -1)
)
scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1])
scoring = torch.matmul(
scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1]) # [128,16,16]
# embedding_2.view(batch_size, self.args.filters_3, 1).shape=[128,16,1]
scoring = torch.matmul( # [128,16]
scoring, embedding_2.view(batch_size, self.args.filters_3, 1)
).view(batch_size, -1)
combined_representation = torch.cat((embedding_1, embedding_2), 1)
block_scoring = torch.t(
combined_representation = torch.cat((embedding_1, embedding_2), 1) # [128,32]
block_scoring = torch.t( # ([16,32]*[32,128])^T=[128,16]
torch.mm(self.weight_matrix_block, torch.t(combined_representation))
)
scores = F.relu(scoring + block_scoring + self.bias.view(-1))
Expand Down
18 changes: 11 additions & 7 deletions src/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ def parameter_parser():
A method to parse up command line parameters.
The default hyperparameters give a high performance model without grid search.
"""
parser = argparse.ArgumentParser(description="Run SimGNN.")
parser = argparse.ArgumentParser(description="Run SimGNN.") # 创建一个命令行参数解析器

parser.add_argument(
"--dataset",
nargs="?",
default="AIDS700nef",
parser.add_argument( # 向解析器添加了一个命令行参数的定义
"--dataset", # 参数名,双减号表示这是一个长参数
nargs="?", # 参数后面可以跟一个值,这个值可以存在也可以不存在。
default="AIDS700nef", # 如果不提供该参数,则默认为 "AIDS700nef"。
help="Dataset name. Default is AIDS700nef",
)

Expand Down Expand Up @@ -97,8 +97,11 @@ def parameter_parser():

parser.add_argument(
"--diffpool",
dest="diffpool",
action="store_true",
dest="diffpool", # 参数的目标属性名称,将该参数的值存储到 args.diffpool 属性中
# 如果没有指定 dest 参数,argparse 默认会根据命令行参数的名称生成目标属性名称。
# 1. 对于长参数名(以 -- 开头),argparse 会将前缀 -- 去除,并将剩余部分的连字符(-)替换为下划线(_)作为目标属性的名称。
# 2. 对于短参数名(以 - 开头),argparse 会将前缀 - 去除,并将剩余部分的每个字符都作为目标属性的名称。
action="store_true", #参数的动作,设置为 store_true 表示如果命令行中出现了该参数,则将其设置为 True。
help="Enable differentiable pooling.",
)

Expand Down Expand Up @@ -129,6 +132,7 @@ def parameter_parser():
help="Send notification message when the code is finished (only Linux & Mac OS support).",
)

# 设置默认参数
parser.set_defaults(histogram=False)
parser.set_defaults(diffpool=False)
parser.set_defaults(plot=False)
Expand Down
Loading