-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
174 lines (143 loc) · 5.39 KB
/
app.py
File metadata and controls
174 lines (143 loc) · 5.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Math Tutor Bot - Web Application
Flask backend that serves the Qwen2.5-Omni-3B model
"""
import io
import os
from datetime import datetime

import numpy as np
import soundfile as sf
import torch
from flask import Flask, render_template, request, jsonify, send_file
from flask_cors import CORS
from kokoro import KPipeline
from transformers import Qwen2_5OmniForConditionalGeneration, AutoTokenizer
# Flask app with CORS enabled so a browser frontend served from another
# origin can call the JSON API endpoints below.
app = Flask(__name__)
CORS(app)
# Global variables for model, tokenizer, and TTS
# All three are populated once by load_model() before the server starts;
# every request handler reads them. conversation_history is a rolling,
# process-wide chat log shared by all clients (single-user assumption).
model = None
tokenizer = None
tts = None
conversation_history = []
def load_model():
    """Initialise the Qwen2.5-Omni model, its tokenizer, and the Kokoro TTS pipeline.

    Populates the module-level ``model``, ``tokenizer`` and ``tts`` globals.
    Intended to run exactly once, before the Flask server starts serving.
    """
    global model, tokenizer, tts

    print("Loading Qwen2.5-Omni-3B model...")
    print("This may take a minute on first load...")

    checkpoint = "Qwen/Qwen2.5-Omni-3B"
    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,  # Half precision for 20-30% speedup
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    print("Model loaded successfully!")
    print(f"Device: {model.device}")

    print("Loading Kokoro TTS for voice generation...")
    tts = KPipeline(lang_code='a')  # 'a' for American English
    print("TTS loaded successfully!")
@app.route('/')
def home():
    """Render the single-page chat frontend."""
    return render_template('index.html')
@app.route('/api/chat', methods=['POST'])
def chat():
    """Handle a chat turn.

    Expects JSON ``{"message": str, "mode": "text"|"voice"}``. Appends the
    user message to the rolling module-level history, generates a reply with
    the Qwen model, and — in voice mode — synthesises the reply to a WAV file
    under ``static/``.

    Returns:
        200 ``{"response": str}`` (text mode, or voice mode after TTS failure)
        200 ``{"response": str, "audio_url": str}`` (voice mode)
        400 ``{"error": "No message provided"}``
        500 ``{"error": str}`` on generation failure
    """
    global conversation_history

    # Tolerate missing or non-JSON bodies instead of letting Flask raise
    # before we can return our own 400 (request.json raises on bad input).
    data = request.get_json(silent=True) or {}
    user_message = data.get('message', '')
    mode = data.get('mode', 'text')  # 'text' or 'voice'

    if not user_message:
        return jsonify({'error': 'No message provided'}), 400

    # Add user message to history
    conversation_history.append({
        "role": "user",
        "content": user_message
    })

    # Keep only last 10 messages for context (5 exchanges)
    if len(conversation_history) > 10:
        conversation_history = conversation_history[-10:]

    # The system prompt is prepended per request and never stored in history,
    # so trimming above can never evict it.
    full_conversation = [
        {"role": "system", "content": "You are a helpful, patient, and encouraging math tutor. Explain concepts clearly, show step-by-step solutions, and help students understand the 'why' behind each step. Use simple language and encourage critical thinking."}
    ] + conversation_history

    try:
        # Render the conversation through the model's chat template.
        text_input = tokenizer.apply_chat_template(
            full_conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer([text_input], return_tensors="pt").to(model.device)

        # Greedy decoding, capped for latency.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=256,  # Reduced from 512 for faster responses
                do_sample=False,
                use_cache=True
            )

        # Omni checkpoints may return a (text_ids, audio) tuple from generate().
        generated_ids = output[0] if isinstance(output, tuple) else output

        # Decode only the newly generated tokens (strip the prompt prefix).
        response_text = tokenizer.batch_decode(
            generated_ids[:, inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )[0]

        conversation_history.append({
            "role": "assistant",
            "content": response_text
        })

        # If voice mode, generate audio using Kokoro TTS
        if mode == 'voice':
            try:
                # KPipeline returns a generator of (graphemes, phonemes, audio)
                # chunks, not a single waveform — collect and concatenate the
                # audio chunks before writing the WAV file. Passing the raw
                # generator to sf.write() would fail.
                chunks = [np.asarray(audio) for _, _, audio in tts(response_text)]
                if not chunks:
                    raise ValueError("TTS produced no audio")
                audio_samples = np.concatenate(chunks)

                # Ensure the output directory exists before writing.
                os.makedirs('static', exist_ok=True)
                audio_filename = f"response_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
                audio_path = os.path.join('static', audio_filename)
                sf.write(audio_path, audio_samples, 24000)  # Kokoro outputs 24 kHz

                return jsonify({
                    'response': response_text,
                    'audio_url': f'/static/{audio_filename}'
                })
            except Exception as e:
                print(f"Error generating audio: {e}")
                # Fall back to text-only if audio fails
                return jsonify({'response': response_text})

        # Text-only mode - just return the text
        return jsonify({'response': response_text})

    except Exception as e:
        print(f"Error generating response: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/api/clear', methods=['POST'])
def clear_history():
    """Reset the rolling conversation history to an empty list."""
    global conversation_history
    conversation_history = []
    payload = {'status': 'success', 'message': 'Conversation history cleared'}
    return jsonify(payload)
@app.route('/api/history', methods=['GET'])
def get_history():
    """Return the current conversation history as JSON."""
    return jsonify(history=conversation_history)
@app.route('/health', methods=['GET'])
def health():
    """Health check: report whether the model is loaded and on which device."""
    loaded = model is not None
    return jsonify({
        'status': 'healthy',
        'model_loaded': loaded,
        'device': str(model.device) if loaded else None,
    })
if __name__ == '__main__':
    # Load model before starting server
    load_model()

    banner = "=" * 60
    print("\n" + banner)
    print("Math Tutor Bot - Web Application")
    print(banner)
    print("Server starting on http://localhost:5001")
    print("Open this URL in your web browser")
    print("Press Ctrl+C to stop the server")
    print(banner + "\n")

    # Run Flask app
    app.run(host='0.0.0.0', port=5001, debug=False)