From 4d929b5e6ef24aeaab97aea635286462423a39af Mon Sep 17 00:00:00 2001
From: robertturner <143536791+robertdhayanturner@users.noreply.github.com>
Date: Mon, 8 Jan 2024 17:24:39 -0500
Subject: [PATCH] Update node_representation_learning.md for content check by RK

---
 .../use_cases/node_representation_learning.md | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/docs/use_cases/node_representation_learning.md b/docs/use_cases/node_representation_learning.md
index 81dc0043f..ae1c3a528 100644
--- a/docs/use_cases/node_representation_learning.md
+++ b/docs/use_cases/node_representation_learning.md
@@ -2,29 +2,26 @@
 
 # Representation Learning on Graph Structured Data
 
-## Introduction: how to represent relationships
+## Introduction: representing things and relationships between them
 
-Of the various types of information - words, pictures, and connections between things - **relationships** are especially interesting. Relationships show how things interact and create networks. But not all ways of representing relationships are the same. In machine learning, **how we do vector representation of relationships affects performance** on a wide range of tasks.
+Of the various types of information - words, pictures, and connections between things - **relationships** are especially interesting. Relationships show how things interact and create networks. But not all ways of representing relationships are the same. In machine learning, **how we represent things and the relationships between them as vectors affects performance** on a wide range of tasks.
 
 Below, we evaluate several approaches to vector representation on a real-life use case: how well each approach classifies academic articles in a subset of the Cora citation network.
 
 We look first at Bag-of-Words (BoW), a standard approach to vectorizing text data in ML. Because BoW can't represent the network structure, we turn to solutions that can help BoW's performance: Node2Vec and GraphSAGE. We also look for a solution to BoW's other shortcoming - its inability to capture semantic meaning. We evaluate LLM embeddings, first on their own, then combined with Node2Vec, and, finally, LLM-trained GraphSAGE.
 
-## Leading our dataset, evaluating BoW
+## Loading our dataset and evaluating BoW
+
+Our use case is a subset of the **Cora citation network**. This subset comprises 2708 scientific papers (nodes) and connections that indicate citations between them. Each paper has a BoW descriptor containing 1433 words. The papers in the dataset are also divided into 7 different topics (classes). Each paper belongs to exactly one of them.
 
-**Our dataset: Cora**
-Our use case is a subset of the Cora citation network. This subset comprises 2708 scientific papers (nodes) and connections that indicate citations between them. Each paper has a BoW descriptor containing 1433 words. The papers in the dataset are also divided into 7 different topics (classes). Each paper belongs to exactly one of them.
-
-**Loading the dataset**
-
-We load the dataset as follows:
+We **load the dataset** as follows:
 
 ```python
 from torch_geometric.datasets import Planetoid
 ds = Planetoid("./data", "Cora")[0]
 ```
 
-**Evaluating BoW on a classification task**
+### Evaluating BoW on a classification task
 
 We can evaluate how well the BoW descriptors represent the articles by measuring classification performance (Accuracy and macro F1). We use a KNN (K-Nearest Neighbors) classifier with 15 neighbors, and cosine similarity as the similarity metric:
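The `evaluate` helper introduced by this sentence is unchanged by the patch, so it falls between the two hunks and isn't shown here. A minimal sketch of what such a helper could look like, assuming scikit-learn's `KNeighborsClassifier` with cosine similarity and its accuracy / macro-F1 metrics (the article's actual implementation may differ):

```python
# Hypothetical sketch of the elided `evaluate` helper (not part of this patch).
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score

def evaluate(x, y):
    # Torch tensors should be detached and moved to CPU before this point,
    # e.g. x = x.detach().cpu().numpy().
    x, y = np.asarray(x), np.asarray(y)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42, stratify=y
    )
    # KNN with 15 neighbors and cosine similarity, as described above.
    knn = KNeighborsClassifier(n_neighbors=15, metric="cosine")
    knn.fit(x_train, y_train)
    pred = knn.predict(x_test)
    print(f">>> Accuracy {accuracy_score(y_test, pred):.3f}")
    print(f">>> F1 macro {f1_score(y_test, pred, average='macro'):.3f}")
```

Called as `evaluate(ds.x, ds.y)`, a helper along these lines produces the kind of accuracy and macro-F1 readout discussed in the next hunk.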
@@ -53,7 +50,7 @@ evaluate(ds.x, ds.y)
 BoW's accuracy and F1 macro scores are pretty good, but leave significant room for improvement. BoW falls short of correctly classifying papers more than 25% of the time. And, on average across classes, BoW is inaccurate nearly 30% of the time.
 
-## Improving on BoW: taking advantage of citation graph data
+## Taking advantage of citation graph data
 
 Can we improve on this? Our citation network contains not only text data but also relationship data - a citation graph. Any given article will tend to cite other articles that belong to the same topic. Therefore, representations that embed not just textual data but also citation data of articles contained in our network will probably classify articles more accurately.
@@ -79,7 +76,7 @@
 Are there methods that capture node connectivity data better? Node2Vec is built to do precisely this, for static networks. So is GraphSAGE, for dynamic ones. Let's look at Node2Vec first.
 
-## Learning node embeddings with Node2Vec
+## Embedding network structure with Node2Vec
 
 As opposed to BoW vectors, node embeddings are vector representations that capture the structural role and properties of nodes in a network. Node2Vec is an algorithm that learns node representations using the Skip-Gram method; it models the conditional probability of encountering a context node given a source node in node sequences (random walks):
@@ -96,7 +93,7 @@ The random walks are sampled according to a policy, which is guided by 2 paramet
 These parameters are particularly useful for accommodating different networks and tasks. Adjusting the values of $p$ and $q$ captures different characteristics of the graph in the sampled walks. BFS-like exploration is useful for learning local patterns. On the other hand, using DFS-like sampling is useful for capturing patterns on a bigger scale, like structural roles.
 
-### Node2Vec embeddings
+### Node2Vec embedding process
 
 In our example, we use the `torch_geometric` implementation of the Node2Vec algorithm. We **initialize the model** by specifying the following attributes:
@@ -143,6 +140,8 @@ for epoch in range(200):
     print(f'Epoch: {epoch:03d}, Loss: {total_loss / len(loader):.4f}')
 ```
 
+### Node2Vec classification performance
+
 Finally, now that we have a fully trained model, we can evaluate the learned embeddings on our classification task, using the `evaluate` function we defined earlier.
 
 ```python
@@ -288,6 +287,8 @@ for epoch in range(100):
     print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
 ```
 
+### GraphSAGE results
+
 Next, we can embed nodes and evaluate the embeddings on the classification task:
 
 ```python
@@ -297,7 +298,7 @@ evaluate(embeddings, ds.y)
 >>> F1 macro 0.820
 ```
 
-The results are slightly worse than the results we got by combining Node2Vec with BoW features. But, remember, we're evaluating GraphSAGE because it can handle dynamic network data, whereas Node2Vec cannot. GraphSAGE embeddings perform well on our classification task _and_ are able to embed completely new nodes as well. When your use case involves new nodes or nodes that evolve, an induction model like GraphSAGE may be a better choice than Node2Vec.
+The results are only slightly worse than those we got by combining Node2Vec with BoW features. But, remember, we're evaluating GraphSAGE because it can handle dynamic network data, whereas Node2Vec cannot. **GraphSAGE embeddings perform well on our classification task _and_ can embed completely new nodes**. When your use case involves new nodes or nodes that evolve, an inductive model like GraphSAGE may be a better choice than Node2Vec.
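To make the inductive claim concrete, here is a hypothetical sketch (not part of the patch) of how an already-trained GraphSAGE model could embed a previously unseen paper. It assumes the trained model can be called as `model(x, edge_index)` on the Cora `Data` object `ds` used above; the new node's features and citation edges are made-up placeholders:

```python
# Hypothetical sketch: embedding an unseen node with a trained GraphSAGE model.
# `model` and `ds` are assumed to be the trained model and Cora Data object from above.
import torch

model.eval()
with torch.no_grad():
    new_x = torch.zeros(1, ds.num_node_features)  # BoW vector of the new paper (placeholder)
    cited = [0, 42]                                # papers the new node cites (placeholder)
    new_id = ds.num_nodes                          # index assigned to the new node
    extra_edges = torch.tensor(
        [[new_id] * len(cited) + cited,
         cited + [new_id] * len(cited)]            # citation edges in both directions
    )
    x = torch.cat([ds.x, new_x], dim=0)
    edge_index = torch.cat([ds.edge_index, extra_edges], dim=1)
    # No retraining needed: GraphSAGE aggregates neighbor features inductively.
    new_embedding = model(x, edge_index)[new_id]
```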
 
 ## Embedding semantics: LLM
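The LLM-embedding comparison promised in the introduction begins under this heading, outside the hunks shown in this patch. As a rough, hypothetical sketch of that step, assuming the raw article texts are available in a list `article_texts` (the Planetoid download ships only BoW vectors) and that a sentence-transformer model stands in for the LLM embedder:

```python
# Hypothetical sketch (not part of this patch): LLM embeddings for the same papers.
# `article_texts` is an assumed list holding the raw text of each of the 2708 papers.
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model choice
llm_embeddings = encoder.encode(article_texts, convert_to_tensor=True)

# Reuse the same KNN-based evaluation as for BoW, Node2Vec, and GraphSAGE.
evaluate(llm_embeddings.cpu(), ds.y)
```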