Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# training output
wandb/
checkpoints/
test_checkpoints/

# poetry
poetry.lock
Expand All @@ -30,4 +31,6 @@ save/
**/save/

# claude
CLAUDE.md
CLAUDE.md

venv_clt/
1 change: 1 addition & 0 deletions scripts/launch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def main():
n_train_batch_per_buffer=36,
total_training_tokens=total_training_tokens,
train_batch_size_tokens=train_batch_size_tokens,
gradient_accumulation_steps=1, # Set > 1 to accumulate gradients
adam_beta1=0.9,
adam_beta2=0.999,
lr=2e-4,
Expand Down
1 change: 1 addition & 0 deletions scripts/launch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def main():
n_train_batch_per_buffer=36,
total_training_tokens=total_training_tokens,
train_batch_size_tokens=train_batch_size_tokens,
gradient_accumulation_steps=1, # Set > 1 to accumulate gradients over multiple micro-batches
adam_beta1=0.9,
adam_beta2=0.999,
lr=2e-4,
Expand Down
Binary file modified src/clt/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file modified src/clt/__pycache__/clt.cpython-311.pyc
Binary file not shown.
Binary file modified src/clt/__pycache__/clt_training_runner.cpython-311.pyc
Binary file not shown.
Binary file modified src/clt/__pycache__/load_model.cpython-311.pyc
Binary file not shown.
Binary file modified src/clt/__pycache__/utils.cpython-311.pyc
Binary file not shown.
5 changes: 3 additions & 2 deletions src/clt/clt_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from clt.config import CLTTrainingRunnerConfig, CLTConfig
from clt.utils import DTYPE_MAP, DummyModel
from clt.clt import CLT
from clt import logger
from clt.load_model import load_model
from clt.training.activations_store import ActivationsStore
from clt.training.clt_trainer import CLTTrainer
Expand Down Expand Up @@ -161,7 +162,7 @@ def run(self):
logger.info(f"lr: {self.cfg.lr}")
logger.info(f"dead_penalty_coef: {self.cfg.dead_penalty_coef}")

