6 changes: 5 additions & 1 deletion .gitignore
@@ -19,4 +19,8 @@ logs/
 separate-logs/
 *.distcp
 
-wandb/
+wandb/
+
+data/
+
+models/
5 changes: 3 additions & 2 deletions maester/config.py
@@ -83,13 +83,14 @@ class Config(BaseSettings):
     data_parallel_shard_degree: int = 8
     data_parallel_replicate_degree: int = 1
     tensor_parallel_degree: int = 1
+    context_parallel_degree: int = 1
     train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 8192 seqlen = ~4M tokens per batch
     gradient_accumulation_steps: int = 1
     gradient_accumulation_sync_each_step: bool = False
     train_num_steps: int = 1000
     compile: bool = True
     enable_loss_parallel: bool = True
-    enable_cut_cross_entropy: bool = True
+    enable_cut_cross_entropy: bool = False
     init_timeout_seconds: int = 300
     train_timeout_seconds: int = 100

@@ -143,7 +144,7 @@ class Config(BaseSettings):
     # lr schedule
     scheduler: str = "linear_warmup_cosine"
     warmup_steps: int = 50
-    cooldown_steps: int = 100 # used for some schedules
+    cooldown_steps: int = 50
 
     # fsdp
     mixed_precision_param: str = 'bfloat16'
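The new `context_parallel_degree` joins the other degrees that must multiply out to the world size (enforced by `ParallelDims._validate` below). A quick back-of-the-envelope check of the `train_batch_size` comment, as a standalone sketch rather than code from this PR:

```python
# Sketch: reproduce the arithmetic in the train_batch_size comment.
# Assumed setup: 32 nodes x 8 GPUs, per-device batch 2, sequence length 8192.
per_device_batch = 2
num_gpus = 32 * 8
seq_len = 8192

tokens_per_step = per_device_batch * num_gpus * seq_len
print(f"{tokens_per_step:,}")  # 4,194,304 -> the "~4M tokens per batch" in config.py
```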
12 changes: 6 additions & 6 deletions maester/models/gemma/__init__.py
@@ -3,15 +3,15 @@
 
 __all__ = ["GemmaTextModel", "ModelArgs"]
 
 gemma3_configs = {
-    "270M": ModelArgs(
-        vocab_size=262_144,
-        dim=640,
-        n_layers=18,
+    "debug": ModelArgs(
+        vocab_size=262_144, # Actual size from google/gemma-3-1b-pt tokenizer
+        dim=1152,
+        n_layers=5,
         n_heads=4,
         num_key_value_heads=1,
         head_dim=256,
-        intermediate_size=2048,
-        attn_types=["local_sliding", "local_sliding", "local_sliding", "local_sliding", "local_sliding", "global"],
+        intermediate_size=6912,
+        attn_types=["local_sliding", "local_sliding", "global", "local_sliding", "local_sliding"],
         use_post_ffw_norm=True,
         use_pre_ffw_norm=True,
         sliding_window_size=512,
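For this 5-layer debug config the new `attn_types` list lines up one-to-one with the layers; if the list were shorter than `n_layers`, the usual Gemma convention is to cycle it. A sketch of that mapping (the indexing expression itself is collapsed out of this diff, so the modulo is an assumption):

```python
# Sketch: how attn_types presumably maps to layers (modulo indexing is an
# assumption; the actual index expression is collapsed in this diff).
attn_types = ["local_sliding", "local_sliding", "global",
              "local_sliding", "local_sliding"]
n_layers = 5

for i in range(n_layers):
    print(f"layer {i}: {attn_types[i % len(attn_types)]}")
# With n_layers == 5 the list is consumed exactly once: a single global-attention
# layer (layer 2) surrounded by sliding-window layers.
```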
43 changes: 31 additions & 12 deletions maester/models/gemma/model.py
@@ -10,6 +10,8 @@
 
 from cut_cross_entropy import linear_cross_entropy, LinearCrossEntropyImpl
 
+from torch.distributed import DeviceMesh
+
 @dataclass
 class ModelArgs:
     """
@@ -40,6 +42,7 @@ class ModelArgs:
     vision_config: dict | None = None # For multimodal models
     tied_embeddings: bool = True # For training compatibility
     init_std: float = 0.02 # For weight initialization
+    attention_backend: str = "flex" # "eager", "flex", or "sdpa", but "flex" is recommended as the others might be incorrect
 
 def precompute_freqs_cis(dim: int,
                          end: int,
@@ -216,12 +219,25 @@ def _ensure_long(val):
     return wrapped_mask_fn
 
 
+@torch._dynamo.disable
+def _no_compile_sdpa(q, k, v, scale: float, is_causal: bool = True, attn_mask: torch.Tensor | None = None):
+    # q,k,v: [B, H, S, D]; CP sharding on S (dim=2)
+    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        return F.scaled_dot_product_attention(
+            q, k, v,
+            dropout_p=0.0,
+            attn_mask=attn_mask,
+            scale=scale,
+            is_causal=is_causal,
+        )
+
 class GemmaAttention(nn.Module):
 
     def __init__(
         self,
         config: ModelArgs,
-        attn_type: str
+        attn_type: str,
+        cp_device_mesh = None
     ):
         super().__init__()
 
@@ -347,13 +363,15 @@ class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
         config: ModelArgs,
-        attn_type: str
+        attn_type: str,
+        cp_device_mesh: DeviceMesh | None
     ):
         super().__init__()
         self.attn_type = attn_type
         self.self_attn = GemmaAttention(
             config=config,
-            attn_type=attn_type
+            attn_type=attn_type,
+            cp_device_mesh=cp_device_mesh
         )
         self.mlp = GemmaMLP(
             hidden_size=config.dim,
@@ -418,7 +436,7 @@ def init_weights(self, init_std: float):
         self.mlp.init_weights(init_std)
 
 class GemmaModel(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -430,7 +448,7 @@ def __init__(self, config: ModelArgs):
                 if config.attn_types is not None
                 else "global"
             )
-            self.layers.append(Gemma2DecoderLayer(config, attn_type))
+            self.layers.append(Gemma2DecoderLayer(config, attn_type, cp_device_mesh=cp_device_mesh))
         self.norm = RMSNorm(config.dim, eps=config.rms_norm_eps)
 
     def forward(
@@ -444,7 +462,7 @@ def forward(
             layer: Gemma2DecoderLayer = self.layers[i] # type: ignore
             hidden_states = layer(
                 hidden_states=hidden_states,
-                freqs_cis=freqs_cis.get(layer.attn_type),
+                freqs_cis=freqs_cis[layer.attn_type],
                 mask=mask,
                 local_mask=local_mask,
             )
@@ -460,22 +478,23 @@ def init_weights(self, init_std: float):
 
 class GemmaTextModel(nn.Module):
     """Text-only Gemma model compatible with training setup."""
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None = None):
         super().__init__()
         self.config = config
         self.model_args = config # For compatibility with training code
         self.vocab_size = config.vocab_size
         self.n_layers = config.n_layers
 
+        self.cp_device_mesh = cp_device_mesh
+
         # Text embeddings
         self.tok_embeddings = Embedding(
             num_embeddings=config.vocab_size,
             embedding_dim=config.dim
         )
 
         # Core transformer model
-        self.model = GemmaModel(config)
+        self.model = GemmaModel(config, cp_device_mesh=cp_device_mesh)
 
         # Precompute RoPE frequencies following multimodal pattern
         head_dim = config.head_dim
         max_seq_len = config.max_seq_len
@@ -651,9 +670,9 @@ def forward(
         return output
 
     @classmethod
-    def from_model_args(cls, model_args: ModelArgs) -> "GemmaTextModel":
+    def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: DeviceMesh | None = None) -> "GemmaTextModel":
        """Initialize from model args (compatible with training loop)."""
-        return cls(model_args)
+        return cls(model_args, cp_device_mesh=cp_device_mesh)
 
 
 class Gemma3MultiModalModel(nn.Module):
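The `@torch._dynamo.disable` on `_no_compile_sdpa` keeps the SDPA call out of the compiled graph, presumably so that PyTorch's experimental context-parallel machinery, which patches `F.scaled_dot_product_attention` at runtime, still takes effect instead of a baked-in compiled kernel. A minimal shape-level sketch of calling the helper (illustrative values, not part of the PR; a CUDA device and bf16 inputs are assumed so the flash backend is eligible):

```python
# Sketch: exercising _no_compile_sdpa (illustrative, not part of the PR).
import torch
from maester.models.gemma.model import _no_compile_sdpa

B, H, S, D = 2, 4, 512, 256  # [batch, heads, seq, head_dim]; CP shards dim 2
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

out = _no_compile_sdpa(q, k, v, scale=D ** -0.5, is_causal=True)
assert out.shape == (B, H, S, D)
```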
5 changes: 4 additions & 1 deletion maester/models/llama/model.py
@@ -20,6 +20,8 @@
 from maester.models.norms import create_norm
 from maester.models.llama.tied_linear import TiedLinear
 
+from torch.distributed.device_mesh import DeviceMesh
+
 
 @dataclass
 class ModelArgs:
@@ -490,12 +492,13 @@ def forward(
         return output
 
     @classmethod
-    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
+    def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: Optional[DeviceMesh] = None) -> "Transformer":
         """
         Initialize a Transformer model from a ModelArgs object.
 
         Args:
             model_args (ModelArgs): Model configuration arguments.
+            cp_device_mesh (Optional[DeviceMesh]): Device mesh for context parallelism.
 
         Returns:
             Transformer: Transformer model.
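Call-site sketch for the widened signature (the `"cp"` mesh-dim lookup assumes the dim name created by `ParallelDims.build_mesh` below; `world_mesh` and `parallel_dims` are illustrative names, not code from this PR):

```python
# Sketch: passing a context-parallel submesh to the model (illustrative names).
cp_mesh = world_mesh["cp"] if parallel_dims.cp_enabled else None
model = Transformer.from_model_args(model_args, cp_device_mesh=cp_mesh)
```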
72 changes: 54 additions & 18 deletions maester/parallelisms/parallel_dims.py
@@ -11,56 +11,76 @@ class ParallelDims:
     dp_replicate: int
     dp_shard: int
     tp: int
+    cp: int
     world_size: int
     enable_loss_parallel: bool
 
     def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp_replicate, dp_shard, tp = self.dp_replicate, self.dp_shard, self.tp
-        for d in (dp_replicate, tp):
+        dp_replicate, dp_shard, tp, cp = self.dp_replicate, self.dp_shard, self.tp, self.cp
+        for d in (dp_replicate, tp, cp):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
 
-        dp = dp_replicate * dp_shard
-        if dp < 0:
-            dp = self.world_size // (tp)
-        self.dp_shard = dp_shard = dp // dp_replicate
+        if dp_shard < 0:
+            self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp * cp)
+        assert dp_shard >= 1
 
         assert dp_replicate >= 1
         assert dp_shard >= 1
         assert tp >= 1, tp
-        assert dp_replicate * dp_shard * tp == self.world_size, (
+        assert cp >= 1, cp
+        assert dp_replicate * dp_shard * tp * cp == self.world_size, (
             f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
-            f"tp({tp}) != WORLD_SIZE({self.world_size})"
+            f"tp({tp}) * cp({cp}) != WORLD_SIZE({self.world_size})"
         )
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.dp_replicate, self.dp_shard, self.tp],
-            ["dp_replicate", "dp_shard", "tp"],
+            [self.dp_replicate, self.dp_shard, self.tp, self.cp],
+            ["dp_replicate", "dp_shard", "tp", "cp"],
         ):
             if d > 1:
                 dims.append(d)
-                if (name == "dp_replicate" and self.dp_shard == 1) or (
-                    name == "dp_shard" and self.dp_replicate == 1
-                ):
-                    names.append("dp")
-                else:
-                    names.append(name)
+                names.append(name)
         if dims == []: # edge case for non-distributed mesh w/ 1 GPU
             dims = [1]
             names = ("dp",)
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
         mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
 
         # Create all the submesh here to ensure all required process groups are
         # initialized
-        if self.dp_replicate > 1 and self.dp_shard > 1:
-            mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
+        dp_mesh_dim_names = [] # for data loading (no comms)
+        dp_shard_cp_mesh_dim_names = [] # for param sharding
+        dp_cp_mesh_dim_names = [] # for loss all-reduce
+
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+            dp_cp_mesh_dim_names.append("dp_replicate")
+        if self.dp_shard_enabled:
+            dp_mesh_dim_names.append("dp_shard")
+            dp_shard_cp_mesh_dim_names.append("dp_shard")
+            dp_cp_mesh_dim_names.append("dp_shard")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_names.append("cp")
+            dp_cp_mesh_dim_names.append("cp")
+
+        if dp_mesh_dim_names != []:
+            mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+        if dp_shard_cp_mesh_dim_names != []:
+            mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
+                mesh_dim_name="dp_shard_cp"
+            )
+        if dp_cp_mesh_dim_names != []:
+            mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
 
         return mesh
 
     @property
@@ -74,6 +94,18 @@ def dp_replicate_enabled(self):
     @property
     def dp_shard_enabled(self):
         return self.dp_shard > 1
+
+    @property
+    def dp_cp_enabled(self):
+        return self.dp_enabled or self.cp_enabled
+
+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
+    @property
+    def fsdp_enabled(self):
+        return self.dp_shard_enabled or self.cp_enabled
 
     @property
     def tp_enabled(self):
@@ -85,4 +117,8 @@ def loss_parallel_enabled(self):
 
     @cached_property
     def model_parallel_size(self):
-        return self.tp
+        return self.tp
+
+    @cached_property
+    def non_data_parallel_size(self):
+        return self.cp * self.tp
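A worked example of the new mesh layout: 4 ranks with `dp_shard=-1` and `cp=2`. This is a sketch assuming `ParallelDims` is a dataclass constructed with keyword arguments; it must run under `torchrun` so a process group exists:

```python
# Sketch: 4 ranks, FSDP x CP (run under torchrun --nproc_per_node=4).
from maester.parallelisms.parallel_dims import ParallelDims

pd = ParallelDims(dp_replicate=1, dp_shard=-1, tp=1, cp=2,
                  world_size=4, enable_loss_parallel=False)
# _validate resolves dp_shard = 4 // (1 * 1 * 2) = 2
mesh = pd.build_mesh("cuda")  # dims (2, 2), names ("dp_shard", "cp")

# Flattened submeshes created by build_mesh:
#   mesh["dp"]          -> ("dp_shard",)       data loading, no CP comms
#   mesh["dp_shard_cp"] -> ("dp_shard", "cp")  FSDP param sharding over 4 ranks
#   mesh["dp_cp"]       -> ("dp_shard", "cp")  loss all-reduce over 4 ranks
```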
10 changes: 4 additions & 6 deletions maester/parallelisms/parallelize_gemma.py
@@ -53,7 +53,7 @@ def parallelize_gemma(
 
     # Compile each layer individually
     if config.compile:
-        apply_compile(model)
+        apply_compile(model, fullgraph=not parallel_dims.cp_enabled) # TODO: fullgraph for CP?
 
     # Apply FSDP
     use_fsdp = parallel_dims.dp_enabled or (
@@ -66,10 +66,9 @@
 
         apply_fsdp(
             model,
-            dp_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
             param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce],
-            tp_enabled=parallel_dims.tp_enabled,
             #pp_enabled=parallel_dims.pp_enabled,
         )
 
@@ -257,10 +256,10 @@ def apply_ac(model: nn.Module, config: Config):
     logger.info("Applied activation checkpointing to the model")
 
 
-def apply_compile(model: nn.Module):
+def apply_compile(model: nn.Module, fullgraph: bool = False):
     """Compile each transformer layer individually."""
     for layer_id, layer in enumerate(model.model.layers):
-        compiled_layer = torch.compile(layer, fullgraph=True)
+        compiled_layer = torch.compile(layer, fullgraph=fullgraph)
         model.model.layers[layer_id] = compiled_layer
     logger.info("Compiled each transformer layer with torch.compile")
 
@@ -270,7 +269,6 @@ def apply_fsdp(
     dp_mesh: DeviceMesh,
     param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
-    tp_enabled: bool,
     pp_enabled: bool = False,
 ):
     """Apply FSDP to Gemma model."""
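With context parallelism on, `fullgraph` compilation is turned off, presumably because the `@torch._dynamo.disable` SDPA helper in the Gemma model forces a graph break that `fullgraph=True` would reject. A sketch of what `apply_compile` now does per layer (mirroring the function above, not new behavior):

```python
# Sketch: per-layer compile with graph breaks allowed when CP is enabled.
# _no_compile_sdpa is dynamo-disabled, so fullgraph must be False under CP.
fullgraph = not parallel_dims.cp_enabled
for layer_id, layer in enumerate(model.model.layers):
    model.model.layers[layer_id] = torch.compile(layer, fullgraph=fullgraph)
```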