[GNN] Reference implementation for GNN node classification #700

Merged · 10 commits · Mar 21, 2024
19 changes: 19 additions & 0 deletions graph_neural_network/Dockerfile
@@ -0,0 +1,19 @@
FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel

WORKDIR /workspace/repository

RUN pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
RUN pip install scikit-learn==0.24.2
RUN pip install torch_geometric==2.4.0
RUN pip install --no-index torch_scatter==2.1.1 torch_sparse==0.6.17 -f https://data.pyg.org/whl/torch-1.13.0+cu117.html
RUN pip install graphlearn-torch==0.2.2

RUN apt update
RUN apt install -y git
RUN pip install git+https://github.com/mlcommons/logging.git

# TF32 instead of FP32 for faster compute
ENV NVIDIA_TF32_OVERRIDE=1

RUN git clone https://github.com/alibaba/graphlearn-for-pytorch.git
WORKDIR /workspace/repository/graphlearn-for-pytorch/examples/igbh
170 changes: 170 additions & 0 deletions graph_neural_network/README.md
@@ -0,0 +1,170 @@
# 1. Problem
This benchmark represents a multi-class node classification task on a heterogeneous graph, the IGBH-Full variant of the [IGB Heterogeneous Dataset](https://github.com/IllinoisGraphBenchmark/IGB-Datasets). The task is carried out using a relational [GAT](https://arxiv.org/abs/1710.10903) (RGAT) model based on the [Relational Graph Attention Networks](https://arxiv.org/abs/1904.05811) paper.

The reference implementation is based on [graphlearn-for-pytorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch).

# 2. Directions
### Steps to configure machine

#### 1. Clone the repository:
```bash
git clone https://github.com/alibaba/graphlearn-for-pytorch
```

or
```bash
git clone https://github.com/mlcommons/training.git
```
once `GNN node classification` is merged into `mlcommons/training`.

#### 2. Build the docker image:

If you cloned the `graphlearn-for-pytorch` repository:
```bash
cd graphlearn-for-pytorch/examples/igbh/
docker build -f Dockerfile -t training_gnn:latest .
```

If you cloned the `mlcommons/training` repository:
```bash
cd training/graph_neural_network/
docker build -f Dockerfile -t training_gnn:latest .
```


### Steps to download and verify data
Download the dataset:
```bash
bash download_igbh_full.sh
```

Before training, generate the seeds for training and validation:
```bash
python split_seeds.py --dataset_size='full'
```

The size of the `IGBH-Full` dataset is 2.2 TB. If you want to test with
the `tiny`, `small` or `medium` datasets, the download procedure is included
in the training script.

### Steps to run and time

#### Single-node Training

The original graph is in the `COO` format and the feature is in FP32. The training script will transform the graph from `COO` to `CSC` and convert the feature to FP16, which can be time-consuming due to the graph scale. We provide a script to convert the graph layout from `COO` to `CSC` and persist the feature in FP16 format:

```bash
python compress_graph.py --dataset_size='full' --layout='CSC' --use_fp16
```
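For intuition, the layout conversion amounts to grouping each node's incoming edges contiguously and storing per-node offsets. The sketch below is a simplified, homogeneous-graph illustration of that idea, not the GLT-based logic in `compress_graph.py`; the tensor names are made up:

```python
import torch

def coo_to_csc(edge_index: torch.Tensor, num_dst_nodes: int):
  # edge_index: [2, num_edges] in COO format (row 0 = source, row 1 = destination).
  src, dst = edge_index[0], edge_index[1]
  # Sort edges by destination so the incoming edges of each node become contiguous.
  perm = torch.argsort(dst)
  indices = src[perm]                       # source nodes, stored in CSC order
  counts = torch.bincount(dst, minlength=num_dst_nodes)
  indptr = torch.zeros(num_dst_nodes + 1, dtype=torch.long)
  indptr[1:] = torch.cumsum(counts, dim=0)  # indptr[v]..indptr[v+1] delimit the in-edges of v
  return indptr, indices

# FP32 -> FP16 compression halves the on-disk and in-memory feature footprint.
feat_fp16 = torch.randn(8, 1024).to(torch.float16)
```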

To train the model using multiple GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='full' --layout='CSC' --use_fp16
```
The number of training processes is equal to the number of GPUs. The `--pin_feature` option controls whether the feature data is pinned in host memory, which enables zero-copy feature access from the GPU but incurs extra memory cost.
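As a rough sketch of the trade-off (generic PyTorch only, not GLT's actual zero-copy path): pinning keeps the feature matrix in page-locked host memory, so the GPU can fetch sampled rows without an extra pageable-to-pinned staging copy, at the cost of locking that host memory for the lifetime of the tensor.

```python
import torch

# Page-locked (pinned) host memory; the extra host memory stays locked while the tensor lives.
feat = torch.randn(1_000_000, 1024).pin_memory()

# Gather sampled rows into a small pinned buffer, then copy asynchronously to the GPU.
idx = torch.randint(0, feat.size(0), (1024,))
staging = torch.empty(1024, 1024, pin_memory=True)
torch.index_select(feat, 0, idx, out=staging)
batch = staging.to('cuda', non_blocking=True)  # asynchronous host-to-device copy
```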


#### Distributed Training

##### 1. Data Partitioning
To partition the dataset (including both the topology and feature):
```bash
python partition.py --dataset_size='full' --num_partitions=2 --use_fp16 --layout='CSC'
```
The above script will partition the dataset into two parts, convert the feature into
the FP16 format, and transform the graph layout from `COO` to `CSC`.
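Conceptually, partitioning assigns each node to one training node and shards the edges and features accordingly. The toy sketch below illustrates the idea with a contiguous node-to-partition assignment; it is not GLT's partitioner and ignores balancing and caching concerns:

```python
import torch

num_nodes, num_partitions = 10, 2
edge_index = torch.randint(0, num_nodes, (2, 40))   # toy COO graph
feat = torch.randn(num_nodes, 8)

# Assign each node to a partition in contiguous chunks.
node2part = torch.arange(num_nodes) * num_partitions // num_nodes

parts = []
for p in range(num_partitions):
  mask = node2part[edge_index[1]] == p              # edges whose destination lives on partition p
  parts.append({
      'edge_index': edge_index[:, mask],
      'feat': feat[node2part == p],                 # feature shard for locally owned nodes
  })
```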

We suggest storing the partitioned data on a distributed file system, such as HDFS or NFS, so that it can be accessed by all training nodes.

##### 2. Two-stage Data Partitioning
To speed up partitioning, GLT also supports a two-stage scheme that splits topology partitioning from feature partitioning. After the topology partitioning is executed on a single node, the feature partitioning can be conducted on each training node in parallel.

The topology partitioning is conducted by executing:
```bash
python partition.py --dataset_size='full' --num_partitions=2 --layout='CSC' --with_feature=0
```

The feature partitioning is conducted on each training node:
```bash
# node 0 which holds partition 0:
python build_partition_feature.py --dataset_size='full' --use_fp16 --in_memory=0 --partition_idx=0

# node 1 which holds partition 1:
python build_partition_feature.py --dataset_size='full' --use_fp16 --in_memory=0 --partition_idx=1
```

##### 3. Model Training
The number of partitions and the number of training nodes must be the same. On each training node, the model can be trained with the following command:

```bash
# node 0:
CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=2 --master_addr=master_address_ip --model='rgat' --dataset_size='full' --layout='CSC'

# node 1:
CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=2 --master_addr=master_address_ip --model='rgat' --dataset_size='full' --layout='CSC'
```
The above script assumes that the training nodes are equipped with 2 GPUs and the number of training processes is equal to the number of GPUs. Each training process has a corresponding
sampling process using the same GPU.

The `master_address_ip` should be replaced with the actual IP address of the master node. The `--pin_feature` option decides if the feature data will be pinned in host memory, which enables zero-copy feature access from GPU but will incur extra memory costs.
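For orientation, the node rank, per-node process count, and master address together define the distributed rendezvous. Below is a generic `torch.distributed` sketch of how such values are typically consumed; the port and backend are assumptions, and the reference script handles this setup internally:

```python
import torch.distributed as dist

num_nodes, num_training_procs = 2, 2      # matches the example command above
node_rank, local_rank = 0, 0              # per-process values; local_rank comes from the launcher
world_size = num_nodes * num_training_procs
rank = node_rank * num_training_procs + local_rank

dist.init_process_group(
    backend='nccl',
    init_method='tcp://master_address_ip:12345',  # port 12345 is an assumption, not a GLT default
    world_size=world_size,
    rank=rank,
)
```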


We recommend separating the sampling and training processes onto different GPUs to achieve better performance. To separate the GPUs used by the sampling and training processes, add `--split_training_sampling` and set `--num_training_procs` to half the number of devices:

```bash
# node 0:
CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='full' --layout='CSC' --split_training_sampling

# node 1:
CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='full' --layout='CSC' --split_training_sampling
```
The above script uses one GPU for training and another for sampling in each node.



# 3. Dataset/Environment
### Publication/Attribution
Arpandeep Khatua, Vikram Sharma Mailthody, Bhagyashree Taleka, Tengfei Ma, Xiang Song, and Wen-mei Hwu. 2023. IGB: Addressing The Gaps In Labeling, Features, Heterogeneity, and Size of Public Graph Datasets for Deep Learning Research. In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD '23). Association for Computing Machinery, New York, NY, USA, 4284–4295. https://doi.org/10.1145/3580305.3599843

### Data preprocessing
The original graph is in the `COO` format and the feature is in FP32 format. Transforming the graph from `COO` to `CSC` and converting the feature to FP16 are allowed, and both are supported by the training script.

### Training and test data separation
The training and validation data are selected from the labeled `paper` nodes of the dataset and are generated by `split_seeds.py`. Different random seeds will result in different training and test data.

### Training data order
Randomly.

### Test data order
Randomly.

# 4. Model
### Publication/Attribution
Dan Busbridge, Dane Sherburn, Pietro Cavallo, and Nils Y. Hammerla. 2019. Relational Graph Attention Networks. https://arxiv.org/abs/1904.05811

### List of layers
Three-layer RGAT model
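For orientation only, the sketch below shows what a three-layer relational GAT can look like in PyTorch Geometric. The hidden size, head count, class count, and reduced edge-type set are illustrative assumptions, not the reference hyperparameters, and this is not the reference model code:

```python
import torch
from torch_geometric.nn import HeteroConv, GATConv

class RGATSketch(torch.nn.Module):
  def __init__(self, edge_types, hidden=512, heads=4, num_classes=2983, num_layers=3):
    super().__init__()
    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
      # One GATConv per relation; HeteroConv sums the per-relation messages for each node type.
      self.convs.append(HeteroConv(
          {et: GATConv((-1, -1), hidden, heads=heads, concat=False, add_self_loops=False)
           for et in edge_types},
          aggr='sum'))
    self.lin = torch.nn.Linear(hidden, num_classes)

  def forward(self, x_dict, edge_index_dict):
    for conv in self.convs:
      x_dict = {k: v.relu() for k, v in conv(x_dict, edge_index_dict).items()}
    return self.lin(x_dict['paper'])   # classify the labeled paper nodes
```

Calling the model on a sampled mini-batch's `x_dict` and `edge_index_dict` would produce logits for the seed `paper` nodes.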

### Loss function
CrossEntropyLoss

### Optimizer
Adam
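Taken together, each training step minimizes the cross-entropy between predicted and true paper labels with Adam. A self-contained toy step is sketched below; the linear stand-in model, learning rate, and shapes are illustrative only:

```python
import torch

model = torch.nn.Linear(1024, 2983)                    # stand-in for the RGAT sketch above
opt = torch.optim.Adam(model.parameters(), lr=0.001)   # learning rate is illustrative
loss_fn = torch.nn.CrossEntropyLoss()

feats = torch.randn(32, 1024)                          # a batch of seed-node embeddings
labels = torch.randint(0, 2983, (32,))
opt.zero_grad()
loss = loss_fn(model(feats), labels)
loss.backward()
opt.step()
```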

# 5. Quality
### Quality metric
The validation accuracy is the target quality metric.
### Quality target
0.72
### Evaluation frequency
4,730,280 training seeds (5% of the entire training seeds, evaluated every 0.05 epoch)
### Evaluation thoroughness
788,380 validation seeds

# 6. Contributors
This benchmark is a collaborative effort with contributions from Alibaba, Intel, and Nvidia:

- Alibaba: Li Su, Baole Ai, Wenting Shen, Shuxian Hu, Wenyuan Yu, Yong Li
- Nvidia: Yunzhou (David) Liu, Kyle Kranen, Shriya Palasamudram
- Intel: Kaixuan Liu, Hesham Mostafa, Sasikanth Avancha, Keith Achorn, Radha Giduthuri, Deepak Canchi
61 changes: 61 additions & 0 deletions graph_neural_network/build_partition_feature.py
@@ -0,0 +1,61 @@
import argparse
import os.path as osp

import graphlearn_torch as glt
import torch

from dataset import IGBHeteroDataset


def partition_feature(src_path: str,
                      dst_path: str,
                      partition_idx: int,
                      chunk_size: int,
                      dataset_size: str='tiny',
                      in_memory: bool=True,
                      use_fp16: bool=False):
  print(f'-- Loading igbh_{dataset_size} ...')
  data = IGBHeteroDataset(src_path, dataset_size, in_memory, with_edges=False, use_fp16=use_fp16)

  print(f'-- Build feature for partition {partition_idx} ...')
  dst_path = osp.join(dst_path, f'{dataset_size}-partitions')
  node_feat_dtype = torch.float16 if use_fp16 else torch.float32
  glt.partition.base.build_partition_feature(root_dir=dst_path,
                                             partition_idx=partition_idx,
                                             chunk_size=chunk_size,
                                             node_feat=data.feat_dict,
                                             node_feat_dtype=node_feat_dtype)


if __name__ == '__main__':
  root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
  glt.utils.ensure_dir(root)
  parser = argparse.ArgumentParser(description="Arguments for partitioning igbh datasets.")
  parser.add_argument('--src_path', type=str, default=root,
                      help='path containing the datasets')
  parser.add_argument('--dst_path', type=str, default=root,
                      help='path containing the partitioned datasets')
  parser.add_argument('--dataset_size', type=str, default='full',
                      choices=['tiny', 'small', 'medium', 'large', 'full'],
                      help='size of the datasets')
  parser.add_argument('--in_memory', type=int, default=0,
                      choices=[0, 1], help='0:read only mmap_mode=r, 1:load into memory')
  parser.add_argument("--partition_idx", type=int, default=0,
                      help="Index of a partition")
  parser.add_argument("--chunk_size", type=int, default=10000,
                      help="Chunk size for feature partitioning.")
  parser.add_argument("--use_fp16", action="store_true",
                      help="save node/edge feature using fp16 format")

  args = parser.parse_args()

  partition_feature(
    args.src_path,
    args.dst_path,
    partition_idx=args.partition_idx,
    chunk_size=args.chunk_size,
    dataset_size=args.dataset_size,
    in_memory=args.in_memory==1,
    use_fp16=args.use_fp16
  )
120 changes: 120 additions & 0 deletions graph_neural_network/compress_graph.py
@@ -0,0 +1,120 @@
import argparse, datetime, os
import numpy as np
import torch
import os.path as osp

import graphlearn_torch as glt

from dataset import float2half
from download import download_dataset
from torch_geometric.utils import add_self_loops, remove_self_loops
from typing import Literal


class IGBHeteroDatasetCompress(object):
  def __init__(self,
               path,
               dataset_size,
               layout: Literal['CSC', 'CSR'] = 'CSC',):
    self.dir = path
    self.dataset_size = dataset_size
    self.layout = layout

    self.ntypes = ['paper', 'author', 'institute', 'fos']
    self.etypes = None
    self.edge_dict = {}
    self.paper_nodes_num = {'tiny':100000, 'small':1000000, 'medium':10000000, 'large':100000000, 'full':269346174}
    self.author_nodes_num = {'tiny':357041, 'small':1926066, 'medium':15544654, 'large':116959896, 'full':277220883}
    if not osp.exists(osp.join(path, self.dataset_size, 'processed')):
      download_dataset(path, 'heterogeneous', dataset_size)
    self.process()

  def process(self):
    paper_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
                                                 'paper__cites__paper', 'edge_index.npy'))).t()
    author_paper_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
                                                  'paper__written_by__author', 'edge_index.npy'))).t()
    affiliation_author_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
                                                        'author__affiliated_to__institute', 'edge_index.npy'))).t()
    paper_fos_edges = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
                                               'paper__topic__fos', 'edge_index.npy'))).t()
    if self.dataset_size in ['large', 'full']:
      paper_published_journal = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
                                                         'paper__published__journal', 'edge_index.npy'))).t()
      paper_venue_conference = torch.from_numpy(np.load(osp.join(self.dir, self.dataset_size, 'processed',
                                                        'paper__venue__conference', 'edge_index.npy'))).t()

    cites_edge = add_self_loops(remove_self_loops(paper_paper_edges)[0])[0]
    self.edge_dict = {
        ('paper', 'cites', 'paper'): (torch.cat([cites_edge[1, :], cites_edge[0, :]]), torch.cat([cites_edge[0, :], cites_edge[1, :]])),
        ('paper', 'written_by', 'author'): author_paper_edges,
        ('author', 'affiliated_to', 'institute'): affiliation_author_edges,
        ('paper', 'topic', 'fos'): paper_fos_edges,
        ('author', 'rev_written_by', 'paper'): (author_paper_edges[1, :], author_paper_edges[0, :]),
        ('institute', 'rev_affiliated_to', 'author'): (affiliation_author_edges[1, :], affiliation_author_edges[0, :]),
        ('fos', 'rev_topic', 'paper'): (paper_fos_edges[1, :], paper_fos_edges[0, :])
    }
    if self.dataset_size in ['large', 'full']:
      self.edge_dict[('paper', 'published', 'journal')] = paper_published_journal
      self.edge_dict[('paper', 'venue', 'conference')] = paper_venue_conference
      self.edge_dict[('journal', 'rev_published', 'paper')] = (paper_published_journal[1, :], paper_published_journal[0, :])
      self.edge_dict[('conference', 'rev_venue', 'paper')] = (paper_venue_conference[1, :], paper_venue_conference[0, :])
    self.etypes = list(self.edge_dict.keys())

    # init graphlearn_torch Dataset.
    edge_dir = 'out' if self.layout == 'CSR' else 'in'
    glt_dataset = glt.data.Dataset(edge_dir=edge_dir)
    glt_dataset.init_graph(
      edge_index=self.edge_dict,
      graph_mode='CPU',
    )

    # save the corresponding csr or csc file
    compress_edge_dict = {}
    compress_edge_dict[('paper', 'cites', 'paper')] = 'paper__cites__paper'
    compress_edge_dict[('paper', 'written_by', 'author')] = 'paper__written_by__author'
    compress_edge_dict[('author', 'affiliated_to', 'institute')] = 'author__affiliated_to__institute'
    compress_edge_dict[('paper', 'topic', 'fos')] = 'paper__topic__fos'
    compress_edge_dict[('author', 'rev_written_by', 'paper')] = 'author__rev_written_by__paper'
    compress_edge_dict[('institute', 'rev_affiliated_to', 'author')] = 'institute__rev_affiliated_to__author'
    compress_edge_dict[('fos', 'rev_topic', 'paper')] = 'fos__rev_topic__paper'
    compress_edge_dict[('paper', 'published', 'journal')] = 'paper__published__journal'
    compress_edge_dict[('paper', 'venue', 'conference')] = 'paper__venue__conference'
    compress_edge_dict[('journal', 'rev_published', 'paper')] = 'journal__rev_published__paper'
    compress_edge_dict[('conference', 'rev_venue', 'paper')] = 'conference__rev_venue__paper'

    for etype in self.etypes:
      graph = glt_dataset.get_graph(etype)
      indptr, indices, _ = graph.export_topology()
      path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype])
      if not os.path.exists(path):
        os.makedirs(path)
      torch.save(indptr, os.path.join(path, 'indptr.pt'))
      torch.save(indices, os.path.join(path, 'indices.pt'))
    path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout)
    print(f"The {self.layout} graph has been persisted in path: {path}")



if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
  glt.utils.ensure_dir(root)
  parser.add_argument('--path', type=str, default=root,
                      help='path containing the datasets')
  parser.add_argument('--dataset_size', type=str, default='full',
                      choices=['tiny', 'small', 'medium', 'large', 'full'],
                      help='size of the datasets')
  parser.add_argument("--layout", type=str, default='CSC')
  parser.add_argument('--use_fp16', action="store_true",
                      help="convert the node/edge feature into fp16 format")
  args = parser.parse_args()
  print(f"Start constructing the {args.layout} graph...")
  igbh_dataset = IGBHeteroDatasetCompress(args.path, args.dataset_size, args.layout)
  if args.use_fp16:
    base_path = osp.join(args.path, args.dataset_size, 'processed')
    float2half(base_path, args.dataset_size)



