
Error when running explore_MNIST.ipynb #1

@PikaPei

Hello!

I read your preprint and found it really interesting!
I tried running the explore_MNIST.ipynb notebook, but I ran into the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[11], line 9
      5 data_generator.manual_seed(data_seed)
      7 network = get_network(config_file_path, network_seed)
----> 9 train_network_mnist(network, train_dataloader, val_dataloader, train_steps=train_steps)
     11 analyze_network_mnist(network, test_dataloader, analyze_receptive_fields=analyze_receptive_fields)
     13 DTP_LDS_network = network

Cell In[4], line 2, in train_network_mnist(network, train_dataloader, val_dataloader, train_steps)
      1 def train_network_mnist(network, train_dataloader, val_dataloader, train_steps):
----> 2     network.train(train_dataloader, val_dataloader, samples_per_epoch=train_steps, val_interval=(0, -1, 100), store_history=True, \
      3                   store_history_interval=(0, -1, 100), status_bar=True)

File ~/project/EIANN/EIANN/network.py:538, in Network.train(self, train_dataloader, val_dataloader, epochs, val_interval, samples_per_epoch, store_history, store_dynamics, store_params, store_history_interval, store_params_interval, save_to_file, status_bar)
    536             post_pop.bias_learning_rule.step()
    537         for projection in post_pop:
--> 538             projection.learning_rule.step()
    540 self.constrain_weights_and_biases()
    542 # Update learning rule parameters

File ~/project/EIANN/EIANN/rules/dendritic_loss.py:36, in DendriticLoss_6.step(self)
     31     delta_weight = torch.outer(
     32         torch.clamp(self.projection.post.forward_dendritic_state.detach().clone(), min=-1, max=1),
     33         torch.clamp(self.projection.pre.forward_prev_activity, min=0, max=1))
     35 # Ensure delta_weight has same dtype as weight for AMP compatibility
---> 36 if self.projection.post.network.use_amp and delta_weight.dtype != self.projection.weight.dtype:
     37     delta_weight = delta_weight.to(self.projection.weight.dtype)
     39 self.projection.weight.data += self.sign * self.learning_rate * delta_weight

File ~/project/EIANN/.pixi/envs/default/lib/python3.12/site-packages/torch/nn/modules/module.py:1964, in Module.__getattr__(self, name)
   1962     if name in modules:
   1963         return modules[name]
-> 1964 raise AttributeError(
   1965     f"'{type(self).__name__}' object has no attribute '{name}'"
   1966 )

AttributeError: 'Network' object has no attribute 'use_amp'

Any advice on how to fix this would be greatly appreciated!
Please let me know if you need any additional information. Thank you.
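The traceback shows `DendriticLoss_6.step()` reading `self.projection.post.network.use_amp`, an attribute the `Network` object apparently never defines. Because `Network` is a `torch.nn.Module`, the failed lookup falls through to `Module.__getattr__`, which raises the `AttributeError` above. The sketch below shows two possible stopgaps. It is not a fix from the EIANN authors, and it assumes that a missing flag should mean mixed precision is off. The first stopgap reads the flag with `getattr(..., False)` inside the rule; the second sets `network.use_amp = False` in the notebook before training. `Network` here is a hypothetical stand-in, and `amp_enabled` and `match_weight_dtype` are hypothetical helpers.

```python
import torch
from torch import nn


class Network(nn.Module):
    """Hypothetical stand-in for EIANN's Network that, like the one in the
    traceback, never defines a `use_amp` attribute."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)


def amp_enabled(network):
    # Hypothetical helper: treat a missing `use_amp` flag as "AMP off"
    # instead of letting nn.Module.__getattr__ raise AttributeError.
    return getattr(network, "use_amp", False)


def match_weight_dtype(network, delta_weight, weight):
    # Mirrors the guarded cast in DendriticLoss_6.step(), but reads the
    # flag defensively through amp_enabled().
    if amp_enabled(network) and delta_weight.dtype != weight.dtype:
        delta_weight = delta_weight.to(weight.dtype)
    return delta_weight


network = Network()

# Reproduces the reported error: plain attribute access on the missing flag.
try:
    network.use_amp
except AttributeError as err:
    print(err)  # 'Network' object has no attribute 'use_amp'

weight = torch.zeros(2, 4, dtype=torch.float16)
delta = torch.ones(2, 4, dtype=torch.float32)

# Defensive read: the missing flag counts as AMP off, so no cast happens.
print(match_weight_dtype(network, delta, weight).dtype)  # torch.float32

# Notebook-side stopgap: define the flag before calling network.train(...).
network.use_amp = False
print(network.use_amp)  # False
```

Setting the attribute from the notebook leaves the library untouched. The `getattr` default would need a small edit to `dendritic_loss.py` (and possibly to any other rule that reads the same flag), but it keeps working for networks built from older config files.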
