
Missing prime_optimizer_state when using FSDP backend #1717

@lintangsutawika

Description


It seems that the quickstart, i.e. setting up SkyRL with

git clone https://github.com/NovaSky-AI/SkyRL.git
cd SkyRL
uv venv --python 3.12
uv sync --extra tinker --extra fsdp

then launching the Tinker API server on the FSDP backend with

uv run --extra tinker --extra fsdp -m skyrl.tinker.api \
    --base-model "Qwen/Qwen3-0.6B" \
    --backend fsdp

and running the tinker-cookbook sl_loop recipe against that server with

git clone https://github.com/thinking-machines-lab/tinker-cookbook.git
cd tinker-cookbook
TINKER_API_KEY=tml-dummy uv run --with tinker --with datasets \
    python -m tinker_cookbook.recipes.sl_loop \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-0.6B" \
    train_on_what=LAST_ASSISTANT_MESSAGE

does not work: model creation fails because the policy worker actor has no prime_optimizer_state method.

╭─────────────────────────────────────────────────────────────────────────────────── Traceback (most recent call last) ────────────────────────────────────────────────────────────────────────────────────╮
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/skyrl/tinker/engine.py:715 in process_single_requests                                                                                                │
│                                                                                                                                                                                                          │
│   712 │   │   for request_id, (model_id, request_type, request_data) in requests.items():                                                                                                                │
│   713 │   │   │   with log_timing(f"process_single_request({request_type.value})"):                                                                                                                      │
│   714 │   │   │   │   try:                                                                                                                                                                               │
│ ❱ 715 │   │   │   │   │   result = self.process_single_request(request_type, model_id,                                                                                                                   │
│       request_data)                                                                                                                                                                                      │
│   716 │   │   │   │   except Exception as e:                                                                                                                                                             │
│   717 │   │   │   │   │   logger.exception(f"Error processing request {request_id}: {e}")                                                                                                                │
│   718 │   │   │   │   │   result = types.ErrorResponse(error=str(e), status="failed")                                                                                                                    │
│                                                                                                                                                                                                          │
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/skyrl/tinker/engine.py:687 in process_single_request                                                                                                 │
│                                                                                                                                                                                                          │
│   684 │   def process_single_request(self, request_type: types.RequestType, model_id: str,                                                                                                               │
│       request_data: dict) -> BaseModel:                                                                                                                                                                  │
│   685 │   │   match request_type:                                                                                                                                                                        │
│   686 │   │   │   case types.RequestType.CREATE_MODEL:                                                                                                                                                   │
│ ❱ 687 │   │   │   │   return self.process_create_model(model_id,                                                                                                                                         │
│       types.CreateModelInput.model_validate(request_data))                                                                                                                                               │
│   688 │   │   │   case types.RequestType.OPTIM_STEP:                                                                                                                                                     │
│   689 │   │   │   │   return self.process_optim_step(model_id,                                                                                                                                           │
│       types.OptimStepInput.model_validate(request_data))                                                                                                                                                 │
│   690 │   │   │   case types.RequestType.SAVE_WEIGHTS_FOR_SAMPLER:                                                                                                                                       │
│                                                                                                                                                                                                          │
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/skyrl/tinker/engine.py:477 in process_create_model                                                                                                   │
│                                                                                                                                                                                                          │
│   474 │   def process_create_model(self, model_id: str, request_data: types.CreateModelInput)                                                                                                            │
│       -> types.CreateModelOutput:                                                                                                                                                                        │
│   475 │   │   """Create and initialize a model."""                                                                                                                                                       │
│   476 │   │   # Create model in backend (allocates adapter_index, creates optimizer, and                                                                                                                 │
│       configures adapter)                                                                                                                                                                                │
│ ❱ 477 │   │   self.backend.create_model(model_id, request_data.lora_config,                                                                                                                              │
│       model_role=request_data.model_role)                                                                                                                                                                │
│   478 │   │                                                                                                                                                                                              │
│   479 │   │   logger.info(f"Created LoRA model {model_id}")                                                                                                                                              │
│   480                                                                                                                                                                                                    │
│                                                                                                                                                                                                          │
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/skyrl/backends/skyrl_train_backend.py:460 in create_model                                                                                            │
│                                                                                                                                                                                                          │
│    457 │   │   │   │   raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}")                                                                                                          │
│    458 │   │   │                                                                                                                                                                                         │
│    459 │   │   │   logger.info("Building models.")                                                                                                                                                       │
│ ❱  460 │   │   │   self._build_policy(PolicyWorker, model_id=model_id)                                                                                                                                   │
│    461 │   │   │   if is_lora:                                                                                                                                                                           │
│    462 │   │   │   │   self._base_lora_signature = self._lora_signature_from(lora_config)                                                                                                                │
│    463 │   │   elif model_role == "critic":                                                                                                                                                              │
│                                                                                                                                                                                                          │
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/skyrl/backends/skyrl_train_backend.py:268 in _build_policy                                                                                           │
│                                                                                                                                                                                                          │
│    265 │   │   │   # the freshly-initialised LoRA into a per-worker pristine slot, then                                                                                                                  │
│    266 │   │   │   # register the first adapter under `model_id`. Must happen while the                                                                                                                  │
│    267 │   │   │   # model + optimizer are still GPU-resident (i.e. before the offload).                                                                                                                 │
│ ❱  268 │   │   │   ray.get(policy_model.async_run_ray_method("pass_through",                                                                                                                             │
│        "prime_optimizer_state"))                                                                                                                                                                         │
│    269 │   │   │   ray.get(policy_model.async_run_ray_method("pass_through",                                                                                                                             │
│        "register_pristine_adapter"))                                                                                                                                                                     │
│    270 │   │   │   ray.get(policy_model.async_run_ray_method("pass_through",                                                                                                                             │
│        "register_adapter", model_id))                                                                                                                                                                    │
│    271                                                                                                                                                                                                   │
│                                                                                                                                                                                                          │
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/skyrl/backends/skyrl_train/workers/worker.py:672 in async_run_ray_method                                                                             │
│                                                                                                                                                                                                          │
│    669 │   │   args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)                                                                                                                     │
│    670 │   │                                                                                                                                                                                             │
│    671 │   │   # Dispatch the method call                                                                                                                                                                │
│ ❱  672 │   │   object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args,                                                                                                               │
│        **kwargs)                                                                                                                                                                                         │
│    673 │   │   return object_refs                                                                                                                                                                        │
│    674                                                                                                                                                                                                   │
│    675                                                                                                                                                                                                   │
│                                                                                                                                                                                                          │
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/skyrl/backends/skyrl_train/distributed/dispatch.py:223 in dispatch                                                                                   │
│                                                                                                                                                                                                          │
│   220 │                                                                                                                                                                                                  │
│   221 │   @classmethod                                                                                                                                                                                   │
│   222 │   def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) ->                                                                                                               │
│       List[ObjectRef]:                                                                                                                                                                                   │
│ ❱ 223 │   │   return [getattr(actor_info.handle, method).remote(*args, **kwargs) for                                                                                                                     │
│       actor_info in actor_infos]                                                                                                                                                                         │
│   224 │                                                                                                                                                                                                  │
│   225 │   @classmethod                                                                                                                                                                                   │
│   226 │   def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:                                                                                                              │
│                                                                                                                                                                                                          │
│ /data/user_data/lsutawik/ssm-self-swe-machine/SkyRL/.venv/lib/python3.12/site-packages/ray/actor.py:2203 in __getattr__                                                                                  │
│                                                                                                                                                                                                          │
│   2200 │   │   │   return self._method_shells[item].bind(self)                                                                                                                                           │
│   2201 │   │                                                                                                                                                                                             │
│   2202 │   │   if not self._ray_is_cross_language:                                                                                                                                                       │
│ ❱ 2203 │   │   │   raise AttributeError(                                                                                                                                                                 │
│   2204 │   │   │   │   f"'{type(self).__name__}' object has " f"no attribute '{item}'"                                                                                                                   │
│   2205 │   │   │   )                                                                                                                                                                                     │
│   2206 │   │   if item in ["__ray_terminate__"]:                                                                                                                                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'ActorHandle' object has no attribute 'prime_optimizer_state'
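For triage: the last frames show that pass-through dispatch resolves the method name directly on each Ray actor handle (getattr(actor_info.handle, method) in dispatch.py:223), so the error is raised on the caller side, before any remote call runs, which suggests the worker class behind the FSDP backend does not define prime_optimizer_state. Below is a Ray-free sketch of that path. ActorHandleStub, pass_through_dispatch, handles_missing, WorkerWithPrime and FSDPWorkerWithoutPrime are all hypothetical stand-ins for illustration, not SkyRL or Ray APIs.

```python
"""Ray-free sketch of the failing dispatch path (all names are hypothetical stand-ins)."""
from typing import Any, List


class _RemoteMethodStub:
    """Stands in for a Ray actor method; a real .remote() returns an ObjectRef."""

    def __init__(self, name: str) -> None:
        self.name = name

    def remote(self, *args: Any, **kwargs: Any) -> str:
        # Placeholder "ref" string instead of a real ray.ObjectRef
        return f"ref:{self.name}"


class ActorHandleStub:
    """Hypothetical stand-in for ray.actor.ActorHandle.

    Like the real handle, it only resolves methods defined on the actor class
    and raises AttributeError for any other name.
    """

    def __init__(self, actor_cls: type) -> None:
        self._actor_cls = actor_cls

    def __getattr__(self, item: str) -> Any:
        # Only called when normal attribute lookup fails, i.e. for method names
        if callable(getattr(self._actor_cls, item, None)):
            return _RemoteMethodStub(item)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")


def pass_through_dispatch(handles: List[ActorHandleStub], method: str, *args: Any, **kwargs: Any) -> List[str]:
    # Same shape as dispatch.py:223: resolve `method` on every handle, then call .remote()
    return [getattr(h, method).remote(*args, **kwargs) for h in handles]


def handles_missing(handles: List[ActorHandleStub], method: str) -> List[int]:
    # Hypothetical pre-check: indices of handles whose actor class lacks `method`
    return [i for i, h in enumerate(handles) if not hasattr(h, method)]


class WorkerWithPrime:
    """Hypothetical worker that defines the method _build_policy calls."""

    def prime_optimizer_state(self) -> None:
        pass


class FSDPWorkerWithoutPrime:
    """Hypothetical worker without prime_optimizer_state, as the FSDP worker appears to be here."""

    def register_adapter(self, model_id: str) -> None:
        pass


if __name__ == "__main__":
    fsdp_handles = [ActorHandleStub(FSDPWorkerWithoutPrime) for _ in range(2)]
    print(handles_missing(fsdp_handles, "prime_optimizer_state"))  # [0, 1]
    try:
        pass_through_dispatch(fsdp_handles, "prime_optimizer_state")
    except AttributeError as err:
        print(err)  # 'ActorHandleStub' object has no attribute 'prime_optimizer_state'
```

A pre-check in the spirit of handles_missing would turn the opaque Ray AttributeError into an error that names the backend; whether the fix belongs in the FSDP worker or in _build_policy is left open here.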

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
