-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtree_instruct.py
More file actions
209 lines (169 loc) · 9.56 KB
/
tree_instruct.py
File metadata and controls
209 lines (169 loc) · 9.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
from source.interaction_instructions import *
from source.agent_personas import *
from source.parse_utils import *
import torch
import pickle
import os
from transformers import pipeline
import re
import argparse
import glob
from source.model_definitions import *
from agents.verifier import Verifier
from agents.instructor import Instructor
from agents.student import Student
def fix_misunderstanding(student: Student, instructor: Instructor, verifier: Verifier, state_representation, target_rep):
    """Question the student about ``target_rep`` until the verifier marks progress.

    Each round asks one instructor question, records the exchange in the
    module-level ``convo_history``, and has the verifier judge the answer.
    Questions are grouped by ``level``: a misunderstood answer stays on the
    same level (more questions, or a teaching hint), an understood answer that
    does not yet resolve any state attribute moves to the next level.  The loop
    ends as soon as the verifier marks at least one unresolved attribute (from
    the target onwards) as resolved.  Afterwards the student proposes bug
    fixes, and if the verifier accepts them every attribute is marked resolved.

    Args:
        student: Answers questions (``ask_student``) and proposes fixes
            (``generate_bug_fixes``).
        instructor: Produces candidate questions and teaching hints.
        verifier: Judges answers, state attributes and proposed bug fixes.
        state_representation: Two-element list ``[attributes, resolved_flags]``;
            ``resolved_flags`` is updated in place / replaced.
        target_rep: The attribute in ``state_representation[0]`` to work on.

    Returns:
        ``(state_representation, buggy_code)``.  ``buggy_code`` is the
        module-level global, returned unchanged.
        NOTE(review): the student's fixes are never applied to the returned
        code -- confirm this is intended.
    """
    global convo_history
    global buggy_code

    level = 0
    level_questions = {}      # level -> questions asked/queued at that level
    level_indices = {}        # level -> index of the next question to ask
    level_explanations = {}   # level -> latest verifier explanation for that level
    is_student_done = False

    # error handling
    # NOTE(review): nothing in this function reads the -1 entry; it looks like
    # an intended fallback question that was never wired in.
    level_questions[-1] = ['Can you walk through the logic of your code?']

    prefix = ""
    candidate_questions = instructor.generate_candidate_questions(convo_history=convo_history, target=target_rep, tag="initial")

    while not is_student_done:
        if level not in level_questions:
            level_questions[level] = candidate_questions
            level_indices[level] = 0  # setting i

        # get student answer for question i at level l
        instructor_question = prefix + level_questions[level][level_indices[level]]
        convo_history.append("Instructor: " + instructor_question)
        student_question_response = student.ask_student(instructor_question)
        convo_history.append("Student: " + student_question_response)
        level_indices[level] += 1

        # use verifier to see if student understands the curr level questions
        clu_explanation, is_curr_level_understand = verifier.assess_understanding_of_curr_level(instructor_question, student_question_response)

        if is_curr_level_understand:
            target_idx = state_representation[0].index(target_rep)
            for i in range(target_idx, len(state_representation[0])):
                rep_task = state_representation[0][i]

                # is this state repr already true?
                if state_representation[1][i]:
                    continue

                # is this state attribute actually resolved?
                exp, flag = verifier.assess_state_level_understanding(instructor_question, student_question_response, rep_task, i == target_idx, convo_history)

                # if it is resolved -> update state representation and move onto next attribute
                # if it is not resolved -> if no progression (i == idx), just prefix_next_level;
                #                       -> if progression (i > idx), ask student to generate bug fixes
                if flag:
                    state_representation[1][i] = True
                    is_student_done = True
                elif i == target_idx:
                    # Kept for the "next" level's question generation below.
                    level_explanations[level] = exp

            if not is_student_done:
                level += 1
                prefix = prefix_next_level
        else:
            # no change in level
            prefix = prefix_same_level
            level_explanations[level] = clu_explanation

        with open(f'{LOG_FOLDER}/{FILE_NAME}/convo.txt', 'a+') as f:
            # Transcript logging is best-effort: a failed write must not abort
            # the tutoring loop, but interrupts should still propagate.
            try:
                f.write(f'Instructor: {instructor_question}\n')
                f.write(f'Student: {student_question_response}\n')
                f.write(f'\tCurrent Target: {target_rep}\n')
                f.write(f'\tCurrent Level: {is_curr_level_understand}, {clu_explanation}\n')
                f.write(f'\tCurrent State Representation: {state_representation[1]}\n\n')
            except Exception:
                f.write('\tCurrent Level: N/A, N/A\n')

        # generate new questions
        if not is_student_done:
            # check to add teaching
            if not is_curr_level_understand and (len(level_questions[level]) >= 3 or level >= 7):
                teaching = instructor.generate_teaching(level_questions[level][0], target_rep)
                candidate_questions = f"Consider the following: {teaching}\n{level_questions[level][-1]}"
                level_questions[level].append(candidate_questions)
                log(f'------------------------------------------------------------------------\nADDING TEACHING:{candidate_questions}\n------------------------------------------------------------------------', os.path.join(LOG_FOLDER, FILE_NAME))
            elif not is_curr_level_understand:  # same level
                candidate_questions = instructor.generate_candidate_questions(convo_history=convo_history, prev_qs='\n'.join(level_questions[level]), target=target_rep, explanation=level_explanations[level], tag="same")
                level_questions[level].extend(candidate_questions)
            else:  # next level
                # `level` was already incremented above, so level - 1 is the
                # level whose question was just answered.
                candidate_questions = instructor.generate_candidate_questions(convo_history=convo_history, prev_qs='\n'.join(level_questions[level - 1]), target=target_rep, explanation=level_explanations[level - 1], tag="next")

    # generate bug fixes
    student_bug_fixes = student.generate_bug_fixes(convo_history)
    with open(f'{LOG_FOLDER}/{FILE_NAME}/bug_fixes.txt', 'a+') as f:
        f.write(f'{student_bug_fixes}\n')

    # Accepted fixes resolve every attribute at once.
    if student_bug_fixes and verifier.check_bug_fixes(student_bug_fixes):
        state_representation[1] = [True] * len(state_representation[1])

    return state_representation, buggy_code
