forked from hiroi-sora/PaddleOCR-json
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathargs.cpp
More file actions
226 lines (213 loc) · 9.76 KB
/
args.cpp
File metadata and controls
226 lines (213 loc) · 9.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <fstream>
#include <include/utility.h>
#include <gflags/gflags.h>
// Working mode
DEFINE_string(image_path, "", "Set image_path to run a single task."); // If an image path is given, OCR is run once on it.
DEFINE_int32(port, -1, "Set to 0 enable random port, set to 1~65535 enables specified port."); // 0 picks a random port, 1~65535 selects that specific port. The default enables anonymous-pipe mode.
DEFINE_string(addr, "loopback", "Socket server addr, the value can be 'loopback', 'localhost', 'any', or other IPv4 address."); // Address mode of the socket server: local loopback / any available address.
// Common arguments: device selection, precision, benchmarking and input/output locations.
DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU.");
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
DEFINE_int32(gpu_mem, 4000, "GPU memory size (MB) to use.");
DEFINE_int32(cpu_threads, 10, "Num of threads with CPU.");
DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
DEFINE_bool(benchmark, false, "Whether use benchmark.");
DEFINE_string(output, "./output/", "Save benchmark log path.");
DEFINE_string(image_dir, "", "Dir of input image.");
DEFINE_string(
type, "ocr",
"Perform ocr or structure, the value is selected in ['ocr','structure'].");
DEFINE_string(config_path, "", "Path of config file."); // Path of the config file.
DEFINE_string(models_path, "", "Path of models folder."); // Path of the models folder.
DEFINE_bool(ensure_ascii, true, "Enable JSON ascii escape."); // When true, JSON ASCII escaping is enabled.
// Detection-related flags: model location, input size limits and DB thresholds.
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_string(limit_type, "max", "limit_type of input image.");
DEFINE_int32(limit_side_len, 960, "limit_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
DEFINE_bool(use_dilation, false, "Whether use the dilation on output map.");
DEFINE_string(det_db_score_mode, "slow", "Whether use polygon score.");
DEFINE_bool(visualize, true, "Whether show the detection results.");
// Classification-related flags (angle classifier): switch, model location, threshold and batch size.
DEFINE_bool(use_angle_cls, false, "Whether use use_angle_cls.");
DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
DEFINE_int32(cls_batch_num, 1, "cls_batch_num.");
// Recognition-related flags: model location, batch size, character dictionary and input image size.
DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
"Path of dictionary.");
DEFINE_int32(rec_img_h, 48, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width");
// Layout-model-related flags: model location, label dictionary, score and NMS thresholds.
DEFINE_string(layout_model_dir, "", "Path of table layout inference model.");
DEFINE_string(layout_dict_path,
"../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt",
"Path of dictionary.");
DEFINE_double(layout_score_threshold, 0.5, "Threshold of score.");
DEFINE_double(layout_nms_threshold, 0.5, "Threshold of nms.");
// Table-structure-model-related flags: model location, input size limit, batch size and dictionary.
DEFINE_string(table_model_dir, "", "Path of table structure inference model.");
DEFINE_int32(table_max_len, 488, "max len size of input image.");
DEFINE_int32(table_batch_num, 1, "table_batch_num.");
DEFINE_bool(merge_no_span_structure, true,
"Whether merge <td> and </td> to <td></td>");
DEFINE_string(table_char_dict_path,
"../../ppocr/utils/dict/table_structure_dict_ch.txt",
"Path of dictionary.");
// OCR forward-pass switches: which stages take part in a forward run.
DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
DEFINE_bool(table, false, "Whether use table structure in forward.");
DEFINE_bool(layout, false, "Whether use layout analysis in forward.");
// Verifies that `path` is usable and appends a diagnostic to `msg` when it is not.
//   path : the path to verify.
//   name : label of the option being checked; used as the prefix of the diagnostic.
//   msg  : accumulator for diagnostics; left untouched when the path exists.
void check_path(const std::string& path, const std::string& name, std::string& msg)
{
    // An empty path gets its own message and is never passed to PathExists.
    if (path.empty()) {
        const std::string note = name + " is empty. ";
        msg += note;
        return;
    }

    // Nothing to report when the path is present.
    if (PaddleOCR::Utility::PathExists(path)) {
        return;
    }

    const std::string note = name + " [" + path + "] does not exist. ";
    msg += note;
}
// Rebases a path that begins with the "models" directory onto the models folder in use.
//   models_path_base : the models folder that replaces the leading "models" component.
//   value            : path to rewrite in place; left untouched unless its first path
//                      component is exactly "models".
// The leading "models" is stripped and the remainder is joined onto `models_path_base`
// with Utility::pathjoin.
void prepend_models(const std::string& models_path_base, std::string& value)
{
    constexpr std::size_t kPrefixLen = 6; // length of "models"

    if (!PaddleOCR::Utility::str_starts_with(value, "models")) {
        return;
    }

    // Only rewrite when "models" is a whole path component. A bare prefix match would
    // also hit names such as "models_v2/x" and turn them into "_v2/x" before joining,
    // which both corrupts such values and breaks repeated calls when the base folder
    // itself starts with "models".
    if (value.size() > kPrefixLen) {
        const char next = value[kPrefixLen];
        if (next != '/' && next != '\\') {
            return;
        }
    }

    value.erase(0, kPrefixLen);
    value = PaddleOCR::Utility::pathjoin(models_path_base, value);
}
// Loads flag values from the file named by --config_path and returns a log string.
//
// Each usable line has the form "key value" or "key=value": the key ends at the first
// space or '=', and everything after that single separator character is the value.
// Lines shorter than 3 characters, lines starting with '#', and lines whose separator is
// missing, first, or last are skipped. Values are passed through prepend_models() before
// being applied. Flags are applied with SET_FLAG_IF_DEFAULT, so config entries have lower
// priority than arguments given on the command line.
std::string read_config()
{
    // Base folder for model files: "models" unless --models_path names an existing path.
    const bool use_custom_models =
        !FLAGS_models_path.empty() && PaddleOCR::Utility::PathExists(FLAGS_models_path);
    const std::string models_root = use_custom_models ? FLAGS_models_path : std::string("models");

    if (!PaddleOCR::Utility::PathExists(FLAGS_config_path)) {
        return "config_path [" + FLAGS_config_path + "] does not exist. ";
    }

    std::ifstream config_file(FLAGS_config_path);
    if (config_file.fail()) {
        return "[WARNING] Unable to open config_path [" + FLAGS_config_path + "]. ";
    }

    std::string log = "Load config from [" + FLAGS_config_path + "] : ";
    int applied = 0; // number of entries for which gflags returned a non-empty result

    for (std::string entry; std::getline(config_file, entry);) {
        const std::size_t len = entry.size();
        if (len < 3 || entry.front() == '#') {
            continue; // blank/too-short line or comment
        }

        // Position of the key/value separator.
        const std::size_t sep = entry.find_first_of(" =");
        if (sep == std::string::npos || sep == 0 || sep + 1 >= len) {
            continue; // no key, or no value after the separator
        }

        const std::string key = entry.substr(0, sep);
        std::string value = entry.substr(sep + 1);
        prepend_models(models_root, value);

        const std::string result = google::SetCommandLineOptionWithMode(
            key.c_str(), value.c_str(), google::SET_FLAG_IF_DEFAULT);
        if (result.empty()) {
            continue; // nothing reported for this entry
        }

        ++applied;
        // Append the report without its final character.
        log.append(result, 0, result.size() - 1);
    }
    config_file.close();

    log += (applied == 0) ? "No valid config found." : ". ";
    return log;
}
// Validates the current flag values.
//
// Returns an empty string when every check passes; otherwise returns the concatenation of
// one diagnostic per failed check.
//
// Side effect: the model/dictionary path flags that are checked are first passed through
// prepend_models(), which may rewrite them in place relative to the models folder.
std::string check_flags() {
    // Base folder for model files: "models" unless --models_path names an existing path.
    std::string models_path_base = "models";
    if (!FLAGS_models_path.empty() && PaddleOCR::Utility::PathExists(FLAGS_models_path))
    {
        models_path_base = FLAGS_models_path;
    }

    std::string msg;

    if (FLAGS_det) { // detection model
        prepend_models(models_path_base, FLAGS_det_model_dir);
        check_path(FLAGS_det_model_dir, "det_model_dir", msg);
    }
    if (FLAGS_rec) { // recognition model
        prepend_models(models_path_base, FLAGS_rec_model_dir);
        check_path(FLAGS_rec_model_dir, "rec_model_dir", msg);
    }
    if (FLAGS_cls && FLAGS_use_angle_cls) { // classification model
        prepend_models(models_path_base, FLAGS_cls_model_dir);
        check_path(FLAGS_cls_model_dir, "cls_model_dir", msg);
    }
    if (!FLAGS_rec_char_dict_path.empty()) { // recognition dictionary
        prepend_models(models_path_base, FLAGS_rec_char_dict_path);
        check_path(FLAGS_rec_char_dict_path, "rec_char_dict_path", msg);
    }
    if (FLAGS_table) { // table structure model
        prepend_models(models_path_base, FLAGS_table_model_dir);
        check_path(FLAGS_table_model_dir, "table_model_dir", msg);
        // With table enabled, the det/rec model dirs are checked even when their own
        // switches are off. In that case they have not been rebased above, so rebase
        // them here before checking, exactly as the det/rec branches do.
        if (!FLAGS_det) {
            prepend_models(models_path_base, FLAGS_det_model_dir);
            check_path(FLAGS_det_model_dir, "det_model_dir", msg);
        }
        if (!FLAGS_rec) {
            prepend_models(models_path_base, FLAGS_rec_model_dir);
            check_path(FLAGS_rec_model_dir, "rec_model_dir", msg);
        }
    }
    if (FLAGS_layout) { // layout model
        prepend_models(models_path_base, FLAGS_layout_model_dir);
        check_path(FLAGS_layout_model_dir, "layout_model_dir", msg);
    }
    if (!FLAGS_config_path.empty()) { // config file: only checked when one was given
        check_path(FLAGS_config_path, "config_path", msg);
    }

    // Enumerated string values.
    if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") {
        msg += "precision should be 'fp32'(default), 'fp16' or 'int8', not " + FLAGS_precision + ". ";
    }
    if (FLAGS_type != "ocr" && FLAGS_type != "structure") {
        msg += "type should be 'ocr'(default) or 'structure', not " + FLAGS_type + ". ";
    }
    if (FLAGS_limit_type != "max" && FLAGS_limit_type != "min") {
        msg += "limit_type should be 'max'(default) or 'min', not " + FLAGS_limit_type + ". ";
    }
    if (FLAGS_det_db_score_mode != "slow" && FLAGS_det_db_score_mode != "fast") {
        // Name the flag that is actually being validated.
        msg += "det_db_score_mode should be 'slow'(default) or 'fast', not " + FLAGS_det_db_score_mode + ". ";
    }
    return msg;
}