Skip to content

Commit 4e7900a

Browse files
authored
Merge pull request #38 from LorenzLamm/dataloading_adjustments
Dataloading adjustments
2 parents 7d986f8 + 0ec47e9 commit 4e7900a

7 files changed

Lines changed: 60 additions & 22 deletions

File tree

src/membrain_pick/cli/predict_cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def predict(
     verbose: bool = Option(
         True, help="Should the prediction progress bar be printed?"
     ), # noqa: B008
+    num_workers: int = Option(
+        None, help="Number of workers for the DataLoader."
+    ), # noqa: B008
 ):
     """Predict the output of the trained model on the given data.

@@ -80,6 +83,7 @@ def predict(
         mean_shift_score_threshold=mean_shift_score_threshold,
         mean_shift_device=mean_shift_device,
         verbose=verbose,
+        num_workers=num_workers,
     )

src/membrain_pick/cli/train_cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def train(
     verbose: bool = Option(
         True, help="Should the training progress bar be printed?"
     ), # noqa: B008
+    num_workers: int = Option(
+        None, help="Number of workers for the DataLoader."
+    ), # noqa: B008
 ):
     """Train a diffusion net model.

@@ -63,6 +66,7 @@ def train(
         mean_shift_output=False,
         max_epochs=max_epochs,
         verbose=verbose,
+        num_workers=num_workers,
     )

@@ -126,6 +130,9 @@ def train_advanced(
     verbose: bool = Option(
         True, help="Should the training progress bar be printed?"
     ), # noqa: B008
+    num_workers: int = Option(
+        None, help="Number of workers for the DataLoader."
+    ), # noqa: B008
 ):
     """Train a diffusion net model.

@@ -163,4 +170,5 @@ def train_advanced(
         mean_shift_margin=mean_shift_margin,
         max_epochs=max_epochs,
         verbose=verbose,
+        num_workers=num_workers,
     )

src/membrain_pick/dataloading/diffusionnet_datamodule.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,24 @@
 def custom_collate(batch):
     """Custom collate function to handle a complex data structure.

-    Each sample is a dictionary containing numpy arrays and another dictionary
-    with sparse matrices. Since we're using a batch size of 1, this function
-    simplifies the handling of these structures.
-
     Args:
-        batch: A list of samples, where each sample is the complex data structure
-            described above.
+        batch: A list of samples, where each sample is the complex data structure.

     Returns:
         Processed batch ready for model input.
     """
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     # Unpack the single sample from the batch
     sample = batch[0]
-    # Initialize a new dictionary to store the processed sample
     processed_sample = {}

     for key, value in sample.items():
         if isinstance(value, np.ndarray):
             # Convert numpy arrays to tensors
-            processed_sample[key] = torch.tensor(value).to(device)
+            processed_sample[key] = torch.tensor(value)
         elif isinstance(value, dict):
-            # For the nested dictionary, we assume it contains sparse matrices
-            # and pass it through directly without modifications
+            # For the nested dictionary, directly pass it through without GPU operations
             processed_sample[key] = {
-                subkey: subvalue.to(device) for subkey, subvalue in value.items()
+                subkey: subvalue for subkey, subvalue in value.items()
             }
         else:
             # Directly pass through any other types of values

src/membrain_pick/networks/diffusion_net/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,21 @@ def random_rotate_points_y(pts):
 # Numpy things


-# Numpy sparse matrix to pytorch
 def sparse_np_to_torch(A):
-    Acoo = A.tocoo()
-    values = Acoo.data
-    indices = np.vstack((Acoo.row, Acoo.col))
+    """
+    Converts a numpy sparse matrix to a PyTorch sparse tensor.
+
+    Args:
+        A: A scipy sparse matrix (e.g., COO, CSR).
+
+    Returns:
+        PyTorch sparse tensor.
+    """
+    Acoo = A.tocoo()  # Convert to COO format if not already
+    values = torch.tensor(Acoo.data, dtype=torch.float32)
+    indices = torch.tensor(np.vstack((Acoo.row, Acoo.col)), dtype=torch.int64)
     shape = Acoo.shape
-    return torch.sparse.FloatTensor(
-        torch.LongTensor(indices), torch.FloatTensor(values), torch.Size(shape)
-    ).coalesce()
+    return torch.sparse_coo_tensor(indices, values, torch.Size(shape)).coalesce()


 # Pytorch sparse to numpy csc matrix

src/membrain_pick/optimization/diffusion_training_pylit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def forward(self, batch):
         return out

     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=1e-3)
+        optimizer = Adam(self.parameters(), lr=1e-3 * 5)
         scheduler = {
             "scheduler": LambdaLR(
                 optimizer, lr_lambda=lambda epoch: (1 - epoch / self.max_epochs) ** 0.9

src/membrain_pick/predict.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
 from membrain_pick.dataloading.diffusionnet_datamodule import (
     MemSegDiffusionNetDataModule,
 )
+from membrain_pick.train import get_optimal_num_workers
 from membrain_pick.optimization.diffusion_training_pylit import DiffusionNetModule

 from membrain_pick.dataloading.data_utils import (
@@ -99,6 +100,7 @@ def predict(
     # mean_shift_device: str = "cuda:0",
     mean_shift_device: str = "cpu",
     verbose: bool = True,
+    num_workers: int = None,
 ):
     """Predict the output of the trained model on the given data.