trainer = CLTTrainer(
self.trainer = CLTTrainer(
clt=self.clt,
activations_store=self.activations_store,
save_checkpoint_fn=self.save_checkpoint,
Expand All @@ -170,7 +171,7 @@ def run(self):
world_size=self.world_size
)

clt = trainer.fit()
clt = self.trainer.fit()

if self.cfg.log_to_wandb and self.is_main_process:
wandb.finish()
Expand Down
Binary file modified src/clt/config/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file modified src/clt/config/__pycache__/clt_config.cpython-311.pyc
Binary file not shown.
Binary file not shown.
9 changes: 8 additions & 1 deletion src/clt/config/clt_training_runner_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class CLTTrainingRunnerConfig(BaseModel):
# -----Training/Optimization--------------
total_training_tokens: int = 100_000_000
train_batch_size_tokens: int = 4096
gradient_accumulation_steps: int = 1
adam_beta1: float = 0.0
adam_beta2: float = 0.999
lr: float = 1e-5
Expand Down Expand Up @@ -199,6 +200,10 @@ def model_post_init(self, __context):
logger.info("d_latent : %d", self.d_latent)
logger.info("total tokens : %.3e", self.total_training_tokens)
logger.info("batch (tokens) : %d", self.train_batch_size_tokens)
if self.gradient_accumulation_steps > 1:
effective_batch_size = self.train_batch_size_tokens * self.gradient_accumulation_steps
logger.info("grad accum steps: %d", self.gradient_accumulation_steps)
logger.info("effective batch : %d", effective_batch_size)
total_steps = self.total_training_tokens // self.train_batch_size_tokens
logger.info("total steps : %d", total_steps)
n_tokens_per_buffer = (
Expand Down Expand Up @@ -228,7 +233,9 @@ def to_dict(self, *, exclude_none: bool = True,**kw) -> Dict[str, Any]:

@property
def total_training_steps(self) -> int:
return int(self.total_training_tokens // self.train_batch_size_tokens)
# Total optimizer steps, accounting for gradient accumulation
micro_batches = int(self.total_training_tokens // self.train_batch_size_tokens)
return micro_batches // self.gradient_accumulation_steps

@property
def is_distributed(self) -> bool:
Expand Down
Binary file modified src/clt/training/__pycache__/activations_store.cpython-311.pyc
Binary file not shown.
Binary file modified src/clt/training/__pycache__/clt_trainer.cpython-311.pyc
Binary file not shown.
Binary file modified src/clt/training/__pycache__/optim.cpython-311.pyc
Binary file not shown.
6 changes: 4 additions & 2 deletions src/clt/training/activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,10 @@ def __iter__(self):
def load_dataset_auto(path_or_name: str, split: str = "train", is_multilingual_split_dataset: bool = False):
if os.path.exists(path_or_name):
logger.info("Loading from disk")

# return load_from_disk(path_or_name)

# Check if it's a dataset saved with save_to_disk
if Path(path_or_name, "state.json").exists():
return load_from_disk(path_or_name)

return load_dataset(
path_or_name,
Expand Down
57 changes: 40 additions & 17 deletions src/clt/training/clt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(

self.n_tokens: int = 0
self.monitoring_l0 = None
self.accumulation_step: int = 0

def _initialize_b_enc(self, n_batches: int = 10):

Expand Down Expand Up @@ -148,6 +149,7 @@ def fit(self):
if self.cfg.from_pretrained_path is None:
self._initialize_b_enc()

#print(f"[TRAINER] GPU {self.rank} - b_enc mean: {self.clt.b_enc.mean().item():.4f}, b_enc sum: {self.clt.b_enc.sum().item():.4f}", flush=True)
logger.info(f"GPU {self.rank} - b_enc mean: {self.clt.b_enc.mean().item():.4f}, b_enc sum: {self.clt.b_enc.sum().item():.4f}")

while self.n_tokens < self.cfg.total_training_tokens:
Expand All @@ -167,11 +169,16 @@ def fit(self):
)

self.n_tokens += self.cfg.train_batch_size_tokens
self.n_training_steps += 1
if self.is_main_process:

# Only log, checkpoint, and count steps after completing accumulation cycle
if self.accumulation_step == 0:
self.n_training_steps += 1

#print(f"[TRAINER] Step {self.n_training_steps} - MSE: {loss_metrics.mse_loss:.4f}, L0: {loss_metrics.l0_loss:.4f}", flush=True)
logger.info(f"Training step {self.n_training_steps}")
self._log_train_step(loss_metrics)
self._run_and_log_evals()
self._checkpoint_if_needed()
self._checkpoint_if_needed()

# if self.cfg.functional_loss is not None and self.fc_scheduler.get_lr() > 0 and start_func_finetuning:
# self._enable_functional_training()
Expand Down Expand Up @@ -302,14 +309,19 @@ def _compute_training_step_loss(self, act_in: torch.Tensor, act_out: torch.Tenso
if self.n_training_steps < 5:
logger.info(f"GPU {self.rank} - act_in sum: {act_in.sum().item():.4f}, shape: {act_in.shape}")

self.optimizer.zero_grad()
# Only zero gradients at the start of accumulation
if self.accumulation_step == 0:
self.optimizer.zero_grad()

if self.scaler is not None:
with autocast(device_type='cuda', dtype=torch.bfloat16):
loss, loss_metrics = self.clt(act_in, act_out, self.l0_scheduler.get_lr(), df_coef=self.cfg.dead_penalty_coef)
else:
loss, loss_metrics = self.clt(act_in, act_out, self.l0_scheduler.get_lr(), df_coef=self.cfg.dead_penalty_coef)

# Scale loss by accumulation steps
loss = loss / self.cfg.gradient_accumulation_steps

if self.n_training_steps == 0 and self.rank == 0:
logger.info(f"feat_act shape: {loss_metrics.feature_acts.shape}")
logger.info(f"act_pred shape: {loss_metrics.act_pred.shape}")
Expand All @@ -324,26 +336,37 @@ def _compute_training_step_loss(self, act_in: torch.Tensor, act_out: torch.Tenso

if self.scaler is not None:
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.clt.parameters(), 1.0)

if self.cfg.is_sharded:
self._synchronize_feature_sharding_gradients()

self.scaler.step(self.optimizer)
self.scaler.update()
# Only step optimizer every N accumulation steps
if (self.accumulation_step + 1) % self.cfg.gradient_accumulation_steps == 0:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.clt.parameters(), 1.0)

if self.cfg.is_sharded:
self._synchronize_feature_sharding_gradients()

self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()

if self.cfg.is_sharded:
self._synchronize_feature_sharding_gradients()

self.optimizer.step()
# Only step optimizer every N accumulation steps
if (self.accumulation_step + 1) % self.cfg.gradient_accumulation_steps == 0:
if self.cfg.is_sharded:
self._synchronize_feature_sharding_gradients()

self.optimizer.step()

# Increment accumulation counter
self.accumulation_step = (self.accumulation_step + 1) % self.cfg.gradient_accumulation_steps

self._log_debug_info(loss_metrics)

self.update_optimizer_lr()
self.l0_scheduler.step()
# Only update learning rate when we actually step the optimizer
if self.accumulation_step == 0:
self.update_optimizer_lr()
self.l0_scheduler.step()

return loss_metrics

def update_optimizer_lr(self) -> float:
Expand Down
Binary file not shown.
Binary file added tests/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
161 changes: 161 additions & 0 deletions tests/training/test_gradient_accumulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""
Entirely made by Claude
"""

"""
Test gradient accumulation by running actual CLT training on NeelNanda dataset
"""
import pytest
import torch
from pathlib import Path
from clt.config import CLTConfig, CLTTrainingRunnerConfig
from clt.clt_training_runner import CLTTrainingRunner
import wandb
from clt import logger


# Resolve the tokenized test dataset relative to the tests/ directory.
test_dir = Path(__file__).resolve().parent.parent
dataset_path = str(test_dir.joinpath("data", "NeelNanda_c4_10k_tokenized"))


def test_gradient_accumulation_training():
    """
    Test gradient accumulation by running actual training and verifying:
    1. Losses decrease over time
    2. Scheduler steps match expected count
    3. Training completes successfully

    End-to-end test: builds a small CLTTrainingRunnerConfig, runs a short
    training loop with gradient_accumulation_steps > 1, then asserts that
    the trainer performed exactly the expected number of optimizer steps
    and consumed exactly the expected number of tokens.
    """

    print("\n" + "="*70)
    print("Testing Gradient Accumulation with Actual Training")
    print("="*70)

    # Small training run configuration
    total_optimizer_steps = 200  # Number of actual optimizer updates
    gradient_accumulation_steps = 4
    train_batch_size_tokens = 128

    # Each optimizer step consumes `gradient_accumulation_steps` micro-batches
    # of `train_batch_size_tokens` tokens, so total tokens must be the product
    # of all three for the step-count assertion below to hold exactly.
    total_training_tokens = train_batch_size_tokens * total_optimizer_steps * gradient_accumulation_steps

    print(f"\nConfiguration:")
    print(f"  Dataset: {dataset_path}")
    print(f"  Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"  Micro-batch size: {train_batch_size_tokens} tokens")
    print(f"  Effective batch size: {train_batch_size_tokens * gradient_accumulation_steps} tokens")
    print(f"  Target optimizer steps: {total_optimizer_steps}")
    print(f"  Total training tokens: {total_training_tokens}")

    cfg = CLTTrainingRunnerConfig(
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype="float32",
        seed=42,
        n_checkpoints=0,  # No checkpoints for testing
        checkpoint_path="test_checkpoints/grad_accum",
        logger_verbose=True,
        model_class_name="HookedTransformer",
        model_name="roneneldan/TinyStories-33M",
        dataset_path=dataset_path,
        context_size=16,
        from_pretrained_path=None,
        d_in=768,
        expansion_factor=4,  # Small for fast testing
        jumprelu_init_threshold=0.03,
        jumprelu_bandwidth=1.0,
        n_batches_in_buffer=4,
        store_batch_size_prompts=8,
        total_training_tokens=total_training_tokens,
        train_batch_size_tokens=train_batch_size_tokens,
        gradient_accumulation_steps=gradient_accumulation_steps,
        adam_beta1=0.9,
        adam_beta2=0.999,
        lr=1e-3,
        lr_warm_up_steps=5,
        lr_decay_steps=5,
        final_lr_scale=0.5,
        l0_coefficient=1.0,
        dead_penalty_coef=0.0,
        dead_feature_window=50,
        l0_warm_up_steps=10,
        l0_waiting_steps=0,
        decay_stable_steps=35,
        cross_layer_decoders=True,
        log_to_wandb=False,
        wandb_project="test-grad-accum",
        wandb_id="test_grad_accum_001",
        wandb_log_frequency=5,
        eval_every_n_wandb_logs=10,
        run_name="test_gradient_accumulation",
        wandb_entity=None,
        ddp=False,
        fsdp=False,
        feature_sharding=False,
    )

    print(f"\nStarting training...")
    print("-"*70)

    # Run training (the returned CLT model itself is not needed by this test)
    runner = CLTTrainingRunner(cfg)
    runner.run()

    # Access trainer after run() completes
    trainer = runner.trainer

    print("-"*70)
    print(f"Training completed!")
    print(f"\nTraining summary:")
    print(f"  Total optimizer steps: {trainer.n_training_steps}")
    print(f"  Total tokens processed: {trainer.n_tokens}")

    # Verify results
    print("\n" + "="*70)
    print("Verification:")
    print("="*70)

    # 1. Check that we completed the expected number of optimizer steps
    actual_steps = trainer.n_training_steps
    print(f"✓ Optimizer steps: {actual_steps} (expected: {total_optimizer_steps})")
    assert actual_steps == total_optimizer_steps, \
        f"Expected {total_optimizer_steps} optimizer steps, got {actual_steps}"

    # 2. Check that total tokens processed is correct
    expected_tokens = total_training_tokens
    actual_tokens = trainer.n_tokens
    print(f"✓ Tokens processed: {actual_tokens} (expected: {expected_tokens})")
    assert actual_tokens == expected_tokens, \
        f"Expected {expected_tokens} tokens, got {actual_tokens}"

    # 3. Verify gradient accumulation worked by checking losses decreased
    # This is the key test for gradient accumulation - training should work correctly
    # NOTE(review): `_losses` is not a guaranteed trainer attribute, so this
    # check is best-effort and only informational (no assert).
    if hasattr(trainer, '_losses') and len(trainer._losses) > 0:
        first_loss = trainer._losses[0]
        last_loss = trainer._losses[-1]
        print(f"✓ Loss progression: {first_loss:.4f} → {last_loss:.4f}")
        # Loss should generally decrease (allowing some variance)
        if last_loss < first_loss * 1.5:  # Allow some increase but not too much
            print(f"✓ Training converged successfully")
        else:
            print(f"⚠ Warning: Loss increased significantly")

    # 4. Verify accumulation counter behavior (if accessible)
    if hasattr(trainer, 'accumulation_step'):
        # After training completes, accumulation_step should be 0 (reset after last batch)
        print(f"✓ Final accumulation step: {trainer.accumulation_step}")

    # 5. Training completed successfully
    print(f"✓ Training completed without errors")

    print("\n" + "="*70)
    print("✅ All gradient accumulation tests PASSED!")
    print("="*70)


if __name__ == "__main__":
    # Allow running this file directly (without pytest) for quick manual checks.
    test_gradient_accumulation_training()
    print("\n✅ Test completed successfully!")