Problem Description
PyTorch has provided a module-replacement API since version 1.x; we should leverage it instead of maintaining our own.
Reproduction Steps
import torch
import torch.nn as nn
class Model(nn.Module):
    """Toy encoder/decoder network used to demonstrate the built-in
    ``nn.Module.get_submodule`` / ``set_submodule`` APIs.

    Input:  tensor of shape ``(..., 10)``.
    Output: tensor of shape ``(..., 10)``.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 10 -> 20 -> 30 with a ReLU between the linear layers.
        self.encoder = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 30)
        )
        # Decoder projects the 30-dim encoding back to the input width.
        self.decoder = nn.Linear(30, 10)

    def forward(self, x):
        """Encode then decode *x*; returns a tensor of shape ``(..., 10)``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x
# Reproduction: exercise torch's native submodule lookup/replacement API.
model = Model()

encoder = model.get_submodule('encoder')
print(f"Encoder: {encoder}")

# NOTE(review): original comment was mojibake — presumably "first layer of
# the encoder"; confirm against the issue author's intent.
first_layer = model.get_submodule('encoder.0')
print(f"First layer: {first_layer}")

# Dotted paths index into nn.Sequential children; 'encoder.1' is the ReLU.
relu = model.get_submodule('encoder.1')
print(f"ReLU: {relu}")

# Swap a single activation in place via set_submodule.
model.set_submodule('encoder.1', nn.GELU())
print(f"After replacement: {model.get_submodule('encoder.1')}")

# Replace the entire encoder with a differently shaped stack.
new_encoder = nn.Sequential(
    nn.Linear(10, 15),
    nn.Tanh(),
    nn.Linear(15, 30)
)
model.set_submodule('encoder', new_encoder)
print(f"New encoder: {model.encoder}")
Environment Information
No response
Error Logs
Additional Context
No response
Problem Description
PyTorch has provided a module-replacement API since version 1.x; we should leverage it instead of maintaining our own.
Reproduction Steps
Environment Information
No response
Error Logs
Additional Context
No response