
Conversation

dilawarm

thanks for making this course open to the public!

Model architecture

  • Hybrid U‑Net Transformer: 16 layers, 8 heads, d_model=1024, seq_len=512
  • U‑Net skip mixing: store activations from layers 1–8, then mix into 9–16 with learned gates
    • Gate params: per‑layer gate with shape (1,1,d_model), initialized to 0.1
    • Mixing: gate * skip + (1 - gate) * current (see the sketch after this list)
  • Attention:
    • Layers 1–8: standard MHA with sliding‑window attention (window=256)
    • Layers 9–16: Multi‑Query Attention (MQA) to reduce KV memory
    • RoPE (rotary pos embeddings) applied to Q and K
  • FFN: Squared‑ReLU with learned exponent α (trainable, init 2.0)
    • Width: 4× in layers 1–8, 2.5× in layers 9–16
  • Norms: FP32 layer norms for numerical stability
  • Gradient checkpointing: enabled (every block) to reduce memory
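
A minimal sketch of the two less standard pieces above, the gated skip mix and the learned-exponent squared ReLU, assuming PyTorch and illustrative module/parameter names rather than the repo's actual code:

```python
import torch
import torch.nn as nn

class GatedSkipMix(nn.Module):
    """Blend a stored layer-1..8 activation into a layer-9..16 hidden state.

    One gate per decoder-half layer, shape (1, 1, d_model), initialized to 0.1,
    mixed as gate * skip + (1 - gate) * current.
    """

    def __init__(self, d_model: int, init: float = 0.1):
        super().__init__()
        self.gate = nn.Parameter(torch.full((1, 1, d_model), init))

    def forward(self, current: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        return self.gate * skip + (1.0 - self.gate) * current


class SquaredReLUAlpha(nn.Module):
    """FFN activation relu(x) ** alpha with a trainable exponent (init 2.0)."""

    def __init__(self, init_alpha: float = 2.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(init_alpha))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # clamp_min avoids NaN gradients for alpha where relu(x) == 0
        return torch.relu(x).clamp_min(1e-6).pow(self.alpha)
```

In the forward pass, layers 1–8 would push their outputs onto a stack and layers 9–16 would pop and mix them in; the exact layer pairing is the repo's choice and isn't spelled out here.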

Precision and kernels

  • Mixed precision: AMP with bfloat16 autocast for forward/backward
  • FP8 attention: enabled via Transformer Engine
  • Flash/SDPA: uses PyTorch scaled‑dot‑product attention with flash/mem‑efficient kernels where masks allow
  • TF32: allowed on CUDA matmul and cuDNN; cuDNN benchmark enabled
  • torch.compile: enabled (mode “reduce‑overhead”, dynamic=True)
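
Roughly how those flags wire together in PyTorch; the training-loop shape and build_model() are assumptions, and the FP8 attention path via Transformer Engine is omitted here:

```python
import torch
import torch.nn.functional as F

# TF32 on CUDA matmuls and cuDNN, plus cuDNN autotuning.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

model = build_model().cuda()  # hypothetical constructor for the model above
model = torch.compile(model, mode="reduce-overhead", dynamic=True)

def forward_backward(batch: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # bfloat16 autocast for the forward pass; inside the model,
    # F.scaled_dot_product_attention picks flash/mem-efficient kernels
    # when the mask pattern allows it.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(batch)
        loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
    loss.backward()
    return loss.detach()
```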

Optimizer and LR schedule

  • Muon (with auxiliary AdamW for 1D params)
  • LRs: embeddings 2e‑3, layers 1–8 8e‑4, layers 9–16 6e‑4, LM head 4e‑4
  • Scheduler: cosine with warmup (warmup_steps=1000, min_lr_scale=0.1)
  • Grad accumulation: 2 micro‑steps per optimizer step
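
A sketch of the parameter grouping and schedule using stock PyTorch pieces. AdamW stands in for Muon (whose setup isn't shown in this thread), and the group-matching rules and block_index helper are assumptions about parameter naming:

```python
import math
import torch

def block_index(name: str) -> int:
    # Hypothetical helper: pull the block index out of a parameter name
    # shaped like "blocks.12.attn.q_proj.weight".
    return int(name.split("blocks.")[1].split(".")[0])

def build_optimizer(model: torch.nn.Module, total_steps: int,
                    warmup_steps: int = 1000, min_lr_scale: float = 0.1):
    embed, early, late, head = [], [], [], []
    for name, p in model.named_parameters():
        if "lm_head" in name:
            head.append(p)
        elif "embed" in name:
            embed.append(p)
        elif "blocks." in name and block_index(name) < 8:
            early.append(p)
        else:
            late.append(p)

    groups = [
        {"params": embed, "lr": 2e-3},  # embeddings
        {"params": early, "lr": 8e-4},  # layers 1-8
        {"params": late,  "lr": 6e-4},  # layers 9-16
        {"params": head,  "lr": 4e-4},  # LM head
    ]
    optimizer = torch.optim.AdamW(groups)  # the actual run uses Muon + aux AdamW

    def lr_lambda(step: int) -> float:
        # Linear warmup, then cosine decay down to min_lr_scale of the base LR.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
        return min_lr_scale + (1.0 - min_lr_scale) * cosine

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```

With --grad-accum 2, two micro-batches accumulate gradients before each optimizer.step() and scheduler.step().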

Review & Reproduce

I invited @marcelroed to my repo. After following the steps in the README, run:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python -m llm.cli.train \
  --train-arrays /root/llm/cache/owt_train.npy \
  --valid-arrays /root/llm/cache/owt_valid.npy \
  --seq-len 512 --d-model 1024 --n-heads 8 --n-layers 16 \
  --tokens-per-batch 32768 --grad-accum 2 \
  --use-amp --gradient-checkpointing \
  --num-workers 24 --prefetch-batches 4 --pin-memory --persistent-workers --non-blocking \
  --auto-scale-batch --auto-scale-interval 100 --target-mem-fraction 0.95 --scale-up-factor 1.15 \
  --checkpoint-dir /root/llm/checkpoints \
  --project owt-train --run-name h100-l512-d1024-fast \
  --log-hist-every 1000 --use-fp8-attention

@marcelroed
Member

I'm pretty sure you're using the wrong tokenizer or otherwise have a bug in the loss calculation; this result is too far ahead of the others to be plausible.

@marcelroed
Member

I can help validate this, but make sure you're using the OWT-trained tokenizer, check that the vocab size is actually 32000, and validate that your dataloader is feeding token IDs that span the full range from 0 to 31999. A common issue I've seen is loading/saving values as uint8 when uint16 is necessary.
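
A quick check along those lines, assuming the cached arrays from the training command above:

```python
import numpy as np

for path in ("/root/llm/cache/owt_train.npy", "/root/llm/cache/owt_valid.npy"):
    tokens = np.load(path, mmap_mode="r")
    print(path, tokens.dtype, int(tokens.min()), int(tokens.max()))
    # uint8 silently wraps any token ID >= 256; uint16 is needed for a 32000 vocab.
    assert tokens.dtype == np.uint16
    # With the OWT-trained tokenizer all IDs should stay below 32000, and the
    # observed max should sit close to 31999 if the vocab is actually used.
    assert tokens.max() < 32000
```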

@marcelroed
Member

Also, if you can add a branch where your code follows the structure outlined in the README (single reproducible script file, pyproject.toml, uv.lock), that would be great!

@marcelroed
Member

A good sanity check for loss is to remember its link to perplexity: cross-entropy loss is the average negative log-likelihood per token, so exp(loss) ≈ average perplexity. Perplexity can be interpreted as "the number of guesses I need on average to get to the correct token." A loss of 1.4999 gives exp(1.4999) ≈ 4.48 guesses on average to reach the right token, which is a lot better than exp(3.0781) ≈ 21.72 guesses.
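
Checking the arithmetic:

```python
import math

print(math.exp(1.4999))  # ≈ 4.48 guesses per token
print(math.exp(3.0781))  # ≈ 21.72 guesses per token
```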

@dilawarm
Author

I used GPT2TokenizerFast and missed the requirement to use the OWT-trained tokenizer. I'll update the table with the new results. Thanks!

dilawarm marked this pull request as a draft on August 26, 2025 at 18:18
@dilawarm
Author

Oh, after reading the full requirements more carefully, I also realized that I used a more powerful accelerator (more memory) than what's stated 🙃
