@@ -49,22 +49,28 @@ class PerceptualLoss(nn.Module):
4949
5050 Args:
5151 spatial_dims: number of spatial dimensions.
52- network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
53- ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``}
54- Specifies the network architecture to use. Defaults to ``"alex"``.
52+ network_type: type of network for perceptual loss. One of:
53+ - "alex"
54+ - "vgg"
55+ - "squeeze"
56+ - "radimagenet_resnet50"
57+ - "medicalnet_resnet10_23datasets"
58+ - "medicalnet_resnet50_23datasets"
59+ - "resnet50"
60+ >>>>>>> a32e2a80e (Fix docstring formatting for network_type)
5561 is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
5662 fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
5763 cache_dir: path to cache directory to save the pretrained network weights.
5864 pretrained: whether to load pretrained weights. This argument only works when using networks from
59- LIPIS or Torchvision. Defaults to ``" True" ``.
65+ LIPIS or Torchvision. Defaults to ``True``.
6066 pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
6167 via using this argument. This argument only works when ``"network_type"`` is "resnet50".
6268 Defaults to `None`.
6369 pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
6470 extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50".
6571 Defaults to `None`.
6672 channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
67- Defaults to ``False``.
73+ Defaults to ``False``.
6874 """
6975
7076 def __init__ (
0 commit comments