We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f0e7da3 commit f623e3eCopy full SHA for f623e3e
1 file changed
src/pquant/core/torch/layers.py
@@ -608,6 +608,7 @@ def add_compression_layers(model, config, input_shape=None):
608
model.to("cuda")
609
if input_shape is not None:
610
model(torch.rand(input_shape).to("cuda"))
611
+ model.to("cuda")
612
return model
613
614
0 commit comments