1+ from dataflow .core import OperatorABC
2+ from dataflow .utils .registry import OPERATOR_REGISTRY
3+ from dataflow import get_logger
4+ import os
5+ import cv2
6+ import json
7+ import math
8+ import torch
9+ import multiprocessing
10+ from collections import defaultdict
11+ from doclayout_yolo import YOLOv10
12+ from typing import List
13+
14+
15+ @OPERATOR_REGISTRY .register ()
16+ class MathVQAExtractDocLayout (OperatorABC ):
17+ def __init__ (self , model_path : str ):
18+ self .logger = get_logger ()
19+ self .model_path = model_path
20+
21+ def check_overlap (self , rect1 , rect2 ):
22+ """检查两个矩形是否重叠"""
23+ return not (rect1 [2 ] < rect2 [0 ] or rect1 [0 ] > rect2 [2 ] or
24+ rect1 [3 ] < rect2 [1 ] or rect1 [1 ] > rect2 [3 ])
25+
26+ def find_best_label_position (self , x1 , y1 , x2 , y2 , text_w , text_h , img_shape , existing_boxes , margin = 10 ):
27+ """为检测框找到最佳的标签位置"""
28+ img_h , img_w = img_shape [:2 ]
29+ # 候选位置(优先级由上到下)
30+ candidates = [
31+ {'x' : x1 , 'y' : y1 - text_h - margin , 'type' : 'top' },
32+ {'x' : x2 + margin , 'y' : y1 + (y2 - y1 - text_h )// 2 , 'type' : 'right' },
33+ {'x' : x1 , 'y' : y2 + text_h + margin , 'type' : 'bottom' },
34+ {'x' : x1 - text_w - margin , 'y' : y1 + (y2 - y1 - text_h )// 2 , 'type' : 'left' },
35+ {'x' : x1 + margin , 'y' : y1 + text_h + margin , 'type' : 'inside' },
36+ ]
37+
38+ def valid (px , py ):
39+ if px < 0 or px + text_w > img_w or py - text_h < 0 or py > img_h :
40+ return False
41+ label_rect = [px , py - text_h , px + text_w , py ]
42+ for box in existing_boxes :
43+ if self .check_overlap (label_rect , box ):
44+ return False
45+ return True
46+
47+ for c in candidates :
48+ if valid (c ['x' ], c ['y' ]):
49+ return c ['x' ], c ['y' ], c ['type' ]
50+ # fallback
51+ fx = max (0 , min (x1 , img_w - text_w ))
52+ fy = max (text_h , y1 - margin )
53+ return fx , fy , 'fallback'
54+
55+ def draw_adaptive_label (self , image , x1 , y1 , x2 , y2 , text , existing_boxes ,
56+ font = cv2 .FONT_HERSHEY_SIMPLEX , fs = 1.0 , ft = 2 ):
57+ """在最佳位置绘制带背景的标签"""
58+ (tw , th ), base = cv2 .getTextSize (text , font , fs , ft )
59+ lx , ly , pos = self .find_best_label_position (x1 , y1 , x2 , y2 , tw , th , image .shape , existing_boxes )
60+ cmap = {
61+ 'top' : (0 , 255 , 255 ),
62+ 'right' : (255 , 0 , 255 ),
63+ 'bottom' : (0 , 255 , 255 ),
64+ 'left' : (255 , 255 , 0 ),
65+ 'inside' : (255 , 165 , 0 ),
66+ 'fallback' : (0 , 0 , 255 ),
67+ }
68+ color = cmap .get (pos , (0 , 255 , 255 ))
69+ pad = 5
70+ cv2 .rectangle (image ,
71+ (lx - pad , ly - th - pad ),
72+ (lx + tw + pad , ly + pad ),
73+ color , - 1 )
74+ cv2 .putText (image , text , (lx , ly ), font , fs , (0 , 0 , 0 ), ft )
75+ return pos
76+
77+ def worker_process (self , img_list , output_img_dir , output_json_dir , prefix , gpu_id , imgsz , conf_thres ):
78+ """
79+ worker 进程:
80+ - 在指定 gpu_id 上加载一次模型
81+ - 顺序处理分配到自己的多张图片
82+ """
83+ # 设置设备
84+ if gpu_id != "" :
85+ device_str = f"cuda:{ gpu_id } "
86+ # 注意:在多进程环境中设置CUDA_VISIBLE_DEVICES可能不会按预期工作
87+ # 更好的做法是在启动进程前设置,或者使用torch.cuda.set_device()
88+ torch .cuda .set_device (gpu_id )
89+ else :
90+ device_str = "cpu"
91+
92+ # 加载模型
93+ model = YOLOv10 (self .model_path )
94+
95+ for current_img_path in img_list :
96+ try :
97+ # 1)推理
98+ dets = model .predict (current_img_path , imgsz = imgsz , conf = conf_thres , device = device_str )
99+ # 2)读取原图
100+ img = cv2 .imread (current_img_path )
101+ if img is None :
102+ self .logger .error (f"无法读取图片: { current_img_path } " )
103+ continue
104+
105+ result = dets [0 ]
106+ boxes = result .boxes
107+ name_map = result .names
108+
109+ detections = []
110+ existing_boxes = []
111+ # 收集所有 box 坐标
112+ if boxes is not None and len (boxes ) > 0 :
113+ for box in boxes :
114+ x1 , y1 , x2 , y2 = box .xyxy [0 ].cpu ().numpy ()
115+ existing_boxes .append ([int (x1 ), int (y1 ), int (x2 ), int (y2 )])
116+
117+ # 画框 + 标签
118+ cls_count = defaultdict (int )
119+ for box in boxes :
120+ x1 , y1 , x2 , y2 = box .xyxy [0 ].cpu ().numpy ()
121+ x1 , y1 , x2 , y2 = map (int , (x1 , y1 , x2 , y2 ))
122+ conf = float (box .conf [0 ].cpu ().numpy ())
123+ cid = int (box .cls [0 ].cpu ().numpy ())
124+ cname = name_map [cid ]
125+ cls_count [cname ] += 1
126+ label = f"{ cname } { cls_count [cname ]} "
127+
128+ # 矩形框
129+ cv2 .rectangle (img , (x1 , y1 ), (x2 , y2 ), (0 , 255 , 0 ), 3 )
130+ # 自适应标签
131+ pos_type = self .draw_adaptive_label (img , x1 , y1 , x2 , y2 , label , existing_boxes ,
132+ fs = 0.8 , ft = 2 )
133+ detections .append ({
134+ "id" : label ,
135+ "class_name" : cname ,
136+ "confidence" : conf ,
137+ "bbox" : [x1 , y1 , x2 , y2 ],
138+ "label_position" : pos_type
139+ })
140+
141+ # 3)保存图像与 JSON
142+ base_name = os .path .splitext (os .path .basename (current_img_path ))[0 ]
143+ out_img_path = os .path .join (output_img_dir , f"{ prefix } _{ base_name } .jpg" )
144+ out_json_path = os .path .join (output_json_dir , f"{ prefix } _{ base_name } .json" )
145+
146+ # 确保输出目录存在
147+ os .makedirs (output_img_dir , exist_ok = True )
148+ os .makedirs (output_json_dir , exist_ok = True )
149+
150+ # 保存图片
151+ success = cv2 .imwrite (out_img_path , img )
152+ if not success :
153+ self .logger .error (f"保存图片失败: { out_img_path } " )
154+
155+ # 保存JSON
156+ with open (out_json_path , 'w' , encoding = 'utf-8' ) as f :
157+ json .dump ({
158+ "image_path" : current_img_path ,
159+ "total_detections" : len (detections ),
160+ "detections" : detections
161+ }, f , ensure_ascii = False , indent = 2 )
162+
163+ self .logger .info (f"处理完成: { current_img_path } -> { out_img_path } , { out_json_path } " )
164+
165+ except Exception as e :
166+ self .logger .error (f"处理图片 { current_img_path } 时出错: { str (e )} " )
167+
168+ def batch_process (self , image_paths , output_folder , output_prefix , imgsz = 1024 , conf_thres = 0.2 ):
169+ """
170+ 批量处理接口:
171+ image_paths: List[str] 待处理图片路径列表
172+ output_folder: str 输出文件夹
173+ output_prefix: str 输出图片/JSON 的前缀
174+ 其余参数为可选模型配置
175+ """
176+ os .makedirs (output_folder , exist_ok = True )
177+
178+ # 创建输出子目录
179+ img_output_dir = os .path .join (output_folder , "images" )
180+ json_output_dir = os .path .join (output_folder , "json" )
181+ os .makedirs (img_output_dir , exist_ok = True )
182+ os .makedirs (json_output_dir , exist_ok = True )
183+
184+ # 可用的 GPU 列表
185+ ngpu = torch .cuda .device_count ()
186+ if ngpu == 0 :
187+ # 如果没有 GPU,则当成 1 个 worker 在 CPU 上跑
188+ gpu_list = ["" ]
189+ else :
190+ gpu_list = list (range (ngpu ))
191+
192+ # 将 image_paths 均匀切分给每个 GPU/worker
193+ chunks = []
194+ n = len (image_paths )
195+ k = len (gpu_list )
196+ per = math .ceil (n / k ) if k > 0 else n
197+ for i in range (k ):
198+ start_idx = i * per
199+ end_idx = min ((i + 1 ) * per , n )
200+ if start_idx < end_idx :
201+ chunks .append (image_paths [start_idx :end_idx ])
202+
203+ # 启动多进程
204+ procs = []
205+ for i , (gpu_id , img_chunk ) in enumerate (zip (gpu_list , chunks )):
206+ if not img_chunk :
207+ continue
208+
209+ p = multiprocessing .Process (
210+ target = self .worker_process ,
211+ args = (img_chunk , img_output_dir , json_output_dir , output_prefix ,
212+ gpu_id , imgsz , conf_thres )
213+ )
214+ p .start ()
215+ procs .append (p )
216+
217+ for p in procs :
218+ p .join ()
219+
220+ def run (self , input_image_folder : str , output_folder : str , output_prefix : str = "doclay" ):
221+ os .makedirs (output_folder , exist_ok = True )
222+ # 获取所有图片路径,确保扩展名为jpg or png
223+ image_list = []
224+ for f in os .listdir (input_image_folder ):
225+ if f .lower ().endswith (('.jpg' , '.png' , '.jpeg' )):
226+ image_list .append (os .path .join (input_image_folder , f ))
227+
228+ # 确保图片路径存在
229+ valid_image_list = []
230+ for img_path in image_list :
231+ if os .path .exists (img_path ):
232+ valid_image_list .append (img_path )
233+ else :
234+ self .logger .warning (f"图片路径不存在: { img_path } " )
235+
236+ if not valid_image_list :
237+ self .logger .warning ("没有找到有效的图片文件" )
238+ return
239+
240+ # 批量处理
241+ self .batch_process (
242+ image_paths = valid_image_list ,
243+ output_folder = output_folder ,
244+ output_prefix = output_prefix ,
245+ imgsz = 1024 ,
246+ conf_thres = 0.2
247+ )
0 commit comments