diff --git a/bench/generation_regression_test.sh b/bench/generation_regression_test.sh new file mode 100755 index 0000000..73703d1 --- /dev/null +++ b/bench/generation_regression_test.sh @@ -0,0 +1,139 @@ +#!/bin/bash +# quant.cpp — Generation Regression Test +# +# Detects autoregressive generation collapse that PPL tests miss. +# Tests: T=0 greedy 500-token generation → verify no garbage output. +# +# The key insight: PPL (teacher-forced) is near-identical for FP32 and +# turbo_kv_4b at all context lengths. But autoregressive generation +# can collapse at ~500 tokens when T=0 repetition compounds KV quant error. +# +# This test catches that class of bugs by checking: +# 1. Loop detection triggers (prevents garbage, so verify it fires) +# 2. Output before loop detection is coherent (no random Unicode) +# 3. PPL sanity check at multiple context lengths +# +# Usage: +# bash bench/generation_regression_test.sh [model.gguf] +# +# Requires: built quant binary in build/ + +set -e + +MODEL="${1:-models/Llama-3.2-1B-Instruct-Q8_0.gguf}" +TQ_RUN="./build/quant" +THREADS=4 +PASS=0 +FAIL=0 + +if [ ! -f "$TQ_RUN" ]; then + echo "Error: $TQ_RUN not found. Build first." + exit 1 +fi +if [ ! -f "$MODEL" ]; then + echo "SKIP: Model not found: $MODEL" + exit 0 +fi + +echo "============================================" +echo " Generation Regression Test" +echo " Model: $MODEL" +echo "============================================" +echo "" + +check() { + local desc="$1" result="$2" + if [ "$result" = "PASS" ]; then + echo " [PASS] $desc" + PASS=$((PASS + 1)) + else + echo " [FAIL] $desc" + FAIL=$((FAIL + 1)) + fi +} + +# Test 1: T=0 generation should NOT produce garbage at 500 tokens +echo "[Test 1] T=0 500-token generation — no garbage output" +OUTPUT=$($TQ_RUN "$MODEL" -p "Explain the theory of relativity in detail" \ + -n 500 -T 0.0 -j $THREADS -k turbo_kv_4b --chat 2>/dev/null) + +# Check for garbage patterns: random Unicode, excessive non-ASCII +# Garbage typically has lots of CJK/Arabic/Thai mixed with Latin +GARBAGE_CHARS=$(echo "$OUTPUT" | tr -cd '\200-\377' | wc -c | tr -d ' ') +TOTAL_CHARS=$(echo "$OUTPUT" | wc -c | tr -d ' ') +if [ "$TOTAL_CHARS" -gt 0 ]; then + GARBAGE_RATIO=$((GARBAGE_CHARS * 100 / TOTAL_CHARS)) +else + GARBAGE_RATIO=100 +fi +if [ "$GARBAGE_RATIO" -lt 30 ]; then + check "turbo_kv_4b output coherence (${GARBAGE_RATIO}% non-ASCII)" "PASS" +else + check "turbo_kv_4b output coherence (${GARBAGE_RATIO}% non-ASCII, threshold 30%)" "FAIL" +fi + +# Test 2: Loop detection should fire for T=0 repetitive prompt +echo "" +echo "[Test 2] Loop detection fires on repetitive T=0 generation" +LOOP_OUTPUT=$($TQ_RUN "$MODEL" -p "what is your name?" \ + -n 1000 -T 0.0 -j $THREADS -k turbo_kv_4b 2>&1) + +if echo "$LOOP_OUTPUT" | grep -q "repetition loop detected"; then + LOOP_TOKENS=$(echo "$LOOP_OUTPUT" | grep "repetition loop" | grep -o "after [0-9]* tokens" | grep -o "[0-9]*") + check "loop detected at ${LOOP_TOKENS} tokens (before 500)" "PASS" +else + TOTAL_TOK=$(echo "$LOOP_OUTPUT" | grep "tok/s" | grep -o "^[0-9]*") + if [ "${TOTAL_TOK:-1000}" -lt 500 ]; then + check "EOS hit at ${TOTAL_TOK} tokens (no loop needed)" "PASS" + else + check "no loop detection in 1000 tokens" "FAIL" + fi +fi + +# Test 3: Non-repetitive generation should NOT trigger loop detection +echo "" +echo "[Test 3] Non-repetitive generation (T=0.7) — no false positives" +NORMAL_OUTPUT=$($TQ_RUN "$MODEL" -p "Tell me a creative story" \ + -n 200 -T 0.7 -j $THREADS -k turbo_kv_4b --chat 2>&1) + +if echo "$NORMAL_OUTPUT" | grep -q "repetition loop detected"; then + check "no false loop detection at T=0.7" "FAIL" +else + check "no false loop detection at T=0.7" "PASS" +fi + +# Test 4: FP32 vs turbo_kv_4b PPL sanity (if ppl data exists) +PPL_FILE="bench/data/ppl_test_1k.txt" +if [ -f "$PPL_FILE" ]; then + echo "" + echo "[Test 4] PPL sanity: turbo_kv_4b within 15% of FP32" + FP32_PPL=$($TQ_RUN "$MODEL" --ppl "$PPL_FILE" -k fp32 -j $THREADS 2>&1 \ + | grep "PPL_CSV" | cut -d, -f3) + Q4_PPL=$($TQ_RUN "$MODEL" --ppl "$PPL_FILE" -k turbo_kv_4b -j $THREADS 2>&1 \ + | grep "PPL_CSV" | cut -d, -f3) + + if [ -n "$FP32_PPL" ] && [ -n "$Q4_PPL" ]; then + # Compare using integer math (multiply by 1000) + FP32_INT=$(echo "$FP32_PPL" | awk '{printf "%d", $1 * 1000}') + Q4_INT=$(echo "$Q4_PPL" | awk '{printf "%d", $1 * 1000}') + THRESHOLD=$((FP32_INT * 115 / 100)) # 15% margin + if [ "$Q4_INT" -le "$THRESHOLD" ]; then + DELTA=$(echo "$FP32_PPL $Q4_PPL" | awk '{printf "%.1f", ($2/$1 - 1)*100}') + check "PPL delta: ${DELTA}% (within 15%)" "PASS" + else + DELTA=$(echo "$FP32_PPL $Q4_PPL" | awk '{printf "%.1f", ($2/$1 - 1)*100}') + check "PPL delta: ${DELTA}% (exceeds 15%)" "FAIL" + fi + else + check "PPL comparison (could not parse results)" "FAIL" + fi +fi + +echo "" +echo "============================================" +echo " Results: ${PASS} passed, ${FAIL} failed" +echo "============================================" + +if [ "$FAIL" -gt 0 ]; then + exit 1 +fi diff --git a/quant.h b/quant.h index 720521c..8cf1383 100644 --- a/quant.h +++ b/quant.h @@ -12179,8 +12179,13 @@ tq_model_t* tq_load_gguf(const char* path) { } const size_t MAX_FP32_BYTES = (size_t)16 * 1024 * 1024 * 1024ULL; /* 16 GB */ - /* TQ_NO_Q4=1 disables Q4 recompression → use direct GGUF dequant for better quality */ + /* TQ_NO_Q4=1 disables Q4 recompression → use direct GGUF dequant for better quality. + * Can be set via environment variable or compile-time define (useful for WASM). */ +#ifdef TQ_NO_Q4 + if (1) { +#else if (getenv("TQ_NO_Q4")) { +#endif fprintf(stderr, "tq_load_gguf: TQ_NO_Q4 set — skipping Q4 conversion, using GGUF on-the-fly dequant\n"); goto skip_q4_conversion; } diff --git a/src/engine/tq_generate.c b/src/engine/tq_generate.c index fb964e7..4180472 100644 --- a/src/engine/tq_generate.c +++ b/src/engine/tq_generate.c @@ -254,13 +254,22 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer, int vocab_size = model->config.vocab_size; float rep_penalty = config->rep_penalty; int rep_window = config->rep_window; - if (rep_window > 64) rep_window = 64; - int recent_tokens[64]; + if (rep_window > 128) rep_window = 128; + int recent_tokens[128]; int recent_count = 0; + /* N-gram loop detection: track recent 4-grams to detect infinite loops. + * Small models with T=0 greedy decoding enter repetition loops where + * the same ~30-token pattern repeats endlessly. KV quantization error + * compounds through these repetitions, eventually collapsing output + * into garbage. Detecting loops early prevents wasted compute. */ + uint32_t ngram_hashes[64]; + int ngram_hash_count = 0; + int loop_detected = 0; + /* Seed recent tokens with tail of prompt for better penalty coverage */ for (int i = (n_prompt > rep_window ? n_prompt - rep_window : 0); i < n_prompt; i++) { - recent_tokens[recent_count % 64] = prompt_tokens[i]; + recent_tokens[recent_count % 128] = prompt_tokens[i]; recent_count++; } @@ -268,8 +277,8 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer, if (rep_penalty > 1.0f) { int window = recent_count < rep_window ? recent_count : rep_window; for (int r = 0; r < window; r++) { - int idx = (recent_count - 1 - r) % 64; - if (idx < 0) idx += 64; + int idx = (recent_count - 1 - r) % 128; + if (idx < 0) idx += 128; int tok = recent_tokens[idx]; if (tok >= 0 && tok < vocab_size) { if (state->logits[tok] > 0) @@ -288,7 +297,7 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer, &rng_state); /* Record first sampled token */ - recent_tokens[recent_count % 64] = next_token; + recent_tokens[recent_count % 128] = next_token; recent_count++; int generated = 0; @@ -483,8 +492,32 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer, &rng_state); /* Record sampled token for repetition penalty */ - recent_tokens[recent_count % 64] = next_token; + recent_tokens[recent_count % 128] = next_token; recent_count++; + + /* N-gram loop detection: hash recent 4-gram and check for repeats */ + if (recent_count >= 4) { + uint32_t h = 0; + for (int r = 0; r < 4; r++) { + int gi = (recent_count - 4 + r) % 128; + h = h * 31 + (uint32_t)recent_tokens[gi]; + } + int matches = 0; + int ring_len = ngram_hash_count < 64 ? ngram_hash_count : 64; + for (int r = 0; r < ring_len; r++) { + if (ngram_hashes[r] == h) matches++; + } + ngram_hashes[ngram_hash_count % 64] = h; + ngram_hash_count++; + if (matches >= 3) { + loop_detected = 1; + break; + } + } + } + + if (loop_detected) { + fprintf(stderr, "[generate] repetition loop detected after %d tokens, stopping\n", generated); } /* Null-terminate output */ diff --git a/wasm/build.sh b/wasm/build.sh index a611472..df74c62 100755 --- a/wasm/build.sh +++ b/wasm/build.sh @@ -40,6 +40,7 @@ emcc "$SCRIPT_DIR/quant_wasm.c" \ -lm \ -DNDEBUG \ -D__EMSCRIPTEN__ \ + -DTQ_NO_Q4=1 \ -Wno-gnu-zero-variadic-macro-arguments \ -Wno-dollar-in-identifier-extension diff --git a/wasm/index.html b/wasm/index.html index a70c6a5..6f4a663 100644 --- a/wasm/index.html +++ b/wasm/index.html @@ -174,16 +174,11 @@
No install. No API key. No server.