Chronologer Retention Time model predictor #761

Open: elaboy wants to merge 109 commits into smith-chem-wisc:master from elaboy:Chronologer.
Changes from all commits (109 commits):
a6b1639 - correct Within calculation
fa4da8b - update unit tests
3246567 - conflicts resolved back to upstream
a018d4d - Merge remote-tracking branch 'upstream/master'
15a37d0 - Merge remote-tracking branch 'upstream/master'
892fa45 - this is the spot
211013c - Merge remote-tracking branch 'upstream/master'
68104ee - Merge branch 'master' of https://github.com/trishorts/mzLib
trishorts d715a08 - Merge remote-tracking branch 'upstream/master'
3565522 - Merge remote-tracking branch 'upstream/master'
72e7b53 - Merge remote-tracking branch 'upstream/master'
593872a - Merge remote-tracking branch 'upstream/master'
trishorts 42dd034 - Merge branch 'master' of https://github.com/trishorts/mzLib
trishorts fbeaec0 - Merge remote-tracking branch 'upstream/master'
trishorts 614ded7 - Merge remote-tracking branch 'upstream/master'
47307c8 - Merge branch 'master' of https://github.com/trishorts/mzLib
28e05ae - Merge remote-tracking branch 'upstream/master'
0a7c609 - Merge remote-tracking branch 'upstream/master'
630d8c7 - Merge remote-tracking branch 'upstream/master'
trishorts f6a386b - Merge branch 'master' of https://github.com/trishorts/mzLib
trishorts d673800 - Merge remote-tracking branch 'upstream/master'
675a0ae - Merge branch 'master' of https://github.com/trishorts/mzLib
15d4baf - Merge remote-tracking branch 'upstream/master'
03ca9f7 - Merge remote-tracking branch 'upstream/master'
d0a4c79 - Merge remote-tracking branch 'upstream/master'
894b998 - Merge remote-tracking branch 'upstream/master'
88269a1 - Merge remote-tracking branch 'upstream/master'
trishorts 9a9b24a - Merge remote-tracking branch 'upstream/master'
trishorts b4ad231 - add space
trishorts bc59b38 - Merge remote-tracking branch 'upstream/master'
trishorts f3c83ae - first move
trishorts d6d934b - psmFromTsv unit tests
trishorts 2db71cd - moved library spectrum
trishorts 562f69d - empty unit test for library spectrum
trishorts d3dcbe9 - m
trishorts 2c4334a - library spectrum unit tests
trishorts a86d68e - lib spec unit tests
trishorts c7ce32d - PSMTSV unit tests
trishorts c610791 - add tests for variants and localized glycans
trishorts 5e09c14 - capitalization convention
trishorts 9055644 - read internal ions test
trishorts 74b80ad - uncomment lines
trishorts d1bc75c - moved fragmentation and library spectrum to new project Omics
trishorts cec311a - Revert "moved fragmentation and library spectrum to new project Omics"
trishorts 8d88b32 - someInterfaces
trishorts df0f605 - good midpont
trishorts cad0d1c - omics classes and interfaces seem tobe working
trishorts 8991e14 - move LibrarySpectrum class to Omics. Create SpectrumMatchFromTsvHeade…
trishorts 02bf807 - not working
trishorts b7d15d6 - Fixed up the PR
nbollis 2502322 - Merge pull request #2 from trishorts/tempPsmFromTsv
trishorts 924e99f - fix broken test
trishorts 10f53a2 - some unit tests
trishorts d0a55b2 - dhg
trishorts 81f9338 - Expanded test coverage on file classes
nbollis 382c0da - new header and xlink psmtsv reader unit tests
trishorts 3abe9a3 - CPU(windows, linux, and mac) dll
elaboy 71c3ead - init
elaboy 7a84810 - Merge branch 'pr/737' into TrainingMethodsForChronologer
elaboy 79e3d09 - Custom Datasets and training functions
elaboy 848f81c - cool progress
elaboy d8576aa - training working
elaboy 81fe5b6 - Working
elaboy d9bf11a - updated Directory
elaboy 4b4d624 - cleaning code
elaboy 7ecc7c6 - Update ChronologerRetentionTimeEstimator.cs
elaboy e786bd6 - Merge branch 'master' into Chronologer
elaboy 816f031 - .
elaboy 4cdc4cf - Update TerminusSpecificProductTypes.cs
elaboy 367cd94 - Delete ChronologerTest.tsv
elaboy d139c70 - Merge branch 'Chronologer' of https://github.com/elaboy/mzLib-Fork in…
elaboy 638a635 - Update TestFlashLFQ.cs
elaboy 143f45f - internal and comments
elaboy 1dddd92 - Merge branch 'master' into Chronologer
elaboy ba8e65c - changed estimator class and added comments
elaboy 684842a - Merge branch 'Chronologer' of https://github.com/elaboy/mzLib-Fork in…
elaboy ebee682 - oops
elaboy 10fd45b - Merge branch 'master' into Chronologer
trishorts adfe301 - Merge branch 'master' into Chronologer
trishorts 06f0798 - Merge branch 'master' into Chronologer
elaboy a484e08 - static method to access chronologer
elaboy fb5e618 - Got rid of variables that were not being used
elaboy 383ea53 - Merge branch 'master' into Chronologer
elaboy 371a077 - fixed the terminus integers and N-acetylation
elaboy a70cc12 - Merge branch 'Chronologer' of https://github.com/elaboy/mzLib-Fork in…
elaboy d2f0a34 - Merge branch 'master' into Chronologer
elaboy 80f6373 - updated the dictionary to include all the supported mods
elaboy 01e339b - more testing and changes to the tensorize method
elaboy 3279426 - removed unused/repeated packagess
elaboy d9b34e3 - Merge branch 'master' into Chronologer
elaboy c599e9d - tensorize method now looks at the base sequence for selenocysteine an…
elaboy c5b852b - Merge branch 'master' into Chronologer
nbollis 0e0fba3 - Merge branch 'master' into Chronologer
trishorts 345577b - Merge branch 'master' into Chronologer
trishorts d93cf8f - making gpu available
elaboy d96e0cb - Revert "Merge branch 'master' into Chronologer"
elaboy bd08d85 - making gpu available
elaboy a9fba92 - Merge branch 'master' into Chronologer
elaboy 6b0d3f7 - cuda support
elaboy 4b8bf13 - Merge branch 'Chronologer' of https://github.com/elaboy/mzLib-Fork in…
elaboy 833d810 - adds cuda support
elaboy abb88a5 - add cuda support
elaboy c3369bb - try catch for cuda intitialization
elaboy 38d6ada - removed tests that where not supposed to be in this branch
elaboy 629c2a3 - contructor is now internal, now use to it being public and some cleanup
elaboy 1635fee - fixing tests and making identification of non-compatible sequences cl…
elaboy 79af77c - removing unused code
elaboy 495444b - same
elaboy a99a4c4 - cpu torchsharp for OSX
mzLib/Omics/Fragmentation/Peptide/TerminusSpecificProductTypes.cs: 2 additions, 1 deletion
```diff
@@ -1,5 +1,6 @@
+
 namespace Omics.Fragmentation.Peptide
-{
+{
     public class TerminusSpecificProductTypes
     {
         /// <summary>
```
mzLib/Proteomics/RetentionTimePrediction/ChronologerModel/Chronologer.cs: 151 additions, 0 deletions (new file)
```csharp
using System;
using System.IO;
using TorchSharp;
using TorchSharp.Modules;

namespace Proteomics.RetentionTimePrediction.ChronologerModel;

/// <summary>
/// Chronologer is a deep learning model for highly accurate prediction of peptide C18 retention times
/// (reported in % ACN). Chronologer was trained on a new large harmonized database of >2.6 million
/// retention time observations (2.25M unique peptides) constructed from 11 community datasets,
/// and it natively supports prediction of 17 different modification types. With only a few
/// observations of a new modification type (>10 peptides), Chronologer can easily be re-trained
/// to predict up to 10 user-supplied modifications.
///
/// Damien Beau Wilburn, Ariana E. Shannon, Vic Spicer, Alicia L. Richards, Darien Yeung,
/// Danielle L. Swaney, Oleg V. Krokhin, Brian C. Searle
/// bioRxiv 2023.05.30.542978; doi: https://doi.org/10.1101/2023.05.30.542978
///
/// https://github.com/searlelab/chronologer
///
/// Licensed under Apache License 2.0
/// </summary>
internal class Chronologer : torch.nn.Module<torch.Tensor, torch.Tensor>
{
    internal Chronologer() : this(Path.Combine(AppDomain.CurrentDomain.BaseDirectory,
        "RetentionTimePrediction", "ChronologerModel", "Chronologer_20220601193755_TorchSharp.dat"))
    {
        RegisterComponents();
    }

    /// <summary>
    /// Initializes a new instance of the Chronologer model class with pre-trained weights from the paper
    /// "Deep learning from harmonized peptide libraries enables retention time prediction of diverse
    /// post-translational modifications" (https://github.com/searlelab/chronologer).
    /// Eval mode is set to true and training mode is set to false by default.
    ///
    /// Please use .Predict() to run the model, not .forward().
    /// </summary>
    /// <param name="weightsPath"></param>
    /// <param name="evalMode"></param>
    private Chronologer(string weightsPath, bool evalMode = true) : base(nameof(Chronologer))
    {
        RegisterComponents();

        LoadWeights(weightsPath); // loads weights from the file

        if (evalMode)
        {
            eval(); // evaluation mode doesn't update the weights
            train(false);
        }
    }

    /// <summary>
    /// Do not use for inference; use .Predict() instead. For why forward() is not called directly
    /// when predicting outside the training loop, see:
    /// https://stackoverflow.com/questions/58508190/in-pytorch-what-is-the-difference-between-forward-and-an-ordinary-method
    /// </summary>
    /// <param name="x"></param>
    /// <returns></returns>
    public override torch.Tensor forward(torch.Tensor x)
    {
        var input = seq_embed.forward(x).transpose(1, -1);

        var residual = input.clone();        // clone the tensor; added back to the input below (residual connection)
        input = conv_layer_1.forward(input); // resnet block
        input = norm_layer_1.forward(input); // batch normalization
        input = relu.forward(input);         // ReLU activation
        input = conv_layer_2.forward(input); // convolutional layer
        input = norm_layer_2.forward(input); // batch normalization
        input = relu.forward(input);         // ReLU activation
        input = term_block.forward(input);   // identity block
        input = residual + input;            // residual connection
        input = relu.forward(input);         // ReLU activation

        residual = input.clone();            // clone the tensor; added back to the input below (residual connection)
        input = conv_layer_4.forward(input); // resnet block
        input = norm_layer_4.forward(input); // batch normalization
        input = relu.forward(input);         // ReLU activation
        input = conv_layer_5.forward(input); // convolutional layer
        input = norm_layer_5.forward(input); // batch normalization
        input = relu.forward(input);         // ReLU activation
        input = term_block.forward(input);   // identity block
        input = residual + input;            // residual connection
        input = relu.forward(input);         // ReLU activation

        residual = input.clone();            // clone the tensor; added back to the input below (residual connection)
        input = conv_layer_7.forward(input); // resnet block
        input = norm_layer_7.forward(input); // batch normalization
        input = term_block.forward(input);   // identity block
        input = relu.forward(input);         // ReLU activation
        input = conv_layer_8.forward(input); // convolutional layer
        input = norm_layer_8.forward(input); // batch normalization
        input = relu.forward(input);         // ReLU activation
        input = term_block.forward(input);   // identity block
        input = residual + input;            // residual connection
        input = relu.forward(input);         // ReLU activation

        input = dropout.forward(input);      // dropout layer
        input = flatten.forward(input);      // flatten layer
        input = output.forward(input);       // output layer

        return input;
    }

    /// <summary>
    /// Loads pre-trained weights from the file Chronologer_20220601193755_TorchSharp.dat.
    /// </summary>
    /// <param name="weightsPath"></param>
    private void LoadWeights(string weightsPath)
    {
        // load weights from the file
        load(weightsPath, true);
    }

    /// <summary>
    /// Predicts the retention time of the input peptide sequence. The input must be a torch.Tensor of shape (1, 52).
    /// </summary>
    /// <param name="input"></param>
    /// <returns></returns>
    internal torch.Tensor Predict(torch.Tensor input)
    {
        return call(input);
    }

    // All modules (shortcut modules exist only so the stored weights load; they are not used in forward())
    private Embedding seq_embed = torch.nn.Embedding(55, 64, 0);
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_1 = torch.nn.Conv1d(64, 64, 1, Padding.Same, dilation: 1);
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_2 = torch.nn.Conv1d(64, 64, 7, Padding.Same, dilation: 1);
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_3 = torch.nn.Conv1d(64, 64, 1, Padding.Same, dilation: 1); // shortcut
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_4 = torch.nn.Conv1d(64, 64, 1, Padding.Same, dilation: 2);
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_5 = torch.nn.Conv1d(64, 64, 7, Padding.Same, dilation: 2);
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_6 = torch.nn.Conv1d(64, 64, 1, Padding.Same, dilation: 2); // shortcut
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_7 = torch.nn.Conv1d(64, 64, 1, Padding.Same, dilation: 3);
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_8 = torch.nn.Conv1d(64, 64, 7, Padding.Same, dilation: 3);
    private torch.nn.Module<torch.Tensor, torch.Tensor> conv_layer_9 = torch.nn.Conv1d(64, 64, 1, Padding.Same, dilation: 3); // shortcut
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_1 = torch.nn.BatchNorm1d(64);
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_2 = torch.nn.BatchNorm1d(64);
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_3 = torch.nn.BatchNorm1d(64); // shortcut
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_4 = torch.nn.BatchNorm1d(64);
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_5 = torch.nn.BatchNorm1d(64);
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_6 = torch.nn.BatchNorm1d(64); // shortcut
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_7 = torch.nn.BatchNorm1d(64);
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_8 = torch.nn.BatchNorm1d(64);
    private torch.nn.Module<torch.Tensor, torch.Tensor> norm_layer_9 = torch.nn.BatchNorm1d(64);
    private torch.nn.Module<torch.Tensor, torch.Tensor> term_block = torch.nn.Identity();
    private torch.nn.Module<torch.Tensor, torch.Tensor> relu = torch.nn.ReLU(true);
    private torch.nn.Module<torch.Tensor, torch.Tensor> dropout = torch.nn.Dropout(0.01);
    private torch.nn.Module<torch.Tensor, torch.Tensor> flatten = torch.nn.Flatten(1);
    private torch.nn.Module<torch.Tensor, torch.Tensor> output = torch.nn.Linear(52 * 64, 1);
}
```
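For reviewers who want to exercise the model, here is a minimal, hypothetical usage sketch; it is not part of this diff. It assumes the peptide has already been mapped to Chronologer's integer vocabulary (the Embedding above takes 55 tokens with padding index 0), that the bundled weights are float32, and that the caller sits inside the assembly, since both the constructor and Predict() are internal. The actual encoding is handled by the tensorize method referenced in the commit log, not shown here.

```csharp
// Hypothetical sketch only; the encoding step is an assumption, not this PR's API.
using TorchSharp;
using Proteomics.RetentionTimePrediction.ChronologerModel;

// Internal ctor: loads the bundled weights and switches to eval mode.
var model = new Chronologer();

// Predict() expects a (1, 52) tensor of integer token indices (0 = padding).
var encoded = torch.zeros(new long[] { 1, 52 }, torch.int64);
// ... fill encoded[0, i] with each residue's/modification's vocabulary index ...

torch.Tensor prediction = model.Predict(encoded); // output shape (1, 1)
float predictedRt = prediction.item<float>();     // C18 retention time, reported in % ACN
```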
Review comment: extra space