2626NMS_THRESH = 0.45
2727IMG_SIZE = (640 , 640 ) # (width, height)
2828
29- CLASSES = ("person" , "bicycle" , "car" ,"motorbike " ,"aeroplane " ,"bus " ,"train" ,"truck " ,"boat" ,"traffic light" ,
29+ # 默认类别定义 (COCO 80类)
30+ DEFAULT_CLASSES = ("person" , "bicycle" , "car" ,"motorbike " ,"aeroplane " ,"bus " ,"train" ,"truck " ,"boat" ,"traffic light" ,
3031 "fire hydrant" ,"stop sign " ,"parking meter" ,"bench" ,"bird" ,"cat" ,"dog " ,"horse " ,"sheep" ,"cow" ,"elephant" ,
3132 "bear" ,"zebra " ,"giraffe" ,"backpack" ,"umbrella" ,"handbag" ,"tie" ,"suitcase" ,"frisbee" ,"skis" ,"snowboard" ,"sports ball" ,"kite" ,
3233 "baseball bat" ,"baseball glove" ,"skateboard" ,"surfboard" ,"tennis racket" ,"bottle" ,"wine glass" ,"cup" ,"fork" ,"knife " ,
3334 "spoon" ,"bowl" ,"banana" ,"apple" ,"sandwich" ,"orange" ,"broccoli" ,"carrot" ,"hot dog" ,"pizza " ,"donut" ,"cake" ,"chair" ,"sofa" ,
3435 "pottedplant" ,"bed" ,"diningtable" ,"toilet " ,"tvmonitor" ,"laptop " ,"mouse " ,"remote " ,"keyboard " ,"cell phone" ,"microwave " ,
3536 "oven " ,"toaster" ,"sink" ,"refrigerator " ,"book" ,"clock" ,"vase" ,"scissors " ,"teddy bear " ,"hair drier" , "toothbrush " )
3637
38+ CLASSES = DEFAULT_CLASSES
39+
40+ def load_classes (path ):
41+ """
42+ 从文件加载类别,支持双引号和逗号分隔的格式
43+ 例如: "person", "bicycle", "car"
44+ """
45+ global CLASSES
46+ if not path or not os .path .exists (path ):
47+ CLASSES = DEFAULT_CLASSES
48+ return
49+
50+ try :
51+ with open (path , 'r' , encoding = 'utf-8' ) as f :
52+ content = f .read ().strip ()
53+ # 简单的解析逻辑:移除换行,按逗号分割,去除空格和双引号
54+ import re
55+ # 匹配双引号内的内容
56+ items = re .findall (r'"([^"]*)"' , content )
57+ if items :
58+ CLASSES = tuple (items )
59+ print (f"Successfully loaded { len (CLASSES )} classes from { path } " )
60+ else :
61+ # 备选方案:如果没匹配到双引号,尝试按逗号分割
62+ items = [item .strip ().strip ('"' ) for item in content .split (',' ) if item .strip ()]
63+ if items :
64+ CLASSES = tuple (items )
65+ print (f"Loaded { len (CLASSES )} classes from { path } (fallback parsing)" )
66+ else :
67+ print (f"Warning: No classes found in { path } , using default COCO classes" )
68+ CLASSES = DEFAULT_CLASSES
69+ except Exception as e :
70+ print (f"Error loading classes from { path } : { e } . Using default COCO classes" )
71+ CLASSES = DEFAULT_CLASSES
72+
3773# 动态配置参数
3874class DetectionConfig :
3975 def __init__ (self ):
@@ -481,6 +517,7 @@ def main():
481517 parser .add_argument ('--model_path' , type = str , required = True , help = 'RKNN model path' )
482518 parser .add_argument ('--camera_id' , type = int , default = 1 , help = 'Camera device ID (default: 1 for /dev/video1)' )
483519 parser .add_argument ('--video_path' , type = str , help = 'Path to video file (overrides camera_id)' )
520+ parser .add_argument ('--class_path' , type = str , help = 'Path to class_config.txt file for dynamic category loading' )
484521 parser .add_argument ('--host' , type = str , default = '0.0.0.0' , help = 'Web server host' )
485522 parser .add_argument ('--port' , type = int , default = 8000 , help = 'Web server port' )
486523 args = parser .parse_args ()
@@ -489,6 +526,10 @@ def main():
489526 print ("Error: RKNN-Toolkit-Lite2 is not available." )
490527 return
491528
529+ # 加载自定义类别
530+ if args .class_path :
531+ load_classes (args .class_path )
532+
492533 # 启动 Web 服务器线程
493534 web_thread = threading .Thread (target = run_fastapi , args = (args .host , args .port ), daemon = True )
494535 web_thread .start ()
0 commit comments