Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,38 +921,63 @@ def _to_model_dtype(model, model_dtype):
return model


def get_module(module, key):
"""Get module from model by key name.
def get_attr(module, key):
"""Get attribute from module by key name.

This function can access both modules and their attributes (like weight, bias).
For accessing only modules, prefer using get_module which uses PyTorch's native API.

Args:
module (torch.nn.Module): original model
key (str): module name to be replaced
key (str): attribute name (e.g., "layer.weight", "layer.bias")

Returns:
The attribute value, or None if not found
"""
name_list = key.split(".")
for name in name_list:
module = getattr(module, name, None)
return module


def set_module(model, key, new_module):
"""Set new module into model by key name.
def set_attr(model, key, new_attr):
"""Set attribute into model by key name.

This function can set both modules and their attributes (like weight, bias).
For setting only modules, prefer using set_module which uses PyTorch's native API.

Args:
model (torch.nn.Module): original model
key (str): module name to be replaced
new_module (torch.nn.Module): new module to be inserted
key (str): attribute name (e.g., "layer.weight", "layer.bias")
new_attr (object): new attribute to be inserted
"""
module = model
name_list = key.split(".")
for name in name_list[:-1]:
if hasattr(module, name):
module = getattr(module, name)
setattr(module, name_list[-1], new_module)
setattr(module, name_list[-1], new_attr)


def get_module(module, key):
"""Get module from model by key name using PyTorch native API.

Args:
module (torch.nn.Module): original model
key (str): module name
"""
return module.get_submodule(key)


def set_module(model, key, new_module):
"""Set new module into model by key name using PyTorch native API.

# For getting and setting attribution, such as 'lm_head.weight'
get_attr = get_module
set_attr = set_module
Args:
model (torch.nn.Module): original model
key (str): module name
new_module (torch.nn.Module): new module to be inserted
"""
model.set_submodule(key, new_module)


def get_layer_features(layer):
Expand Down
Loading