slow predict_as_dataframe function #830

@sevmag

Description

The function:

def predict_as_dataframe(

Can be really slow when additional attributes are requested; the attribute collection is handled here:

for batch in dataloader:
    for attr in attributes:
        attribute = batch[attr]
        if isinstance(attribute, torch.Tensor):
            attribute = attribute.detach().cpu().numpy()
        # Check if node level predictions
        # If true, additional attributes are repeated
        # to make dimensions fit
        if pulse_level_predictions:
            if len(attribute) < np.sum(
                batch.n_pulses.detach().cpu().numpy()
            ):
                attribute = np.repeat(
                    attribute, batch.n_pulses.detach().cpu().numpy()
                )
        attributes[attr].extend(attribute)
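
For context, the repetition step maps one attribute value per event onto one value per pulse. A tiny standalone illustration of that alignment (the numbers below are made up, not taken from the issue):

import numpy as np

energy = np.array([10.0, 250.0])   # one attribute value per event
n_pulses = np.array([3, 2])        # pulses per event in the batch
aligned = np.repeat(energy, n_pulses)
print(aligned)                     # [ 10.  10.  10. 250. 250.] -> one value per pulse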

In the prediction loop we are already iterating over the whole dataset. Can we think of a way to use that initial loop to collect the additional attributes as well, instead of looping through the dataloader a second time? At the very least, the second loop should get a progress bar so the user knows the program has not stalled (I was confused more than once about why it was taking so long).
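
One possible direction, sketched below under explicit assumptions: run the model and gather the additional attributes in the same pass over the dataloader, and wrap that pass in a progress bar. The function name predict_and_collect and the direct model(batch) call are hypothetical and not the current graphnet implementation.

import numpy as np
import torch
from tqdm import tqdm


def predict_and_collect(model, dataloader, attribute_names, pulse_level_predictions=False):
    """Hypothetical single-pass version: run the model and gather the
    additional attributes from each batch in the same loop."""
    predictions = []
    attributes = {attr: [] for attr in attribute_names}
    model.eval()
    with torch.no_grad():
        # tqdm makes the (single) pass over the dataloader visible to the user.
        for batch in tqdm(dataloader, desc="Predicting", unit="batch"):
            predictions.append(model(batch))  # assumes the model accepts a batch directly
            n_pulses = batch.n_pulses.detach().cpu().numpy()
            for attr in attribute_names:
                values = batch[attr]
                if isinstance(values, torch.Tensor):
                    values = values.detach().cpu().numpy()
                # Repeat event-level attributes so they line up with
                # pulse-level predictions (one row per pulse).
                if pulse_level_predictions and len(values) < np.sum(n_pulses):
                    values = np.repeat(values, n_pulses)
                attributes[attr].extend(values)
    return predictions, attributes

Even if the prediction itself has to stay inside the existing Trainer-based predict call, simply wrapping the current second loop in tqdm(dataloader, desc=...) would already make it obvious that the program is progressing rather than stalled.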
