This repository implements multiple Graph Neural Network architectures designed to improve the accuracy and efficiency of particle track finding in pixel detectors (such as those in the CMS and ATLAS experiments) for the High Luminosity LHC. For more information on tracking at the HL-LHC, please see [1]-[3]. This project is supported by the Institute for Research and Innovation in Software for High Energy Physics (IRIS-HEP) and maintained by Savannah Thais and Gage DeZoort. We work closely with the ExaTrkX project and Mark Neubauer's group at UIUC.
This code uses data from the CERN TrackML Challenge, which can be downloaded from Kaggle. The repository includes a Python library, located in the trackml-library folder, for reading the data into dataframes. After first downloading this repository, you must install this library by running
pip install /path/to/repository
To run a GNN model over the TrackML dataset, run
python prepare_XX.py configs/XX/prepare/config_name.yaml
to create the graphs and
python train_XX.py configs/XX/train/config_name.yaml
to train the GNN.
In both commands, replace 'XX' with the name of your chosen model. There are many example config files for different configurations (e.g. including endcaps or data augmentation), but you can also create a custom config for your task. Be sure to change the input directory and output directory settings in the config files to match your system. Note that the IN's plotting and training scripts use the same configuration file, configs/IN/train_IN.yaml.
Note: the prepare.py and train.py scripts must be run from the main directory. We include the core scripts for both models (Edge Classifier 1 and Interaction Network) there; additional prepare and train scripts can be found in the script folder.
All example code in this repository uses the CERN TrackML Challenge dataset. Each event is separated into four files: cells, hits, particles, and truth; we rely on hits to build the graphs and particles + truth to define the training labels.
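As a rough illustration of how truth information becomes a training label, the sketch below marks a segment as true when both of its hits share the same non-zero particle_id (in the TrackML truth files, a particle_id of 0 marks a hit that does not belong to a reconstructible particle). The toy `truth` mapping and the `label_segments` helper are hypothetical and are not taken from this repository; the actual labelling in the prepare scripts may apply further requirements.

```python
# Hypothetical toy truth table: hit_id -> particle_id (0 means noise in TrackML).
truth = {1: 101, 2: 101, 3: 202, 4: 0, 5: 0}


def label_segments(segments, truth):
    """Return 1 for segments whose two hits come from the same real particle."""
    labels = []
    for hit_a, hit_b in segments:
        pid_a, pid_b = truth[hit_a], truth[hit_b]
        # Noise hits (particle_id 0) never form a true segment.
        labels.append(int(pid_a != 0 and pid_a == pid_b))
    return labels


segments = [(1, 2), (2, 3), (4, 5)]
y = label_segments(segments, truth)
print(y)  # [1, 0, 0]
```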
The data is simulated at high pileup using a toy detector that roughly approximates the ATLAS detector. It is divided into volumes (detector regions), each of which is divided into several layers of modules, as shown below. Here, we focus on the innermost pixel layers, but this can be adjusted in the prepare.py scripts.
Each TrackML event is converted into a directed multigraph of hits connected by segments. Different pre-processing strategies are available, each with different graph construction efficiencies. This repo contains two such strategies:
- Select one hit per particle per layer, connect hits in adjacent layers. This is the strategy used by the HEP.TrkX collaboration, which we denote "layer pairs (LP)" (see prepare_LP.py).
- Select hits between adjacent layers and hits within the same layer, requiring that same-layer hits are within some distance dR of each other. This is called "layer pairs plus (LPP)" (see prepare_LPP.py).
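The two strategies can be sketched as follows. This is a minimal illustration only: the `Hit` record, the `build_segments` helper, and the use of a Euclidean distance as a stand-in for dR are all assumptions, and the per-particle hit selection and geometric cuts applied by prepare_LP.py and prepare_LPP.py are omitted.

```python
import math
from collections import namedtuple

# Hypothetical hit record: layer index plus cylindrical coordinates.
Hit = namedtuple("Hit", ["hit_id", "layer", "r", "phi", "z"])


def distance(a, b):
    """Euclidean distance between two hits (an assumed stand-in for dR)."""
    ax, ay = a.r * math.cos(a.phi), a.r * math.sin(a.phi)
    bx, by = b.r * math.cos(b.phi), b.r * math.sin(b.phi)
    return math.sqrt((ax - bx) ** 2 + (ay - by) ** 2 + (a.z - b.z) ** 2)


def build_segments(hits, same_layer_dr=None):
    """Connect hits in adjacent layers; optionally add close same-layer pairs."""
    segments = []
    for a in hits:
        for b in hits:
            if b.layer == a.layer + 1:
                segments.append((a.hit_id, b.hit_id))  # LP-style segment
            elif (same_layer_dr is not None
                  and a.layer == b.layer
                  and a.hit_id < b.hit_id
                  and distance(a, b) < same_layer_dr):
                segments.append((a.hit_id, b.hit_id))  # extra LPP-style segment
    return segments


hits = [
    Hit(0, 0, 30.0, 0.10, 5.0),
    Hit(1, 0, 30.0, 0.11, 5.5),  # close to hit 0 in the same layer
    Hit(2, 1, 70.0, 0.12, 12.0),
    Hit(3, 2, 115.0, 0.15, 20.0),
]
lp_segments = build_segments(hits)
lpp_segments = build_segments(hits, same_layer_dr=2.0)
```

With these toy hits, the LPP-style call returns every LP-style segment plus one same-layer segment between hits 0 and 1.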
The models folder contains the structures to create graphs and multiple GNN models. The graphs are defined on a per-model basis, but all share four common attributes: a vector of hit features, a matrix representing incoming segments, a matrix representing outgoing segments, and a binary segment truth vector (x, R_i, R_o, and y respectively).
model/XX/graph.py: defines a graph as a namedtuple containing:
- x: a size (Nhits x 3) feature matrix with one row per hit, containing the hit coordinates (r, phi, z).
- R_i: a size (Nhits x Nsegs) matrix whose entries (R_i)hs are 1 when segment s is incoming to hit h, and 0 otherwise.
- R_o: a size (Nhits x Nsegs) matrix whose entries (R_o)hs are 1 when segment s is outgoing from hit h, and 0 otherwise.
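- y: a size (Nsegs x 1) binary truth vector whose s-th entry is 1 when segment s is a true segment, and 0 otherwise.
- R_a (IN only): a size (Nsegs x 1) vector whose s-th entry is 0 if segment s connects opposite-layer hits and 1 if segment s connects same-layer hits.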
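To make the incidence matrices concrete, here is a minimal sketch that builds R_i and R_o from a list of (start, end) hit indices. The field names follow the attributes listed above, but the `make_graph` helper is hypothetical and plain Python lists are used only to keep the example dependency-free; the repository's own graph.py may store these differently.

```python
from collections import namedtuple

# Field names follow the attributes described above; the helper is hypothetical.
Graph = namedtuple("Graph", ["x", "R_i", "R_o", "y"])


def make_graph(x, segments, y):
    """Build Nhits x Nsegs incidence matrices from (start, end) hit indices."""
    n_hits, n_segs = len(x), len(segments)
    R_i = [[0] * n_segs for _ in range(n_hits)]
    R_o = [[0] * n_segs for _ in range(n_hits)]
    for s, (start, end) in enumerate(segments):
        R_o[start][s] = 1  # segment s is outgoing from its start hit
        R_i[end][s] = 1    # segment s is incoming to its end hit
    return Graph(x, R_i, R_o, y)


# Three toy hits with (r, phi, z) features and three candidate segments.
x = [(30.0, 0.10, 5.0), (70.0, 0.12, 12.0), (115.0, 0.15, 20.0)]
graph = make_graph(x, segments=[(0, 1), (1, 2), (0, 2)], y=[1, 1, 0])
```

Every column of R_i and of R_o contains exactly one 1, since each segment has a single end hit and a single start hit.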
Edge Classifier 1 is based on the 2018 ExaTrkX paper, which employs a set of recurrent iterations of alternating EdgeNetwork and NodeNetwork components:
- NodeNetwork: computes new features for every node using the edge-weighted aggregated features of the connected nodes on the previous and next detector layers (aggregated separately), as well as the node's current features (a traditional Graph Convolutional Network).
- EdgeNetwork: computes weights for every edge of the graph using the features of the start and end nodes.
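The data flow of the alternating components can be sketched with toy stand-ins. In the sketch below, both `edge_network` and `node_network` are hypothetical, hand-written functions rather than trained networks: a real model would replace the distance-based score and the simple average with learned layers. Only the pattern (edge weights from endpoint features, then node updates from separately aggregated previous-layer and next-layer messages) is meant to mirror the description above.

```python
import math


def edge_network(x, segments):
    """Toy stand-in for the EdgeNetwork: one weight in (0, 1) per segment,
    computed from the features of its start and end nodes."""
    weights = []
    for start, end in segments:
        score = -sum((a - b) ** 2 for a, b in zip(x[start], x[end]))
        weights.append(1.0 / (1.0 + math.exp(-score)))
    return weights


def node_network(x, segments, weights):
    """Toy stand-in for the NodeNetwork: combine each node's own features with
    edge-weighted sums over previous-layer and next-layer neighbours."""
    n_feat = len(x[0])
    msg_in = [[0.0] * n_feat for _ in x]
    msg_out = [[0.0] * n_feat for _ in x]
    for (start, end), w in zip(segments, weights):
        for k in range(n_feat):
            msg_in[end][k] += w * x[start][k]   # message from the previous layer
            msg_out[start][k] += w * x[end][k]  # message from the next layer
    # A trained model would pass (msg_in, msg_out, x) through learned layers;
    # a plain average is used here purely for illustration.
    return [[(i + o + c) / 3.0 for i, o, c in zip(mi, mo, xc)]
            for mi, mo, xc in zip(msg_in, msg_out, x)]


x = [[0.0, 0.0], [0.1, 0.0], [2.0, 2.0]]
segments = [(0, 1), (0, 2)]
for _ in range(2):  # recurrent iterations of the two components
    w = edge_network(x, segments)
    x = node_network(x, segments, w)
edge_weights = edge_network(x, segments)  # final per-segment scores
```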
The IN is based on the model outlined in arXiv:1612.00222. In essence, the IN is a GNN architecture that applies relational and object models in stages to infer abstract interactions and object dynamics. We extend the paper's definition of an IN to include a recurrent structure, which involves repetitions of the following cycle: marshal hits and segments into interaction terms, apply the relational model to the interaction terms to produce effects, aggregate effects across the incoming segments of each hit, apply the object model to hits and effects to produce re-embedded hit features, re-marshal the re-embedded hits, and apply the relational model a second time to produce edge weights (truth probabilities) for each segment. This cycle is formally defined in model/IN/interaction_network.py, which depends on the following relational and object models:
- model/IN/relational_model.py: an MLP that outputs 1 parameter per segment, which we interpret as an edge weight
- model/IN/object_model.py: an MLP that outputs 3 parameters per hit, which are the re-embedded position features of the hit
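The cycle described above can be traced step by step in a small sketch. Everything here is an illustrative assumption: `relational_model` and `object_model` are fixed arithmetic stand-ins for the two MLPs, the final sigmoid is assumed as the way to map scores to (0, 1), and none of the function names are taken from model/IN/interaction_network.py.

```python
import math


def marshal(x, segments):
    """Pair sender and receiver hit features for every segment."""
    return [(x[start], x[end]) for start, end in segments]


def relational_model(term):
    """Stand-in for the relational MLP: one number per segment."""
    sender, receiver = term
    return -sum((s - r) ** 2 for s, r in zip(sender, receiver))


def aggregate(effects, segments, n_hits):
    """Sum effects over the segments incoming to each hit."""
    totals = [0.0] * n_hits
    for (_, end), effect in zip(segments, effects):
        totals[end] += effect
    return totals


def object_model(hit, effect):
    """Stand-in for the object MLP: three re-embedded features per hit."""
    return [h + 0.1 * effect for h in hit]


def interaction_network(x, segments, n_iterations=1):
    """Run the marshal / relate / aggregate / re-embed / relate cycle."""
    weights = []
    for _ in range(n_iterations):
        effects = [relational_model(t) for t in marshal(x, segments)]
        totals = aggregate(effects, segments, len(x))
        x = [object_model(h, e) for h, e in zip(x, totals)]
        # Second relational pass on the re-embedded hits gives edge scores,
        # squashed into (0, 1) with an assumed sigmoid.
        scores = [relational_model(t) for t in marshal(x, segments)]
        weights = [1.0 / (1.0 + math.exp(-s)) for s in scores]
    return weights


x = [[0.3, 0.10, 0.05], [0.7, 0.12, 0.12], [1.1, 0.15, 0.20]]
segments = [(0, 1), (1, 2), (0, 2)]
edge_weights = interaction_network(x, segments)
```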
The scripts train_XX.py build and train the specified GNN on preprocessed TrackML graphs. The training definitions for each model live in the models/trainers folder.
The training data is organized into mini-batches, which are used to optimize the loss function specified in your config file using your selected optimizer. The network is then tested on a separate sample of pre-processed graphs. The results of the trained network can be evaluated using the scripts described in the Analyses section below.
In the slurm folder we provide scripts to submit the pre-processing and training processes as jobs on a Slurm-based job submission system.
The analyses folder contains measurement scripts and plotting functions. Different tools are defined for the different stages of this project, namely graph construction (the "prepare" stage), model training (the "train" stage), and track finding (the "tracking" stage).
Graph construction efficiencies depend on the specific graph construction method. For this reason, we have defined two scripts that generate the true number of segments we'd expect from a single constructed graph (for various pt cuts ranging between 0.5 and 5 GeV): analyses/truth/generate_truth_LP.py and analyses/truth/generate_truth_LPP.py. Graph construction involves cuts that diminish the number of true segments in our graphs, so it is important to know how many we started with. This information is stored in "truth tables," which are used as inputs to the graph construction measurement script:
- analyses/prepare_efficiencies.py: calculate and plot key graph construction metrics, including segment efficiency (defined as sum(y)/len(y) for each graph), truth efficiency (defined as sum(y)/total_true for each graph), graph size, hits per graph, and nodes per graph.
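The two efficiency definitions quoted above amount to the following arithmetic. The function names and toy numbers are illustrative only and are not taken from analyses/prepare_efficiencies.py; `total_true` stands for the per-event entry read from a truth table.

```python
def segment_efficiency(y):
    """Fraction of segments in the graph that are true: sum(y) / len(y)."""
    return sum(y) / len(y)


def truth_efficiency(y, total_true):
    """Fraction of all true segments kept by the graph: sum(y) / total_true."""
    return sum(y) / total_true


y = [1, 0, 1, 1, 0, 0, 1, 0]  # toy truth vector for one constructed graph
total_true = 5                # toy truth-table entry for the same event

print(segment_efficiency(y))            # 0.5
print(truth_efficiency(y, total_true))  # 0.8
```

In this toy case half of the constructed segments are true, and the graph retains four of the five true segments expected for the event.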
The plotting functions themselves are contained in analyses/prepare_plots.py. YAML files like analyses/prepare_efficiencies_IN.yaml specify the model and graph construction strategy we're interested in evaluating.
Several post-training plots and metrics are defined in analyses/training_plots.py, including loss-per-epoch plots, color-coded r-z and x-y views of the classified segments, confusion matrix entries plotted as a function of edge weight cut ("discriminant" value), true/false edges plotted as a function of edge weight, and ROC curves. These plots are key to evaluating our models, and are produced in scripts like plot_IN.py.
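As a sketch of what lies behind the confusion-matrix and ROC plots, the snippet below counts true/false positives and negatives for a given edge weight cut and derives ROC points from a scan over cuts. The helper names and toy values are hypothetical and are not the functions in analyses/training_plots.py; the sketch also assumes the sample contains at least one true and one false segment, so the rates are well defined.

```python
def confusion_counts(weights, y, cut):
    """Count (TP, FP, TN, FN) when segments with weight > cut are called true."""
    tp = fp = tn = fn = 0
    for w, label in zip(weights, y):
        predicted = w > cut
        if predicted and label:
            tp += 1
        elif predicted and not label:
            fp += 1
        elif not predicted and not label:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn


def roc_points(weights, y, cuts):
    """Return (false positive rate, true positive rate) for each cut."""
    points = []
    for cut in cuts:
        tp, fp, tn, fn = confusion_counts(weights, y, cut)
        points.append((fp / (fp + tn), tp / (tp + fn)))
    return points


weights = [0.95, 0.80, 0.60, 0.40, 0.30, 0.10]  # toy edge weights
y = [1, 1, 0, 1, 0, 0]                          # toy truth labels
counts = confusion_counts(weights, y, cut=0.5)
roc = roc_points(weights, y, cuts=[0.0, 0.5, 1.0])
print(counts)  # (2, 1, 2, 1)
```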
Upcoming - stay tuned!
[1] "HL-LHC Tracking Challenge", talk by Jean-Roch Vlimant
[2] "Tracking performance with the HL-LHC ATLAS detector", paper by the ATLAS Collaboration
[3] "Expected Performance of Tracking in CMS at the HL-LHC", paper by the CMS Collaboration
[4] "Interaction Networks for Learning about Objects, Relations and Physics", arXiv:1612.00222