@@ -2257,3 +2257,398 @@ cv2.destroyAllWindows()
```
</TabItem>
</Tabs>
2260+
### PT格式转OM格式

PT 格式的模型可以先转换为 ONNX,再转换为 OM 格式;OM 是华为昇腾(Ascend)等国产芯片使用的离线推理模型格式。使用 OM 格式模型进行推理的代码如下:
2264+
```python showLineNumbers
2266+ import sys
2267+ import os
2268+ import cv2
2269+ import time
2270+ import numpy as np
2271+ from typing import List, Dict, Tuple
2272+ import argparse
2273+ import subprocess
2274+ import tempfile
2275+ import shutil
2276+
# Optional dependency: ais_bench provides the Python API for Ascend OM model
# inference. Record availability instead of failing hard, so the CLI fallback
# path in OMInference can still be used.
try:
    import ais_bench
    AISBENCH_AVAILABLE = True
except ImportError:
    AISBENCH_AVAILABLE = False
    print("警告: ais_bench 未安装,OM 模型推理将不可用。")
2283+
# ==========================================
# YOLOv8 base logic (preprocess, NMS, postprocess)
# ==========================================
class BaseYOLOv8Logic:
    """Core YOLOv8 computation logic shared by all backends.

    Handles letterbox preprocessing to a fixed 640x640 input, greedy NMS,
    and decoding of the raw (cx, cy, w, h, class scores) network output
    back to pixel boxes in the original image. Uses a VisDrone-style
    11-class label set (see ``class_names``).
    """

    def __init__(self, conf_threshold: float = 0.25, iou_threshold: float = 0.5):
        # Score threshold below which detections are discarded.
        self.conf_threshold = conf_threshold
        # IoU threshold used by NMS to suppress overlapping boxes.
        self.iou_threshold = iou_threshold
        # Fixed square network input resolution.
        self.target_h = 640
        self.target_w = 640

        # Class ids counted as "vehicles" by callers (car/van/truck/bus/motor).
        self.vehicle_classes = [3, 4, 5, 8, 9]
        self.class_names = {
            0: 'pedestrian', 1: 'people', 2: 'bicycle', 3: 'car',
            4: 'van', 5: 'truck', 6: 'tricycle', 7: 'awning-tricycle',
            8: 'bus', 9: 'motor', 10: 'others'
        }

    def _preprocess(self, frame: np.ndarray) -> Tuple[np.ndarray, float, float, Tuple[int, int]]:
        """Letterbox *frame* to 640x640 and convert to an NCHW float batch.

        Returns ``(batch, scale_w, scale_h, (pad_w, pad_h))`` where the
        scales are resized/original size ratios (kept for interface
        compatibility; the postprocessor recomputes scaling from the pads)
        and the pads are the letterbox offsets.
        """
        h, w = frame.shape[:2]
        scale = min(self.target_h / h, self.target_w / w)
        new_h, new_w = int(h * scale), int(w * scale)

        # cv2.resize takes (width, height) — not (height, width).
        img_resized = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # Pad with the conventional YOLO letterbox gray value 114.
        img_padded = np.full((self.target_h, self.target_w, 3), 114, dtype=np.uint8)
        pad_h = (self.target_h - new_h) // 2
        pad_w = (self.target_w - new_w) // 2
        img_padded[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = img_resized

        # BGR -> RGB, scale to [0, 1], HWC -> CHW, add batch dimension.
        img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_BGR2RGB)
        img_normalized = img_rgb.astype(np.float32) / 255.0
        img_nchw = np.transpose(img_normalized, (2, 0, 1))
        img_batch = np.expand_dims(img_nchw, axis=0)

        return img_batch, new_w / w, new_h / h, (pad_w, pad_h)

    def _apply_nms(self, boxes: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Greedy non-maximum suppression.

        *boxes* is (N, 4) in (x1, y1, x2, y2) form; returns the indices of
        the boxes kept, in descending score order.
        """
        if len(boxes) == 0:
            return np.array([])

        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        # +1 pixel-area convention, matching the classic NMS formulation.
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            if order.size == 1:
                break

            # Intersection of the current best box with all remaining boxes.
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h

            iou = inter / (areas[i] + areas[order[1:]] - inter)
            # Keep only boxes that do not overlap the winner too much.
            inds = np.where(iou <= self.iou_threshold)[0]
            order = order[inds + 1]

        return np.array(keep)

    def _postprocess(self, output: np.ndarray,
                     scale_w: float, scale_h: float,
                     pad_offset: Tuple[int, int],
                     original_shape: Tuple[int, int]) -> List[Dict]:
        """Decode raw model output into detection dicts in original-image pixels.

        *output* is ``(1, 4+C, N)`` or ``(4+C, N)``; *scale_w*/*scale_h* are
        accepted for interface compatibility but scaling is derived from
        *pad_offset* and *original_shape*. Each returned dict has keys
        ``bbox`` (x1, y1, x2, y2 floats), ``confidence``, ``class_id``,
        ``class_name``.
        """
        detections = []
        # Normalize layout to (N, 4+C): one row per candidate box.
        if output.ndim == 3:
            output = np.transpose(output, (0, 2, 1))
            if output.shape[0] == 1:
                output = output[0]

        coords = output[:, :4]          # (cx, cy, w, h) in letterbox pixels
        class_scores = output[:, 4:]

        confidences = np.max(class_scores, axis=1)
        class_ids = np.argmax(class_scores, axis=1)

        valid_mask = confidences >= self.conf_threshold
        valid_coords = coords[valid_mask]
        valid_ids = class_ids[valid_mask]
        valid_confs = confidences[valid_mask]

        if len(valid_coords) > 0:
            # Convert center-size to corner form.
            boxes = np.zeros_like(valid_coords)
            boxes[:, 0] = valid_coords[:, 0] - valid_coords[:, 2] / 2
            boxes[:, 1] = valid_coords[:, 1] - valid_coords[:, 3] / 2
            boxes[:, 2] = valid_coords[:, 0] + valid_coords[:, 2] / 2
            boxes[:, 3] = valid_coords[:, 1] + valid_coords[:, 3] / 2
            # NOTE: clips both axes to target_w; equivalent to per-axis
            # clipping here because target_h == target_w == 640.
            boxes = np.clip(boxes, 0, self.target_w)

            keep = self._apply_nms(boxes, valid_confs)

            pad_w, pad_h = pad_offset
            oh, ow = original_shape[:2]

            for idx in keep:
                x1, y1, x2, y2 = boxes[idx]
                # Remove the letterbox padding offset.
                x1 -= pad_w; y1 -= pad_h; x2 -= pad_w; y2 -= pad_h

                # Size of the unpadded (resized) region inside the letterbox.
                unp_h = self.target_h - 2 * pad_h if pad_h > 0 else self.target_h
                unp_w = self.target_w - 2 * pad_w if pad_w > 0 else self.target_w

                # Rescale from the resized region back to original pixels.
                if unp_h > 0 and unp_w > 0:
                    x1 = x1 / unp_w * ow
                    y1 = y1 / unp_h * oh
                    x2 = x2 / unp_w * ow
                    y2 = y2 / unp_h * oh

                # Clamp to the original image bounds.
                x1 = max(0, min(ow, x1))
                y1 = max(0, min(oh, y1))
                x2 = max(0, min(ow, x2))
                y2 = max(0, min(oh, y2))

                detections.append({
                    'bbox': [float(x1), float(y1), float(x2), float(y2)],
                    'confidence': float(valid_confs[idx]),
                    'class_id': int(valid_ids[idx]),
                    'class_name': self.class_names.get(int(valid_ids[idx]), f'Class_{valid_ids[idx]}')
                })
        return detections
2417+
2418+
# ==========================================
# OM model inference class
# ==========================================
class OMInference(BaseYOLOv8Logic):
    """YOLOv8 inference backend for Ascend OM models.

    Prefers the ais_bench Python API (``InferSession``); if that cannot be
    loaded, falls back to invoking ``ais_bench`` as a CLI subprocess using
    a scratch directory for the .npy input/output files.
    """

    def __init__(self, model_path: str, device_id: int = 0,
                 conf_threshold: float = 0.25, iou_threshold: float = 0.5):
        super().__init__(conf_threshold, iou_threshold)
        self.model_path = model_path
        self.device_id = device_id
        self.model_type = "OM"

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型不存在: {model_path}")

        self._check_and_fix_permissions(model_path)

        self.use_python_api = False
        self.session = None

        if AISBENCH_AVAILABLE:
            try:
                from ais_bench.infer.interface import InferSession
                self.session = InferSession(device_id, model_path)
                self.use_python_api = True
                print(f"[{self.model_type}] ✓ 使用 Python API 加载成功")
            except Exception as e:
                # Non-fatal: predict() will use the CLI fallback instead.
                print(f"[{self.model_type}] API 加载失败: {e}")

        # Scratch directory for the CLI fallback's input/output files.
        self.temp_dir = tempfile.mkdtemp(prefix="om_infer_")
        # Headless environments: stop Qt-based code from requiring a display.
        os.environ['QT_QPA_PLATFORM'] = 'offscreen'

    def _cleanup(self):
        """Remove the scratch directory created in __init__ (best effort)."""
        temp_dir = getattr(self, 'temp_dir', None)
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
            self.temp_dir = None

    def __del__(self):
        # Fix: the original never removed temp_dir, leaking one directory
        # per OMInference instance.
        self._cleanup()

    def _check_and_fix_permissions(self, path: str):
        """Best-effort chown of *path* (and its directory) when running as root."""
        try:
            stat_info = os.stat(path)
            # Only act when we are root and the model is owned by someone else.
            if os.getuid() == 0 and stat_info.st_uid != 0:
                print(f"[{self.model_type}] 检测到权限问题,尝试修复...")
                try:
                    os.chown(path, 0, 0)
                    if os.path.dirname(path):
                        os.chown(os.path.dirname(path), 0, 0)
                    print(f"[{self.model_type}] ✓ 权限修复完成")
                except Exception:
                    pass  # non-fatal: inference may still succeed
        except Exception:
            pass  # stat/getuid failures are ignored by design

    def warmup(self, warmup_iterations: int = 3):
        """Run a few dummy inferences so subsequent timings are stable."""
        print(f"[{self.model_type}] 预热模型 ({warmup_iterations} 次)...")
        dummy_img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
        for _ in range(warmup_iterations):
            self.predict(dummy_img)
        print(f"[{self.model_type}] ✓ 预热完成")

    def predict(self, frame: np.ndarray) -> Tuple[List[Dict], int, float]:
        """Run detection on one BGR frame.

        Returns ``(detections, vehicle_count, inference_time_ms)``; the
        reported time covers only the model forward pass.
        """
        # 1. Preprocess (letterbox to 640x640 NCHW batch).
        img_batch, scale_w, scale_h, pad_offset = self._preprocess(frame)

        # 2. Inference.
        start_time = time.time()
        output = np.array([])

        if self.use_python_api:
            try:
                out = self.session.infer([img_batch])
                if isinstance(out, list) and len(out) > 0:
                    output = out[0]
            except Exception as e:
                print(f"推理错误: {e}")
        else:
            # CLI fallback: write the batch to disk and invoke ais_bench.
            temp_input = os.path.join(self.temp_dir, "input.npy")
            np.save(temp_input, img_batch)
            output_dir = os.path.join(self.temp_dir, "output")
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            os.makedirs(output_dir)

            cmd = ["python3", "-m", "ais_bench", "--model", self.model_path, "--input", temp_input,
                   "--output", output_dir, "--outfmt", "NPY", "--device", str(self.device_id), "--loop", "1"]
            res = subprocess.run(cmd, capture_output=True, text=True)
            if res.returncode == 0:
                # Load the first .npy result the CLI produced.
                for fname in os.listdir(output_dir):
                    if fname.endswith('.npy'):
                        output = np.load(os.path.join(output_dir, fname))
                        break

        inference_time = (time.time() - start_time) * 1000

        # 3. Postprocess back to original-image coordinates.
        detections = []
        if output.size > 0:
            detections = self._postprocess(output, scale_w, scale_h, pad_offset, frame.shape)

        vehicle_count = sum(1 for d in detections if d['class_id'] in self.vehicle_classes)
        return detections, vehicle_count, inference_time

    def draw_detections(self, frame: np.ndarray, detections: List[Dict]) -> np.ndarray:
        """Return a copy of *frame* with boxes and score labels drawn on it."""
        img = frame.copy()
        for det in detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            cid = det['class_id']
            # Vehicles in green, everything else in red (BGR order).
            color = (0, 255, 0) if cid in self.vehicle_classes else (0, 0, 255)

            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            label = f"{det['class_name']} {det['confidence']:.2f}"
            (tw, th), _ = cv2.getTextSize(label, 0, 0.5, 1)
            # Filled background behind the label for readability.
            cv2.rectangle(img, (x1, y1 - th - 5), (x1 + tw, y1), color, -1)
            cv2.putText(img, label, (x1, y1 - 5), 0, 0.5, (255, 255, 255), 1)
        return img
2528+
2529+
# ==========================================
# Main program
# ==========================================
def _run_on_image(inferencer, args):
    """Run detection on a single image and save the annotated result."""
    print(f"\n处理图片: {args.input}")
    frame = cv2.imread(args.input)
    if frame is None:
        print("读取图片失败")
        return

    dets, count, time_cost = inferencer.predict(frame)
    res_img = inferencer.draw_detections(frame, dets)

    out_name = os.path.basename(args.input)
    out_path = os.path.join(args.output_dir, f"res_{out_name}")
    cv2.imwrite(out_path, res_img)

    print(f"推理完成: {len(dets)} 个对象, {count} 个车辆, 耗时 {time_cost:.1f}ms")
    print(f"结果已保存: {out_path}")

    if args.verbose:
        print("\n[详细检测结果]")
        for i, det in enumerate(dets):
            print(f"{i + 1}. {det['class_name']} ({det['confidence']:.2f}) - {det['bbox']}")


def _run_on_video(inferencer, args):
    """Run detection frame-by-frame on a video, saving every annotated frame."""
    print(f"\n处理视频: {args.input}")
    cap = cv2.VideoCapture(args.input)
    if not cap.isOpened():
        print("打开视频失败")
        return

    video_name = os.path.splitext(os.path.basename(args.input))[0]
    # Dedicated sub-directory for this video's annotated frames.
    frame_dir = os.path.join(args.output_dir, f"{video_name}_frames")
    os.makedirs(frame_dir, exist_ok=True)

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    f_count = 0
    total_time = 0

    print(f"总帧数: {total_frames}")
    print(f"保存目录: {frame_dir}")
    print(f"调试模式: {'开启 (打印所有检测框)' if args.verbose else '关闭 (仅打印摘要)'}")
    print("-" * 80)

    while True:
        # Optional hard cap on the number of processed frames.
        if args.max_frames and f_count >= args.max_frames:
            print(f"\n[完成] 达到最大帧数限制: {args.max_frames}")
            break

        ret, frame = cap.read()
        if not ret:
            print("\n[完成] 视频处理完毕或读取结束")
            break

        f_count += 1

        dets, count, time_cost = inferencer.predict(frame)
        total_time += time_cost

        # Draw and save every frame.
        res_img = inferencer.draw_detections(frame, dets)
        save_name = f"frame_{f_count:05d}.jpg"
        save_path = os.path.join(frame_dir, save_name)
        cv2.imwrite(save_path, res_img)

        # Per-frame summary: [frame/total] time | detections | vehicles | file.
        print(f"[Frame {f_count:04d}/{total_frames}] Time: {time_cost:6.2f}ms | "
              f"Detections: {len(dets):2d} | Vehicles: {count:2d} | "
              f"Saved: {save_name}")

        # Verbose mode: dump every box of every frame.
        if args.verbose:
            for i, det in enumerate(dets):
                bbox_str = (f"[{det['bbox'][0]:.0f}, {det['bbox'][1]:.0f}, "
                            f"{det['bbox'][2]:.0f}, {det['bbox'][3]:.0f}]")
                print(f"  -> [{i + 1}] {det['class_name']:15s} conf: {det['confidence']:.2f} bbox: {bbox_str}")

    cap.release()
    avg_time = total_time / f_count if f_count > 0 else 0
    print("-" * 80)
    print("\n视频处理统计:")
    print(f"处理帧数: {f_count}")
    print(f"总耗时: {total_time:.1f}ms")
    print(f"平均耗时: {avg_time:.1f}ms")
    print(f"平均FPS: {1000 / avg_time:.1f}" if avg_time > 0 else "平均FPS: N/A")
    print(f"结果保存在: {frame_dir}")


def main():
    """CLI entry point: run OM-model detection on an image or a video."""
    parser = argparse.ArgumentParser(description="Ascend OM 模型推理工具 (调试增强版)")
    parser.add_argument("--model", type=str, required=True, help="OM 模型路径")
    parser.add_argument("--input", type=str, required=True, help="输入路径 (图片或视频)")
    parser.add_argument("--output-dir", type=str, default="./om_output_debug", help="结果保存目录")
    parser.add_argument("--device-id", type=int, default=0, help="Ascend 设备 ID")
    parser.add_argument("--conf", type=float, default=0.25, help="置信度阈值")
    parser.add_argument("--iou", type=float, default=0.5, help="NMS IoU 阈值")
    parser.add_argument("--max-frames", type=int, default=None, help="视频最大处理帧数 (None=全部)")
    parser.add_argument("--verbose", action="store_true", help="开启详细打印:输出每一帧的每个检测框坐标")

    args = parser.parse_args()

    print("\n" + "=" * 60)
    print("Ascend OM 模型推理工具 (调试增强版)")
    print("=" * 60)

    # Initialize the inferencer; abort cleanly on any setup failure.
    try:
        inferencer = OMInference(args.model, args.device_id, args.conf, args.iou)
        inferencer.warmup()
    except Exception as e:
        print(f"初始化失败: {e}")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    # Dispatch on the input file extension.
    if args.input.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')):
        _run_on_image(inferencer, args)
    elif args.input.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        _run_on_video(inferencer, args)
    else:
        print("不支持的文件格式")
2650+
2651+
if __name__ == "__main__":
    # Fix: the garbled source compared against " __main__" (leading space),
    # which would never match and so main() would never run.
    main()
```
0 commit comments