-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
93 lines (79 loc) · 4.64 KB
/
main.py
File metadata and controls
93 lines (79 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from conf.setting import *
import cv2
import shutil
import json
from tqdm import trange, tqdm
from recognition import FlowchartRecognition
import threading
import torch
from PIL import Image
from FasterRCNN.predict import flowchat_recognize, create_model
from cnocr import CnOcr
from paddleocr import PaddleOCR, draw_ocr
def init_flowchart_recognize_model(
    classed_file="FasterRCNN/save_weights/V15ArrowMix/classes.json",
    models_path="./FasterRCNN/save_weights/V15ArrowMix/resNetFpn-model-15.pth",
):
    """Create the flowchart detection model and its label lookup table.

    The class mapping is read from ``classed_file`` (JSON), a model is built
    with ``create_model``, the weights stored at ``models_path`` are loaded
    into it, and the model is moved to CUDA when available (CPU otherwise)
    and switched to evaluation mode.

    Args:
        classed_file: Path of the JSON file holding the class mapping.
        models_path: Path of the saved weights. The file may hold either the
            state dict itself or a dict that stores it under ``"model"``.

    Returns:
        ``(model, category_index)`` where ``category_index`` maps every
        stringified value of the class JSON to its stringified key.

    Raises:
        AssertionError: If ``classed_file`` or ``models_path`` does not exist.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the class mapping and invert it (value -> key, both as strings).
    assert os.path.exists(classed_file), f"json file {classed_file} dose not exist."
    with open(classed_file, 'r') as class_file:
        name_to_id = json.load(class_file)
    category_index = {}
    for class_name, class_id in name_to_id.items():
        category_index[str(class_id)] = str(class_name)

    # One output more than the number of listed classes
    # (presumably the background class -- confirm against create_model).
    detector = create_model(num_classes=len(name_to_id) + 1)

    # Restore the trained weights; they are always deserialized onto the CPU
    # first and only afterwards moved to the target device.
    assert os.path.exists(models_path), f"{models_path} file dose not exist."
    checkpoint = torch.load(models_path, map_location='cpu')
    if "model" in checkpoint:
        checkpoint = checkpoint["model"]
    detector.load_state_dict(checkpoint)
    detector.to(target_device)
    detector.eval()  # switch to evaluation mode

    return detector, category_index
def start_recognize(recognize_model, category_index, ocr_model, img_path, result_save_path, img_name, recognize_lock, cnocr_lock, file_create_lock):
    """Recognize one flowchart image, store the result as JSON and draw it.

    A ``FlowchartRecognition`` instance is created from the given models and
    locks and run on ``{img_path}/{img_name}``. The returned shape nodes are
    flattened (the nested ``size`` entry becomes top-level ``width``,
    ``height`` and ``rows``) and written together with the arrow nodes to
    ``result_save_path`` through ``tools.write_2_json``. Finally the
    recognizer is asked to draw the recognized nodes and edges for the image.

    Args:
        recognize_model: Detection model handed to ``FlowchartRecognition``.
        category_index: Label lookup handed to ``FlowchartRecognition``.
        ocr_model: OCR model handed to ``FlowchartRecognition``.
        img_path: Folder containing the image ('/'-separated).
        result_save_path: Folder the JSON result is written to.
        img_name: File name of the image inside ``img_path``.
        recognize_lock: Lock passed through to ``FlowchartRecognition``.
        cnocr_lock: Lock passed through to ``FlowchartRecognition``.
        file_create_lock: Lock passed through to ``FlowchartRecognition``.
    """
    recognizer = FlowchartRecognition(
        recognize_model,
        category_index,
        ocr_model,
        recognize_lock,
        cnocr_lock,
        file_create_lock,
    )

    source_image = f"{img_path}/{img_name}"
    # The recognizer also receives "<parent folder>/<file name>" as the name.
    relative_name = f"{img_path.split('/')[-1]}/{img_name}"
    shape_nodes, arrow_nodes = recognizer.recognize_flowchart(
        image_path=source_image,
        img_name=relative_name,
    )

    # Flatten each node: lift width/height/rows out of the nested "size" dict.
    flat_nodes = [
        {
            "id": node["id"],
            "Name": node["Name"],
            "coordinate": node["coordinate"],
            "top": node["top"],
            "left": node["left"],
            "width": node["size"]["width"],
            "height": node["size"]["height"],
            "rows": node["size"]["rows"],
        }
        for node in shape_nodes
    ]

    # Result file name: every '.png' / '.jpg' occurrence becomes '.json'.
    json_name = img_name.replace('.png', '.json').replace('.jpg', '.json')
    tools.write_2_json(
        {"nodes": flat_nodes, "edges": arrow_nodes},
        f"{result_save_path}/{json_name}",
    )
    recognizer.draw_recognized_node_edges(flat_nodes, arrow_nodes, source_image)
def start_recognize_flowchart(flowchart_img_savepath, result_savepath):
    """Run flowchart recognition on every image of every sub-folder.

    The detection model and the PaddleOCR model are created once and shared
    by all jobs. One job (an argument tuple for ``start_recognize``) is
    queued per file found in each sub-folder of ``flowchart_img_savepath``;
    all jobs are then handed to ``tools.multi_thread_run``.

    Args:
        flowchart_img_savepath: Root folder whose sub-folders contain the
            flowchart images.
        result_savepath: Root folder for the results; each job writes into
            ``{result_savepath}/{sub-folder name}``.
    """
    recognize_model, category_index = init_flowchart_recognize_model()
    # OCR back-ends that were used previously, kept for reference:
    #   Chinese: CnOcr(rec_model_name='densenet_lite_136-gru', rec_model_backend="pytorch",
    #                  det_model_name="db_resnet34", det_model_backend="pytorch")
    #   English: CnOcr(det_model_name='en_PP-OCRv3_det', rec_model_name='en_PP-OCRv3')
    ocr_model = PaddleOCR(use_gpu=True, lang="ch", det_model_dir="models/paddleocr/ch_PP-OCRv4_det_infer/", rec_model_dir="models/paddleocr/ch_PP-OCRv4_rec_infer/")
    print(f"recognize_model:{recognize_model}")
    print(f"category_index:{category_index}")

    # The same three locks are shared by every job.
    recognize_lock = threading.Lock()
    cnocr_lock = threading.Lock()
    file_create_lock = threading.Lock()

    run_paras = []
    for folder in os.listdir(flowchart_img_savepath):
        folder_path = f"{flowchart_img_savepath}/{folder}"
        # Skip stray regular files in the root folder (e.g. .DS_Store):
        # os.listdir() on them would raise NotADirectoryError and abort the run.
        if not os.path.isdir(folder_path):
            continue
        flowchart_imgs = os.listdir(folder_path)
        for img_name in tqdm(flowchart_imgs, total=len(flowchart_imgs), desc=f"{folder}: "):
            run_paras.append((
                recognize_model,
                category_index,
                ocr_model,
                folder_path,
                f"{result_savepath}/{folder}",
                img_name,
                recognize_lock,
                cnocr_lock,
                file_create_lock,
            ))

    # Dispatch all queued jobs at once; the first argument is presumably the
    # number of worker threads -- confirm against tools.multi_thread_run.
    tools.multi_thread_run(32, start_recognize, run_paras, "Recognize Flowchart: ")
def save_cover_shape_imgs(flowchart_data_path, flowhcart_img_path, file_name, result_save_path, cover_img_save_path):
    """Paint over every recognized node of a flowchart image and save a copy.

    The recognition result belonging to ``file_name`` is read from
    ``flowchart_data_path``; for each entry in its ``"nodes"`` list a filled
    white rectangle is drawn over the node's ``"coordinate"`` box. The
    covered image is written to ``cover_img_save_path`` under the original
    file name. If the result file or the image cannot be read, the function
    returns without writing anything.

    Args:
        flowchart_data_path: Folder containing the recognition JSON files.
        flowhcart_img_path: Folder containing the source image.
        file_name: Image file name (e.g. ``foo.png`` or ``foo.jpg``).
        result_save_path: Unused; kept so existing callers keep working.
        cover_img_save_path: Folder the covered image is written to; it is
            created when missing.
    """
    # Use the same '.png' / '.jpg' -> '.json' renaming that start_recognize
    # applies when it writes the result, so .jpg images are found as well.
    json_name = file_name.replace('.png', '.json').replace('.jpg', '.json')
    try:
        flowchart_data = tools.read_json(f"{flowchart_data_path}/{json_name}")
    except Exception:
        # Best effort: without a readable recognition result there is nothing
        # to cover. (Exception, not a bare except, so Ctrl-C still works.)
        return

    image = cv2.imread(f"{flowhcart_img_path}/{file_name}")
    if image is None:
        # cv2.imread signals a missing or undecodable file by returning None.
        return

    for item in flowchart_data['nodes']:
        box = item["coordinate"]
        # Corners are (box[0], box[2]) and (box[1], box[3]); presumably the
        # box is stored as (x_min, x_max, y_min, y_max) -- confirm against the
        # recognizer output. Thickness -1 draws a filled rectangle.
        cv2.rectangle(image, (int(box[0]), int(box[2])), (int(box[1]), int(box[3])), (255, 255, 255), -1)

    # exist_ok avoids the check-then-create race when called concurrently.
    os.makedirs(cover_img_save_path, exist_ok=True)
    cv2.imwrite(f"{cover_img_save_path}/{file_name}", image)
if __name__ == "__main__":
    # Script entry point: recognize every flowchart found below ./images and
    # store the results below ./results.
    start_recognize_flowchart("images", "results")