diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 94fb7198b..9d7acc91e 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -921,12 +921,18 @@ def _to_model_dtype(model, model_dtype): return model -def get_module(module, key): - """Get module from model by key name. +def get_attr(module, key): + """Get attribute from module by key name. + + This function can access both modules and their attributes (like weight, bias). + For accessing only modules, prefer using get_module which uses PyTorch's native API. Args: module (torch.nn.Module): original model - key (str): module name to be replaced + key (str): attribute name (e.g., "layer.weight", "layer.bias") + + Returns: + The attribute value, or None if not found """ name_list = key.split(".") for name in name_list: @@ -934,25 +940,44 @@ def get_module(module, key): return module -def set_module(model, key, new_module): - """Set new module into model by key name. +def set_attr(model, key, new_attr): + """Set attribute into model by key name. + + This function can set both modules and their attributes (like weight, bias). + For setting only modules, prefer using set_module which uses PyTorch's native API. Args: model (torch.nn.Module): original model - key (str): module name to be replaced - new_module (torch.nn.Module): new module to be inserted + key (str): attribute name (e.g., "layer.weight", "layer.bias") + new_attr (object): new attribute to be inserted """ module = model name_list = key.split(".") for name in name_list[:-1]: if hasattr(module, name): module = getattr(module, name) - setattr(module, name_list[-1], new_module) + setattr(module, name_list[-1], new_attr) + + +def get_module(module, key): + """Get module from model by key name using PyTorch native API. + Args: + module (torch.nn.Module): original model + key (str): module name + """ + return module.get_submodule(key) + + +def set_module(model, key, new_module): + """Set new module into model by key name using PyTorch native API. -# For getting and setting attribution, such as 'lm_head.weight' -get_attr = get_module -set_attr = set_module + Args: + model (torch.nn.Module): original model + key (str): module name + new_module (torch.nn.Module): new module to be inserted + """ + model.set_submodule(key, new_module) def get_layer_features(layer):