forked from Avika2211/pdf-image-classifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathai_classifier.py
More file actions
335 lines (285 loc) · 14.9 KB
/
ai_classifier.py
File metadata and controls
335 lines (285 loc) · 14.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import os
import json
import logging
import io
import time
import random
import numpy as np
from PIL import Image
import google.generativeai as genai
import streamlit as st
class AIFigureClassifier:
    """AI-powered figure classifier using Google Gemini, with a heuristic fallback.

    Progress and error messages are shown in the Streamlit UI (``st.*``) and
    written to the module logger.

    Attributes:
        logger: Logger named after this module.
        api_key: Key passed to the constructor, or ``GEMINI_API_KEY`` from the environment.
        model: ``genai.GenerativeModel`` once configured, otherwise ``None``.
        confidence_score: Confidence of the most recent classification.
        api_configured: ``True`` after the connection test succeeded; set back to
            ``False`` when ``classify_figure`` sees an auth-style error.
        figure_categories: Category key -> description; also embedded in the prompt.
    """

    def __init__(self, api_key=None):
        """Set up the Gemini client if an API key is available.

        Args:
            api_key: Gemini API key. Falls back to the ``GEMINI_API_KEY``
                environment variable when omitted or empty.

        Raises:
            ValueError: When configuring/testing the API fails with an error whose
                text contains ``API_KEY_INVALID`` or ``401``. Any other failure only
                shows a Streamlit warning and leaves the classifier in fallback mode.
        """
        self.logger = logging.getLogger(__name__)
        self.api_key = api_key or os.environ.get("GEMINI_API_KEY")
        self.model = None
        self.confidence_score = 0.0
        self.api_configured = False

        # Try to configure Gemini API
        if self.api_key:
            try:
                genai.configure(api_key=self.api_key)
                self.model = genai.GenerativeModel("gemini-1.5-flash")
                # One cheap request so an unusable key surfaces here rather than
                # on the first classification; the reply itself is not needed.
                self.model.generate_content("Test connection")
                self.api_configured = True
                self.logger.info("Gemini API configured successfully")
            except Exception as e:
                error_msg = str(e)
                self.logger.error(f"Failed to configure Gemini API: {error_msg}")
                self.api_configured = False
                if "API_KEY_INVALID" in error_msg or "401" in error_msg:
                    raise ValueError(f"Invalid Gemini API key: {error_msg}") from e
                st.warning(f"Gemini API configuration failed: {error_msg}. Using fallback classification.")
        else:
            self.logger.warning("No Gemini API key provided. Using fallback classification only.")
            st.warning("⚠️ No Gemini API key provided. Using basic heuristic classification. For better results, please provide a valid Gemini API key.")

        self.figure_categories = {
            "bar_chart": "Bar Chart - Shows data using rectangular bars",
            "pie_chart": "Pie Chart - Circular chart showing proportions",
            "line_graph": "Line Graph - Shows trends over time or continuous data",
            "scatter_plot": "Scatter Plot - Shows relationship between two variables",
            "histogram": "Histogram - Shows distribution of data",
            "box_plot": "Box Plot - Shows statistical distribution",
            "heatmap": "Heatmap - Shows data intensity with colors",
            "flowchart": "Flowchart - Shows process or workflow",
            "organizational_chart": "Organizational Chart - Shows hierarchy",
            "network_diagram": "Network Diagram - Shows connections between entities",
            "scientific_diagram": "Scientific Diagram - Technical/scientific illustration",
            "medical_diagram": "Medical Diagram - Anatomical or medical illustration",
            "engineering_diagram": "Engineering Diagram - Technical drawing or schematic",
            "map": "Map - Geographic or spatial representation",
            "floor_plan": "Floor Plan - Architectural layout",
            "timeline": "Timeline - Shows events over time",
            "table": "Table - Structured data in rows and columns",
            "infographic": "Infographic - Visual information presentation",
            "photograph": "Photograph - Real-world image",
            "screenshot": "Screenshot - Computer screen capture",
            "logo": "Logo - Brand or company symbol",
            "chart_other": "Other Chart Type - Specialized chart not in main categories",
            "diagram_other": "Other Diagram - General diagram or illustration",
            "unknown": "Unknown - Cannot determine figure type",
        }

    def is_api_available(self):
        """Check if Gemini API is properly configured and available."""
        return self.api_configured and self.model is not None

    def classify_figure(self, image):
        """Classify a figure using the Gemini API, falling back to heuristics.

        Args:
            image: A PIL image; it is converted to RGB before being sent to Gemini
                and passed unchanged to the heuristic fallback.

        Returns:
            dict with keys ``classification``, ``confidence``, ``description``,
            ``details``, ``reasoning`` and ``method``. ``method`` is
            ``'gemini_api'`` on success; otherwise the result comes from
            ``_fallback_classification``.

        Rate-limit errors (``429`` / ``RESOURCE_EXHAUSTED``) are retried with
        exponential backoff plus jitter; other errors fall back immediately.
        """
        # If API is not configured, use fallback immediately
        if not self.is_api_available():
            st.info("🔄 Using heuristic classification (no API key provided)")
            return self._fallback_classification(image)

        max_retries = 3
        base_delay = 1.0
        # The prompt does not change between attempts, so build it once.
        prompt = self._create_classification_prompt()

        for attempt in range(max_retries):
            try:
                st.info(f"🤖 Using Gemini AI for classification (attempt {attempt + 1})")
                # Normalise to RGB so palette/alpha/CMYK inputs are sent uniformly.
                img = image.convert("RGB")
                response = self.model.generate_content([prompt, img])
                text = response.text  # read once; the accessor does work on each call

                json_text = self._extract_json_text(text)
                if json_text is None:
                    if text and text.strip():
                        self.logger.error(f"Gemini response contained no JSON: {text[:200]!r}")
                        st.warning("⚠️ Gemini response contained no JSON, using fallback")
                    else:
                        st.warning("⚠️ Empty response from Gemini API, using fallback")
                    return self._fallback_classification(image)

                try:
                    result = json.loads(json_text)
                except json.JSONDecodeError as e:
                    self.logger.error(f"JSON parsing failed: {e}")
                    st.warning("⚠️ API response parsing failed, using fallback")
                    return self._fallback_classification(image)

                # The model sometimes returns a one-element array instead of an object.
                if isinstance(result, list):
                    result = result[0] if result else None
                if not isinstance(result, dict):
                    self.logger.error(f"Unexpected JSON structure from Gemini: {type(result).__name__}")
                    st.warning("⚠️ API response parsing failed, using fallback")
                    return self._fallback_classification(image)

                self.confidence_score = self._coerce_confidence(result.get('confidence'))
                st.success("✅ AI classification successful!")
                return {
                    'classification': result.get('type', 'unknown'),
                    'confidence': self.confidence_score,
                    'description': result.get('description', 'No description available'),
                    'details': result.get('details', {}),
                    'reasoning': result.get('reasoning', ''),
                    'method': 'gemini_api',
                }
            except Exception as e:
                error_msg = str(e)
                self.logger.error(f"AI classification attempt {attempt + 1} failed: {error_msg}")
                # Handle specific error types
                if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg:
                    if attempt < max_retries - 1:
                        delay = base_delay * (2 ** attempt) + random.uniform(0, 1)
                        st.warning(f"⏳ Rate limit hit, retrying in {delay:.1f}s...")
                        time.sleep(delay)
                        continue
                    st.error("❌ Rate limit exceeded. Using fallback classification.")
                    return self._fallback_classification(image)
                # NOTE(review): substring matching on status codes is coarse; a 400 may
                # also be a malformed request rather than a bad key.
                if "400" in error_msg or "401" in error_msg or "403" in error_msg:
                    st.error("❌ API key invalid or expired. Using fallback classification.")
                    self.api_configured = False  # Disable API for future calls
                    return self._fallback_classification(image)
                if "SAFETY" in error_msg:
                    st.warning("⚠️ Content filtered by safety settings. Using fallback classification.")
                    return self._fallback_classification(image)
                st.warning(f"⚠️ API error: {error_msg}. Using fallback classification.")
                return self._fallback_classification(image)

        # If all retries failed, use fallback
        st.warning("⚠️ All API attempts failed. Using fallback classification.")
        return self._fallback_classification(image)

    @staticmethod
    def _extract_json_text(text):
        """Return the JSON payload embedded in a model reply, or ``None`` if none is found.

        Gemini often wraps JSON in a Markdown fence (```json ... ```) or adds prose
        around it. This strips a surrounding fence; if the remainder still does not
        start with ``{`` or ``[``, it falls back to the outermost ``{...}`` span of
        the original text. The returned string is not validated as JSON.
        """
        if not text:
            return None
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (it may carry a "json" tag) and the closing fence.
            _, _, cleaned = cleaned.partition("\n")
            cleaned = cleaned.strip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3].rstrip()
        if cleaned.startswith(("{", "[")):
            return cleaned
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            return None
        return text[start:end + 1]

    @staticmethod
    def _coerce_confidence(value, default=0.5):
        """Return ``value`` as a float clamped to [0, 1]; ``default`` if it is not numeric or NaN."""
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return default
        if np.isnan(confidence):
            return default
        return min(max(confidence, 0.0), 1.0)

    def _create_classification_prompt(self):
        """Build the Gemini prompt listing every category and the expected JSON output."""
        categories_text = "\n".join([f"- {key}: {desc}" for key, desc in self.figure_categories.items()])
        return f"""
Analyze this figure/image and classify it into one of the following categories.
AVAILABLE CATEGORIES:
{categories_text}
CLASSIFICATION REQUIREMENTS:
1. Look carefully at visual structure and content.
2. For graphs: identify if bar/pie/line/scatter etc.
3. For diagrams: identify domain (scientific, engineering, etc.).
4. For photos/screenshots: recognize realism or UI elements.
OUTPUT FORMAT (JSON):
{{
"type": "category_key_from_list_above",
"confidence": 0.95,
"description": "Brief description of what you see",
"details": {{
"visual_elements": ["key", "elements"],
"data_type": "type of data if any",
"domain": "subject domain if known"
}},
"reasoning": "Why you chose this classification"
}}
"""

    def _fallback_classification(self, image=None):
        """Classify a figure from simple pixel statistics (no network access).

        Uses the aspect ratio, the standard deviation of all pixel values and the
        mean grayscale intensity to pick a category with a fixed confidence.

        Args:
            image: Anything ``np.array`` turns into a 2-D (grayscale) or 3-D
                (H x W x channels) array, e.g. a PIL image; or ``None``.

        Returns:
            dict in the same shape as ``classify_figure``'s result. ``method`` is
            ``'heuristic'``, ``'fallback'`` (no image) or ``'error_fallback'``
            (analysis raised).
        """
        try:
            if image is not None:
                img_array = np.array(image)
                height, width = img_array.shape[:2]
                aspect_ratio = width / height

                # Convert to grayscale for analysis (ITU-R 601 luma weights).
                if len(img_array.shape) == 3:
                    gray = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
                else:
                    gray = img_array

                std_intensity = np.std(gray)
                mean_intensity = np.mean(gray)

                # Colour statistics; for grayscale input reuse the intensity spread.
                if len(img_array.shape) == 3:
                    std_color = np.std(img_array)
                    color_variance = np.var(img_array, axis=(0, 1)).mean()
                else:
                    std_color = std_intensity
                    color_variance = 0

                # Heuristic decision tree keyed primarily on aspect ratio.
                if aspect_ratio > 3:
                    classification = 'timeline'
                    confidence = 0.7
                    description = 'Very wide layout suggests timeline or horizontal process'
                elif aspect_ratio > 2:
                    classification = 'timeline'
                    confidence = 0.6
                    description = 'Wide layout suggests timeline or process flow'
                elif 0.8 <= aspect_ratio <= 1.2:  # Square-ish
                    if std_color > 80:
                        classification = 'pie_chart'
                        confidence = 0.6
                        description = 'Square, colorful layout suggests pie chart or similar'
                    elif std_color > 50:
                        classification = 'chart_other'
                        confidence = 0.5
                        description = 'Square layout with moderate color variation suggests chart'
                    else:
                        classification = 'diagram_other'
                        confidence = 0.4
                        description = 'Square, simple layout suggests diagram'
                elif aspect_ratio > 1.5:  # Wide
                    if std_color > 60:
                        classification = 'bar_chart'
                        confidence = 0.6
                        description = 'Wide, colorful layout suggests bar chart'
                    else:
                        classification = 'flowchart'
                        confidence = 0.5
                        description = 'Wide layout suggests flowchart or process diagram'
                elif aspect_ratio < 0.7:  # Tall
                    if std_color > 60:
                        classification = 'bar_chart'
                        confidence = 0.5
                        description = 'Tall, colorful layout suggests vertical bar chart'
                    else:
                        classification = 'organizational_chart'
                        confidence = 0.5
                        description = 'Tall layout suggests organizational chart or hierarchy'
                else:  # Regular proportions
                    if std_color > 100:
                        classification = 'photograph'
                        confidence = 0.6
                        description = 'High color variance suggests photograph'
                    elif std_color > 70:
                        classification = 'chart_other'
                        confidence = 0.5
                        description = 'Moderate color variance suggests chart or graph'
                    elif mean_intensity > 200:
                        classification = 'screenshot'
                        confidence = 0.4
                        description = 'High brightness suggests screenshot'
                    else:
                        classification = 'diagram_other'
                        confidence = 0.4
                        description = 'Basic characteristics suggest diagram'

                self.confidence_score = confidence
                return {
                    'classification': classification,
                    'confidence': confidence,
                    'description': description,
                    'details': {
                        'visual_elements': ['heuristic analysis'],
                        'aspect_ratio': f'{aspect_ratio:.2f}',
                        'color_variance': f'{color_variance:.1f}',
                        'intensity_std': f'{std_intensity:.1f}',
                        'analysis_method': 'heuristic fallback',
                    },
                    'reasoning': f'Used heuristic analysis: aspect ratio {aspect_ratio:.2f}, color variance {color_variance:.1f}',
                    'method': 'heuristic',
                }

            # If no image provided
            self.confidence_score = 0.3
            return {
                'classification': 'unknown',
                'confidence': 0.3,
                'description': 'Figure present, classification unavailable',
                'details': {},
                'reasoning': 'No image data available for analysis',
                'method': 'fallback',
            }
        except Exception as e:
            # Best-effort by design: never let the fallback itself break the caller,
            # but leave a trace so failures are diagnosable.
            self.logger.warning(f"Heuristic fallback classification failed: {e}")
            self.confidence_score = 0.2
            return {
                'classification': 'unknown',
                'confidence': 0.2,
                'description': 'Figure detected but analysis failed',
                'details': {},
                'reasoning': f'Fallback error: {str(e)}',
                'method': 'error_fallback',
            }

    def get_supported_categories(self):
        """Return the category key -> description mapping used in the prompt."""
        return self.figure_categories

    def get_confidence(self):
        """Return the confidence of the most recent classification."""
        return self.confidence_score

    def batch_classify(self, images, progress_callback=None):
        """Classify multiple images in order.

        Args:
            images: Sized sequence of images accepted by ``classify_figure``.
            progress_callback: Optional ``callable(done, total)``; invoked with
                ``(i + 1, total)`` just before the i-th image is classified.

        Returns:
            list of result dicts, one per image, in input order.
        """
        results = []
        total = len(images)
        for i, image in enumerate(images):
            if progress_callback:
                progress_callback(i + 1, total)
            results.append(self.classify_figure(image))
        return results