Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
e0b7243
mad the basis for maskpred pretraining and TheseusDeepIce
niklasmei Jun 13, 2025
f387e0a
updated the mask_pred module; currently still in my theseus.py file
niklasmei Jun 13, 2025
408cb3f
updated the maskpred frame; nowa certain percentage is masked and com…
niklasmei Jun 26, 2025
e234e03
added loss function
niklasmei Jul 2, 2025
21f7a45
fixed mse loss
niklasmei Jul 2, 2025
8c34835
renamed file and brought up to date
niklasmei Jul 15, 2025
f92a13f
added in the charge prediction functionality and learend masking values
niklasmei Oct 7, 2025
1115508
some minor fixes
niklasmei Oct 7, 2025
0e751f6
restored saving
niklasmei Oct 7, 2025
eafe2e3
tested and should be ready for pull request
niklasmei Dec 1, 2025
3d38cae
black reformatted
niklasmei Dec 1, 2025
d280f9a
further fix for passing checks
niklasmei Dec 1, 2025
172ce83
added docstrings
niklasmei Dec 1, 2025
dae5642
another reformatting
niklasmei Dec 1, 2025
cb74825
docformatter reformatting
niklasmei Dec 1, 2025
e7d8ffb
more formatting
niklasmei Dec 1, 2025
89fe2ae
formatting
niklasmei Dec 1, 2025
ef53032
formatting
niklasmei Dec 1, 2025
744d376
mypy fix
niklasmei Dec 1, 2025
4fc261b
formatting? black check fails even though running black on my end lea…
niklasmei Dec 1, 2025
91822d4
avoid mypy problem with type of 'rep'
niklasmei Dec 1, 2025
d248288
avoid mypy error due to type of 'rep' variable
niklasmei Dec 1, 2025
c982cca
black formatting
niklasmei Dec 1, 2025
4504994
mypy fix
niklasmei Dec 1, 2025
103a40d
again weird behaviour of black
niklasmei Dec 1, 2025
99b5b38
still error from black
niklasmei Dec 1, 2025
8d8448e
formatting
niklasmei Dec 2, 2025
a4d0668
removed the loss function, kept rest, added a short example file in g…
niklasmei Jan 23, 2026
6de0871
reformatting again
niklasmei Jan 23, 2026
f2707ac
pydocstyle on example
niklasmei Jan 23, 2026
3ff1abf
still fixing example
niklasmei Jan 23, 2026
720f7a0
still example
niklasmei Jan 23, 2026
ca9509f
still example
niklasmei Jan 23, 2026
480ea2a
still example
niklasmei Jan 23, 2026
52fd234
adjusted the path in the example file
niklasmei Jan 26, 2026
73a7558
implemented task that handles augmentation and loss calculation as a …
niklasmei Feb 23, 2026
a4a07ff
removed old mask_pred_frame
niklasmei Feb 23, 2026
59469a6
minor fix with task
niklasmei Feb 23, 2026
0eedefe
changed from print to logging + minor removals
niklasmei Mar 17, 2026
5f2cbec
black
niklasmei Mar 17, 2026
33b22da
docformatter
niklasmei Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions examples/pretraining_maskpred_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Minimal example for use of maskpred pretraining."""

from typing import Tuple

from graphnet.models.pretraining_maskpred import mask_pred_frame
from graphnet.models.pretraining_maskpred import default_mask_augment
from graphnet.models.pretraining_maskpred import default_loss_calc
from graphnet.models import Model
from torch_geometric.data import Data
from graphnet.models.data_representation.graphs import KNNGraph
from graphnet.data.dataset.sqlite.sqlite_dataset import SQLiteDataset
from graphnet.data.dataloader import DataLoader
from graphnet.constants import EXAMPLE_DATA_DIR

from torch_scatter import scatter

import torch
from torch import Tensor

from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs.nodes import NodesAsPulses

from graphnet.models.task.task import UnsupervisedTask


class simple_model(Model):
"""Just for a dummy model."""

def __init__(
self,
) -> None:
"""Construct."""
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Linear(4, 10), torch.nn.SELU(), torch.nn.Linear(10, 5)
)

def forward(self, data: Data) -> Tuple[Tensor, Tensor]:
"""Forward pass."""
x = self.net(data.x)
x_rep = scatter(src=x, index=data.batch, dim=0, reduce="max")
return x, x_rep


def test() -> None:
"""Short test with saving at the end."""
graph_definition = KNNGraph(
detector=Prometheus(),
node_definition=NodesAsPulses(),
nb_nearest_neighbours=8,
)

dataset = SQLiteDataset(
path=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
pulsemaps="total",
truth_table="mc_truth",
features=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t", "q"],
truth=["injection_energy", "injection_zenith"],
data_representation=graph_definition,
)

dataloader = DataLoader(
dataset,
batch_size=3,
num_workers=10,
)

for batch in dataloader:
data = batch
break

dummy_model = simple_model()
default_task = UnsupervisedTask(
default_mask_augment(), default_loss_calc()
)

model = mask_pred_frame(
encoder=dummy_model,
bert_task=default_task,
encoder_out_dim=5,
need_charge_rep=False,
)

out = model(data)
print(out)

# for training
# model.fit(train_dataloader=dataloader, max_epochs=10, gpus=1)

# for saving
# model.save_pretrained_model('some/path')
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these comments intended?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I commented the training and the saving so that the example file runs even if someone does not have a gpu available and also without actually saving a model that does not actually do anything. I can uncomment or remove if you want me to



if __name__ == "__main__":
test()
Loading
Loading