@@ -120,7 +122,9 @@ def predict(
         k_eig=k_eig,
         batch_size=1,
         force_recompute=force_recompute_partitioning,
-        num_workers=0,
+        num_workers=(
+            num_workers if num_workers is not None else get_optimal_num_workers()
+        ),
         pin_memory=False,
         overfit=False,
     )
@@ -164,6 +168,13 @@ def predict(
     outputs = []
     for i in range(all_diffusion_feature.shape[1] - 15):
         batch["diffusion_inputs"]["features"] = all_diffusion_feature[:, i : i + 16]
+        # put the batch on the device
+        for key in batch:
+            if isinstance(batch[key], torch.Tensor):
+                batch[key] = batch[key].to(device)
+            elif isinstance(batch[key], dict):
+                for sub_key in batch[key]:
+                    batch[key][sub_key] = batch[key][sub_key].to(device)
         with torch.no_grad():
             output = model(batch)
         outputs.append(output["mse"].squeeze().detach().cpu().numpy())

src/membrain_pick/train.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@
 from membrain_pick.optimization.diffusion_training_pylit import DiffusionNetModule


+def get_optimal_num_workers():
+    """
+    Dynamically determine an optimal number of DataLoader workers.
+
+    Returns:
+        int: Recommended number of workers.
+    """
+    cpu_count = os.cpu_count()
+    if not cpu_count:
+        return 0  # Fallback if CPU count is unavailable
+    return min(cpu_count // 2, 16)
+
+
 def train(
     data_dir: str,
     training_dir: str = "./training_output",
@@ -43,6 +56,7 @@ def train(
     # Training parameters
     max_epochs: int = 1000,
     verbose: bool = True,
+    num_workers: int = None,
 ):

     train_path = os.path.join(data_dir, "train")
@@ -69,7 +83,9 @@ def train(
         position_tokens=position_tokens,
         k_eig=k_eig,
         batch_size=1,
-        num_workers=0,
+        num_workers=(
+            num_workers if num_workers is not None else get_optimal_num_workers()
+        ),
         pin_memory=False,
     )
     data_module.setup()
@@ -135,6 +151,7 @@ def on_epoch_start(self, trainer, pl_module):
         ],
         max_epochs=max_epochs,
         enable_progress_bar=verbose,
+        accumulate_grad_batches=16,
     )

     # Start the training process

0 commit comments

Comments (0)