[Enhancement]: Replace get_module, set_module with torch API get_submodule, set_submodule #1362

@xin3he

Problem Description

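PyTorch provides built-in APIs for looking up and replacing submodules by dotted path (`get_submodule`, available since 1.x, and `set_submodule`). We should use them instead of maintaining our own `get_module` and `set_module` helpers.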

Reproduction Steps

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 30)
        )
        self.decoder = nn.Linear(30, 10)
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Model()

encoder = model.get_submodule('encoder')
print(f"Encoder: {encoder}")

first_layer = model.get_submodule('encoder.0')  # first layer of the encoder
print(f"First layer: {first_layer}")

relu = model.get_submodule('encoder.1')  # ReLU layer
print(f"ReLU: {relu}")

model.set_submodule('encoder.1', nn.GELU())
print(f"After replacement: {model.get_submodule('encoder.1')}")

new_encoder = nn.Sequential(
    nn.Linear(10, 15),
    nn.Tanh(),
    nn.Linear(15, 30)
)
model.set_submodule('encoder', new_encoder)
print(f"New encoder: {model.encoder}")
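
One possible migration path, sketched under assumptions: keep thin wrappers with the existing helper names so call sites do not have to change at once, and delegate to the torch API inside them. The signatures `get_module(model, name)` and `set_module(model, name, new_module)` below are hypothetical and may not match the project's current helpers. The `hasattr` fallback is also an assumption, intended for torch builds where `set_submodule` is not available.

```python
import torch.nn as nn


def get_module(model: nn.Module, name: str) -> nn.Module:
    """Hypothetical wrapper: delegate the lookup to torch's get_submodule."""
    return model.get_submodule(name)


def set_module(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Hypothetical wrapper: replace the submodule at the dotted path `name`."""
    if hasattr(model, "set_submodule"):
        model.set_submodule(name, new_module)
        return
    # Fallback (assumption: set_submodule is missing in this torch build).
    # Resolve the parent module, then assign the child by attribute name.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)  # "" resolves to the model itself
    setattr(parent, child_name, new_module)
```

With wrappers like these, `set_module(model, 'encoder.1', nn.GELU())` would behave like the `model.set_submodule('encoder.1', nn.GELU())` call above, and the wrappers could be removed once all call sites use the torch methods directly.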

Environment Information

No response

Error Logs

Additional Context

No response
