#!/usr/bin/env python3
"""
Debug script to compare prompt construction between our framework and direct generation.
"""
import sys
from pathlib import Path

import numpy as np
import torch

# Add src to path
sys.path.append(str(Path(__file__).parent / "src"))

from models.model_factory import ModelFactory
from adaptive.adaptive_cot import AdaptiveCoT
from adaptive.fewshot_examples import FewShotExampleLoader


def test_prompt_construction():
    """Test prompt construction differences."""
    print("🔬 Debug Prompt Construction")
    print("=" * 60)

    # Load model
    print("🔧 Loading model...")
    model = ModelFactory.create_model(
        model_type="deepseek",
        model_name="/raid/LLM/llama3.1-8b-instruct",
        config={"gpu_id": 0}
    )
    model.load_model()

    # Test problem
    problem = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
    print(f"Problem: {problem}")
    print()

    # Test 1: Zero-shot direct generation prompt
    print("🔧 Test 1: Zero-shot Direct Generation Prompt")
    print("-" * 50)
    direct_prompt = f"Q: {problem}\nA:"
    print(f"Direct prompt: '{direct_prompt}'")
    print(f"Direct prompt repr: {repr(direct_prompt)}")
    print()
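    # Note: repr() above exposes whitespace and control characters (trailing
    # spaces, \n vs \r\n, non-breaking spaces) that the plain prints would hide.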

    # Test 2: Zero-shot our framework prompt
    print("🔧 Test 2: Zero-shot Our Framework Prompt")
    print("-" * 50)
    config = {
        "adaptive_branching": False,
        "min_branches": 1,
        "max_branches": 1,
        "default_branches": 1,
        "num_fewshot": 0,
        "temperature": 0.0,
        "top_p": 1.0,
        "max_tokens": 512,
    }
    adaptive_cot = AdaptiveCoT(model, config)

    # Check what prompt would be used in _generate_reasoning_paths
    if adaptive_cot.num_fewshot > 0:
        examples = adaptive_cot.fewshot_loader.get_fewshot_examples("gsm8k", adaptive_cot.num_fewshot)
        framework_prompt = adaptive_cot.fewshot_loader.format_fewshot_prompt(examples, problem)
    else:
        framework_prompt = f"Q: {problem}\nA:"
    print(f"Framework prompt: '{framework_prompt}'")
    print(f"Framework prompt repr: {repr(framework_prompt)}")
    print()
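    # With num_fewshot == 0 the framework falls through to the same
    # "Q: ...\nA:" template as Test 1, so the zero-shot prompts should
    # match byte-for-byte; Test 5 verifies this.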

    # Test 3: 8-shot direct generation prompt
    print("🔧 Test 3: 8-shot Direct Generation Prompt")
    print("-" * 50)
    fewshot_loader = FewShotExampleLoader()
    examples = fewshot_loader.get_fewshot_examples("gsm8k", 8)
    direct_8shot_prompt = fewshot_loader.format_fewshot_prompt(examples, problem)
    print(f"Direct 8-shot prompt: '{direct_8shot_prompt[:200]}...'")
    print(f"Direct 8-shot prompt repr: {repr(direct_8shot_prompt[:200])}...")
    print()

    # Test 4: 8-shot our framework prompt
    print("🔧 Test 4: 8-shot Our Framework Prompt")
    print("-" * 50)
    config_8shot = {
        "adaptive_branching": False,
        "min_branches": 1,
        "max_branches": 1,
        "default_branches": 1,
        "num_fewshot": 8,
        "temperature": 0.0,
        "top_p": 1.0,
        "max_tokens": 512,
    }
    adaptive_cot_8shot = AdaptiveCoT(model, config_8shot)
    if adaptive_cot_8shot.num_fewshot > 0:
        examples = adaptive_cot_8shot.fewshot_loader.get_fewshot_examples("gsm8k", adaptive_cot_8shot.num_fewshot)
        framework_8shot_prompt = adaptive_cot_8shot.fewshot_loader.format_fewshot_prompt(examples, problem)
    else:
        framework_8shot_prompt = f"Q: {problem}\nA:"
    print(f"Framework 8-shot prompt: '{framework_8shot_prompt[:200]}...'")
    print(f"Framework 8-shot prompt repr: {repr(framework_8shot_prompt[:200])}...")
    print()
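    # Tests 3 and 4 both request ("gsm8k", 8) examples, one via a fresh
    # FewShotExampleLoader and one via the loader held by AdaptiveCoT, so a
    # divergence here would point at loader construction rather than the call.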

    # Test 5: Compare prompts
    print("🔧 Test 5: Prompt Comparisons")
    print("-" * 50)
    print(f"Zero-shot prompts identical: {direct_prompt == framework_prompt}")
    print(f"8-shot prompts identical: {direct_8shot_prompt == framework_8shot_prompt}")
    if direct_prompt != framework_prompt:
        print("\nZero-shot prompt differences:")
        for i, (c1, c2) in enumerate(zip(direct_prompt, framework_prompt)):
            if c1 != c2:
                print(f"  Position {i}: '{c1}' vs '{c2}' (ord: {ord(c1)} vs {ord(c2)})")
    if direct_8shot_prompt != framework_8shot_prompt:
        print("\n8-shot prompt differences:")
        for i, (c1, c2) in enumerate(zip(direct_8shot_prompt, framework_8shot_prompt)):
            if c1 != c2:
                print(f"  Position {i}: '{c1}' vs '{c2}' (ord: {ord(c1)} vs {ord(c2)})")

    # Test 6: Check if there are any other differences in the generation process
    print("\n🔧 Test 6: Generation Process Differences")
    print("-" * 50)

    # Check the exact method being called in our framework
    print("Our framework calls _generate_reasoning_paths with:")
    print(f"  problem: {problem[:50]}...")
    print("  num_branches: 1")
    print("  prefill_signals: {}")

    # Check what happens inside _generate_reasoning_paths
    print("\nInside _generate_reasoning_paths:")
    print(f"  num_fewshot: {adaptive_cot.num_fewshot}")
    print(f"  prompt construction: {'few-shot' if adaptive_cot.num_fewshot > 0 else 'zero-shot'}")

    # Test 7: Check if there are any other parameters being passed
    print("\n🔧 Test 7: Generation Parameters")
    print("-" * 50)

    # Override the generate method to log all parameters
    original_generate = model.generate

    def logged_generate(prompt, **kwargs):
        print("Parameters passed to model.generate:")
        for key, value in kwargs.items():
            print(f"  {key}: {value}")
        print(f"Prompt length: {len(prompt)}")
        print(f"Prompt first 100 chars: {prompt[:100]}")
        print()
        return original_generate(prompt, **kwargs)

    model.generate = logged_generate
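    # The shim only logs and then delegates, so generation behaviour is
    # unchanged. Assuming _generate_reasoning_paths routes through
    # model.generate, the same logging covers both calls below.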

    # Test zero-shot generation
    print("Zero-shot generation parameters:")
    torch.manual_seed(42)
    np.random.seed(42)
    generated = model.generate(
        direct_prompt,
        max_tokens=512,
        temperature=0.0,
        top_p=1.0,
        do_sample=False,
        num_return_sequences=1
    )

    # Test framework generation
    print("\nFramework generation parameters:")
    reasoning_paths = adaptive_cot._generate_reasoning_paths(problem, 1, {})
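    # Optional cleanup: restore the real generate() so any later use of
    # `model` is not wrapped by the logging shim.
    model.generate = original_generate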


if __name__ == "__main__":
    test_prompt_construction()