```python
## Initialising FishLeg ##
opt = FishLeg(
    model_FishLeg,
    aux_loader,
    likelihood,
    lr=lr,
    beta=beta,
    weight_decay=weight_decay,
    aux_lr=aux_lr,
    aux_betas=(0.9, 0.999),
    aux_eps=aux_eps,
    damping=damping,
    update_aux_every=update_aux_every,
    writer=writer,
    method="antithetic",
    method_kwargs={"eps": 1e-4},
    precondition_aux=True,
    aux_log=True
)
```
This needs to include `device=device`, so that it becomes:
```python
## Initialising FishLeg ##
opt = FishLeg(
    model_FishLeg,
    aux_loader,
    likelihood,
    lr=lr,
    beta=beta,
    weight_decay=weight_decay,
    aux_lr=aux_lr,
    aux_betas=(0.9, 0.999),
    aux_eps=aux_eps,
    damping=damping,
    update_aux_every=update_aux_every,
    writer=writer,
    method="antithetic",
    method_kwargs={"eps": 1e-4},
    precondition_aux=True,
    aux_log=True,
    device=device
)
```
This is needed because `FishLeg.__init__()` defaults to the CPU. When using MPS or CUDA (both supported in the tutorial), the tensors end up on separate devices, resulting in this error:

```
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

or alternatively:

```
Expected all tensors to be on the same device, but found at least two devices, mps and cpu!
```
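A minimal sketch of the failure mode, assuming PyTorch. `ToyPreconditioner` is a hypothetical stand-in, not FishLeg's actual code. Like the default described above, it allocates its auxiliary tensors on the CPU unless a `device` is passed. As a result, `step()` hits the same "Expected all tensors to be on the same device" error once the model lives on CUDA or MPS. `pick_device` (also hypothetical) assumes a CUDA, then MPS, then CPU fallback order for choosing the device.

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    # Assumed selection order: CUDA, then MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


class ToyPreconditioner:
    """Hypothetical optimizer that owns auxiliary tensors (NOT FishLeg's code).

    Mimics the behaviour reported above: auxiliary state defaults to the CPU.
    """

    def __init__(self, params, device="cpu"):
        self.params = list(params)
        # Auxiliary tensors go on `device`, regardless of where the params live.
        self.aux = [torch.ones_like(p, device=device) for p in self.params]

    @torch.no_grad()
    def step(self, lr: float = 0.1) -> None:
        for p, a in zip(self.params, self.aux):
            # Raises "Expected all tensors to be on the same device" if `a`
            # is on the CPU while `p` and `p.grad` are on CUDA/MPS.
            p.sub_(lr * a * p.grad)


device = pick_device()
model = nn.Linear(4, 2).to(device)
model(torch.randn(8, 4, device=device)).pow(2).mean().backward()

# Omitting device=... reproduces the error on CUDA/MPS; passing it avoids it.
opt = ToyPreconditioner(model.parameters(), device=device)
opt.step()
```

Passing the model's device into the constructor, as the fix above does for `FishLeg`, keeps every tensor touched in `step()` on one device.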
This fix has been tested on CUDA and CPU; I would appreciate someone checking it on MPS.
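When checking the fix on MPS, a small helper can confirm that nothing was left behind on the CPU after a training step. This is a sketch assuming PyTorch. It also assumes that the optimizer keeps per-parameter tensors in the standard `torch.optim.Optimizer.state` dict. Whether FishLeg stores its auxiliary tensors there is unverified, so the hypothetical `tensor_devices` helper may need extending to cover them. `torch.optim.SGD` stands in for FishLeg in the usage lines.

```python
import torch
import torch.nn as nn


def tensor_devices(model: nn.Module, optimizer: torch.optim.Optimizer) -> set:
    """Device types of the model's parameters/buffers and the optimizer's state tensors."""
    devices = {t.device.type for t in model.parameters()}
    devices |= {t.device.type for t in model.buffers()}
    for state in optimizer.state.values():
        for value in state.values():
            # 0-dim CPU tensors (e.g. Adam's `step` counter in recent PyTorch
            # versions) may mix with device tensors, so they are skipped.
            if torch.is_tensor(value) and value.dim() > 0:
                devices.add(value.device.type)
    return devices


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = nn.Linear(4, 2).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model(torch.randn(8, 4, device=device)).sum().backward()
opt.step()  # creates the momentum buffers in opt.state
print(tensor_devices(model, opt))  # a single entry means everything shares one device
```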