Auto-resolve seq_len in config (because math is hard) #155

@adamimos

The Problem

Every time we write a training script, we perform the same sacred ritual:

```python
bos = cfg.generative_process.bos_token
eos = cfg.generative_process.eos_token
ctx_len = cfg.predictive_model.instance.cfg.n_ctx
seq_len = ctx_len - int(bos is not None) - int(eos is not None)  # 🧙‍♂️ ancient incantation
```

This is:

  1. Repetitive - we copy-paste this into every script like it's 2005
  2. Error-prone - one day someone will forget the int() cast and spend 3 hours debugging
  3. Undignified - we have a beautiful config resolution system that auto-resolves vocab_size, bos_token, eos_token, and d_vocab, but we draw the line at subtraction?

The Dream

```python
seq_len = cfg.generative_process.seq_len  # ✨ it just works ✨
```

The Solution

Simplexity already does the hard work of resolving bos_token and eos_token. We just need to:

  1. Add seq_len: int = MISSING to GenerativeProcessConfig
  2. Grab n_ctx from the model config (it's not dynamically resolved, just a YAML value)
  3. Compute seq_len = n_ctx - int(bos is not None) - int(eos is not None) during resolution
  4. Never think about this again

Why This Isn't Already Done

Honestly? No idea. Maybe we thought "subtraction builds character." Maybe we were saving it for a rainy day. Well, it's raining now. ☔


🤖 Prompt for Your Favorite LLM Coding Agent

(Probably Claude, let's be real)

I need you to add automatic seq_len resolution to simplexity. Here's what to do:

1. In `simplexity/structured_configs/generative_process.py`:
   - Add `seq_len: int = MISSING` to the `GenerativeProcessConfig` dataclass
   - Modify `resolve_generative_process_config` to accept an optional `n_ctx: int | None = None` parameter
   - At the end of that function, if n_ctx is provided and seq_len is MISSING, compute:
     `cfg.seq_len = n_ctx - int(bos is not None) - int(eos is not None)`
   - Log it like the other resolved values

2. In `simplexity/run_management/run_management.py`:
   - Add a helper function `_get_model_n_ctx(cfg, instance_keys)` that:
     - Filters for predictive model targets
     - If it's a HookedTransformer config, returns `instance.cfg.n_ctx`
     - Otherwise returns None
   - In `_setup_generative_processes`, before calling `resolve_generative_process_config`:
     - Call `n_ctx = _get_model_n_ctx(cfg, instance_keys)`
     - Pass `n_ctx=n_ctx` to `resolve_generative_process_config`

That's ~20-30 lines total. The patterns for dynamic resolution already exist in the codebase - follow them.

This issue was generated with the help of Claude, who is very tired of watching humans do arithmetic.
