diff --git a/graph_neural_network/Dockerfile b/graph_neural_network/Dockerfile new file mode 100644 index 000000000..3ea26c0f3 --- /dev/null +++ b/graph_neural_network/Dockerfile @@ -0,0 +1,19 @@ +FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel + +WORKDIR /workspace/repository + +RUN pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 +RUN pip install scikit-learn==0.24.2 +RUN pip install torch_geometric==2.4.0 +RUN pip install --no-index torch_scatter==2.1.1 torch_sparse==0.6.17 -f https://data.pyg.org/whl/torch-1.13.0+cu117.html +RUN pip install graphlearn-torch==0.2.2 + +RUN apt update +RUN apt install -y git +RUN pip install git+https://github.com/mlcommons/logging.git + +# TF32 instead of FP32 for faster compute +ENV NVIDIA_TF32_OVERRIDE=1 + +RUN git clone https://github.com/alibaba/graphlearn-for-pytorch.git +WORKDIR /workspace/repository/graphlearn-for-pytorch/examples/igbh diff --git a/graph_neural_network/README.md b/graph_neural_network/README.md new file mode 100644 index 000000000..4b33ddafb --- /dev/null +++ b/graph_neural_network/README.md @@ -0,0 +1,170 @@ +# 1. Problem +This benchmark represents a multi-class node classification task in a heterogenous graph using the [IGB Heterogeneous Dataset](https://github.com/IllinoisGraphBenchmark/IGB-Datasets) named IGBH-Full. The task is carried out using a [GAT](https://arxiv.org/abs/1710.10903) model based on the [Relational Graph Attention Networks](https://arxiv.org/abs/1904.05811) paper. + +The reference implementation is based on [graphlearn-for-pytorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch). + +# 2. Directions +### Steps to configure machine + +#### 1. Clone the repository: +```bash +git clone https://github.com/alibaba/graphlearn-for-pytorch +``` + +or +```bash +git clone https://github.com/mlcommons/training.git +``` +once `GNN node classification` is merged into `mlcommons/training`. + +#### 2. Build the docker image: + +If you cloned the `graphlearn-for-pytorch` repository: +```bash +cd graphlearn-for-pytorch/examples/igbh/ +docker build -f Dockerfile -t training_gnn:latest . +``` + +If you cloned the `mlcommons/training` repository: +```bash +cd training/gnn_node_classification/ +docker build -f Dockerfile -t training_gnn:latest . +``` + + +### Steps to download and verify data +Download the dataset: +```bash + +bash download_igbh_full.sh +``` + +Before training, generate the seeds for training and validation: +```bash +python split_seeds.py --dataset_size='full' +``` + +The size of the `IGBH-Full` dataset is 2.2 TB. If you want to test with +the `tiny`, `small` or `medium` datasets, the download procedure is included +in the training script. + +### Steps to run and time + +#### Single-node Training + +The original graph is in the `COO` format and the feature is in the FP32 format. The training script will transform the graph from `COO` to `CSC` and convert the feature to FP16, which could be time consuming due to the graph scale. We provide a script to convert the graph layout from `COO` to `CSC` and persist the feature in FP16 format: + +```bash +python compress_graph.py --dataset_size='full' --layout='CSC' --use_fp16 +``` + +To train the model using multiple GPUs: +```bash +CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='full' --layout='CSC' --use_fp16 +``` +The number of training processes is equal to the number of GPUS. Option `--pin_feature` decides if the feature data will be pinned in host memory, which enables zero-copy feature access from GPU, but will incur extra memory costs. + + +#### Distributed Training + +##### 1. Data Partitioning +To partition the dataset (including both the topology and feature): +```bash +python partition.py --dataset_size='full' --num_partitions=2 --use_fp16 --layout='CSC' +``` +The above script will partition the dataset into two parts, convert the feature into +the FP16 format, and transform the graph layout from `COO` to `CSC`. + +We suggest using a distributed file system to store the partitioned data, such as HDFS or NFS, suhc that partitioned data can be accessed by all training nodes. + +##### 2. Two-stage Data Partitioning +To speed up the partitioning process, GLT also supports two-stage partitioning, which splits the process of topology partitioning and feature partitioning. After the topology partitioning is executed in a single node, the feature partitioning process can be conducted in each training node in parallel to speedup the partitioning process. + +The topology partitioning is conducted by executing: +```bash +python partition.py --dataset_size='full' --num_partitions=2 --layout='CSC' --with_feature=0 +``` + +The feature partitioning in conducted in each training node: +```bash +# node 0 which holds partition 0: +python build_partition_feature.py --dataset_size='full' --use_fp16 --in_memory=0 --partition_idx=0 + +# node 1 which holds partition 1: +python build_partition_feature.py --dataset_size='full' --use_fp16 --in_memory=0 --partition_idx=1 +``` + +##### 2. Model Training +The number of partitions and number of training nodes must be the same. In each training node, the model can be trained using the following command: + +```bash +# node 0: +CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=2 --master_addr=master_address_ip --model='rgat' --dataset_size='full' --layout='CSC' + +# node 1: +CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=2 --master_addr=master_address_ip --model='rgat' --dataset_size='full' --layout='CSC' +``` +The above script assumes that the training nodes are equipped with 2 GPUs and the number of training processes is equal to the number of GPUs. Each training process has a corresponding +sampling process using the same GPU. + +The `master_address_ip` should be replaced with the actual IP address of the master node. The `--pin_feature` option decides if the feature data will be pinned in host memory, which enables zero-copy feature access from GPU but will incur extra memory costs. + + + We recommend separating the sampling and training processes to different GPUs to achieve better performance. To seperate the GPU used by sampling and training processes, please add `--split_training_sampling` and set `--num_training_procs` as half of the number of devices: + +```bash +# node 0: +CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='full' --layout='CSC' --split_training_sampling + +# node 1: +CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='full' --layout='CSC' --split_training_sampling +``` +The above script uses one GPU for training and another for sampling in each node. + + + +# 3. Dataset/Environment +### Publication/Attribution +Arpandeep Khatua, Vikram Sharma Mailthody, Bhagyashree Taleka, Tengfei Ma, Xiang Song, and Wen-mei Hwu. 2023. IGB: Addressing The Gaps In Labeling, Features, Heterogeneity, and Size of Public Graph Datasets for Deep Learning Research. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD '23). Association for Computing Machinery, New York, NY, USA, 4284–4295. https://doi.org/10.1145/3580305.3599843 + +### Data preprocessing +The original graph is in the `COO` format and the feature is in FP32 format. It is allowed to transform the graph from `COO` to `CSC` and convert the feature to FP16 (supported by the training script). + +### Training and test data separation +The training and validation data are selected from the labeled ``paper`` nodes from the dataset and are generated by `split_seeds.py`. Differnet random seeds will result in different training and test data. + +### Training data order +Randomly. + +### Test data order +Randomly. + +# 4. Model +### Publication/Attribution +Dan Busbridge and Dane Sherburn and Pietro Cavallo and Nils Y. Hammerla, Relational Graph Attention Networks, 2019, https://arxiv.org/abs/1904.05811 + +### List of layers +Three-layer RGAT model + +### Loss function +CrossEntropyLoss + +### Optimizer +Adam + +# 5. Quality +### Quality metric +The validation accuracy is the target quality metric. +### Quality target +0.72 +### Evaluation frequency +4,730,280 training seeds (5% of the entire training seeds, evaluated every 0.05 epoch) +### Evaluation thoroughness +788,380 validation seeds + +# 6. Contributors +This benchmark is a collaborative effort with contributions from Alibaba, Intel, and Nvidia: + +- Alibaba: Li Su, Baole Ai, Wenting Shen, Shuxian Hu, Wenyuan Yu, Yong Li +- Nvidia: Yunzhou (David) Liu, Kyle Kranen, Shriya Palasamudram +- Intel: Kaixuan Liu, Hesham Mostafa, Sasikanth Avancha, Keith Achorn, Radha Giduthuri, Deepak Canchi \ No newline at end of file diff --git a/graph_neural_network/build_partition_feature.py b/graph_neural_network/build_partition_feature.py new file mode 100644 index 000000000..25a7187ba --- /dev/null +++ b/graph_neural_network/build_partition_feature.py @@ -0,0 +1,61 @@ +import argparse +import os.path as osp + +import graphlearn_torch as glt +import torch + +from dataset import IGBHeteroDataset + + +def partition_feature(src_path: str, + dst_path: str, + partition_idx: int, + chunk_size: int, + dataset_size: str='tiny', + in_memory: bool=True, + use_fp16: bool=False): + print(f'-- Loading igbh_{dataset_size} ...') + data = IGBHeteroDataset(src_path, dataset_size, in_memory, with_edges=False, use_fp16=use_fp16) + + print(f'-- Build feature for partition {partition_idx} ...') + dst_path = osp.join(dst_path, f'{dataset_size}-partitions') + node_feat_dtype = torch.float16 if use_fp16 else torch.float32 + glt.partition.base.build_partition_feature(root_dir = dst_path, + partition_idx = partition_idx, + chunk_size = chunk_size, + node_feat = data.feat_dict, + node_feat_dtype = node_feat_dtype) + + +if __name__ == '__main__': + root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh') + glt.utils.ensure_dir(root) + parser = argparse.ArgumentParser(description="Arguments for partitioning ogbn datasets.") + parser.add_argument('--src_path', type=str, default=root, + help='path containing the datasets') + parser.add_argument('--dst_path', type=str, default=root, + help='path containing the partitioned datasets') + parser.add_argument('--dataset_size', type=str, default='full', + choices=['tiny', 'small', 'medium', 'large', 'full'], + help='size of the datasets') + parser.add_argument('--in_memory', type=int, default=0, + choices=[0, 1], help='0:read only mmap_mode=r, 1:load into memory') + parser.add_argument("--partition_idx", type=int, default=0, + help="Index of a partition") + parser.add_argument("--chunk_size", type=int, default=10000, + help="Chunk size for feature partitioning.") + parser.add_argument("--use_fp16", action="store_true", + help="save node/edge feature using fp16 format") + + + args = parser.parse_args() + + partition_feature( + args.src_path, + args.dst_path, + partition_idx=args.partition_idx, + chunk_size=args.chunk_size, + dataset_size=args.dataset_size, + in_memory=args.in_memory==1, + use_fp16=args.use_fp16 + ) diff --git a/graph_neural_network/compress_graph.py b/graph_neural_network/compress_graph.py new file mode 100644 index 000000000..65b87a750 --- /dev/null +++ b/graph_neural_network/compress_graph.py @@ -0,0 +1,120 @@ +import argparse, datetime, os +import numpy as np +import torch +import os.path as osp + +import graphlearn_torch as glt + +from dataset import float2half +from download import download_dataset +from torch_geometric.utils import add_self_loops, remove_self_loops +from typing import Literal + + +class IGBHeteroDatasetCompress(object): + def __init__(self, + path, + dataset_size, + layout: Literal['CSC', 'CSR'] = 'CSC',): + self.dir = path + self.dataset_size = dataset_size + self.layout = layout + + self.ntypes = ['paper', 'author', 'institute', 'fos'] + self.etypes = None + self.edge_dict = {} + self.paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174} + self.author_nodes_num = {'tiny':357041, 'small':1926066, 'medium':15544654, 'large':116959896, 'full':277220883} + if not osp.exists(osp.join(path, self.dataset_size, 'processed')): + download_dataset(path, 'heterogeneous', dataset_size) + self.process() + + def process(self): + paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__cites__paper', 'edge_index.npy'))).t() + author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__written_by__author', 'edge_index.npy'))).t() + affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'author__affiliated_to__institute', 'edge_index.npy'))).t() + paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__topic__fos', 'edge_index.npy'))).t() + if self.dataset_size in ['large', 'full']: + paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__published__journal', 'edge_index.npy'))).t() + paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed', + 'paper__venue__conference', 'edge_index.npy'))).t() + + cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0] + self.edge_dict = { + ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])), + ('paper', 'written_by', 'author'): author_paper_edges, + ('author', 'affiliated_to', 'institute'): affiliation_author_edges, + ('paper', 'topic', 'fos'): paper_fos_edges, + ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]), + ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]), + ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :]) + } + if self.dataset_size in ['large', 'full']: + self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal + self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference + self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :]) + self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :]) + self.etypes = list(self.edge_dict.keys()) + + # init graphlearn_torch Dataset. + edge_dir = 'out' if self.layout == 'CSR' else 'in' + glt_dataset = glt.data.Dataset(edge_dir=edge_dir) + glt_dataset.init_graph( + edge_index=self.edge_dict, + graph_mode='CPU', + ) + + # save the corresponding csr or csc file + compress_edge_dict = {} + compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper' + compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author' + compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute' + compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos' + compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper' + compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author' + compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper' + compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal' + compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference' + compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper' + compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper' + + for etype in self.etypes: + graph = glt_dataset.get_graph(etype) + indptr, indices, _ = graph.export_topology() + path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype]) + if not os.path.exists(path): + os.makedirs(path) + torch.save(indptr, os.path.join(path, 'indptr.pt')) + torch.save(indices, os.path.join(path, 'indices.pt')) + path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout) + print(f"The {self.layout} graph has been persisted in path: {path}") + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh') + glt.utils.ensure_dir(root) + parser.add_argument('--path', type=str, default=root, + help='path containing the datasets') + parser.add_argument('--dataset_size', type=str, default='full', + choices=['tiny', 'small', 'medium', 'large', 'full'], + help='size of the datasets') + parser.add_argument("--layout", type=str, default='CSC') + parser.add_argument('--use_fp16', action="store_true", + help="convert the node/edge feature into fp16 format") + args = parser.parse_args() + print(f"Start constructing the {args.layout} graph...") + igbh_dataset = IGBHeteroDatasetCompress(args.path, args.dataset_size, args.layout) + if args.use_fp16: + base_path = osp.join(args.path, args.dataset_size, 'processed') + float2half(base_path, args.dataset_size) + + + + diff --git a/graph_neural_network/dataset.py b/graph_neural_network/dataset.py new file mode 100644 index 000000000..9f7556daf --- /dev/null +++ b/graph_neural_network/dataset.py @@ -0,0 +1,270 @@ +import numpy as np +import torch +import os.path as osp + +from torch_geometric.utils import add_self_loops, remove_self_loops +from download import download_dataset +from typing import Literal + +def float2half(base_path, dataset_size): + paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174} + author_nodes_num = {'tiny':357041, 'small':1926066, 'medium':15544654, 'large':116959896, 'full':277220883} + # paper node + paper_feat_path = osp.join(base_path, 'paper', 'node_feat.npy') + paper_fp16_feat_path = osp.join(base_path, 'paper', 'node_feat_fp16.pt') + if not osp.exists(paper_fp16_feat_path): + if dataset_size in ['large', 'full']: + num_paper_nodes = paper_nodes_num[dataset_size] + paper_node_features = torch.from_numpy(np.memmap(paper_feat_path, dtype='float32', mode='r', shape=(num_paper_nodes,1024))) + else: + paper_node_features = torch.from_numpy(np.load(paper_feat_path, mmap_mode='r')) + paper_node_features = paper_node_features.half() + torch.save(paper_node_features, paper_fp16_feat_path) + + # author node + author_feat_path = osp.join(base_path, 'author', 'node_feat.npy') + author_fp16_feat_path = osp.join(base_path, 'author', 'node_feat_fp16.pt') + if not osp.exists(author_fp16_feat_path): + if dataset_size in ['large', 'full']: + num_author_nodes = author_nodes_num[dataset_size] + author_node_features = torch.from_numpy(np.memmap(author_feat_path, dtype='float32', mode='r', shape=(num_author_nodes,1024))) + else: + author_node_features = torch.from_numpy(np.load(author_feat_path, mmap_mode='r')) + author_node_features = author_node_features.half() + torch.save(author_node_features, author_fp16_feat_path) + + # institute node + institute_feat_path = osp.join(base_path, 'institute', 'node_feat.npy') + institute_fp16_feat_path = osp.join(base_path, 'institute', 'node_feat_fp16.pt') + if not osp.exists(institute_fp16_feat_path): + institute_node_features = torch.from_numpy(np.load(institute_feat_path, mmap_mode='r')) + institute_node_features = institute_node_features.half() + torch.save(institute_node_features, institute_fp16_feat_path) + + # fos node + fos_feat_path = osp.join(base_path, 'fos', 'node_feat.npy') + fos_fp16_feat_path = osp.join(base_path, 'fos', 'node_feat_fp16.pt') + if not osp.exists(fos_fp16_feat_path): + fos_node_features = torch.from_numpy(np.load(fos_feat_path, mmap_mode='r')) + fos_node_features = fos_node_features.half() + torch.save(fos_node_features, fos_fp16_feat_path) + + if dataset_size in ['large', 'full']: + # conference node + conference_feat_path = osp.join(base_path, 'conference', 'node_feat.npy') + conference_fp16_feat_path = osp.join(base_path, 'conference', 'node_feat_fp16.pt') + if not osp.exists(conference_fp16_feat_path): + conference_node_features = torch.from_numpy(np.load(conference_feat_path, mmap_mode='r')) + conference_node_features = conference_node_features.half() + torch.save(conference_node_features, conference_fp16_feat_path) + + # journal node + journal_feat_path = osp.join(base_path, 'journal', 'node_feat.npy') + journal_fp16_feat_path = osp.join(base_path, 'journal', 'node_feat_fp16.pt') + if not osp.exists(journal_fp16_feat_path): + journal_node_features = torch.from_numpy(np.load(journal_feat_path, mmap_mode='r')) + journal_node_features = journal_node_features.half() + torch.save(journal_node_features, journal_fp16_feat_path) + +class IGBHeteroDataset(object): + def __init__(self, + path, + dataset_size='tiny', + in_memory=True, + use_label_2K=False, + with_edges=True, + layout: Literal['CSC', 'CSR', 'COO'] = 'COO', + use_fp16=False): + self.dir = path + self.dataset_size = dataset_size + self.in_memory = in_memory + self.use_label_2K = use_label_2K + self.with_edges = with_edges + self.layout = layout + self.use_fp16 = use_fp16 + + self.ntypes = ['paper', 'author', 'institute', 'fos'] + self.etypes = None + self.edge_dict = {} + self.feat_dict = {} + self.paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174} + self.author_nodes_num = {'tiny':357041, 'small':1926066, 'medium':15544654, 'large':116959896, 'full':277220883} + # 'paper' nodes. + self.label = None + self.train_idx = None + self.val_idx = None + self.test_idx = None + self.base_path = osp.join(path, self.dataset_size, 'processed') + if not osp.exists(self.base_path): + download_dataset(path, 'heterogeneous', dataset_size) + if self.use_fp16: + float2half(self.base_path, self.dataset_size) + self.process() + + def process(self): + # load edges + if self.with_edges: + if self.layout == 'COO': + if self.in_memory: + paper_paper_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__cites__paper', 'edge_index.npy'))).t() + author_paper_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__written_by__author', 'edge_index.npy'))).t() + affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'author__affiliated_to__institute', 'edge_index.npy'))).t() + paper_fos_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__topic__fos', 'edge_index.npy'))).t() + if self.dataset_size in ['large', 'full']: + paper_published_journal = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__published__journal', 'edge_index.npy'))).t() + paper_venue_conference = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__venue__conference', 'edge_index.npy'))).t() + else: + paper_paper_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__cites__paper', 'edge_index.npy'), mmap_mode='r')).t() + author_paper_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__written_by__author', 'edge_index.npy'), mmap_mode='r')).t() + affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'author__affiliated_to__institute', 'edge_index.npy'), mmap_mode='r')).t() + paper_fos_edges = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__topic__fos', 'edge_index.npy'), mmap_mode='r')).t() + if self.dataset_size in ['large', 'full']: + paper_published_journal = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__published__journal', 'edge_index.npy'), mmap_mode='r')).t() + paper_venue_conference = torch.from_numpy(np.load(osp.join(self.base_path, + 'paper__venue__conference', 'edge_index.npy'), mmap_mode='r')).t() + + cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0] + self.edge_dict = { + ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])), + ('paper', 'written_by', 'author'): author_paper_edges, + ('author', 'affiliated_to', 'institute'): affiliation_author_edges, + ('paper', 'topic', 'fos'): paper_fos_edges, + ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]), + ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]), + ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :]) + } + if self.dataset_size in ['large', 'full']: + self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal + self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference + self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :]) + self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :]) + + # directly load from CSC or CSC files, which can be generated using compress_graph.py + else: + compress_edge_dict = {} + compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper' + compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author' + compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute' + compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos' + compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper' + compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author' + compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper' + if self.dataset_size in ['large', 'full']: + compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal' + compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference' + compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper' + compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper' + + for etype in compress_edge_dict.keys(): + edge_path = osp.join(self.base_path, self.layout, compress_edge_dict[etype]) + try: + edge_path = osp.join(self.base_path, self.layout, compress_edge_dict[etype]) + indptr = torch.load(osp.join(edge_path, 'indptr.pt')) + indices = torch.load(osp.join(edge_path, 'indices.pt')) + if self.layout == 'CSC': + self.edge_dict[etype] = (indices, indptr) + else: + self.edge_dict[etype] = (indptr, indices) + except FileNotFoundError as e: + print(f"FileNotFound: {e}") + exit() + except Exception as e: + print(f"Exception: {e}") + exit() + self.etypes = list(self.edge_dict.keys()) + + # load features and labels + label_file = 'node_label_19.npy' if not self.use_label_2K else 'node_label_2K.npy' + paper_feat_path = osp.join(self.base_path, 'paper', 'node_feat.npy') + paper_lbl_path = osp.join(self.base_path, 'paper', label_file) + num_paper_nodes = self.paper_nodes_num[self.dataset_size] + if self.in_memory: + if self.use_fp16: + paper_node_features = torch.load(osp.join(self.base_path, 'paper', 'node_feat_fp16.pt')) + else: + paper_node_features = torch.from_numpy(np.load(paper_feat_path)) + else: + if self.dataset_size in ['large', 'full']: + paper_node_features = torch.from_numpy(np.memmap(paper_feat_path, dtype='float32', mode='r', shape=(num_paper_nodes,1024))) + else: + paper_node_features = torch.from_numpy(np.load(paper_feat_path, mmap_mode='r')) + if self.dataset_size in ['large', 'full']: + paper_node_labels = torch.from_numpy(np.memmap(paper_lbl_path, dtype='float32', mode='r', shape=(num_paper_nodes))).to(torch.long) + else: + paper_node_labels = torch.from_numpy(np.load(paper_lbl_path)).to(torch.long) + self.feat_dict['paper'] = paper_node_features + self.label = paper_node_labels + + num_author_nodes = self.author_nodes_num[self.dataset_size] + author_feat_path = osp.join(self.base_path, 'author', 'node_feat.npy') + if self.in_memory: + if self.use_fp16: + author_node_features = torch.load(osp.join(self.base_path, 'author', 'node_feat_fp16.pt')) + else: + author_node_features = torch.from_numpy(np.load(author_feat_path)) + else: + if self.dataset_size in ['large', 'full']: + author_node_features = torch.from_numpy(np.memmap(author_feat_path, dtype='float32', mode='r', shape=(num_author_nodes,1024))) + else: + author_node_features = torch.from_numpy(np.load(author_feat_path, mmap_mode='r')) + self.feat_dict['author'] = author_node_features + + if self.in_memory: + if self.use_fp16: + institute_node_features = torch.load(osp.join(self.base_path, 'institute', 'node_feat_fp16.pt')) + else: + institute_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'institute', 'node_feat.npy'))) + else: + institute_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'institute', 'node_feat.npy'), mmap_mode='r')) + self.feat_dict['institute'] = institute_node_features + + if self.in_memory: + if self.use_fp16: + fos_node_features = torch.load(osp.join(self.base_path, 'fos', 'node_feat_fp16.pt')) + else: + fos_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'fos', 'node_feat.npy'))) + else: + fos_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'fos', 'node_feat.npy'), mmap_mode='r')) + self.feat_dict['fos'] = fos_node_features + + if self.dataset_size in ['large', 'full']: + if self.in_memory: + if self.use_fp16: + conference_node_features = torch.load(osp.join(self.base_path, 'conference', 'node_feat_fp16.pt')) + else: + conference_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'conference', 'node_feat.npy'))) + else: + conference_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'conference', 'node_feat.npy'), mmap_mode='r')) + self.feat_dict['conference'] = conference_node_features + + if self.in_memory: + if self.use_fp16: + journal_node_features = torch.load(osp.join(self.base_path, 'journal', 'node_feat_fp16.pt')) + else: + journal_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'journal', 'node_feat.npy'))) + else: + journal_node_features = torch.from_numpy(np.load(osp.join(self.base_path, 'journal', 'node_feat.npy'), mmap_mode='r')) + self.feat_dict['journal'] = journal_node_features + + # Please ensure that train_idx and val_idx have been generated using split_seeds.py + try: + self.train_idx = torch.load(osp.join(self.base_path, 'train_idx.pt')) + self.val_idx = torch.load(osp.join(self.base_path, 'val_idx.pt')) + except FileNotFoundError as e: + print(f"FileNotFound: {e}, please ensure that train_idx and val_idx have been generated using split_seeds.py") + exit() + except Exception as e: + print(f"Exception: {e}") + exit() + diff --git a/graph_neural_network/dist_train_rgnn.py b/graph_neural_network/dist_train_rgnn.py new file mode 100644 index 000000000..a77a1d34e --- /dev/null +++ b/graph_neural_network/dist_train_rgnn.py @@ -0,0 +1,475 @@ +import argparse, datetime +import os.path as osp +import time, tqdm + +import graphlearn_torch as glt +import mlperf_logging.mllog.constants as mllog_constants +import numpy as np +import sklearn.metrics +import torch +import torch.distributed + +from mlperf_logging_utils import get_mlperf_logger, submission_info +from torch.nn.parallel import DistributedDataParallel +from utilities import create_ckpt_folder +from rgnn import RGNN + +mllogger = get_mlperf_logger(path=osp.dirname(osp.abspath(__file__))) + +def evaluate(model, dataloader, current_device, use_fp16, with_gpu, + rank, world_size, epoch_num): + if rank == 0: + mllogger.start( + key=mllog_constants.EVAL_START, + metadata={mllog_constants.EPOCH_NUM: epoch_num}, + ) + predictions = [] + labels = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataloader): + batch_size = batch['paper'].batch_size + if use_fp16: + x_dict = {node_name: node_feat.to(current_device).to(torch.float32) + for node_name, node_feat in batch.x_dict.items()} + else: + x_dict = {node_name: node_feat.to(current_device) + for node_name, node_feat in batch.x_dict.items()} + out = model(x_dict, + batch.edge_index_dict, + num_sampled_nodes_dict=batch.num_sampled_nodes, + num_sampled_edges_dict=batch.num_sampled_edges)[:batch_size] + batch_size = min(out.shape[0], batch_size) + labels.append(batch['paper'].y[:batch_size].cpu().clone().numpy()) + predictions.append(out.argmax(1).cpu().clone().numpy()) + + predictions = np.concatenate(predictions) + labels = np.concatenate(labels) + acc = sklearn.metrics.accuracy_score(labels, predictions) + + if with_gpu: + torch.cuda.synchronize() + torch.distributed.barrier() + + acc_tensor = torch.tensor(acc).to(current_device) + torch.distributed.all_reduce(acc_tensor, op=torch.distributed.ReduceOp.SUM) + global_acc = acc_tensor.item() / world_size + if rank == 0: + mllogger.event( + key=mllog_constants.EVAL_ACCURACY, + value=global_acc, + metadata={mllog_constants.EPOCH_NUM: epoch_num}, + ) + mllogger.end( + key=mllog_constants.EVAL_STOP, + metadata={mllog_constants.EPOCH_NUM: epoch_num}, + ) + return acc, global_acc + +def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs, + split_training_sampling, hidden_channels, num_classes, num_layers, + model_type, num_heads, fan_out, epochs, train_batch_size, val_batch_size, + learning_rate, + random_seed, + dataset, train_idx, val_idx, + train_channel_size, + val_channel_size, + master_addr, + training_pg_master_port, + train_loader_master_port, + val_loader_master_port, + with_gpu, trim_to_layer, use_fp16, + edge_dir, rpc_timeout, + validation_acc, validation_frac_within_epoch, evaluate_on_epoch_end, + checkpoint_on_epoch_end, ckpt_steps, ckpt_path): + + world_size=num_nodes*num_training_procs + rank=node_rank*num_training_procs+local_proc_rank + if rank == 0: + if ckpt_steps > 0: + ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__))) + + glt.utils.common.seed_everything(random_seed) + + # Initialize graphlearn_torch distributed worker group context. + glt.distributed.init_worker_group( + world_size=world_size, + rank=rank, + group_name='distributed-igbh-trainer' + ) + + current_ctx = glt.distributed.get_context() + if with_gpu: + if split_training_sampling: + current_device = torch.device((local_proc_rank * 2) % torch.cuda.device_count()) + sampling_device = torch.device((local_proc_rank * 2 + 1) % torch.cuda.device_count()) + else: + current_device = torch.device(local_proc_rank % torch.cuda.device_count()) + sampling_device = current_device + else: + current_device = torch.device('cpu') + sampling_device = current_device + + # Initialize training process group of PyTorch. + torch.distributed.init_process_group( + backend='nccl' if with_gpu else 'gloo', + timeout=datetime.timedelta(seconds=rpc_timeout), + rank=current_ctx.rank, + world_size=current_ctx.world_size, + init_method='tcp://{}:{}'.format(master_addr, training_pg_master_port) + ) + + # Create distributed neighbor loader for training + train_idx = train_idx.split(train_idx.size(0) // num_training_procs)[local_proc_rank] + train_loader = glt.distributed.DistNeighborLoader( + data=dataset, + num_neighbors=[int(fanout) for fanout in fan_out.split(',')], + input_nodes=('paper', train_idx), + batch_size=train_batch_size, + shuffle=True, + drop_last=False, + edge_dir=edge_dir, + collect_features=True, + to_device=current_device, + random_seed=random_seed, + worker_options = glt.distributed.MpDistSamplingWorkerOptions( + num_workers=1, + worker_devices=sampling_device, + worker_concurrency=4, + master_addr=master_addr, + master_port=train_loader_master_port, + channel_size=train_channel_size, + pin_memory=True, + rpc_timeout=rpc_timeout, + num_rpc_threads=2 + ) + ) + # Create distributed neighbor loader for validation. + val_idx = val_idx.split(val_idx.size(0) // num_training_procs)[local_proc_rank] + val_loader = glt.distributed.DistNeighborLoader( + data=dataset, + num_neighbors=[int(fanout) for fanout in fan_out.split(',')], + input_nodes=('paper', val_idx), + batch_size=val_batch_size, + shuffle=True, + drop_last=False, + edge_dir=edge_dir, + collect_features=True, + to_device=current_device, + random_seed=random_seed, + worker_options = glt.distributed.MpDistSamplingWorkerOptions( + num_workers=1, + worker_devices=sampling_device, + worker_concurrency=4, + master_addr=master_addr, + master_port=val_loader_master_port, + channel_size=val_channel_size, + pin_memory=True, + rpc_timeout=rpc_timeout, + num_rpc_threads=2 + ) + ) + + # Load checkpoint + ckpt = None + if ckpt_path is not None: + try: + ckpt = torch.load(ckpt_path) + except FileNotFoundError as e: + print(f"Checkpoint file not found: {e}") + return -1 + + # Define model and optimizer. + if with_gpu: + torch.cuda.set_device(current_device) + model = RGNN(dataset.get_edge_types(), + dataset.node_features['paper'].shape[1], + hidden_channels, + num_classes, + num_layers=num_layers, + dropout=0.2, + model=model_type, + heads=num_heads, + node_type='paper', + with_trim=trim_to_layer).to(current_device) + if ckpt is not None: + model.load_state_dict(ckpt['model_state_dict']) + model = DistributedDataParallel(model, + device_ids=[current_device.index] if with_gpu else None, + find_unused_parameters=True) + + param_size = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + size_all_mb = (param_size + buffer_size) / 1024**2 + print('model size: {:.3f}MB'.format(size_all_mb)) + + loss_fcn = torch.nn.CrossEntropyLoss().to(current_device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + if ckpt is not None: + optimizer.load_state_dict(ckpt['optimizer_state_dict']) + batch_num = (len(train_idx) + train_batch_size - 1) // train_batch_size + validation_freq = int(batch_num * validation_frac_within_epoch) + is_success = False + epoch_num = 0 + + training_start = time.time() + for epoch in range(epochs): + model.train() + total_loss = 0 + train_acc = 0 + idx = 0 + gpu_mem_alloc = 0 + epoch_start = time.time() + for batch in tqdm.tqdm(train_loader): + idx += 1 + batch_size = batch['paper'].batch_size + if use_fp16: + x_dict = {node_name: node_feat.to(current_device).to(torch.float32) + for node_name,node_feat in batch.x_dict.items()} + else: + x_dict = {node_name: node_feat.to(current_device) + for node_name,node_feat in batch.x_dict.items()} + out = model(x_dict, + batch.edge_index_dict, + num_sampled_nodes_dict=batch.num_sampled_nodes, + num_sampled_edges_dict=batch.num_sampled_edges)[:batch_size] + batch_size = min(batch_size, out.shape[0]) + y = batch['paper'].y[:batch_size] + loss = loss_fcn(out, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() + train_acc += sklearn.metrics.accuracy_score(y.cpu().numpy(), + out.argmax(1).detach().cpu().numpy())*100 + gpu_mem_alloc += ( + torch.cuda.max_memory_allocated() / 1000000 + if with_gpu + else 0 + ) + #checkpoint + if ckpt_steps > 0 and idx % ckpt_steps == 0: + if with_gpu: + torch.cuda.synchronize() + torch.distributed.barrier() + if rank == 0: + epoch_num = round((epoch + idx / batch_num), 2) + glt.utils.common.save_ckpt(idx + epoch * batch_num, + ckpt_dir, model.module, optimizer, epoch_num) + torch.distributed.barrier() + # evaluate + if idx % validation_freq == 0: + if with_gpu: + torch.cuda.synchronize() + torch.distributed.barrier() + epoch_num = round((epoch + idx / batch_num), 2) + model.eval() + rank_val_acc, global_acc = evaluate(model, val_loader, current_device, + use_fp16, with_gpu, rank, + world_size, epoch_num) + if validation_acc is not None and global_acc >= validation_acc: + is_success = True + break + model.train() + + train_acc /= idx + gpu_mem_alloc /= idx + + if with_gpu: + torch.cuda.synchronize() + torch.distributed.barrier() + + #checkpoint at the end of epoch + if checkpoint_on_epoch_end: + if rank == 0: + epoch_num = epoch + 1 + glt.utils.common.save_ckpt(idx + epoch * batch_num, + ckpt_dir, model.module, optimizer, epoch_num) + torch.distributed.barrier() + + # evaluate at the end of epoch + if evaluate_on_epoch_end and not is_success: + epoch_num = epoch + 1 + model.eval() + rank_val_acc, global_acc = evaluate(model, val_loader, current_device, + use_fp16, with_gpu, rank, world_size, + epoch_num) + if validation_acc is not None and global_acc >= validation_acc: + is_success = True + + tqdm.tqdm.write( + "Rank{:02d} | Epoch {:03d} | Loss {:.4f} | Train Acc {:.2f} | Val Acc {:.2f} | Time {} | GPU {:.1f} MB".format( + current_ctx.rank, + epoch, + total_loss, + train_acc, + rank_val_acc*100, + str(datetime.timedelta(seconds = int(time.time() - epoch_start))), + gpu_mem_alloc + ) + ) + + # stop training if success + if is_success: + break + + if rank == 0: + status = mllog_constants.SUCCESS if is_success else mllog_constants.ABORTED + mllogger.end(key=mllog_constants.RUN_STOP, + metadata={mllog_constants.STATUS: status, + mllog_constants.EPOCH_NUM: epoch_num, + } + ) + print("Total time taken " + str(datetime.timedelta(seconds = int(time.time() - training_start)))) + + +if __name__ == '__main__': + mllogger.event(key=mllog_constants.CACHE_CLEAR, value=True) + mllogger.start(key=mllog_constants.INIT_START) + + parser = argparse.ArgumentParser() + root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh') + glt.utils.ensure_dir(root) + parser.add_argument('--path', type=str, default=root, + help='path containing the datasets') + parser.add_argument('--dataset_size', type=str, default='full', + choices=['tiny', 'small', 'medium', 'large', 'full'], + help='size of the datasets') + parser.add_argument('--num_classes', type=int, default=2983, + choices=[19, 2983], help='number of classes') + parser.add_argument('--in_memory', type=int, default=0, + choices=[0, 1], help='0:read only mmap_mode=r, 1:load into memory') + # Model + parser.add_argument('--model', type=str, default='rgat', + choices=['rgat', 'rsage']) + # Model parameters + parser.add_argument('--fan_out', type=str, default='15,10,5') + parser.add_argument('--train_batch_size', type=int, default=512) + parser.add_argument('--val_batch_size', type=int, default=512) + parser.add_argument('--hidden_channels', type=int, default=512) + parser.add_argument('--learning_rate', type=float, default=0.001) + parser.add_argument('--epochs', type=int, default=2) + parser.add_argument('--num_layers', type=int, default=3) + parser.add_argument('--num_heads', type=int, default=4) + parser.add_argument('--random_seed', type=int, default=42) + # Distributed settings. + parser.add_argument("--num_nodes", type=int, default=2, + help="Number of distributed nodes.") + parser.add_argument("--node_rank", type=int, default=0, + help="The current node rank.") + parser.add_argument("--num_training_procs", type=int, default=2, + help="The number of traning processes per node.") + parser.add_argument("--master_addr", type=str, default='localhost', + help="The master address for RPC initialization.") + parser.add_argument("--training_pg_master_port", type=int, default=12111, + help="The port used for PyTorch's process group initialization across training processes.") + parser.add_argument("--train_loader_master_port", type=int, default=12112, + help="The port used for RPC initialization across all sampling workers of train loader.") + parser.add_argument("--val_loader_master_port", type=int, default=12113, + help="The port used for RPC initialization across all sampling workers of val loader.") + parser.add_argument("--cpu_mode", action="store_true", + help="Only use CPU for sampling and training, default is False.") + parser.add_argument("--edge_dir", type=str, default='out', + help="sampling direction, can be 'in' for 'by_dst' or 'out' for 'by_src' for partitions.") + parser.add_argument('--layout', type=str, default='COO', + help="Layout of input graph: CSC, CSR, COO. Default is COO.") + parser.add_argument('--train_channel_size', type=str, default='16GB', + help="Size of shared memory queue to put sampled results for train dataset") + parser.add_argument('--val_channel_size', type=str, default='16GB', + help="Size of shared memory queue to put sampled results for val dataset") + parser.add_argument("--rpc_timeout", type=int, default=180, + help="rpc timeout in seconds") + parser.add_argument("--split_training_sampling", action="store_true", + help="Use seperate GPUs for training and sampling processes.") + parser.add_argument("--with_trim", action="store_true", + help="use trim_to_layer function from PyG") + parser.add_argument("--use_fp16", action="store_true", + help="load node/edge feature using fp16 format to reduce memory usage") + parser.add_argument("--validation_frac_within_epoch", type=float, default=0.05, + help="Fraction of the epoch after which validation should be performed.") + parser.add_argument("--validation_acc", type=float, default=0.72, + help="Validation accuracy threshold to stop training once reached.") + parser.add_argument("--evaluate_on_epoch_end", action="store_true", + help="Evaluate using validation set on each epoch end."), + parser.add_argument("--checkpoint_on_epoch_end", action="store_true", + help="Save checkpoint on each epoch end."), + parser.add_argument('--ckpt_steps', type=int, default=-1, + help="Save checkpoint every n steps. Default is -1, which means no checkpoint is saved.") + parser.add_argument('--ckpt_path', type=str, default=None, + help="Path to load checkpoint from. Default is None.") + args = parser.parse_args() + assert args.layout in ['COO', 'CSC', 'CSR'] + + glt.utils.common.seed_everything(args.random_seed) + # when set --cpu_mode or GPU is not available, use cpu only mode. + args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available() + if args.with_gpu: + assert(not args.num_training_procs > torch.cuda.device_count()) + if args.split_training_sampling: + assert(not args.num_training_procs > torch.cuda.device_count() // 2) + + if args.node_rank == 0: + world_size = args.num_nodes * args.num_training_procs + submission_info(mllogger, 'GNN', 'reference_implementation') + + mllogger.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=world_size*args.train_batch_size) + mllogger.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1) + mllogger.event(key=mllog_constants.OPT_NAME, value='adam') + mllogger.event(key=mllog_constants.OPT_BASE_LR, value=args.learning_rate) + mllogger.event(key=mllog_constants.SEED,value=args.random_seed) + mllogger.end(key=mllog_constants.INIT_STOP) + mllogger.start(key=mllog_constants.RUN_START) + + print('--- Loading data partition ...\n') + data_pidx = args.node_rank % args.num_nodes + dataset = glt.distributed.DistDataset(edge_dir=args.edge_dir) + dataset.load( + root_dir=osp.join(args.path, f'{args.dataset_size}-partitions'), + partition_idx=data_pidx, + graph_mode='ZERO_COPY' if args.with_gpu else 'CPU', + input_layout = args.layout, + feature_with_gpu=args.with_gpu, + whole_node_label_file={'paper': osp.join(args.path, f'{args.dataset_size}-label', 'label.pt')} + ) + train_idx = torch.load( + osp.join(args.path, f'{args.dataset_size}-train-partitions', f'partition{data_pidx}.pt') + ) + val_idx = torch.load( + osp.join(args.path, f'{args.dataset_size}-val-partitions', f'partition{data_pidx}.pt') + ) + train_idx.share_memory_() + val_idx.share_memory_() + + print('--- Launching training processes ...\n') + torch.multiprocessing.spawn( + run_training_proc, + args=(args.num_nodes, args.node_rank, args.num_training_procs, + args.split_training_sampling, args.hidden_channels, args.num_classes, + args.num_layers, args.model, args.num_heads, args.fan_out, + args.epochs, args.train_batch_size, args.val_batch_size, args.learning_rate, + args.random_seed, + dataset, train_idx, val_idx, + args.train_channel_size, + args.val_channel_size, + args.master_addr, + args.training_pg_master_port, + args.train_loader_master_port, + args.val_loader_master_port, + args.with_gpu, + args.with_trim, + args.use_fp16, + args.edge_dir, + args.rpc_timeout, + args.validation_acc, + args.validation_frac_within_epoch, + args.evaluate_on_epoch_end, + args.checkpoint_on_epoch_end, + args.ckpt_steps, + args.ckpt_path), + nprocs=args.num_training_procs, + join=True + ) diff --git a/graph_neural_network/download.py b/graph_neural_network/download.py new file mode 100644 index 000000000..944c7f7c4 --- /dev/null +++ b/graph_neural_network/download.py @@ -0,0 +1,91 @@ +import tarfile, hashlib, os +import os.path as osp +from tqdm import tqdm +import urllib.request as ur + +# https://github.com/IllinoisGraphBenchmark/IGB-Datasets/blob/main/igb/download.py + +GBFACTOR = float(1 << 30) + +def decide_download(url): + d = ur.urlopen(url) + size = int(d.info()["Content-Length"])/GBFACTOR + ### confirm if larger than 1GB + if size > 1: + return input("This will download %.2fGB. Will you proceed? (y/N) " % (size)).lower() == "y" + else: + return True + + +dataset_urls = { + 'homogeneous' : { + 'tiny' : 'https://igb-public.s3.us-east-2.amazonaws.com/igb-homogeneous/igb_homogeneous_tiny.tar.gz', + 'small' : 'https://igb-public.s3.us-east-2.amazonaws.com/igb-homogeneous/igb_homogeneous_small.tar.gz', + 'medium' : 'https://igb-public.s3.us-east-2.amazonaws.com/igb-homogeneous/igb_homogeneous_medium.tar.gz' + }, + 'heterogeneous' : { + 'tiny' : 'https://igb-public.s3.us-east-2.amazonaws.com/igb-heterogeneous/igb_heterogeneous_tiny.tar.gz', + 'small' : 'https://igb-public.s3.us-east-2.amazonaws.com/igb-heterogeneous/igb_heterogeneous_small.tar.gz', + 'medium' : 'https://igb-public.s3.us-east-2.amazonaws.com/igb-heterogeneous/igb_heterogeneous_medium.tar.gz' + } +} + + +md5checksums = { + 'homogeneous' : { + 'tiny' : '34856534da55419b316d620e2d5b21be', + 'small' : '6781c699723529902ace0a95cafe6fe4', + 'medium' : '4640df4ceee46851fd18c0a44ddcc622' + }, + 'heterogeneous' : { + 'tiny' : '83fbc1091497ff92cf20afe82fae0ade', + 'small' : '2f42077be60a074aec24f7c60089e1bd', + 'medium' : '7f0df4296eca36553ff3a6a63abbd347' + } +} + + +def check_md5sum(dataset_type, dataset_size, filename): + original_md5 = md5checksums[dataset_type][dataset_size] + + with open(filename, 'rb') as file_to_check: + data = file_to_check.read() + md5_returned = hashlib.md5(data).hexdigest() + + if original_md5 == md5_returned: + print(" md5sum verified.") + return + else: + os.remove(filename) + raise Exception(" md5sum verification failed!.") + + +def download_dataset(path, dataset_type, dataset_size): + output_directory = path + url = dataset_urls[dataset_type][dataset_size] + if decide_download(url): + data = ur.urlopen(url) + size = int(data.info()["Content-Length"]) + chunk_size = 1024*1024 + num_iter = int(size/chunk_size) + 2 + downloaded_size = 0 + filename = path + "/igb_" + dataset_type + "_" + dataset_size + ".tar.gz" + with open(filename, 'wb') as f: + pbar = tqdm(range(num_iter)) + for _ in pbar: + chunk = data.read(chunk_size) + downloaded_size += len(chunk) + pbar.set_description("Downloaded {:.2f} GB".format(float(downloaded_size)/GBFACTOR)) + f.write(chunk) + print("Downloaded" + " igb_" + dataset_type + "_" + dataset_size, end=" ->") + check_md5sum(dataset_type, dataset_size, filename) + file = tarfile.open(filename) + file.extractall(output_directory) + file.close() + size = 0 + for path, _, files in os.walk(output_directory+"/"+dataset_size): + for f in files: + fp = osp.join(path, f) + size += osp.getsize(fp) + print("Final dataset size {:.2f} GB.".format(size/GBFACTOR)) + os.remove(filename) diff --git a/graph_neural_network/download_igbh_full.sh b/graph_neural_network/download_igbh_full.sh new file mode 100644 index 000000000..7f46721c5 --- /dev/null +++ b/graph_neural_network/download_igbh_full.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +#https://github.com/IllinoisGraphBenchmark/IGB-Datasets/blob/main/igb/download_igbh600m.sh +echo "IGBH600M download starting" +cd ../../data/ +mkdir -p igbh/full/processed +cd igbh/full/processed + +# paper +mkdir paper +cd paper +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper/node_feat.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper/node_label_19.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper/node_label_2K.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper/paper_id_index_mapping.npy +cd .. + +# paper__cites__paper +mkdir paper__cites__paper +cd paper__cites__paper +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper__cites__paper/edge_index.npy +cd .. + +# author +mkdir author +cd author +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/author/author_id_index_mapping.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/author/node_feat.npy +cd .. + +# conference +mkdir conference +cd conference +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/conference/conference_id_index_mapping.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/conference/node_feat.npy +cd .. + +# institute +mkdir institute +cd institute +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/institute/institute_id_index_mapping.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/institute/node_feat.npy +cd .. + +# journal +mkdir journal +cd journal +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/journal/journal_id_index_mapping.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/journal/node_feat.npy +cd .. + +# fos +mkdir fos +cd fos +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/fos/fos_id_index_mapping.npy +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/fos/node_feat.npy +cd .. + +# author__affiliated_to__institute +mkdir author__affiliated_to__institute +cd author__affiliated_to__institute +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/author__affiliated_to__institute/edge_index.npy +cd .. + +# paper__published__journal +mkdir paper__published__journal +cd paper__published__journal +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper__published__journal/edge_index.npy +cd .. + +# paper__topic__fos +mkdir paper__topic__fos +cd paper__topic__fos +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper__topic__fos/edge_index.npy +cd .. + +# paper__venue__conference +mkdir paper__venue__conference +cd paper__venue__conference +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper__venue__conference/edge_index.npy +cd .. + +# paper__written_by__author +mkdir paper__written_by__author +cd paper__written_by__author +wget -c https://igb-public.s3.us-east-2.amazonaws.com/IGBH/processed/paper__written_by__author/edge_index.npy +cd .. + +echo "IGBH-IGBH download complete" diff --git a/graph_neural_network/mlperf_logging_utils.py b/graph_neural_network/mlperf_logging_utils.py new file mode 100644 index 000000000..109ecd533 --- /dev/null +++ b/graph_neural_network/mlperf_logging_utils.py @@ -0,0 +1,33 @@ +import os +from mlperf_logging import mllog +from mlperf_logging.mllog import constants +from mlperf_logging.mllog.mllog import MLLogger + +def get_mlperf_logger(path, filename='mlperf_gnn.log'): + mllog.config(filename=os.path.join(path, filename)) + mllogger = mllog.get_mllogger() + mllogger.logger.propagate = False + return mllogger + +def submission_info(mllogger: MLLogger, benchmark_name: str, submitter_name: str): + """Logs required for a valid MLPerf submission.""" + mllogger.event( + key=constants.SUBMISSION_BENCHMARK, + value=benchmark_name, + ) + mllogger.event( + key=constants.SUBMISSION_ORG, + value=submitter_name, + ) + mllogger.event( + key=constants.SUBMISSION_DIVISION, + value=constants.CLOSED, + ) + mllogger.event( + key=constants.SUBMISSION_STATUS, + value=constants.ONPREM, + ) + mllogger.event( + key=constants.SUBMISSION_PLATFORM, + value=submitter_name, + ) diff --git a/graph_neural_network/partition.py b/graph_neural_network/partition.py new file mode 100644 index 000000000..b580022a9 --- /dev/null +++ b/graph_neural_network/partition.py @@ -0,0 +1,147 @@ +import argparse +import os.path as osp + +import graphlearn_torch as glt +import torch + +from dataset import IGBHeteroDataset +from typing import Literal + +def partition_dataset(src_path: str, + dst_path: str, + num_partitions: int, + chunk_size: int, + dataset_size: str='tiny', + in_memory: bool=True, + edge_assign_strategy: str='by_src', + use_label_2K: bool=False, + with_feature: bool=True, + use_fp16: bool=False, + layout: Literal['CSC', 'CSR', 'COO'] = 'COO'): + print(f'-- Loading igbh_{dataset_size} ...') + data = IGBHeteroDataset(src_path, dataset_size, in_memory, use_label_2K, use_fp16=use_fp16) + node_num = {k : v.shape[0] for k, v in data.feat_dict.items()} + + print('-- Saving label ...') + label_dir = osp.join(dst_path, f'{dataset_size}-label') + glt.utils.ensure_dir(label_dir) + torch.save(data.label.squeeze(), osp.join(label_dir, 'label.pt')) + + print('-- Partitioning training idx ...') + train_idx = data.train_idx + train_idx = train_idx.split(train_idx.size(0) // num_partitions) + train_idx_partitions_dir = osp.join(dst_path, f'{dataset_size}-train-partitions') + glt.utils.ensure_dir(train_idx_partitions_dir) + for pidx in range(num_partitions): + torch.save(train_idx[pidx], osp.join(train_idx_partitions_dir, f'partition{pidx}.pt')) + + print('-- Partitioning validation idx ...') + val_idx = data.val_idx + val_idx = val_idx.split(val_idx.size(0) // num_partitions) + val_idx_partitions_dir = osp.join(dst_path, f'{dataset_size}-val-partitions') + glt.utils.ensure_dir(val_idx_partitions_dir) + for pidx in range(num_partitions): + torch.save(val_idx[pidx], osp.join(val_idx_partitions_dir, f'partition{pidx}.pt')) + + print('-- Partitioning graph and features ...') + partitions_dir = osp.join(dst_path, f'{dataset_size}-partitions') + partitioner = glt.partition.RandomPartitioner( + output_dir=partitions_dir, + num_parts=num_partitions, + num_nodes=node_num, + edge_index=data.edge_dict, + node_feat=data.feat_dict, + node_feat_dtype = torch.float16 if use_fp16 else torch.float32, + edge_assign_strategy=edge_assign_strategy, + chunk_size=chunk_size, + ) + partitioner.partition(with_feature) + + if layout in ['CSC', 'CSR']: + compress_edge_dict = {} + compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper' + compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author' + compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute' + compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos' + compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper' + compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author' + compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper' + compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal' + compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference' + compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper' + compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper' + + for pidx in range(num_partitions): + base_path = osp.join(dst_path, f'{dataset_size}-partitions', f'part{pidx}', 'graph') + device = torch.device('cpu') + graph_dict = {} + for etype, e_path in compress_edge_dict.items(): + graph = glt.partition.base.load_graph_partition_data(osp.join(base_path, e_path), device) + if graph != None: + graph_dict[etype] = graph + + edge_dir = 'out' if layout == 'CSR' else 'in' + dataset = glt.distributed.DistDataset(edge_dir=edge_dir) + edge_index, edge_ids, edge_weights = {}, {}, {} + for k, v in graph_dict.items(): + edge_index[k] = v.edge_index + edge_ids[k] = v.eids + edge_weights[k] = v.weights + # COO is the oroginal layout of raw igbh graph + dataset.init_graph(edge_index, edge_ids, edge_weights, layout='COO', + graph_mode='CPU', device=device) + + for etype in graph_dict: + graph = dataset.get_graph(etype) + indptr, indices, _ = graph.export_topology() + path = osp.join(base_path, compress_edge_dict[etype]) + if layout == 'CSR': + torch.save(indptr, osp.join(path, 'rows.pt')) + torch.save(indices, osp.join(path, 'cols.pt')) + else: + torch.save(indptr, osp.join(path, 'cols.pt')) + torch.save(indices, osp.join(path, 'rows.pt')) + +if __name__ == '__main__': + root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh') + glt.utils.ensure_dir(root) + parser = argparse.ArgumentParser(description="Arguments for partitioning ogbn datasets.") + parser.add_argument('--src_path', type=str, default=root, + help='path containing the datasets') + parser.add_argument('--dst_path', type=str, default=root, + help='path containing the partitioned datasets') + parser.add_argument('--dataset_size', type=str, default='full', + choices=['tiny', 'small', 'medium', 'large', 'full'], + help='size of the datasets') + parser.add_argument('--num_classes', type=int, default=2983, + choices=[19, 2983], help='number of classes') + parser.add_argument('--in_memory', type=int, default=0, + choices=[0, 1], help='0:read only mmap_mode=r, 1:load into memory') + parser.add_argument("--num_partitions", type=int, default=2, + help="Number of partitions") + parser.add_argument("--chunk_size", type=int, default=10000, + help="Chunk size for feature partitioning.") + parser.add_argument("--edge_assign_strategy", type=str, default='by_src', + help="edge assign strategy can be either 'by_src' or 'by_dst'") + parser.add_argument('--with_feature', type=int, default=1, + choices=[0, 1], help='0:do not partition feature, 1:partition feature') + parser.add_argument('--use_fp16', action="store_true", + help="save partitioned node/edge feature into fp16 format") + parser.add_argument("--layout", type=str, default='COO', + help="layout of the partitioned graph: CSC, CSR, COO") + + args = parser.parse_args() + + partition_dataset( + args.src_path, + args.dst_path, + num_partitions=args.num_partitions, + chunk_size=args.chunk_size, + dataset_size=args.dataset_size, + in_memory=args.in_memory, + edge_assign_strategy=args.edge_assign_strategy, + use_label_2K=args.num_classes==2983, + with_feature=args.with_feature==1, + use_fp16=args.use_fp16, + layout = args.layout + ) diff --git a/graph_neural_network/rgnn.py b/graph_neural_network/rgnn.py new file mode 100644 index 000000000..39e9c9908 --- /dev/null +++ b/graph_neural_network/rgnn.py @@ -0,0 +1,66 @@ +import torch +import torch.nn.functional as F + +from torch_geometric.nn import HeteroConv, GATConv, GCNConv, SAGEConv +from torch_geometric.utils import trim_to_layer + +class RGNN(torch.nn.Module): + r""" [Relational GNN model](https://arxiv.org/abs/1703.06103). + + Args: + etypes: edge types. + in_dim: input size. + h_dim: Dimension of hidden layer. + out_dim: Output dimension. + num_layers: Number of conv layers. + dropout: Dropout probability for hidden layers. + model: "rsage" or "rgat". + heads: Number of multi-head-attentions for GAT. + node_type: The predict node type for node classification. + + """ + def __init__(self, etypes, in_dim, h_dim, out_dim, num_layers=2, + dropout=0.2, model='rgat', heads=4, node_type=None, with_trim=False): + super().__init__() + self.node_type = node_type + if node_type is not None: + self.lin = torch.nn.Linear(h_dim, out_dim) + + self.convs = torch.nn.ModuleList() + for i in range(num_layers): + in_dim = in_dim if i == 0 else h_dim + h_dim = out_dim if (i == (num_layers - 1) and node_type is None) else h_dim + if model == 'rsage': + self.convs.append(HeteroConv({ + etype: SAGEConv(in_dim, h_dim, root_weight=False) + for etype in etypes})) + elif model == 'rgat': + self.convs.append(HeteroConv({ + etype: GATConv(in_dim, h_dim // heads, heads=heads, add_self_loops=False) + for etype in etypes})) + self.dropout = torch.nn.Dropout(dropout) + self.with_trim = with_trim + + def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict=None, + num_sampled_nodes_dict=None): + for i, conv in enumerate(self.convs): + if self.with_trim: + x_dict, edge_index_dict, _ = trim_to_layer( + layer=i, + num_sampled_nodes_per_hop=num_sampled_nodes_dict, + num_sampled_edges_per_hop=num_sampled_edges_dict, + x=x_dict, + edge_index=edge_index_dict + ) + for key in list(edge_index_dict.keys()): + if key[0] not in x_dict or key[-1] not in x_dict: + del edge_index_dict[key] + + x_dict = conv(x_dict, edge_index_dict) + if i != len(self.convs) - 1: + x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()} + x_dict = {key: self.dropout(x) for key, x in x_dict.items()} + if hasattr(self, 'lin'): # for node classification + return self.lin(x_dict[self.node_type]) + else: + return x_dict diff --git a/graph_neural_network/split_seeds.py b/graph_neural_network/split_seeds.py new file mode 100644 index 000000000..c8675ba92 --- /dev/null +++ b/graph_neural_network/split_seeds.py @@ -0,0 +1,59 @@ +import argparse +import os.path as osp +import torch + +class SeedSplitter(object): + def __init__(self, + path, + dataset_size='tiny', + use_label_2K=True, + random_seed=42, + validation_frac=0.01): + self.path = path + self.dataset_size = dataset_size + self.use_label_2K = use_label_2K + self.random_seed = random_seed + self.validation_frac = validation_frac + self.paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174} + self.process() + + def process(self): + torch.manual_seed(self.random_seed) + n_labeled_idx = self.paper_nodes_num[self.dataset_size] + if self.dataset_size == 'full': + if self.use_label_2K: + n_labeled_idx = 157675969 + else: + n_labeled_idx = 227130858 + + shuffled_index = torch.randperm(n_labeled_idx) + n_train = int(n_labeled_idx * 0.6) + n_val = int(n_labeled_idx * self.validation_frac) + + train_idx = shuffled_index[:n_train] + val_idx = shuffled_index[n_train : n_train + n_val] + + path = osp.join(self.path, self.dataset_size, 'processed') + torch.save(train_idx, osp.join(path, 'train_idx.pt')) + torch.save(val_idx, osp.join(path, 'val_idx.pt')) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh') + parser.add_argument('--path', type=str, default=root, + help='path containing the datasets') + parser.add_argument('--dataset_size', type=str, default='full', + choices=['tiny', 'small', 'medium', 'large', 'full'], + help='size of the datasets') + parser.add_argument("--random_seed", type=int, default='42') + parser.add_argument('--num_classes', type=int, default=2983, + choices=[19, 2983], help='number of classes') + parser.add_argument("--validation_frac", type=float, default=0.005, + help="Fraction of labeled vertices to be used for validation.") + + args = parser.parse_args() + splitter = SeedSplitter(path=args.path, + dataset_size=args.dataset_size, + use_label_2K=(args.num_classes==2983), + random_seed=args.random_seed, + validation_frac=args.validation_frac) \ No newline at end of file diff --git a/graph_neural_network/train_rgnn_multi_gpu.py b/graph_neural_network/train_rgnn_multi_gpu.py new file mode 100644 index 000000000..53ec87a0c --- /dev/null +++ b/graph_neural_network/train_rgnn_multi_gpu.py @@ -0,0 +1,358 @@ +import argparse, datetime, os +import numpy as np +import os.path as osp +import sklearn.metrics +import time, tqdm +import torch +import warnings + +import torch.distributed as dist +import graphlearn_torch as glt +import mlperf_logging.mllog.constants as mllog_constants + +from torch.nn.parallel import DistributedDataParallel + +from dataset import IGBHeteroDataset +from mlperf_logging_utils import get_mlperf_logger, submission_info +from utilities import create_ckpt_folder +from rgnn import RGNN + +warnings.filterwarnings("ignore") +mllogger = get_mlperf_logger(path=osp.dirname(osp.abspath(__file__))) + +def evaluate(model, dataloader, current_device, rank, world_size, epoch_num): + if rank == 0: + mllogger.start( + key=mllog_constants.EVAL_START, + metadata={mllog_constants.EPOCH_NUM: epoch_num}, + ) + predictions = [] + labels = [] + with torch.no_grad(): + for batch in dataloader: + batch_size = batch['paper'].batch_size + out = model( + { + node_name: node_feat.to(current_device).to(torch.float32) + for node_name, node_feat in batch.x_dict.items() + }, + batch.edge_index_dict + )[:batch_size] + labels.append(batch['paper'].y[:batch_size].cpu().numpy()) + predictions.append(out.argmax(1).cpu().numpy()) + + predictions = np.concatenate(predictions) + labels = np.concatenate(labels) + acc = sklearn.metrics.accuracy_score(labels, predictions) + + torch.cuda.synchronize() + dist.barrier() + + acc_tensor = torch.tensor(acc).to(current_device) + torch.distributed.all_reduce(acc_tensor, op=torch.distributed.ReduceOp.SUM) + global_acc = acc_tensor.item() / world_size + if rank == 0: + mllogger.event( + key=mllog_constants.EVAL_ACCURACY, + value=global_acc, + metadata={mllog_constants.EPOCH_NUM: epoch_num}, + ) + mllogger.end( + key=mllog_constants.EVAL_STOP, + metadata={mllog_constants.EPOCH_NUM: epoch_num}, + ) + return acc, global_acc + +def run_training_proc(rank, world_size, + hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out, + epochs, train_batch_size, val_batch_size, learning_rate, random_seed, dataset, + train_idx, val_idx, with_gpu, validation_acc, validation_frac_within_epoch, + evaluate_on_epoch_end, checkpoint_on_epoch_end, ckpt_steps, ckpt_path): + if rank == 0: + if ckpt_steps > 0: + ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__))) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group('nccl', rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + glt.utils.common.seed_everything(random_seed) + current_device =torch.device(rank) + + print(f'Rank {rank} init graphlearn_torch NeighborLoader...') + # Create rank neighbor loader for training + train_idx = train_idx.split(train_idx.size(0) // world_size)[rank] + train_loader = glt.loader.NeighborLoader( + data=dataset, + num_neighbors=[int(fanout) for fanout in fan_out.split(',')], + input_nodes=('paper', train_idx), + batch_size=train_batch_size, + shuffle=True, + drop_last=False, + device=current_device, + seed=random_seed + ) + + # Create rank neighbor loader for validation. + val_idx = val_idx.split(val_idx.size(0) // world_size)[rank] + val_loader = glt.loader.NeighborLoader( + data=dataset, + num_neighbors=[int(fanout) for fanout in fan_out.split(',')], + input_nodes=('paper', val_idx), + batch_size=val_batch_size, + shuffle=True, + drop_last=False, + device=current_device, + seed=random_seed + ) + # Load checkpoint + ckpt = None + if ckpt_path is not None: + try: + ckpt = torch.load(ckpt_path) + except FileNotFoundError as e: + print(f"Checkpoint file not found: {e}") + return -1 + + # Define model and optimizer. + model = RGNN(dataset.get_edge_types(), + dataset.node_features['paper'].shape[1], + hidden_channels, + num_classes, + num_layers=num_layers, + dropout=0.2, + model=model_type, + heads=num_heads, + node_type='paper').to(current_device) + if ckpt is not None: + model.load_state_dict(ckpt['model_state_dict']) + model = DistributedDataParallel(model, + device_ids=[current_device.index] if with_gpu else None, + find_unused_parameters=True) + + param_size = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + size_all_mb = (param_size + buffer_size) / 1024**2 + print('model size: {:.3f}MB'.format(size_all_mb)) + + loss_fcn = torch.nn.CrossEntropyLoss().to(current_device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + if ckpt is not None: + optimizer.load_state_dict(ckpt['optimizer_state_dict']) + + batch_num = (len(train_idx) + train_batch_size - 1) // train_batch_size + validation_freq = int(batch_num * validation_frac_within_epoch) + is_success = False + epoch_num = 0 + + training_start = time.time() + for epoch in tqdm.tqdm(range(epochs)): + model.train() + total_loss = 0 + train_acc = 0 + idx = 0 + gpu_mem_alloc = 0 + epoch_start = time.time() + for batch in train_loader: + idx += 1 + batch_size = batch['paper'].batch_size + out = model( + { + node_name: node_feat.to(current_device).to(torch.float32) + for node_name, node_feat in batch.x_dict.items() + }, + batch.edge_index_dict + )[:batch_size] + y = batch['paper'].y[:batch_size] + loss = loss_fcn(out, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() + train_acc += sklearn.metrics.accuracy_score(y.cpu().numpy(), + out.argmax(1).detach().cpu().numpy())*100 + gpu_mem_alloc += ( + torch.cuda.max_memory_allocated() / 1000000 + if with_gpu + else 0 + ) + #checkpoint + if ckpt_steps > 0 and idx % ckpt_steps == 0: + if with_gpu: + torch.cuda.synchronize() + dist.barrier() + if rank == 0: + epoch_num = round((epoch + idx / batch_num), 2) + glt.utils.common.save_ckpt(idx + epoch * batch_num, + ckpt_dir, model.module, optimizer, epoch_num) + dist.barrier() + # evaluate + if idx % validation_freq == 0: + if with_gpu: + torch.cuda.synchronize() + dist.barrier() + epoch_num = round((epoch + idx / batch_num), 2) + model.eval() + rank_val_acc, global_acc = evaluate(model, val_loader, current_device, + rank, world_size, epoch_num) + if validation_acc is not None and global_acc >= validation_acc: + is_success = True + break + model.train() + + if with_gpu: + torch.cuda.synchronize() + dist.barrier() + + #checkpoint at the end of epoch + if checkpoint_on_epoch_end: + if rank == 0: + epoch_num = epoch + 1 + glt.utils.common.save_ckpt(idx + epoch * batch_num, + ckpt_dir, model.module, optimizer, epoch_num) + dist.barrier() + + # evaluate at the end of epoch + if evaluate_on_epoch_end and not is_success: + epoch_num = epoch + 1 + model.eval() + rank_val_acc, global_acc = evaluate(model, val_loader, current_device, + rank, world_size, epoch_num) + if validation_acc is not None and global_acc >= validation_acc: + is_success = True + + #tqdm + train_acc /= idx + gpu_mem_alloc /= idx + tqdm.tqdm.write( + "Rank{:02d} | Epoch {:03d} | Loss {:.4f} | Train Acc {:.2f} | Val Acc {:.2f} | Time {} | GPU {:.1f} MB".format( + rank, + epoch, + total_loss, + train_acc, + rank_val_acc*100, + str(datetime.timedelta(seconds = int(time.time() - epoch_start))), + gpu_mem_alloc + ) + ) + + # stop training if success + if is_success: + break + + #log run status + if rank == 0: + status = mllog_constants.SUCCESS if is_success else mllog_constants.ABORTED + mllogger.end(key=mllog_constants.RUN_STOP, + metadata={mllog_constants.STATUS: status, + mllog_constants.EPOCH_NUM: epoch_num, + } + ) + print("Total time taken " + str(datetime.timedelta(seconds = int(time.time() - training_start)))) + + +if __name__ == '__main__': + mllogger.event(key=mllog_constants.CACHE_CLEAR, value=True) + mllogger.start(key=mllog_constants.INIT_START) + + parser = argparse.ArgumentParser() + root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh') + glt.utils.ensure_dir(root) + parser.add_argument('--path', type=str, default=root, + help='path containing the datasets') + parser.add_argument('--dataset_size', type=str, default='full', + choices=['tiny', 'small', 'medium', 'large', 'full'], + help='size of the datasets') + parser.add_argument('--num_classes', type=int, default=2983, + choices=[19, 2983], help='number of classes') + parser.add_argument('--in_memory', type=int, default=1, + choices=[0, 1], help='0:read only mmap_mode=r, 1:load into memory') + # Model + parser.add_argument('--model', type=str, default='rgat', + choices=['rgat', 'rsage']) + # Model parameters + parser.add_argument('--fan_out', type=str, default='15,10,5') + parser.add_argument('--train_batch_size', type=int, default=1024) + parser.add_argument('--val_batch_size', type=int, default=1024) + parser.add_argument('--hidden_channels', type=int, default=512) + parser.add_argument('--learning_rate', type=float, default=0.001) + parser.add_argument('--epochs', type=int, default=2) + parser.add_argument('--num_layers', type=int, default=3) + parser.add_argument('--num_heads', type=int, default=4) + parser.add_argument('--random_seed', type=int, default=42) + parser.add_argument("--cpu_mode", action="store_true", + help="Only use CPU for sampling and training, default is False.") + parser.add_argument("--edge_dir", type=str, default='in') + parser.add_argument('--layout', type=str, default='COO', + help="Layout of input graph. Default is COO.") + parser.add_argument("--pin_feature", action="store_true", + help="Pin the feature in host memory. Default is False.") + parser.add_argument("--use_fp16", action="store_true", + help="To use FP16 for loading the features. Default is False.") + parser.add_argument("--validation_frac_within_epoch", type=float, default=0.05, + help="Fraction of the epoch after which validation should be performed.") + parser.add_argument("--validation_acc", type=float, default=0.72, + help="Validation accuracy threshold to stop training once reached.") + parser.add_argument("--evaluate_on_epoch_end", action="store_true", + help="Evaluate using validation set on each epoch end.") + parser.add_argument("--checkpoint_on_epoch_end", action="store_true", + help="Save checkpoint on each epoch end.") + parser.add_argument('--ckpt_steps', type=int, default=-1, + help="Save checkpoint every n steps. Default is -1, which means no checkpoint is saved.") + parser.add_argument('--ckpt_path', type=str, default=None, + help="Path to load checkpoint from. Default is None.") + args = parser.parse_args() + args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available() + assert args.layout in ['COO', 'CSC', 'CSR'] + + glt.utils.common.seed_everything(args.random_seed) + world_size = torch.cuda.device_count() + submission_info(mllogger, 'GNN', 'reference_implementation') + mllogger.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=world_size*args.train_batch_size) + mllogger.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1) + mllogger.event(key=mllog_constants.OPT_NAME, value='adam') + mllogger.event(key=mllog_constants.OPT_BASE_LR, value=args.learning_rate) + mllogger.event(key=mllog_constants.SEED,value=args.random_seed) + mllogger.end(key=mllog_constants.INIT_STOP) + mllogger.start(key=mllog_constants.RUN_START) + + igbh_dataset = IGBHeteroDataset(args.path, args.dataset_size, args.in_memory, + args.num_classes==2983, True, args.layout, + args.use_fp16) + # init graphlearn_torch Dataset. + glt_dataset = glt.data.Dataset(edge_dir=args.edge_dir) + + glt_dataset.init_node_features( + node_feature_data=igbh_dataset.feat_dict, + with_gpu=args.with_gpu and args.pin_feature + ) + + glt_dataset.init_graph( + edge_index=igbh_dataset.edge_dict, + layout = args.layout, + graph_mode='ZERO_COPY' if args.with_gpu else 'CPU', + ) + + glt_dataset.init_node_labels(node_label_data={'paper': igbh_dataset.label}) + + train_idx = igbh_dataset.train_idx.clone().share_memory_() + val_idx = igbh_dataset.val_idx.clone().share_memory_() + + print('--- Launching training processes ...\n') + torch.multiprocessing.spawn( + run_training_proc, + args=(world_size, args.hidden_channels, args.num_classes, args.num_layers, + args.model, args.num_heads, args.fan_out, args.epochs, + args.train_batch_size, args.val_batch_size, + args.learning_rate, args.random_seed, + glt_dataset, train_idx, val_idx, args.with_gpu, + args.validation_acc, args.validation_frac_within_epoch, + args.evaluate_on_epoch_end, args.checkpoint_on_epoch_end, + args.ckpt_steps, args.ckpt_path), + nprocs=world_size, + join=True + ) diff --git a/graph_neural_network/utilities.py b/graph_neural_network/utilities.py new file mode 100644 index 000000000..10cb1514d --- /dev/null +++ b/graph_neural_network/utilities.py @@ -0,0 +1,11 @@ +import os +import time + +def create_ckpt_folder(base_dir, prefix="ckpt"): + timestamp = time.strftime("%Y%m%d-%H%M%S") + folder_name = f"{prefix}_{timestamp}" if prefix else timestamp + full_path = os.path.join(base_dir, folder_name) + if not os.path.exists(full_path): + os.makedirs(full_path) + return full_path +