def run():
    """Tutor the student on the currently loaded problem until every state attribute is resolved.

    Reads the problem data from module-level globals, builds the verifier,
    instructor and student agents, and repeatedly calls
    ``fix_misunderstanding`` on the first unresolved attribute.  The
    conversation history is reset after each call.  When no attribute is left
    unresolved, the current ``buggy_code`` is written to ``correct_code.txt``
    in the case's log directory.
    """
    global problem_statement, correct_code, buggy_code, bug_fixes, bug_description, LOG_FOLDER, FILE_NAME
    global convo_history

    agent_log = os.path.join(LOG_FOLDER, FILE_NAME)
    verifier = Verifier(problem_statement, correct_code, buggy_code, bug_fixes, bug_description, model=llama_8b_model, log_file=agent_log)
    # The instructor shares the verifier's model instance.
    instructor = Instructor(problem_statement, buggy_code, bug_fixes, bug_description, model=verifier.model, log_file=agent_log)
    student = Student(problem_statement, buggy_code, model=mistral_model, log_file=agent_log)

    # get state representation based on student initial progress
    # starting point: buggy code
    # how do we get to the ending point: correct code
    attributes = verifier.get_state_repr()
    state_repr = [attributes, [False] * len(attributes)]

    summary = "".join(f"{attr}, {resolved}\n" for attr, resolved in zip(state_repr[0], state_repr[1]))
    log(f"State Representation: {summary}", os.path.join(LOG_FOLDER, FILE_NAME))

    # Keep going while at least one attribute is still unresolved.
    while False in state_repr[1]:
        first_open = state_repr[1].index(False)
        target = state_repr[0][first_open]

        state_repr, student_new_code = fix_misunderstanding(student, instructor, verifier, state_repr, target)
        buggy_code = student_new_code
        convo_history = []

        summary = "".join(f"{attr}, {resolved}\n" for attr, resolved in zip(state_repr[0], state_repr[1]))
        log(f"State Representation: {summary}", os.path.join(LOG_FOLDER, FILE_NAME))

    with open(f'{LOG_FOLDER}/{FILE_NAME}/correct_code.txt', 'w+') as out:
        out.write(buggy_code)
