-
Notifications
You must be signed in to change notification settings - Fork 111
BERT-style mask prediction pretraining #851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
niklasmei
merged 41 commits into
graphnet-team:main
from
niklasmei:maskpred_pretraining
Mar 18, 2026
Merged
Changes from all commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
e0b7243
mad the basis for maskpred pretraining and TheseusDeepIce
niklasmei f387e0a
updated the mask_pred module; currently still in my theseus.py file
niklasmei 408cb3f
updated the maskpred frame; nowa certain percentage is masked and com…
niklasmei e234e03
added loss function
niklasmei 21f7a45
fixed mse loss
niklasmei 8c34835
renamed file and brought up to date
niklasmei f92a13f
added in the charge prediction functionality and learend masking values
niklasmei 1115508
some minor fixes
niklasmei 0e751f6
restored saving
niklasmei eafe2e3
tested and should be ready for pull request
niklasmei 3d38cae
black reformatted
niklasmei d280f9a
further fix for passing checks
niklasmei 172ce83
added docstrings
niklasmei dae5642
another reformatting
niklasmei cb74825
docformatter reformatting
niklasmei e7d8ffb
more formatting
niklasmei 89fe2ae
formatting
niklasmei ef53032
formatting
niklasmei 744d376
mypy fix
niklasmei 4fc261b
formatting? black check fails even though running black on my end lea…
niklasmei 91822d4
avoid mypy problem with type of 'rep'
niklasmei d248288
avoid mypy error due to type of 'rep' variable
niklasmei c982cca
black formatting
niklasmei 4504994
mypy fix
niklasmei 103a40d
again weird behaviour of black
niklasmei 99b5b38
still error from black
niklasmei 8d8448e
formatting
niklasmei a4d0668
removed the loss function, kept rest, added a short example file in g…
niklasmei 6de0871
reformatting again
niklasmei f2707ac
pydocstyle on example
niklasmei 3ff1abf
still fixing example
niklasmei 720f7a0
still example
niklasmei ca9509f
still example
niklasmei 480ea2a
still example
niklasmei 52fd234
adjusted the path in the example file
niklasmei 73a7558
implemented task that handles augmentation and loss calculation as a …
niklasmei a4a07ff
removed old mask_pred_frame
niklasmei 59469a6
minor fix with task
niklasmei 0eedefe
changed from print to logging + minor removals
niklasmei 5f2cbec
black
niklasmei 33b22da
docformatter
niklasmei File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| """Minimal example for use of maskpred pretraining.""" | ||
|
|
||
| from typing import Tuple | ||
|
|
||
| from graphnet.models.pretraining_maskpred import mask_pred_frame | ||
| from graphnet.models.pretraining_maskpred import default_mask_augment | ||
| from graphnet.models.pretraining_maskpred import default_loss_calc | ||
| from graphnet.models import Model | ||
| from torch_geometric.data import Data | ||
| from graphnet.models.data_representation.graphs import KNNGraph | ||
| from graphnet.data.dataset.sqlite.sqlite_dataset import SQLiteDataset | ||
| from graphnet.data.dataloader import DataLoader | ||
| from graphnet.constants import EXAMPLE_DATA_DIR | ||
|
|
||
| from torch_scatter import scatter | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
| from graphnet.models.detector.prometheus import Prometheus | ||
| from graphnet.models.graphs.nodes import NodesAsPulses | ||
|
|
||
| from graphnet.models.task.task import UnsupervisedTask | ||
|
|
||
|
|
||
| class simple_model(Model): | ||
| """Just for a dummy model.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| ) -> None: | ||
| """Construct.""" | ||
| super().__init__() | ||
| self.net = torch.nn.Sequential( | ||
| torch.nn.Linear(4, 10), torch.nn.SELU(), torch.nn.Linear(10, 5) | ||
| ) | ||
|
|
||
| def forward(self, data: Data) -> Tuple[Tensor, Tensor]: | ||
| """Forward pass.""" | ||
| x = self.net(data.x) | ||
| x_rep = scatter(src=x, index=data.batch, dim=0, reduce="max") | ||
| return x, x_rep | ||
|
|
||
|
|
||
| def test() -> None: | ||
| """Short test with saving at the end.""" | ||
| graph_definition = KNNGraph( | ||
| detector=Prometheus(), | ||
| node_definition=NodesAsPulses(), | ||
| nb_nearest_neighbours=8, | ||
| ) | ||
|
|
||
| dataset = SQLiteDataset( | ||
| path=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db", | ||
| pulsemaps="total", | ||
| truth_table="mc_truth", | ||
| features=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t", "q"], | ||
| truth=["injection_energy", "injection_zenith"], | ||
| data_representation=graph_definition, | ||
| ) | ||
|
|
||
| dataloader = DataLoader( | ||
| dataset, | ||
| batch_size=3, | ||
| num_workers=10, | ||
| ) | ||
|
|
||
| for batch in dataloader: | ||
| data = batch | ||
| break | ||
|
|
||
| dummy_model = simple_model() | ||
| default_task = UnsupervisedTask( | ||
| default_mask_augment(), default_loss_calc() | ||
| ) | ||
|
|
||
| model = mask_pred_frame( | ||
| encoder=dummy_model, | ||
| bert_task=default_task, | ||
| encoder_out_dim=5, | ||
| need_charge_rep=False, | ||
| ) | ||
|
|
||
| out = model(data) | ||
| print(out) | ||
|
|
||
| # for training | ||
| # model.fit(train_dataloader=dataloader, max_epochs=10, gpus=1) | ||
|
|
||
| # for saving | ||
| # model.save_pretrained_model('some/path') | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test() | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these comments intended?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I commented the training and the saving so that the example file runs even if someone does not have a gpu available and also without actually saving a model that does not actually do anything. I can uncomment or remove if you want me to