2 changes: 1 addition & 1 deletion skyrl-tx/pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
"rich>=14.1.0",
"safetensors>=0.6.2",
"tokenizers>=0.21.2",
"transformers>=4.56.1",
"transformers>=4.56.1,<5",
"typer>=0.17.4",
# "wandb>=0.22.0",
"peft",
2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_llama3.py
@@ -39,7 +39,7 @@ def test_llama3(tp: int):

base_config = AutoConfig.from_pretrained(model_name)
config = Llama3Config(base_config, max_lora_adapters=1, max_lora_rank=1, shard_attention_heads=True)
- mesh = jax.make_mesh((1, tp), ("dp", "tp"))
+ mesh = jax.make_mesh((1, tp), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Llama3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_safetensors(tmp, config, model)
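This mesh change is the same one-line pattern repeated in the test files below: jax.make_mesh now receives an explicit axis_types tuple that marks every mesh axis as Auto. A minimal sketch of the pattern outside the test harness, with an illustrative 1x1 mesh (the sizes and variable names here are made up; only the call shape mirrors the diff):

import jax
from jax.sharding import AxisType

# Mark both mesh axes as Auto so sharding stays implicit (left to the
# compiler) rather than being tracked explicitly in array types.
mesh = jax.make_mesh((1, 1), ("dp", "tp"), axis_types=(AxisType.Auto,) * 2)

with jax.set_mesh(mesh):
    pass  # model construction and weight loading happen here in the tests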
2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_llama3_lora_training.py
@@ -18,7 +18,7 @@ def test_lora_training():
config = Llama3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)

checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])
- mesh = jax.make_mesh((1, 1), ("dp", "tp"))
+ mesh = jax.make_mesh((1, 1), ("dp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Llama3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
load_safetensors(checkpoint_path, config, model)
2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_models_common.py
@@ -30,7 +30,7 @@ def load_model(tmp_dir, model_name, config_cls, model_cls, mesh_axes, *, loss_ch
loss_chunk_size=loss_chunk_size,
gradient_checkpointing=False,
)
- mesh = jax.make_mesh((1, 1), mesh_axes)
+ mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_safetensors(tmp_dir, config, model)
26 changes: 13 additions & 13 deletions skyrl-tx/tests/models/test_qwen3.py
@@ -40,7 +40,7 @@ def test_qwen3(tp: int):

base_config = PretrainedConfig.from_pretrained("Qwen/Qwen3-0.6B")
config = Qwen3Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True)
- mesh = jax.make_mesh((1, tp), ("fsdp", "tp"))
+ mesh = jax.make_mesh((1, tp), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_safetensors(tmp, config, model)
@@ -73,7 +73,7 @@ def test_qwen3_moe_layer(ep: int, tp: int):
with torch.no_grad():
hf_final_hidden_states, hf_router_logits = hf_moe_layer.forward(x)

- mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
+ mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
with jax.set_mesh(mesh):
moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_moe_base_weights(moe_layer, hf_moe_layer)
@@ -99,10 +99,10 @@ def load_lora_weights(
and jax_module.lora_scaling is not None
and jax_module.lora_ranks is not None
)
- jax_module.lora_A.value = jax_module.lora_A.value.at[adapter_idx].set(jnp.array(lora_A_weights))
- jax_module.lora_B.value = jax_module.lora_B.value.at[adapter_idx].set(jnp.array(lora_B_weights))
- jax_module.lora_scaling.value = jax_module.lora_scaling.value.at[adapter_idx].set(scaling)
- jax_module.lora_ranks.value = jax_module.lora_ranks.value.at[adapter_idx].set(rank)
+ jax_module.lora_A[...] = jax_module.lora_A[...].at[adapter_idx].set(jnp.array(lora_A_weights))
+ jax_module.lora_B[...] = jax_module.lora_B[...].at[adapter_idx].set(jnp.array(lora_B_weights))
+ jax_module.lora_scaling[...] = jax_module.lora_scaling[...].at[adapter_idx].set(scaling)
+ jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank)


@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
@@ -116,7 +116,7 @@ def test_qwen3_moe_layer_lora(ep: int, tp: int):
hf_moe_layer = hf_model.model.layers[0].mlp
x = torch.randn(3, 4, config.hidden_size)

- mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
+ mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
with jax.set_mesh(mesh):
moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_moe_base_weights(moe_layer, hf_moe_layer)
@@ -128,8 +128,8 @@ def test_qwen3_moe_layer_lora(ep: int, tp: int):
for adapter_idx in range(config.max_lora_adapters):
for proj in [moe_layer.experts.gate_proj, moe_layer.experts.up_proj, moe_layer.experts.down_proj]:
assert proj.lora_A is not None and proj.lora_B is not None
- lora_A = rng.normal(0, 1.0, proj.lora_A.value.shape[1:])
- lora_B = rng.normal(0, 1.0, proj.lora_B.value.shape[1:])
+ lora_A = rng.normal(0, 1.0, proj.lora_A[...].shape[1:])
+ lora_B = rng.normal(0, 1.0, proj.lora_B[...].shape[1:])
load_lora_weights(proj, adapter_idx, lora_A, lora_B, scaling, rank)

# Test with different adapters per sample
@@ -150,12 +150,12 @@ def test_qwen3_moe_layer_lora(ep: int, tp: int):

# For each expert, merge: base + scaling * (lora_A @ lora_B)
for expert_idx in range(config.num_experts):
- lora_A = proj.lora_A.value[adapter_idx, expert_idx, :, :]
- lora_B = proj.lora_B.value[adapter_idx, expert_idx, :, :]
+ lora_A = proj.lora_A[...][adapter_idx, expert_idx, :, :]
+ lora_B = proj.lora_B[...][adapter_idx, expert_idx, :, :]
lora_delta = scaling * (lora_A @ lora_B)

merged_weight = proj.weight[expert_idx, :, :] + lora_delta
- proj_merged.weight.value = proj_merged.weight.value.at[expert_idx, :, :].set(merged_weight)
+ proj_merged.weight[...] = proj_merged.weight[...].at[expert_idx, :, :].set(merged_weight)

# Run merged model on this sample
x_sample = x[sample_idx : sample_idx + 1].numpy()
@@ -215,7 +215,7 @@ def test_qwen3_lora():
shard_attention_heads=True,
)

- mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
+ mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_safetensors(base_tmp, config, model)
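The load_lora_weights and adapter-merge hunks above show the other recurring pattern in this diff: Flax NNX variables are read and written through [...] indexing instead of the .value attribute. A small standalone sketch of the idiom, with made-up shapes and a bare nnx.Param in place of a module attribute:

import jax.numpy as jnp
from flax import nnx

# Padded per-adapter LoRA factor: (max_adapters, in_features, max_rank), shapes are illustrative
lora_A = nnx.Param(jnp.zeros((4, 8, 2)))

print(lora_A[...].shape)  # read the underlying array via [...]

# Write back a functional update of adapter slot 0 via [...]
lora_A[...] = lora_A[...].at[0].set(jnp.ones((8, 2)))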
4 changes: 2 additions & 2 deletions skyrl-tx/tests/models/test_qwen3_generate.py
@@ -46,7 +46,7 @@ def test_qwen3_generate():
base_config = PretrainedConfig.from_pretrained(model_name)
config = Qwen3Config(base_config, max_lora_adapters=2, max_lora_rank=32, shard_attention_heads=True)

- mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
+ mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_safetensors(tmp, config, model)
@@ -146,7 +146,7 @@ def test_qwen3_generate_speed():

with tempfile.TemporaryDirectory() as tmp:
hf_model.save_pretrained(tmp, safe_serialization=True)
- mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
+ mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Qwen3ForCausalLM(config, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))
load_safetensors(tmp, config, model)
2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_qwen3_lora_training.py
@@ -18,7 +18,7 @@ def test_lora_training():
config = Qwen3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)

checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])
- mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
+ mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Qwen3ForCausalLM(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
load_safetensors(checkpoint_path, config, model)
10 changes: 5 additions & 5 deletions skyrl-tx/tests/utils/test_models.py
@@ -31,7 +31,7 @@ def create_test_model(base_model_name: str, rank: int, alpha: int, adapter_index

config = Qwen3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)

- mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
+ mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with jax.set_mesh(mesh):
model = Qwen3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
init_lora_adapter(model, adapter_index=adapter_index, lora_config=LoraConfig(rank=rank, alpha=alpha, seed=0))
@@ -57,12 +57,12 @@ def test_save_load_lora_checkpoint(storage_type: str, monkeypatch, tmp_path: Pat
# Set LoRA weights to random values for testing (to catch transpose bugs)
q_proj = model.model.layers[0].self_attn.q_proj
rng1, rng2 = jax.random.split(jax.random.PRNGKey(42))
- q_proj.lora_A.value = jax.random.normal(rng1, q_proj.lora_A.value.shape)
- q_proj.lora_B.value = jax.random.normal(rng2, q_proj.lora_B.value.shape)
+ q_proj.lora_A[...] = jax.random.normal(rng1, q_proj.lora_A[...].shape)
+ q_proj.lora_B[...] = jax.random.normal(rng2, q_proj.lora_B[...].shape)

# Store expected values (trimmed to rank and transposed)
- expected_lora_A = np.array(q_proj.lora_A.value[adapter_index, :, :rank].T)
- expected_lora_B = np.array(q_proj.lora_B.value[adapter_index, :rank, :].T)
+ expected_lora_A = np.array(q_proj.lora_A[...][adapter_index, :, :rank].T)
+ expected_lora_B = np.array(q_proj.lora_B[...][adapter_index, :rank, :].T)

# Save and verify checkpoint exists
models.save_lora_checkpoint(model, base_model_name, adapter_config, adapter_index, output_path)
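The expected-value lines above trim the padded LoRA factors to the adapter's configured rank and transpose them, presumably to match the layout written by save_lora_checkpoint. A sketch of just that slicing, with made-up shapes:

import numpy as np

rank, max_rank = 8, 32
in_features, out_features = 16, 16

lora_A = np.random.randn(in_features, max_rank)   # padded A factor for one adapter
lora_B = np.random.randn(max_rank, out_features)  # padded B factor for one adapter

expected_A = lora_A[:, :rank].T  # shape (rank, in_features)
expected_B = lora_B[:rank, :].T  # shape (out_features, rank)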
24 changes: 14 additions & 10 deletions skyrl-tx/tx/layers/lora.py
@@ -149,7 +149,7 @@ def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> ja
def T(self):
"""Return a callable that projects hidden states back to vocabulary space."""
# TODO: Apply lora adapters here as well
- return lambda hidden_states, adapter_indices=None: hidden_states @ self.embedding.value.T
+ return lambda hidden_states, adapter_indices=None: hidden_states @ self.embedding[...].T


class LoRALinear(LoRAMixin, nnx.Linear):
@@ -224,8 +224,8 @@ def __init__(

self.weight = Param(num_experts, in_features, out_features, dtype=dtype, kernel_init=kernel_init, rngs=rngs)

- assert self.weight.value.sharding is not None, "LoRAExpert layer needs sharding"
- sharding = self.weight.value.sharding.spec
+ assert self.weight[...].sharding is not None, "LoRAExpert layer needs sharding"
+ sharding = self.weight[...].sharding.spec
self.init_lora(
max_lora_adapters=max_lora_adapters,
max_lora_rank=max_lora_rank,
@@ -245,7 +245,7 @@ def __call__(
*,
group_offset: jax.Array | None = None,
) -> jax.Array:
- base_out = ragged_dot(x, self.weight.value, group_sizes, group_offset=group_offset)
+ base_out = ragged_dot(x, self.weight[...], group_sizes, group_offset=group_offset)

if self.max_lora_adapters == 0 or adapter_indices_sorted is None:
return base_out
@@ -259,14 +259,18 @@ def __call__(
# Expert-first flattening so local expert groups are contiguous
flattened_indices = expert_indices * self.max_lora_adapters + adapter_indices_sorted
num_flattened_groups = self.num_experts * self.max_lora_adapters
- num_local_experts = self.lora_A.value.shape[1]
+ num_local_experts = self.lora_A[...].shape[1]

# Reshape LoRA weights in expert-first order
- lora_A = self.lora_A.value.transpose((1, 0, 2, 3)).reshape(
- self.max_lora_adapters * num_local_experts, self.in_features, self.max_lora_rank
+ lora_A = (
+ self.lora_A[...]
+ .transpose((1, 0, 2, 3))
+ .reshape(self.max_lora_adapters * num_local_experts, self.in_features, self.max_lora_rank)
)
- lora_B = self.lora_B.value.transpose((1, 0, 2, 3)).reshape(
- self.max_lora_adapters * num_local_experts, self.max_lora_rank, self.out_features
+ lora_B = (
+ self.lora_B[...]
+ .transpose((1, 0, 2, 3))
+ .reshape(self.max_lora_adapters * num_local_experts, self.max_lora_rank, self.out_features)
)

# Sort tokens by combined index
@@ -281,7 +285,7 @@ def __call__(

# Unsort and apply scaling
lora_output = lora_output_sorted[unsort_indices]
- lora_output = lora_output * self.lora_scaling.value[adapter_indices_sorted, None]
+ lora_output = lora_output * self.lora_scaling[...][adapter_indices_sorted, None]

return base_out + lora_output

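For context on the reshapes above: the per-adapter, per-expert LoRA factors are rearranged into an expert-first layout so that tokens sorted by (expert, adapter) form contiguous groups for a single ragged matmul. A toy sketch of just the index flattening, with made-up sizes and routing:

import jax.numpy as jnp

max_lora_adapters = 3                               # illustrative
expert_indices = jnp.array([0, 0, 1, 1])            # local expert chosen per token
adapter_indices_sorted = jnp.array([0, 2, 1, 1])    # LoRA adapter per token

# Expert-first combined group id, mirroring the flattening in LoRAExpert.__call__
flattened_indices = expert_indices * max_lora_adapters + adapter_indices_sorted
print(flattened_indices)  # [0 2 4 4]: all groups for expert 0 precede those for expert 1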
1 change: 1 addition & 0 deletions skyrl-tx/tx/tinker/backends/jax.py
@@ -181,6 +181,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig):
config.tensor_parallel_size,
),
("fsdp", "ep", "tp"),
+ axis_types=(jax.sharding.AxisType.Auto,) * 3,
)
with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True):
self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0))
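The backend builds the same Auto-typed mesh, but with three axes (fsdp, ep, tp) sized from its parallelism config, and constructs the model under both the mesh and eager-sharding contexts. A hedged sketch of that setup with hypothetical sizes and variable names (only the axis names, the axis_types argument, and nnx.use_eager_sharding appear in the diff itself):

import jax
from flax import nnx

fsdp_size, ep_size, tp_size = 1, 1, 1  # hypothetical parallelism settings

mesh = jax.make_mesh(
    (fsdp_size, ep_size, tp_size),
    ("fsdp", "ep", "tp"),
    axis_types=(jax.sharding.AxisType.Auto,) * 3,
)

with jax.set_mesh(mesh), nnx.use_eager_sharding(True):
    ...  # the model is constructed here so parameters can be sharded eagerly as they are created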