Though BoW representations embody _some_ information about article connectivity, they fall short of capturing the citation network's structure.

Can we make up for BoW's inability to represent the citation network's structure? Are there methods that capture node and node connectivity data better?

Node2Vec is built to do precisely this. So is GraphSAGE. First, let's look at Node2Vec.

## Learning node embeddings with Node2Vec

Let's also see if Node2Vec does a better job of **representing citation data** than BoW.
This time, using Node2Vec, we can see a well-defined separation; these embeddings capture the connectivity of the graph much better than BoW did.

But can we _further_ improve classification performance?
One method is _combining_ the two information sources: relational (Node2Vec) embeddings and textual (BoW) features.

### Node2Vec embeddings + text-based (BoW) features

A straightforward approach for combining vectors from different sources is to **concatenate them dimension-wise**. We have BoW features `v_bow` and Node2Vec embeddings `v_n2v`. The fused representation would then be `v_fused = torch.cat((v_n2v, v_bow), dim=1)`. But before combining them, we should examine the L2 norm distribution of both embeddings, to ensure that one kind of representation doesn't dominate the other:
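
To make this concrete, here's a minimal sketch of the norm check and the naive fusion, assuming `v_n2v` and `v_bow` are PyTorch float tensors with one row per article (the shapes below are illustrative, not taken from the original code):

```python
import torch

# Illustrative stand-ins for the real features (one row per article).
v_n2v = torch.randn(2708, 128)   # Node2Vec embeddings
v_bow = torch.rand(2708, 1433)   # BoW features

# Compare the L2 norm distributions of the two representations.
print("Node2Vec mean L2 norm:", v_n2v.norm(dim=1).mean().item())
print("BoW mean L2 norm:", v_bow.norm(dim=1).mean().item())

# Naive dimension-wise fusion.
v_fused = torch.cat((v_n2v, v_bow), dim=1)
```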

![L2 norm distribution of text based and Node2Vec embeddings](../assets/use_cases/node_representation_learning/l2_norm.png)

From the plot above, it's clear that the scales of the embedding vector lengths differ. To keep the larger-magnitude Node2Vec vectors from overshadowing the BoW vectors, we can normalize each kind of embedding by its average vector length.

But we can _further_ optimize performance by introducing a **weighting factor** ($\alpha$). The combined representations are constructed as `x = torch.cat((alpha * v_n2v, v_bow), dim=1)`. To determine the appropriate value for $\alpha$, we employ a 1D grid search approach. Our results are displayed in the following plot.
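
As a sketch of how this might look in code, reusing the tensors above and the `evaluate` helper from the article's earlier snippets (the search range for $\alpha$ is an assumption):

```python
# Normalize each representation by its average vector length,
# so that neither source dominates the fused vector.
v_n2v_scaled = v_n2v / v_n2v.norm(dim=1).mean()
v_bow_scaled = v_bow / v_bow.norm(dim=1).mean()

# 1D grid search over the weighting factor alpha (range is illustrative).
for alpha in torch.linspace(0.1, 2.0, 20):
    x = torch.cat((alpha * v_n2v_scaled, v_bow_scaled), dim=1)
    evaluate(x, ds.y)  # reports accuracy / F1 macro for this alpha
```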

![Grid search for alpha](../assets/use_cases/node_representation_learning/grid_search_alpha_bow.png)

```python
evaluate(x, ds.y)
>>> F1 macro 0.831
```

By combining representations of the articles' network structure (Node2Vec) and text (BoW), we were able to significantly improve classification performance. Specifically, the Node2Vec + BoW fusion improved accuracy by 3 percentage points over the Node2Vec-only classifier, and by 11.4 points over the BoW-only classifier.

These are impressive results. **But what if our citation network grows? What happens when new papers need to be classified?**

The results are slightly worse than those we got by combining Node2Vec with BoW features.

In addition to being unable to represent network structure, BoW vectors treat words as contextless occurrences, merely in terms of their frequency, and so can't capture semantic meaning. As a result, they perform less well on classification and article-relatedness tasks than approaches that can embed semantics. Let's summarize the classification performance results we obtained above using BoW features.

| Metric | BoW | Node2Vec | Node2Vec+BoW | GraphSAGE (BoW-trained) |
| --- | --- | --- | --- | --- |
| Accuracy | 0.738 | 0.822 | 0.852 | 0.844 |
| F1 (macro) | 0.701 | 0.803 | 0.831 | 0.820 |

These article classification results **can be improved further using LLM embeddings**, because LLM embeddings excel in capturing semantic meaning.

To do this, we used the `all-mpnet-base-v2` model available on [Hugging Face](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) to embed the title and abstract of each paper. Otherwise, data loading and optimization are done exactly as in the code snippets above for BoW features; we simply replace the BoW features with LLM features.
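
For reference, a minimal sketch of the embedding step using the `sentence-transformers` library; `titles` and `abstracts` are hypothetical lists of strings, one entry per paper:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Concatenate title and abstract for each paper (names are illustrative).
texts = [title + " " + abstract for title, abstract in zip(titles, abstracts)]

# One 768-dimensional embedding per paper.
v_llm = model.encode(texts, convert_to_tensor=True)
```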

The results obtained with LLM features only, Node2Vec combined with LLM features, and GraphSAGE trained on LLM features appear in the following table, along with the improvement (in percentage points) over the corresponding BoW-based results:

| Metric | LLM | Node2Vec+LLM | GraphSAGE (LLM-trained) |
| --- | --- | --- | --- |
| Accuracy | 0.816 (+7.8%) | **0.867** (+1.5%) | 0.852 (+0.8%) |
| F1 (macro) | 0.779 (+7.8%) | **0.840** (+0.9%) | 0.831 (+1.1%) |


## Conclusion: LLM, Node2Vec, and GraphSAGE beat BoW at learning node and node relationship data

At least for classification tasks on our article citation dataset, we can conclude that:

1. LLM features beat BoW features in all scenarios.
2. Combining text-based representations with network structure was better than either alone.
3. We achieved the best results using Node2Vec with LLM features.


As a final note, we've included a **pro vs con comparison** of our two node representation learning algorithms, **Node2Vec and GraphSAGE**, to help you think about which model might be a better fit for your use case:

| Aspect | Node2Vec | GraphSAGE |
| --- | --- | --- |
| Generalizing to new nodes | No | Yes |
| Inference time | Constant | We have control over inference time |
| Accommodating different graph types and objectives | By setting the $p$ and $q$ parameters, we can adapt representations to fit our needs | Limited control |
| Combining with other representations | Concatenation | By design, the model learns to map node representations to embeddings |
| Dependency on additional representations | Relies solely on graph data | Relies on quality and availability of node representations; impacts model performance if lacking |
| Embedding flexibility | Very flexible node representations | Neighboring nodes can't have much variation in their representations |

---
## Contributors
