Skip to content

Commit bfd2c6e

Browse files
添加OM格式推理代码
1 parent 6632a42 commit bfd2c6e

1 file changed

Lines changed: 395 additions & 0 deletions

File tree

docs/docs/深度学习/opencv.mdx

Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,3 +2257,398 @@ cv2.destroyAllWindows()
22572257
```
22582258
</TabItem>
22592259
</Tabs>
2260+
2261+
### PT格式转OM格式
2262+
2263+
PT 格式可以先转为 ONNX,再由 ONNX 转为 OM;OM 是华为昇腾(Ascend)芯片使用的离线推理模型格式。使用 OM 格式模型进行推理的代码如下:
2264+
2265+
```python showLineNumbers
2266+
import sys
2267+
import os
2268+
import cv2
2269+
import time
2270+
import numpy as np
2271+
from typing import List, Dict, Tuple
2272+
import argparse
2273+
import subprocess
2274+
import tempfile
2275+
import shutil
2276+
2277+
try:
    import ais_bench
except ImportError:
    # Optional dependency: without ais_bench the OM inference backends
    # cannot run, so flag it and warn the user up front.
    AISBENCH_AVAILABLE = False
    print("警告: ais_bench 未安装,OM 模型推理将不可用。")
else:
    AISBENCH_AVAILABLE = True
