
Commit e68d98d

Merge pull request #3 from CompOmics/Batch-prediction
Batch prediction and clear error log
2 parents 388772e + 1ebd382 commit e68d98d

7 files changed

Lines changed: 395 additions & 136 deletions


ideeplc/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 """iDeepLC: A deep Learning-based retention time predictor for unseen modified peptides with a novel encoding system"""
 
-__version__ = "1.3.1"
+__version__ = "1.3.2"

ideeplc/data_initialize.py

Lines changed: 98 additions & 14 deletions
@@ -1,14 +1,13 @@
 import logging
-from typing import Tuple, Union
+from typing import Tuple, Union, Iterator
 import pandas as pd
 import numpy as np
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 from ideeplc.utilities import df_to_matrix, reform_seq
 
 LOGGER = logging.getLogger(__name__)
 
 
-# Making the pytorch dataset
 class MyDataset(Dataset):
     def __init__(self, sequences: np.ndarray, retention: np.ndarray) -> None:
         self.sequences = sequences
@@ -25,15 +24,14 @@ def data_initialize(
     csv_path: str, **kwargs
 ) -> Union[Tuple[MyDataset, np.ndarray], Tuple[MyDataset, np.ndarray]]:
     """
-    Initialize peptides matrices based on a CSV file containing raw peptide sequences.
+    Initialize peptide matrices based on a CSV file containing raw peptide sequences.
 
     :param csv_path: Path to the CSV file containing raw peptide sequences.
-    :return: DataLoader for prediction.
+    :return: Dataset for prediction or fine-tuning and x_shape.
     """
-
     LOGGER.info(f"Loading peptides from {csv_path}")
+
     try:
-        # Load peptides from CSV file
         df = pd.read_csv(csv_path)
     except FileNotFoundError:
         LOGGER.error(f"File {csv_path} not found.")
@@ -63,22 +61,108 @@ def data_initialize(
     LOGGER.info(
         f"Loaded and reformed {len(reformed_peptides)} peptides sequences from the file."
     )
+
     try:
-        # Convert sequences to matrix format
         sequences, tr, errors = df_to_matrix(reformed_peptides, df)
     except Exception as e:
         LOGGER.error(f"Error converting sequences to matrix format: {e}")
         raise
+
     if errors:
         LOGGER.warning(f"Errors encountered during conversion: {errors}")
 
     prediction_dataset = MyDataset(sequences, tr)
 
-    # Create DataLoader objects
-    dataloader_pred = DataLoader(prediction_dataset)
-    # passing the training X shape
-    for batch in dataloader_pred:
-        x_shape = batch[0].shape
-        break
+    if len(prediction_dataset) == 0:
+        LOGGER.error("No valid peptide entries were found in the input file.")
+        raise ValueError("No valid peptide entries were found in the input file.")
+
+    # Keep historical x_shape contract expected by model/tests: (batch, channels, length)
+    x_shape = (1,) + prediction_dataset[0][0].shape
     LOGGER.info(f"Dataset initialized with data shape {x_shape}.")
     return prediction_dataset, x_shape
+
+
+def data_initialize_chunked(
+    csv_path: str, chunk_size: int = 10000, **kwargs
+) -> Iterator[Tuple[pd.DataFrame, MyDataset, np.ndarray]]:
+    """
+    Initialize peptide matrices from a CSV file in chunks.
+
+    :param csv_path: Path to the CSV file containing raw peptide sequences.
+    :param chunk_size: Number of rows to load per chunk.
+    :return: Iterator yielding dataframe chunk, dataset chunk, and x_shape.
+    """
+    LOGGER.info(f"Loading peptides from {csv_path} in chunks of {chunk_size}")
+
+    try:
+        chunk_iter = pd.read_csv(csv_path, chunksize=chunk_size)
+    except FileNotFoundError:
+        LOGGER.error(f"File {csv_path} not found.")
+        raise
+    except pd.errors.EmptyDataError:
+        LOGGER.error(f"File {csv_path} is empty.")
+        raise
+    except Exception as e:
+        LOGGER.error(f"Error reading {csv_path}: {e}")
+        raise
+
+    for chunk_idx, df in enumerate(chunk_iter, start=1):
+        if "seq" not in df.columns:
+            LOGGER.error("CSV file must contain a 'seq' column with peptide sequences.")
+            raise ValueError("Missing 'seq' column in the CSV file.")
+        if "modifications" not in df.columns:
+            LOGGER.error(
+                "CSV file must contain a 'modifications' column with peptide modifications."
+            )
+            raise ValueError("Missing 'modifications' column in the CSV file.")
+        if "tr" not in df.columns:
+            LOGGER.error("CSV file must contain a 'tr' column with retention times.")
+            raise ValueError("Missing 'tr' column in the CSV file.")
+
+        reformed_peptides = [
+            reform_seq(seq, mod) for seq, mod in zip(df["seq"], df["modifications"])
+        ]
+        LOGGER.info(
+            f"Chunk {chunk_idx}: loaded and reformed {len(reformed_peptides)} peptides sequences."
+        )
+
+        try:
+            sequences, tr, errors = df_to_matrix(reformed_peptides, df)
+        except Exception as e:
+            LOGGER.error(
+                f"Error converting sequences to matrix format in chunk {chunk_idx}: {e}"
+            )
+            raise
+
+        if errors:
+            LOGGER.warning(f"Errors encountered during conversion in chunk {chunk_idx}: {errors}")
+
+        prediction_dataset = MyDataset(sequences, tr)
+
+        if len(prediction_dataset) == 0:
+            LOGGER.warning(f"Chunk {chunk_idx} contains no valid peptide entries.")
+            continue
+
+        # Keep historical x_shape contract expected by model/tests: (batch, channels, length)
+        x_shape = (1,) + prediction_dataset[0][0].shape
+        LOGGER.info(f"Chunk {chunk_idx} initialized with data shape {x_shape}.")
+        yield df, prediction_dataset, x_shape
+
+
+def get_input_shape_from_first_chunk(csv_path: str, chunk_size: int = 10000):
+    """
+    Get the input shape from the first valid chunk of a CSV file.
+
+    :param csv_path: Path to the CSV file containing raw peptide sequences.
+    :param chunk_size: Number of rows to load per chunk.
+    :return: x_shape for model initialization.
+    """
+    for _, dataset_chunk, x_shape in data_initialize_chunked(
+        csv_path=csv_path, chunk_size=chunk_size
+    ):
+        LOGGER.info(f"Detected input shape from first valid chunk: {x_shape}")
+        return x_shape
+
+    LOGGER.error("No valid chunks found in the input file.")
+    raise ValueError("No valid chunks found in the input file.")

ideeplc/fine_tuning.py

Lines changed: 52 additions & 26 deletions
@@ -24,6 +24,8 @@ def __init__(
         validation_data=None,
         validation_split=0.1,
         patience=5,
+        num_workers=0,
+        pin_memory=False,
     ):
         """
         Initialize the fine-tuner with the model and data loaders.
@@ -38,6 +40,8 @@ def __init__(
         :param validation_data: Optional validation dataset.
         :param validation_split: Fraction of training data to use for validation.
         :param patience: Number of epochs with no improvement after which training will be stopped.
+        :param num_workers: Number of workers for the DataLoader.
+        :param pin_memory: Whether to pin memory in the DataLoader.
         """
         self.model = model.to(device)
         self.train_data = train_data
@@ -49,6 +53,8 @@ def __init__(
         self.validation_data = validation_data
         self.validation_split = validation_split
         self.patience = patience
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
 
     def _freeze_layers(self, layers_to_freeze):
         """
@@ -71,34 +77,52 @@ def prepare_data(self, data, shuffle=True):
         :param shuffle: Whether to shuffle the data.
         :return: DataLoader for the dataset.
         """
-        return DataLoader(data, batch_size=self.batch_size, shuffle=shuffle)
+        return DataLoader(
+            data,
+            batch_size=self.batch_size,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+        )
 
     def fine_tune(self, layers_to_freeze=None):
         """
         Fine-tune the iDeepLC model on the training dataset.
 
         :param layers_to_freeze: List of layer names to freeze during fine-tuning.
+        :return: Best model based on validation loss.
         """
         LOGGER.info("Starting fine-tuning...")
+
         if layers_to_freeze:
             self._freeze_layers(layers_to_freeze)
 
-        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, self.model.parameters()),
+            lr=self.learning_rate,
+        )
         loss_fn = self.loss_function
-        # Prepare DataLoader
-        if self.validation_data:
-            dataloader_train = self.prepare_data(self.train_data)
+
+        if self.validation_data is not None:
+            dataloader_train = self.prepare_data(self.train_data, shuffle=True)
             dataloader_val = self.prepare_data(self.validation_data, shuffle=False)
         else:
-            # Split the training data into training and validation sets
             train_size = int((1 - self.validation_split) * len(self.train_data))
             val_size = len(self.train_data) - train_size
+
+            if train_size == 0 or val_size == 0:
+                raise ValueError(
+                    "Training dataset is too small for the requested validation split."
+                )
+
             train_dataset, val_dataset = torch.utils.data.random_split(
                 self.train_data, [train_size, val_size]
             )
-            dataloader_train = self.prepare_data(train_dataset)
+            dataloader_train = self.prepare_data(train_dataset, shuffle=True)
             dataloader_val = self.prepare_data(val_dataset, shuffle=False)
+
         LOGGER.info(f"Training on {len(dataloader_train.dataset)} samples.")
+        LOGGER.info(f"Validating on {len(dataloader_val.dataset)} samples.")
 
         best_model = copy.deepcopy(self.model)
         best_loss = float("inf")
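Filtering on requires_grad matters once _freeze_layers has run: parameters of frozen layers never reach the optimizer, so Adam allocates no state for them and cannot update them. A self-contained sketch of the same pattern, with a toy model standing in for the iDeepLC network:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

    # Freeze the first layer, analogous to what _freeze_layers does by name.
    for param in model[0].parameters():
        param.requires_grad = False

    # Only trainable parameters are handed to Adam, as in the diff above.
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-3,
    )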
@@ -107,15 +131,15 @@ def fine_tune(self, layers_to_freeze=None):
         for epoch in range(self.epochs):
             self.model.train()
             running_loss = 0.0
+
             for batch in dataloader_train:
                 inputs, target = batch
-                inputs, target = inputs.to(self.device), target.to(self.device)
+                inputs = inputs.to(self.device, non_blocking=True)
+                target = target.to(self.device, non_blocking=True)
 
-                # Forward pass
                 outputs = self.model(inputs.float())
                 loss = loss_fn(outputs, target.float().view(-1, 1))
 
-                # Backward pass and optimization
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
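The non_blocking=True copies pair with the new pin_memory option in prepare_data: when the DataLoader serves batches from pinned host memory, CUDA can overlap the host-to-device transfer with compute, and on CPU-only runs both flags are harmless no-ops. A standalone sketch of the combination (toy tensors, arbitrary sizes):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
    loader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True)

    for inputs, target in loader:
        # With pinned memory, these copies can run asynchronously on CUDA.
        inputs = inputs.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)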
@@ -125,22 +149,24 @@ def fine_tune(self, layers_to_freeze=None):
             avg_loss = running_loss / len(dataloader_train.dataset)
             LOGGER.info(f"Epoch [{epoch + 1}/{self.epochs}], Loss: {avg_loss:.4f}")
 
-            # Validate the model after each epoch
-            if dataloader_val:
-                val_loss, _, _, _ = validate(
-                    self.model, dataloader_val, loss_fn, self.device
+            val_loss, _, _, _ = validate(
+                self.model, dataloader_val, loss_fn, self.device
+            )
+
+            if val_loss < best_loss:
+                best_loss = val_loss
+                best_model = copy.deepcopy(self.model)
+                patience_counter = 0
+                LOGGER.info(f"New best validation loss: {best_loss:.4f}")
+            else:
+                patience_counter += 1
+                LOGGER.info(
+                    f"No improvement in validation loss. Patience: {patience_counter}/{self.patience}"
                 )
-                if val_loss < best_loss:
-                    best_loss = val_loss
-                    best_model = copy.deepcopy(self.model)
-                    patience_counter = 0
-                    LOGGER.info(f"New best validation loss: {best_loss:.4f}")
-                else:
-                    patience_counter += 1
-
-                if patience_counter >= self.patience:
-                    LOGGER.info("Early stopping triggered.")
-                    break
+
+            if patience_counter >= self.patience:
+                LOGGER.info("Early stopping triggered.")
+                break
 
         LOGGER.info("Fine-tuning complete.")
-        return best_model
+        return best_model
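Since fine_tune now always constructs a validation loader (from validation_data or from the split), the epoch loop validates unconditionally and logs the patience counter. A self-contained sketch of that early-stopping pattern on toy data (model, sizes, and names are illustrative only):

    import copy

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    model = nn.Linear(4, 1)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    train_loader = DataLoader(
        TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=16
    )
    val_loader = DataLoader(
        TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=16
    )

    best_loss, best_model, patience_counter, patience = float("inf"), None, 0, 3

    for epoch in range(50):
        model.train()
        for inputs, target in train_loader:
            loss = loss_fn(model(inputs), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validate after every epoch, as the refactored fine_tune() does.
        model.eval()
        with torch.no_grad():
            val_loss = sum(loss_fn(model(x), y).item() for x, y in val_loader)

        if val_loss < best_loss:
            best_loss, best_model, patience_counter = val_loss, copy.deepcopy(model), 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break  # early stopping, mirroring the diff's patience logic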