# Matches the serialized bug-fix block; group 1 is the newline-separated fixes.
_BUG_FIXES_BLOCK = re.compile(r'^---\nbug_fixes:\n([\S\s]*)\n---\n$', re.IGNORECASE)


def keep_first_bug_fix(bug_fixes):
    """Return ``bug_fixes`` with its list of fixes cut down to the first line.

    The input must look like ``'---\\nbug_fixes:\\n<fix>\\n<fix>...\\n---\\n'``
    (header matched case-insensitively).  The replacement is assembled by
    slicing rather than passed to ``re.sub`` as a template string, so
    backslashes inside the fix text (e.g. ``"\\n"`` in a code snippet) are
    kept literally instead of being interpreted as regex escapes.

    Raises:
        ValueError: if ``bug_fixes`` does not match the expected format.
    """
    match = _BUG_FIXES_BLOCK.search(bug_fixes)
    if match is None:
        raise ValueError('bug_fixes is not in the expected "---/bug_fixes:/.../---" format')
    first_bug_fix = match.group(1).split("\n")[0]
    return f'{bug_fixes[:match.start()]}---\nbug_fixes:\n{first_bug_fix}\n---\n{bug_fixes[match.end():]}'


if __name__ == "__main__":
    import traceback

    parser = argparse.ArgumentParser()
    # NOTE(review): '..data_pkls/' looks like a typo for '../data_pkls/' -- confirm before changing.
    parser.add_argument('--file', type=str, default='..data_pkls/')
    parser.add_argument('--bug_num', type=int, default=1)
    parser.add_argument('--log_folder', type=str, default='single_bug_llama')
    args = parser.parse_args()

    # Module-level globals below (LOG_FOLDER, FILE_NAME, problem_statement, ...)
    # are read by run() and fix_misunderstanding().
    LOG_FOLDER = args.log_folder
    os.makedirs(LOG_FOLDER, exist_ok=True)

    for pkl_path in glob.glob(f'{args.file}/*.pkl'):
        FILE_NAME = os.path.basename(pkl_path)
        try:
            print("HELLO", FILE_NAME)

            # Start each case from a clean log directory: create it if needed
            # and drop outputs left over from a previous run.
            case_dir = os.path.join(LOG_FOLDER, FILE_NAME)
            os.makedirs(case_dir, exist_ok=True)
            for stale_name in ('log.txt', 'bug_fixes.txt', 'convo.txt', 'correct_code.txt'):
                stale_path = os.path.join(case_dir, stale_name)
                if os.path.exists(stale_path):
                    os.remove(stale_path)

            # SECURITY: pickle.load executes arbitrary code from the file;
            # only point --file at trusted data.
            with open(pkl_path, 'rb') as pkl_file:
                extracted_data = pickle.load(pkl_file)

            problem_statement = extracted_data['problem']
            buggy_code = extracted_data['buggy_code']
            bug_fixes = extracted_data['bug_fixes']
            if args.bug_num == 1:
                # Single-bug mode: keep only the first listed fix.
                bug_fixes = keep_first_bug_fix(bug_fixes)
            bug_description = extracted_data['bug_desc']
            correct_code = extracted_data['correct_code']
            unit_tests = ''
            convo_history = []

            log(f"problem statement:\n{problem_statement}\nbuggy_code:\n{buggy_code}\ncorrect_code:\n{correct_code}\nbug_fixes:\n{bug_fixes}", os.path.join(LOG_FOLDER, FILE_NAME))
            run()
        except Exception:
            # One bad case must not stop the batch, but surface the reason
            # and let Ctrl-C / SystemExit propagate.
            traceback.print_exc()
            with open('FAILURE_CASES.txt', 'a+') as failures:
                failures.write(f'{FILE_NAME}\n')