2283+
2284+
# ==========================================
# YOLOv8 base logic (preprocessing, NMS, postprocessing)
# ==========================================
class BaseYOLOv8Logic:
    """Core YOLOv8 computation shared by all inference backends.

    Provides letterbox preprocessing, greedy NMS, and decoding of the raw
    model output back into original-image pixel coordinates.

    Args:
        conf_threshold: minimum class confidence for a detection to be kept.
        iou_threshold: IoU threshold used by non-maximum suppression.
    """

    def __init__(self, conf_threshold: float = 0.25, iou_threshold: float = 0.5):
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        # Fixed network input resolution (letterboxed to 640x640).
        self.target_h = 640
        self.target_w = 640

        # Class ids counted as "vehicles" by predict(); VisDrone-style labels.
        self.vehicle_classes = [3, 4, 5, 8, 9]
        self.class_names = {
            0: 'pedestrian', 1: 'people', 2: 'bicycle', 3: 'car',
            4: 'van', 5: 'truck', 6: 'tricycle', 7: 'awning-tricycle',
            8: 'bus', 9: 'motor', 10: 'others'
        }

    def _preprocess(self, frame: np.ndarray) -> Tuple[np.ndarray, float, float, Tuple[int, int]]:
        """Letterbox *frame* (BGR, HxWx3 uint8) to the network input size.

        Returns:
            (NCHW float32 batch in [0, 1], width scale new_w/w,
             height scale new_h/h, (pad_w, pad_h) letterbox offsets).
        """
        h, w = frame.shape[:2]
        scale = min(self.target_h / h, self.target_w / w)
        new_h, new_w = int(h * scale), int(w * scale)

        # cv2.resize takes (width, height) — keep this order to avoid a
        # dimension-mismatch error.
        img_resized = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # Gray (114) padding is the standard YOLO letterbox fill value.
        img_padded = np.full((self.target_h, self.target_w, 3), 114, dtype=np.uint8)
        pad_h = (self.target_h - new_h) // 2
        pad_w = (self.target_w - new_w) // 2
        img_padded[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = img_resized

        img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_BGR2RGB)
        img_normalized = img_rgb.astype(np.float32) / 255.0
        img_nchw = np.transpose(img_normalized, (2, 0, 1))
        img_batch = np.expand_dims(img_nchw, axis=0)

        scale_w = new_w / w
        scale_h = new_h / h

        return img_batch, scale_w, scale_h, (pad_w, pad_h)

    def _apply_nms(self, boxes: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Greedy non-maximum suppression.

        Args:
            boxes: (N, 4) xyxy boxes.
            scores: (N,) confidence scores.

        Returns:
            Indices of kept boxes, ordered by descending score.
        """
        if len(boxes) == 0:
            return np.array([])

        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        # +1 keeps the original inclusive-pixel area convention.
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            if order.size == 1:
                break

            # Intersection of the top box with all remaining candidates.
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h

            iou = inter / (areas[i] + areas[order[1:]] - inter)
            # Keep only candidates that do not overlap the winner too much.
            inds = np.where(iou <= self.iou_threshold)[0]
            order = order[inds + 1]

        return np.array(keep)

    def _postprocess(self, output: np.ndarray,
                     scale_w: float, scale_h: float,
                     pad_offset: Tuple[int, int],
                     original_shape: Tuple[int, int]) -> List[Dict]:
        """Decode raw model output into detections in original-image pixels.

        Args:
            output: raw head output, (1, 4 + C, N) or (N, 4 + C).
            scale_w, scale_h: resize scales returned by _preprocess.
            pad_offset: (pad_w, pad_h) letterbox offsets from _preprocess.
            original_shape: original frame shape (h, w[, c]).

        Returns:
            List of dicts with 'bbox' [x1, y1, x2, y2], 'confidence',
            'class_id' and 'class_name'.
        """
        detections = []
        # (1, 4 + C, N) -> (1, N, 4 + C), then drop the batch dimension.
        if output.ndim == 3:
            output = np.transpose(output, (0, 2, 1))
        if output.shape[0] == 1:
            output = output[0]

        coords = output[:, :4]          # cx, cy, w, h in letterboxed pixels
        class_scores = output[:, 4:]

        confidences = np.max(class_scores, axis=1)
        class_ids = np.argmax(class_scores, axis=1)

        valid_mask = confidences >= self.conf_threshold
        valid_coords = coords[valid_mask]
        valid_ids = class_ids[valid_mask]
        valid_confs = confidences[valid_mask]

        if len(valid_coords) > 0:
            # cxcywh -> xyxy
            boxes = np.zeros_like(valid_coords)
            boxes[:, 0] = valid_coords[:, 0] - valid_coords[:, 2] / 2
            boxes[:, 1] = valid_coords[:, 1] - valid_coords[:, 3] / 2
            boxes[:, 2] = valid_coords[:, 0] + valid_coords[:, 2] / 2
            boxes[:, 3] = valid_coords[:, 1] + valid_coords[:, 3] / 2
            # Fix: clip each axis to its own bound.  The previous
            # np.clip(boxes, 0, target_w) clipped y against the width,
            # which only worked because the input happens to be square.
            boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, self.target_w)
            boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, self.target_h)

            keep = self._apply_nms(boxes, valid_confs)

            pad_w, pad_h = pad_offset
            oh, ow = original_shape[:2]

            for idx in keep:
                x1, y1, x2, y2 = boxes[idx]
                # Undo the letterbox: subtract the padding, then divide by
                # the exact resize scale.  Fix: the old code reconstructed
                # the unpadded size as target - 2*pad, which is off by one
                # pixel whenever target - new is odd, while the exact
                # scale_w/scale_h arguments were passed in but never used.
                if scale_w > 0 and scale_h > 0:
                    x1 = (x1 - pad_w) / scale_w
                    y1 = (y1 - pad_h) / scale_h
                    x2 = (x2 - pad_w) / scale_w
                    y2 = (y2 - pad_h) / scale_h

                # Clamp to the original image bounds.
                x1 = max(0, min(ow, x1))
                y1 = max(0, min(oh, y1))
                x2 = max(0, min(ow, x2))
                y2 = max(0, min(oh, y2))

                detections.append({
                    'bbox': [float(x1), float(y1), float(x2), float(y2)],
                    'confidence': float(valid_confs[idx]),
                    'class_id': int(valid_ids[idx]),
                    'class_name': self.class_names.get(int(valid_ids[idx]), f'Class_{int(valid_ids[idx])}')
                })
        return detections
2417+
2418+
2419+
# ==========================================
# OM model inference class
# ==========================================
class OMInference(BaseYOLOv8Logic):
    """YOLOv8 inference on a Huawei Ascend OM model.

    Prefers the ais_bench Python API (InferSession); if that is unavailable
    or fails to load, predict() falls back to invoking ais_bench as a CLI
    subprocess, exchanging tensors through .npy files in a temp directory.

    Args:
        model_path: path to the .om model file.
        device_id: Ascend device index.
        conf_threshold: confidence threshold forwarded to the base logic.
        iou_threshold: NMS IoU threshold forwarded to the base logic.

    Raises:
        FileNotFoundError: if *model_path* does not exist.
    """
    def __init__(self, model_path: str, device_id: int = 0, conf_threshold: float = 0.25, iou_threshold: float = 0.5):
        super().__init__(conf_threshold, iou_threshold)
        self.model_path = model_path
        self.device_id = device_id
        self.model_type = "OM"

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型不存在: {model_path}")

        self._check_and_fix_permissions(model_path)

        self.use_python_api = False
        self.session = None

        if AISBENCH_AVAILABLE:
            try:
                from ais_bench.infer.interface import InferSession
                self.session = InferSession(device_id, model_path)
                self.use_python_api = True
                print(f"[{self.model_type}] ✓ 使用 Python API 加载成功")
            except Exception as e:
                # Fall through to the CLI path; keep the error visible.
                print(f"[{self.model_type}] API 加载失败: {e}")

        # Scratch directory for the CLI fallback; released by close()/__del__.
        self.temp_dir = tempfile.mkdtemp(prefix="om_infer_")
        # NOTE(review): presumably forces Qt (pulled in via OpenCV) into
        # headless mode on display-less devices — confirm.
        os.environ['QT_QPA_PLATFORM'] = 'offscreen'

    def close(self):
        """Remove the temporary working directory created in __init__.

        Fixes a resource leak: the mkdtemp() directory was never removed.
        Safe to call multiple times.
        """
        temp_dir = getattr(self, 'temp_dir', None)
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
            self.temp_dir = None

    def __del__(self):
        # Best-effort cleanup; never raise from a finalizer.
        try:
            self.close()
        except Exception:
            pass

    def _check_and_fix_permissions(self, path: str):
        """Best-effort chown of the model file (and its directory) to root
        when running as root but the file is owned by another user.

        All failures are deliberately swallowed: this is an optional
        convenience, not a prerequisite for inference.
        """
        try:
            stat_info = os.stat(path)
            if os.getuid() == 0 and stat_info.st_uid != 0:
                print(f"[{self.model_type}] 检测到权限问题,尝试修复...")
                try:
                    os.chown(path, 0, 0)
                    if os.path.dirname(path): os.chown(os.path.dirname(path), 0, 0)
                    print(f"[{self.model_type}] ✓ 权限修复完成")
                except Exception:
                    pass
        except Exception:
            pass

    def warmup(self, warmup_iterations: int = 3):
        """Run a few dummy inferences so later timings exclude first-call
        initialization cost."""
        print(f"[{self.model_type}] 预热模型 ({warmup_iterations}次)...")
        dummy_img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
        for _ in range(warmup_iterations):
            self.predict(dummy_img)
        print(f"[{self.model_type}] ✓ 预热完成")

    def predict(self, frame: np.ndarray) -> Tuple[List[Dict], int, float]:
        """Run single-frame inference.

        Args:
            frame: BGR image (HxWx3 uint8).

        Returns:
            (detections, vehicle_count, inference_time_ms).  The timing
            covers inference only, not pre/postprocessing.
        """
        # 1. Preprocess (letterbox to network input size).
        img_batch, scale_w, scale_h, pad_offset = self._preprocess(frame)

        # 2. Inference.
        start_time = time.time()
        output = np.array([])

        if self.use_python_api:
            try:
                out = self.session.infer([img_batch])
                if isinstance(out, list) and len(out) > 0: output = out[0]
            except Exception as e:
                print(f"推理错误: {e}")
        else:
            # CLI fallback: hand the input tensor to ais_bench via .npy files.
            temp_input = os.path.join(self.temp_dir, "input.npy")
            np.save(temp_input, img_batch)
            output_dir = os.path.join(self.temp_dir, "output")
            if os.path.exists(output_dir): shutil.rmtree(output_dir)
            os.makedirs(output_dir)

            cmd = ["python3", "-m", "ais_bench", "--model", self.model_path, "--input", temp_input,
                   "--output", output_dir, "--outfmt", "NPY", "--device", str(self.device_id), "--loop", "1"]
            res = subprocess.run(cmd, capture_output=True, text=True)
            if res.returncode == 0:
                # Pick the first .npy the tool produced as the model output.
                for f in os.listdir(output_dir):
                    if f.endswith('.npy'):
                        output = np.load(os.path.join(output_dir, f))
                        break

        inference_time = (time.time() - start_time) * 1000

        # 3. Postprocess (empty output => no detections, not an error).
        detections = []
        if output.size > 0:
            detections = self._postprocess(output, scale_w, scale_h, pad_offset, frame.shape)

        vehicle_count = sum(1 for d in detections if d['class_id'] in self.vehicle_classes)
        return detections, vehicle_count, inference_time

    def draw_detections(self, frame: np.ndarray, detections: List[Dict]) -> np.ndarray:
        """Return a copy of *frame* with detection boxes and labels drawn.

        Vehicles are drawn green, everything else red.
        """
        img = frame.copy()
        for det in detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            cid = det['class_id']
            color = (0, 255, 0) if cid in self.vehicle_classes else (0, 0, 255)

            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            label = f"{det['class_name']} {det['confidence']:.2f}"
            # Filled background behind the label for readability.
            (tw, th), _ = cv2.getTextSize(label, 0, 0.5, 1)
            cv2.rectangle(img, (x1, y1 - th - 5), (x1 + tw, y1), color, -1)
            cv2.putText(img, label, (x1, y1 - 5), 0, 0.5, (255, 255, 255), 1)
        return img
2528+
2529+
2530+
# ==========================================
# Main program
# ==========================================
def main():
    """CLI entry point: parse arguments, load the OM model, and run
    inference on a single image or every frame of a video, saving
    annotated results under --output-dir."""
    parser = argparse.ArgumentParser(description="Ascend OM 模型推理工具 (调试增强版)")
    parser.add_argument("--model", type=str, required=True, help="OM 模型路径")
    parser.add_argument("--input", type=str, required=True, help="输入路径 (图片或视频)")
    parser.add_argument("--output-dir", type=str, default="./om_output_debug", help="结果保存目录")
    parser.add_argument("--device-id", type=int, default=0, help="Ascend 设备 ID")
    parser.add_argument("--conf", type=float, default=0.25, help="置信度阈值")
    parser.add_argument("--iou", type=float, default=0.5, help="NMS IoU 阈值")
    parser.add_argument("--max-frames", type=int, default=None, help="视频最大处理帧数 (None=全部)")
    parser.add_argument("--verbose", action="store_true", help="开启详细打印:输出每一帧的每个检测框坐标")

    args = parser.parse_args()

    print("\n" + "="*60)
    print("Ascend OM 模型推理工具 (调试增强版)")
    print("="*60)

    # Initialize the inferencer and warm it up; abort on any failure.
    try:
        inferencer = OMInference(args.model, args.device_id, args.conf, args.iou)
        inferencer.warmup()
    except Exception as e:
        print(f"初始化失败: {e}")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    # Dispatch on the input file extension: image vs. video.
    if args.input.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')):
        print(f"\n处理图片: {args.input}")
        frame = cv2.imread(args.input)
        if frame is None:
            print("读取图片失败")
            return

        dets, count, time_cost = inferencer.predict(frame)
        res_img = inferencer.draw_detections(frame, dets)

        out_name = os.path.basename(args.input)
        out_path = os.path.join(args.output_dir, f"res_{out_name}")
        cv2.imwrite(out_path, res_img)

        print(f"推理完成: {len(dets)} 个对象, {count} 个车辆, 耗时 {time_cost:.1f}ms")
        print(f"结果已保存: {out_path}")

        if args.verbose:
            print("\n[详细检测结果]")
            for i, det in enumerate(dets):
                print(f" {i+1}. {det['class_name']} ({det['confidence']:.2f}) - {det['bbox']}")

    elif args.input.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        print(f"\n处理视频: {args.input}")
        cap = cv2.VideoCapture(args.input)
        if not cap.isOpened():
            print("打开视频失败")
            return

        video_name = os.path.splitext(os.path.basename(args.input))[0]
        # Dedicated subdirectory holding this video's annotated frames.
        frame_dir = os.path.join(args.output_dir, f"{video_name}_frames")
        os.makedirs(frame_dir, exist_ok=True)

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        f_count = 0
        total_time = 0

        print(f"总帧数: {total_frames}")
        print(f"保存目录: {frame_dir}")
        print(f"调试模式: {'开启 (打印所有检测框)' if args.verbose else '关闭 (仅打印摘要)'}")
        print("-" * 80)

        while True:
            if args.max_frames and f_count >= args.max_frames:
                print(f"\n[完成] 达到最大帧数限制: {args.max_frames}")
                break

            ret, frame = cap.read()
            if not ret:
                print("\n[完成] 视频处理完毕或读取结束")
                break

            f_count += 1

            # Inference on the current frame.
            dets, count, time_cost = inferencer.predict(frame)
            total_time += time_cost

            # Draw and save (every frame is saved).
            res_img = inferencer.draw_detections(frame, dets)
            save_name = f"frame_{f_count:05d}.jpg"
            save_path = os.path.join(frame_dir, save_name)
            cv2.imwrite(save_path, res_img)

            # Per-frame debug line.
            # Format: [frame/total] time | detections | vehicles | saved file
            print(f"[Frame {f_count:04d}/{total_frames}] Time: {time_cost:6.2f}ms | "
                  f"Detections: {len(dets):2d} | Vehicles: {count:2d} | "
                  f"Saved: {save_name}")

            # With --verbose, also print every detection box.
            if args.verbose:
                for i, det in enumerate(dets):
                    bbox_str = f"[{det['bbox'][0]:.0f}, {det['bbox'][1]:.0f}, {det['bbox'][2]:.0f}, {det['bbox'][3]:.0f}]"
                    print(f" -> [{i+1}] {det['class_name']:15s} conf:{det['confidence']:.2f} bbox:{bbox_str}")

        cap.release()
        avg_time = total_time / f_count if f_count > 0 else 0
        print("-" * 80)
        print(f"\n视频处理统计:")
        print(f" 处理帧数: {f_count}")
        print(f" 总耗时: {total_time:.1f}ms")
        print(f" 平均耗时: {avg_time:.1f}ms")
        print(f" 平均FPS: {1000/avg_time:.1f}" if avg_time > 0 else " 平均FPS: N/A")
        print(f" 结果保存在: {frame_dir}")

    else:
        print("不支持的文件格式")
2650+
2651+
2652+
# Script entry point: run the CLI when executed directly.
if __name__ == "__main__":
    main()
2654+
```

0 commit comments

Comments
 (0)