Skip to content

Commit 1619cc2

Browse files
committed
feat(test): add execution time statistics for test cases
1 parent 192b86c commit 1619cc2

1 file changed

Lines changed: 28 additions & 2 deletions

File tree

tests/test_main.cpp

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
#include <iomanip>
1515
#include <dirent.h>
1616
#include <sys/stat.h>
17+
#include <chrono>
1718
#include "tokenizer.hpp"
1819

1920
#include <utf8proc/utf8proc.h>
@@ -33,6 +34,10 @@ namespace Color {
3334
const std::string GREY = "\033[90m";
3435
}
3536

37+
// ==================== 全局统计 ====================
38+
double g_total_load_ms = 0;
39+
double g_total_encode_ms = 0;
40+
3641
// 计算字符串在终端显示的视觉宽度,跳过 ANSI 转义序列,处理 ZWJ Emoji 序列
3742
int get_display_width(const std::string& str) {
3843
int width = 0;
@@ -137,7 +142,11 @@ bool run_basic_test(tokenizer::PreTrainedTokenizer* tok, const json& test_case,
137142
std::vector<int> expected_ids = test_case["ids_raw"].get<std::vector<int>>();
138143

139144
// 1. 测试 Encode
145+
auto start = std::chrono::high_resolution_clock::now();
140146
std::vector<int> result = tok->encode(input, false);
147+
auto end = std::chrono::high_resolution_clock::now();
148+
g_total_encode_ms += std::chrono::duration<double, std::milli>(end - start).count();
149+
141150
bool ids_match = (result == expected_ids);
142151

143152
// 2. 测试 Decode
@@ -183,8 +192,8 @@ bool run_chat_test(tokenizer::PreTrainedTokenizer* tok, const json& test_case, b
183192
std::vector<int> expected_ids = test_case["ids"].get<std::vector<int>>();
184193
bool add_gen_prompt = test_case.value("add_generation_prompt", false);
185194

186-
std::string result_text; // Declare result_text
187-
tokenizer::ChatMessages messages; // Declare messages
195+
std::string result_text;
196+
tokenizer::ChatMessages messages;
188197
bool has_complex = false;
189198
if (test_case["messages"].is_array()) {
190199
for (const auto& msg : test_case["messages"]) {
@@ -194,6 +203,8 @@ bool run_chat_test(tokenizer::PreTrainedTokenizer* tok, const json& test_case, b
194203
}
195204
}
196205
}
206+
207+
auto start = std::chrono::high_resolution_clock::now();
197208
if (has_complex) {
198209
result_text = tok->apply_chat_template(test_case["messages"].dump(), add_gen_prompt);
199210
} else {
@@ -205,6 +216,9 @@ bool run_chat_test(tokenizer::PreTrainedTokenizer* tok, const json& test_case, b
205216

206217
// 2. 比较生成的 Tokens
207218
std::vector<int> result_ids = tok->encode(result_text, false);
219+
auto end = std::chrono::high_resolution_clock::now();
220+
g_total_encode_ms += std::chrono::duration<double, std::milli>(end - start).count();
221+
208222
bool ids_match = (result_ids == expected_ids);
209223

210224
if (text_match && ids_match) {
@@ -240,11 +254,15 @@ TestResult run_model_tests(const std::string& model_path, const std::string& mod
240254
TestResult result;
241255

242256
// 1. 加载 tokenizer
257+
auto start = std::chrono::high_resolution_clock::now();
243258
auto tok = tokenizer::AutoTokenizer::from_pretrained(model_path);
259+
auto end = std::chrono::high_resolution_clock::now();
260+
244261
if (!tok) {
245262
std::cout << Color::RED << " ❌ Failed to load tokenizer" << Color::RESET << std::endl;
246263
return result;
247264
}
265+
g_total_load_ms += std::chrono::duration<double, std::milli>(end - start).count();
248266

249267
// 2. 加载 test_cases.jsonl
250268
std::string cases_path = model_path + "/test_cases.jsonl";
@@ -461,10 +479,18 @@ int main(int argc, char** argv) {
461479
for (const auto& m : failed_models) {
462480
std::cout << Color::RED << " - " << m << Color::RESET << std::endl;
463481
}
482+
std::cout << "--------------------------------------------------" << std::endl;
483+
std::cout << " Total Loading Time: " << std::fixed << std::setprecision(2) << g_total_load_ms << " ms" << std::endl;
484+
std::cout << " Total Encode Time : " << std::fixed << std::setprecision(2) << g_total_encode_ms << " ms" << std::endl;
485+
std::cout << " Total Time : " << std::fixed << std::setprecision(2) << (g_total_load_ms + g_total_encode_ms) << " ms" << std::endl;
464486
return 1;
465487
} else {
466488
std::cout << Color::GREEN << " Failed : 0" << Color::RESET << std::endl;
467489
std::cout << "\n✨ All tests passed! ✨" << std::endl;
490+
std::cout << "--------------------------------------------------" << std::endl;
491+
std::cout << " Total Loading Time: " << std::fixed << std::setprecision(2) << g_total_load_ms << " ms" << std::endl;
492+
std::cout << " Total Encode Time : " << std::fixed << std::setprecision(2) << g_total_encode_ms << " ms" << std::endl;
493+
std::cout << " Total Time : " << std::fixed << std::setprecision(2) << (g_total_load_ms + g_total_encode_ms) << " ms" << std::endl;
468494
return 0;
469495
}
470496
}

0 commit comments

Comments
 (0)