LoRA weight file saved during FSDP2 training is empty #156

@PlutoQyl

elif self.accelerator.is_fsdp2:
    # FSDP/FSDP2
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
    if state_dict_keys is not None:
        # Temporarily mark unwanted params as frozen
        # This `requires_grad` trick does not work correctly. Don't know why.
        original_state = {}

        # Freeze unwanted params
        for name, param in model.named_parameters():
            original_state[name] = param.requires_grad
            param.requires_grad = is_param_match_key(name, state_dict_keys)

        options = StateDictOptions(
            full_state_dict=True,
            broadcast_from_rank0=True,
            cpu_offload=True,
            ignore_frozen_params=True,
        )
        state_dict = get_model_state_dict(model, options=options)

        # Restore original state
        for name, param in model.named_parameters():
            param.requires_grad = original_state[name]

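I came across this section of the code. Does the author have a solution for it yet? I ran into this problem while training Online DPO (with ZeRO-2, training Flux2klein-9B on H20 runs out of GPU memory).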
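One possible workaround, offered only as a sketch and not as the repository's method: instead of toggling `requires_grad` and relying on `ignore_frozen_params`, gather the full state dict and filter it by key afterwards. The helper names `matches_any_key`, `filter_state_dict`, and `gather_filtered_state_dict` are hypothetical, and the substring matching is an assumption about how `is_param_match_key` behaves. As far as I understand the PyTorch documentation, combining `full_state_dict=True` with `cpu_offload=True` returns a populated dict only on rank 0 and an empty dict on the other ranks, so the file should be written from rank 0 only. That may also be worth checking as a cause of the empty file.

```python
from typing import Any, Iterable, Mapping


def matches_any_key(name: str, keys: Iterable[str]) -> bool:
    """Hypothetical stand-in for the repo's `is_param_match_key` (substring match assumed)."""
    return any(key in name for key in keys)


def filter_state_dict(state_dict: Mapping[str, Any], keys: Iterable[str]) -> dict:
    """Keep only the entries whose parameter name matches one of `keys`."""
    keys = list(keys)  # allow `keys` to be a one-shot iterable
    return {
        name: value
        for name, value in state_dict.items()
        if matches_any_key(name, keys)
    }


def gather_filtered_state_dict(model, keys: Iterable[str]) -> dict:
    """Gather the full state dict, then filter it, without touching `requires_grad`."""
    # Imported lazily so the pure-Python helpers above work without torch installed.
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
    )

    # Every rank must make this call; it gathers the sharded parameters.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    full_state_dict = get_model_state_dict(model, options=options)
    # On ranks that received an empty dict this simply returns {}.
    return filter_state_dict(full_state_dict, keys)
```

The trade-off is that the whole model is gathered to CPU before filtering, which costs more host memory and time than gathering only the LoRA parameters. Every rank still has to call `gather_filtered_state_dict`, but only the rank that receives a non-empty result should save it.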


Labels: help wanted
