Update Training a Classifier tutorial — replace deprecated transforms, use torch.accelerator #3878

@sekyondaMeta

Description

The Training a Classifier tutorial contains deprecated APIs and outdated patterns that should be modernized.

Changes needed

Deprecated APIs

| Issue | Current code | Replacement | Since / notes |
|---|---|---|---|
| `transforms.ToTensor()` is deprecated | `transforms.ToTensor()` | `transforms.v2.ToImage()` + `transforms.v2.ToDtype(torch.float32, scale=True)` | v2 transforms are available in torchvision 0.16+. `ToTensor()` was deprecated in torchvision 0.18 and will be removed in a future version. |
| `transforms.Normalize()` from the v1 namespace | `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` | `transforms.v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` | Use the v2 namespace for consistency once the pipeline migrates to v2. |
| `transforms.Compose()` from the v1 namespace | `transforms.Compose([...])` | `transforms.v2.Compose([...])` | Use the v2 namespace for consistency. |

Suboptimal / Outdated Patterns

| Issue | Current code | Modern alternative | Notes |
|---|---|---|---|
| Device selection only checks CUDA | `torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')` | `torch.device(torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu')` | Ignores MPS and XPU. The `torch.accelerator` API is available since PyTorch 2.6. |
| Model is saved with a `.pth` extension | `torch.save(net.state_dict(), './cifar_net.pth')` | Use the `.pt` extension: `'./cifar_net.pt'` | PyTorch convention favors `.pt` over `.pth`, which avoids confusion with Python path configuration (`.pth`) files. |
| Links point to the GitHub `master` branch | `https://github.com/pytorch/examples/tree/master/imagenet` | `https://github.com/pytorch/examples/tree/main/imagenet` | PyTorch repositories have renamed their default branch from `master` to `main`. |
