Skip to content

Commit d1f992a

Browse files
committed
[Knowledge Cleaning] Add Math Vision QA Extract Pipeline Demo
1 parent 0b2a82d commit d1f992a

10 files changed

Lines changed: 676 additions & 3 deletions

dataflow/operators/knowledge_cleaning/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
# from .generate.mathbook_question_extract import MathBookQuestionExtract
1212
# from .generate.kbc_multihop_qa_generator import KBCMultiHopQAGenerator
1313
from .generate.kbc_multihop_qa_generator_batch import KBCMultiHopQAGeneratorBatch
14+
from .generate.math_vqa_extract_pdf2img import MathVQAExtractPdf2Img
15+
from .generate.math_vqa_extract_doclayout import MathVQAExtractDocLayout
16+
from .generate.math_vqa_extract_pic_extractor import MathVQAExtractPicExtractor
17+
from .generate.math_vqa_extract_qapair_extractor import MathVQAExtractQAPairExtractor
18+
from .generate.math_vqa_extract_tag2img import MathVQAExtractTag2Img
1419

1520

1621
else:
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
from dataflow.core import OperatorABC
2+
from dataflow.utils.registry import OPERATOR_REGISTRY
3+
from dataflow import get_logger
4+
import os
5+
import cv2
6+
import json
7+
import math
8+
import torch
9+
import multiprocessing
10+
from collections import defaultdict
11+
from doclayout_yolo import YOLOv10
12+
from typing import List
13+
14+
15+
@OPERATOR_REGISTRY.register()
class MathVQAExtractDocLayout(OperatorABC):
    """Detect document-layout regions on page images with a DocLayout-YOLO model.

    For each input image the operator draws the detected boxes plus
    collision-avoiding labels, then writes one annotated JPG and one JSON file
    describing the detections. Work is fanned out across all visible GPUs via
    one worker process per GPU (or a single CPU worker when none is available).
    """

    def __init__(self, model_path: str):
        # model_path: filesystem path to the DocLayout-YOLO weights.
        # The model itself is loaded lazily inside each worker process so that
        # every process gets its own copy on its own device.
        self.logger = get_logger()
        self.model_path = model_path

    def check_overlap(self, rect1, rect2):
        """Return True when two [x1, y1, x2, y2] rectangles overlap."""
        return not (rect1[2] < rect2[0] or rect1[0] > rect2[2] or
                    rect1[3] < rect2[1] or rect1[1] > rect2[3])

    def find_best_label_position(self, x1, y1, x2, y2, text_w, text_h, img_shape, existing_boxes, margin=10):
        """Find a label anchor near box (x1, y1, x2, y2) that stays inside the
        image and overlaps none of the rectangles in ``existing_boxes``.

        Returns:
            (label_x, label_y, position_type) where ``position_type`` names the
            candidate slot that was used ('top', 'right', 'bottom', 'left',
            'inside', or 'fallback').
        """
        img_h, img_w = img_shape[:2]
        # Candidate positions, tried in priority order.
        candidates = [
            {'x': x1, 'y': y1 - text_h - margin, 'type': 'top'},
            {'x': x2 + margin, 'y': y1 + (y2 - y1 - text_h) // 2, 'type': 'right'},
            {'x': x1, 'y': y2 + text_h + margin, 'type': 'bottom'},
            {'x': x1 - text_w - margin, 'y': y1 + (y2 - y1 - text_h) // 2, 'type': 'left'},
            {'x': x1 + margin, 'y': y1 + text_h + margin, 'type': 'inside'},
        ]

        def valid(px, py):
            # (px, py) is the text baseline origin: the rendered label occupies
            # [px, py - text_h, px + text_w, py].
            if px < 0 or px + text_w > img_w or py - text_h < 0 or py > img_h:
                return False
            label_rect = [px, py - text_h, px + text_w, py]
            for box in existing_boxes:
                if self.check_overlap(label_rect, box):
                    return False
            return True

        for c in candidates:
            if valid(c['x'], c['y']):
                return c['x'], c['y'], c['type']
        # Fallback: clamp the label into the image near the box's top edge.
        fx = max(0, min(x1, img_w - text_w))
        fy = max(text_h, y1 - margin)
        return fx, fy, 'fallback'

    def draw_adaptive_label(self, image, x1, y1, x2, y2, text, existing_boxes,
                            font=cv2.FONT_HERSHEY_SIMPLEX, fs=1.0, ft=2):
        """Draw ``text`` with a filled background at the best position around
        the box and return the position type that was chosen."""
        (tw, th), base = cv2.getTextSize(text, font, fs, ft)
        lx, ly, pos = self.find_best_label_position(x1, y1, x2, y2, tw, th, image.shape, existing_boxes)
        # Background color encodes which candidate slot the label landed in.
        cmap = {
            'top': (0, 255, 255),
            'right': (255, 0, 255),
            'bottom': (0, 255, 255),
            'left': (255, 255, 0),
            'inside': (255, 165, 0),
            'fallback': (0, 0, 255),
        }
        color = cmap.get(pos, (0, 255, 255))
        pad = 5
        cv2.rectangle(image,
                      (lx - pad, ly - th - pad),
                      (lx + tw + pad, ly + pad),
                      color, -1)
        cv2.putText(image, text, (lx, ly), font, fs, (0, 0, 0), ft)
        return pos

    def worker_process(self, img_list, output_img_dir, output_json_dir, prefix, gpu_id, imgsz, conf_thres):
        """Worker process body.

        Loads the model once on ``gpu_id`` (or CPU when ``gpu_id`` is the
        empty string), then processes its share of images sequentially,
        writing one annotated JPG and one JSON per image.
        """
        if gpu_id != "":
            device_str = f"cuda:{gpu_id}"
            # NOTE: setting CUDA_VISIBLE_DEVICES after the process has started
            # is unreliable; select the device explicitly instead.
            torch.cuda.set_device(gpu_id)
        else:
            device_str = "cpu"

        # Load the model once per worker.
        model = YOLOv10(self.model_path)

        # Create output directories once, outside the per-image loop.
        os.makedirs(output_img_dir, exist_ok=True)
        os.makedirs(output_json_dir, exist_ok=True)

        for current_img_path in img_list:
            try:
                # 1) inference
                dets = model.predict(current_img_path, imgsz=imgsz, conf=conf_thres, device=device_str)
                # 2) read the original image
                img = cv2.imread(current_img_path)
                if img is None:
                    self.logger.error(f"无法读取图片: {current_img_path}")
                    continue

                result = dets[0]
                boxes = result.boxes
                name_map = result.names

                detections = []
                existing_boxes = []
                # Guard against an empty / None detection result so the draw
                # loop below never iterates over None.
                if boxes is not None and len(boxes) > 0:
                    # First collect every box so labels can avoid all of them.
                    for box in boxes:
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        existing_boxes.append([int(x1), int(y1), int(x2), int(y2)])

                    # Draw boxes + adaptive labels.
                    cls_count = defaultdict(int)
                    for box in boxes:
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
                        conf = float(box.conf[0].cpu().numpy())
                        cid = int(box.cls[0].cpu().numpy())
                        cname = name_map[cid]
                        cls_count[cname] += 1
                        label = f"{cname}{cls_count[cname]}"

                        # Detection rectangle.
                        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 3)
                        # Collision-avoiding label.
                        pos_type = self.draw_adaptive_label(img, x1, y1, x2, y2, label, existing_boxes,
                                                            fs=0.8, ft=2)
                        detections.append({
                            "id": label,
                            "class_name": cname,
                            "confidence": conf,
                            "bbox": [x1, y1, x2, y2],
                            "label_position": pos_type
                        })

                # 3) save annotated image and JSON
                base_name = os.path.splitext(os.path.basename(current_img_path))[0]
                out_img_path = os.path.join(output_img_dir, f"{prefix}_{base_name}.jpg")
                out_json_path = os.path.join(output_json_dir, f"{prefix}_{base_name}.json")

                success = cv2.imwrite(out_img_path, img)
                if not success:
                    self.logger.error(f"保存图片失败: {out_img_path}")

                with open(out_json_path, 'w', encoding='utf-8') as f:
                    json.dump({
                        "image_path": current_img_path,
                        "total_detections": len(detections),
                        "detections": detections
                    }, f, ensure_ascii=False, indent=2)

                self.logger.info(f"处理完成: {current_img_path} -> {out_img_path}, {out_json_path}")

            except Exception as e:
                # Log and continue: one bad image must not kill the worker.
                self.logger.error(f"处理图片 {current_img_path} 时出错: {str(e)}")

    def batch_process(self, image_paths, output_folder, output_prefix, imgsz=1024, conf_thres=0.2):
        """Batch-processing entry point.

        Args:
            image_paths: list of image file paths to process.
            output_folder: root output directory ("images/" and "json/"
                subdirectories are created inside it).
            output_prefix: prefix for the emitted image / JSON file names.
            imgsz, conf_thres: optional model inference settings.
        """
        os.makedirs(output_folder, exist_ok=True)

        # Create the output subdirectories.
        img_output_dir = os.path.join(output_folder, "images")
        json_output_dir = os.path.join(output_folder, "json")
        os.makedirs(img_output_dir, exist_ok=True)
        os.makedirs(json_output_dir, exist_ok=True)

        # Available GPUs; fall back to a single CPU worker (gpu_id "").
        ngpu = torch.cuda.device_count()
        if ngpu == 0:
            gpu_list = [""]
        else:
            gpu_list = list(range(ngpu))

        # Split image_paths evenly across workers.
        chunks = []
        n = len(image_paths)
        k = len(gpu_list)
        per = math.ceil(n / k) if k > 0 else n
        for i in range(k):
            start_idx = i * per
            end_idx = min((i + 1) * per, n)
            if start_idx < end_idx:
                chunks.append(image_paths[start_idx:end_idx])

        # Launch one process per non-empty chunk.
        procs = []
        for i, (gpu_id, img_chunk) in enumerate(zip(gpu_list, chunks)):
            if not img_chunk:
                continue

            p = multiprocessing.Process(
                target=self.worker_process,
                args=(img_chunk, img_output_dir, json_output_dir, output_prefix,
                      gpu_id, imgsz, conf_thres)
            )
            p.start()
            procs.append(p)

        for p in procs:
            p.join()

    def run(self, input_image_folder: str, output_folder: str, output_prefix: str = "doclay"):
        """Collect jpg/png/jpeg images from ``input_image_folder`` and run the
        batch layout-detection pipeline over them."""
        os.makedirs(output_folder, exist_ok=True)
        # Gather all image paths with a jpg / png / jpeg extension.
        image_list = []
        for f in os.listdir(input_image_folder):
            if f.lower().endswith(('.jpg', '.png', '.jpeg')):
                image_list.append(os.path.join(input_image_folder, f))

        # Keep only paths that actually exist.
        valid_image_list = []
        for img_path in image_list:
            if os.path.exists(img_path):
                valid_image_list.append(img_path)
            else:
                self.logger.warning(f"图片路径不存在: {img_path}")

        if not valid_image_list:
            self.logger.warning("没有找到有效的图片文件")
            return

        # Batch process.
        self.batch_process(
            image_paths=valid_image_list,
            output_folder=output_folder,
            output_prefix=output_prefix,
            imgsz=1024,
            conf_thres=0.2
        )
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from dataflow.core import OperatorABC
2+
import fitz
3+
import os
4+
from dataflow.utils.registry import OPERATOR_REGISTRY
5+
from dataflow import get_logger
6+
7+
@OPERATOR_REGISTRY.register()
class MathVQAExtractPdf2Img(OperatorABC):
    """Render every page of a PDF file to a JPG image via PyMuPDF (fitz)."""

    def __init__(self, dpi: int = 300):
        # dpi: rendering resolution passed to fitz's get_pixmap().
        self.logger = get_logger()
        self.dpi = dpi

    def run(self, input_pdf_path: str, output_image_folder: str):
        """Convert a PDF file into one image per page.

        Args:
            input_pdf_path: path to the source PDF file.
            output_image_folder: directory the page images are written to
                (created if missing); pages are saved as page_<index>.jpg.

        Returns:
            output_image_folder, for chaining into the next pipeline step.
        """
        # Make the output directory first so rendering can save immediately.
        os.makedirs(output_image_folder, exist_ok=True)
        doc = fitz.open(input_pdf_path)
        try:
            # Render each page at the configured DPI.
            for page_index in range(len(doc)):
                page = doc.load_page(page_index)
                pix = page.get_pixmap(dpi=self.dpi)
                pix.save(os.path.join(output_image_folder, f"page_{page_index}.jpg"))
                self.logger.info(f"Converted page {page_index} to image")
        finally:
            # Always release the document handle, even if a page fails.
            doc.close()
        return output_image_folder
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from dataflow.utils.registry import OPERATOR_REGISTRY
2+
from dataflow import get_logger
3+
from dataflow.core import OperatorABC
4+
from dataflow.core import LLMServingABC
5+
import pandas as pd
6+
import random
7+
from dataflow.prompts.kbcleaning import MathVQAExtractPrompt
8+
import os
9+
from typing import List
10+
11+
@OPERATOR_REGISTRY.register()
class MathVQAExtractPicExtractor(OperatorABC):
    """Feed page images (in overlapping pairs) to a multimodal LLM and save
    the extraction responses as a JSONL file."""

    def __init__(self,
                 llm_serving: LLMServingABC = None,
                 model: str = "o4-mini"
                 ):
        # llm_serving: backend that exposes generate_from_input_multi_images().
        # model: model identifier forwarded to the serving backend.
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.prompt = MathVQAExtractPrompt()
        self.model = model

    def _format_instructions(self, image_files: List[str]):
        """Group sorted page images into overlapping sliding pairs.

        Pages i and i+1 are sent together so questions spanning a page break
        can be extracted; the last page (which has no successor) is appended
        on its own. Returns (list_of_image_paths, list_of_image_labels),
        aligned element-for-element.
        """
        list_of_image_paths = []
        list_of_image_labels = []
        # Derive "page_<idx>" labels from the "..._<idx>.<ext>" file names.
        labels = ["page_" + image_file.split("_")[-1].split(".")[0] for image_file in image_files]
        for index in range(len(image_files) - 1):
            list_of_image_paths.append([image_files[index], image_files[index + 1]])
            list_of_image_labels.append([labels[index], labels[index + 1]])
        # The last page has no next page, so it is submitted alone.
        list_of_image_paths.append([image_files[-1]])
        list_of_image_labels.append([labels[-1]])
        return list_of_image_paths, list_of_image_labels

    def run(self, input_layout_path: str, output_folder: str):
        """Run extraction over <input_layout_path>/images and write
        <output_folder>/vqa_extract.jsonl; returns the result DataFrame."""
        # Collect absolute paths of all files in the layout "images" folder.
        image_files = [os.path.join(input_layout_path, "images", image_file) for image_file in os.listdir(os.path.join(input_layout_path, "images"))]
        # Keep only jpg / png files.
        image_files = [image_file for image_file in image_files if image_file.endswith(".jpg") or image_file.endswith(".png")]

        # Guard: without images, _format_instructions would index [-1] on an
        # empty list. Return an empty frame instead of crashing.
        if not image_files:
            self.logger.warning(f"No page images found under {os.path.join(input_layout_path, 'images')}")
            return pd.DataFrame(columns=["image_path", "image_label", "response"])

        def filename2idx(filename: str):
            # Use basename so this works on any OS path separator.
            return int(os.path.splitext(os.path.basename(filename))[0].split("_")[-1])
        # Sort by numeric page index, smallest first.
        image_files.sort(key=filename2idx)

        list_of_image_paths, list_of_image_labels = self._format_instructions(image_files)
        system_prompt = self.prompt.build_prompt()

        responses = self.llm_serving.generate_from_input_multi_images(list_of_image_paths, list_of_image_labels, system_prompt, self.model)

        # Assemble image paths, labels and responses as three JSONL columns.
        list_of_dict = []
        for image_path, image_label, response in zip(list_of_image_paths, list_of_image_labels, responses):
            list_of_dict.append({"image_path": image_path, "image_label": image_label, "response": response})
        df = pd.DataFrame(list_of_dict)

        # Persist the DataFrame as a JSONL file.
        os.makedirs(output_folder, exist_ok=True)
        df.to_json(os.path.join(output_folder, "vqa_extract.jsonl"), orient="records", lines=True, force_ascii=False)

        return df

0 commit comments

Comments
 (0)