Open
Labels: feature (New feature or request)
Description
The function `predict_as_dataframe` (graphnet/src/graphnet/models/easy_model.py, line 323 in 401f28b) can be really slow when querying for additional attributes, which is handled here:
graphnet/src/graphnet/models/easy_model.py, lines 381 to 397 in 401f28b:

    for batch in dataloader:
        for attr in attributes:
            attribute = batch[attr]
            if isinstance(attribute, torch.Tensor):
                attribute = attribute.detach().cpu().numpy()

            # Check if node level predictions
            # If true, additional attributes are repeated
            # to make dimensions fit
            if pulse_level_predictions:
                if len(attribute) < np.sum(
                    batch.n_pulses.detach().cpu().numpy()
                ):
                    attribute = np.repeat(
                        attribute, batch.n_pulses.detach().cpu().numpy()
                    )
            attributes[attr].extend(attribute)

In the prediction loop, we already loop through the whole dataset. Can we think of a way to use that initial loop to collect the additional attributes as well, without iterating over the dataloader a second time? At the very least, I think we should add a progress bar to the second loop to let the user know that the program hasn't stalled (I was very confused at times about why it was taking so long).
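
To make the single-pass idea concrete, here is a minimal sketch, not the actual graphnet implementation. The function `predict_with_attributes`, its `predict_fn` argument, and the dict-of-arrays batches are all hypothetical stand-ins chosen for illustration. NumPy arrays replace the torch tensors, and the batches are assumed to be non-empty. How this would hook into the real prediction loop in `easy_model.py` is not shown.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Mapping, Sequence

import numpy as np


def predict_with_attributes(
    batches: Iterable[Mapping[str, np.ndarray]],
    predict_fn: Callable[[Mapping[str, np.ndarray]], np.ndarray],
    additional_attributes: Sequence[str],
    pulse_level_predictions: bool = False,
) -> Dict[str, np.ndarray]:
    """Collect predictions and extra attributes in one pass (hypothetical helper)."""
    predictions: List[np.ndarray] = []
    collected: Dict[str, List[np.ndarray]] = defaultdict(list)

    for batch in batches:  # the only loop over the data
        # Hypothetical stand-in for the model's forward pass on one batch.
        predictions.append(np.asarray(predict_fn(batch)))

        for attr in additional_attributes:
            attribute = np.asarray(batch[attr])
            if pulse_level_predictions:
                n_pulses = np.asarray(batch["n_pulses"])
                # Event-level attributes are repeated once per pulse so
                # that they line up with pulse-level predictions.
                if len(attribute) < n_pulses.sum():
                    attribute = np.repeat(attribute, n_pulses)
            collected[attr].append(attribute)

    # Concatenate once at the end instead of extending a list per element.
    result = {"prediction": np.concatenate(predictions)}
    for attr in additional_attributes:
        result[attr] = np.concatenate(collected[attr])
    return result
```

The sketch gathers each attribute from the batch at the moment that batch is predicted, so the data is read only once. If merging the loops turns out to be impractical, a lighter fix would be to wrap `dataloader` in the second loop with `tqdm`, assuming it is available as a dependency, so that progress is at least visible.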