Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 11aaf19

Browse files
authored
Merge pull request #1665 from janhq/j/simultaneous-download
feat: simultaneous download
2 parents ce7af64 + 4c110bf commit 11aaf19

File tree

9 files changed

+454
-272
lines changed

9 files changed

+454
-272
lines changed

engine/cli/commands/engine_install_cmd.cc

Lines changed: 15 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,13 @@ bool EngineInstallCmd::Exec(const std::string& engine,
3636
DownloadProgress dp;
3737
dp.Connect(host_, port_);
3838
// engine can be small, so need to start ws first
39-
auto dp_res = std::async(std::launch::deferred, [&dp, &engine] {
40-
return dp.Handle(DownloadType::Engine);
39+
auto dp_res = std::async(std::launch::deferred, [&dp] {
40+
bool need_cuda_download = !system_info_utils::GetCudaVersion().empty();
41+
if (need_cuda_download) {
42+
return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit});
43+
} else {
44+
return dp.Handle({DownloadType::Engine});
45+
}
4146
});
4247

4348
auto versions_url = url_parser::Url{
@@ -133,12 +138,6 @@ bool EngineInstallCmd::Exec(const std::string& engine,
133138
if (!dp_res.get())
134139
return false;
135140

136-
bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
137-
if (check_cuda_download) {
138-
if (!dp.Handle(DownloadType::CudaToolkit))
139-
return false;
140-
}
141-
142141
CLI_LOG("Engine " << engine << " downloaded successfully!")
143142
return true;
144143
}
@@ -147,8 +146,14 @@ bool EngineInstallCmd::Exec(const std::string& engine,
147146
DownloadProgress dp;
148147
dp.Connect(host_, port_);
149148
// engine can be small, so need to start ws first
150-
auto dp_res = std::async(std::launch::deferred,
151-
[&dp] { return dp.Handle(DownloadType::Engine); });
149+
auto dp_res = std::async(std::launch::deferred, [&dp] {
150+
bool need_cuda_download = !system_info_utils::GetCudaVersion().empty();
151+
if (need_cuda_download) {
152+
return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit});
153+
} else {
154+
return dp.Handle({DownloadType::Engine});
155+
}
156+
});
152157

153158
auto install_url = url_parser::Url{
154159
.protocol = "http",
@@ -183,12 +188,6 @@ bool EngineInstallCmd::Exec(const std::string& engine,
183188
if (!dp_res.get())
184189
return false;
185190

186-
bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
187-
if (check_cuda_download) {
188-
if (!dp.Handle(DownloadType::CudaToolkit))
189-
return false;
190-
}
191-
192191
CLI_LOG("Engine " << engine << " downloaded successfully!")
193192
return true;
194193
}

engine/cli/commands/engine_update_cmd.cc

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,13 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port,
2424
DownloadProgress dp;
2525
dp.Connect(host, port);
2626
// engine can be small, so need to start ws first
27-
auto dp_res = std::async(std::launch::deferred, [&dp, &engine] {
28-
return dp.Handle(DownloadType::Engine);
27+
auto dp_res = std::async(std::launch::deferred, [&dp] {
28+
bool need_cuda_download = !system_info_utils::GetCudaVersion().empty();
29+
if (need_cuda_download) {
30+
return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit});
31+
} else {
32+
return dp.Handle({DownloadType::Engine});
33+
}
2934
});
3035

3136
auto update_url = url_parser::Url{
@@ -48,12 +53,6 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port,
4853
if (!dp_res.get())
4954
return false;
5055

51-
bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
52-
if (check_cuda_download) {
53-
if (!dp.Handle(DownloadType::CudaToolkit))
54-
return false;
55-
}
56-
5756
CLI_LOG("Engine " << engine << " updated successfully!")
5857
return true;
5958
}

engine/cli/commands/model_pull_cmd.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,7 @@ std::optional<std::string> ModelPullCmd::Exec(const std::string& host, int port,
143143
reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
144144
#endif
145145
dp.Connect(host, port);
146-
if (!dp.Handle(DownloadType::Model))
146+
if (!dp.Handle({DownloadType::Model}))
147147
return std::nullopt;
148148
if (force_stop)
149149
return std::nullopt;

engine/cli/utils/download_progress.cc

Lines changed: 49 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,8 @@ bool DownloadProgress::Connect(const std::string& host, int port) {
3434
return true;
3535
}
3636

37-
bool DownloadProgress::Handle(const DownloadType& event_type) {
37+
bool DownloadProgress::Handle(
38+
const std::unordered_set<DownloadType>& event_type) {
3839
assert(!!ws_);
3940
#if defined(_WIN32)
4041
HANDLE h_out = GetStdHandle(STD_OUTPUT_HANDLE);
@@ -50,10 +51,14 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
5051
}
5152
}
5253
#endif
53-
status_ = DownloadStatus::DownloadStarted;
54+
for (auto et : event_type) {
55+
status_[et] = DownloadStatus::DownloadStarted;
56+
}
5457
std::unique_ptr<indicators::DynamicProgress<indicators::ProgressBar>> bars;
5558

56-
std::vector<std::unique_ptr<indicators::ProgressBar>> items;
59+
std::unordered_map<std::string,
60+
std::pair<int, std::unique_ptr<indicators::ProgressBar>>>
61+
items;
5762
indicators::show_console_cursor(false);
5863
auto start = std::chrono::steady_clock::now();
5964
auto handle_message = [this, &bars, &items, event_type,
@@ -78,22 +83,28 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
7883
auto ev = cortex::event::GetDownloadEventFromJson(
7984
json_helper::ParseJsonString(message));
8085
// Ignore other task type
81-
if (ev.download_task_.type != event_type) {
86+
if (event_type.find(ev.download_task_.type) == event_type.end()) {
8287
return;
8388
}
8489
auto now = std::chrono::steady_clock::now();
8590
if (!bars) {
8691
bars = std::make_unique<
8792
indicators::DynamicProgress<indicators::ProgressBar>>();
88-
for (auto& i : ev.download_task_.items) {
89-
items.emplace_back(std::make_unique<indicators::ProgressBar>(
90-
indicators::option::BarWidth{50}, indicators::option::Start{"["},
91-
indicators::option::Fill{"="}, indicators::option::Lead{">"},
92-
indicators::option::End{"]"},
93-
indicators::option::PrefixText{pad_string(Repo2Engine(i.id))},
94-
indicators::option::ForegroundColor{indicators::Color::white},
95-
indicators::option::ShowRemainingTime{false}));
96-
bars->push_back(*(items.back()));
93+
}
94+
for (auto& i : ev.download_task_.items) {
95+
if (items.find(i.id) == items.end()) {
96+
auto idx = items.size();
97+
items[i.id] = std::pair(
98+
idx,
99+
std::make_unique<indicators::ProgressBar>(
100+
indicators::option::BarWidth{50},
101+
indicators::option::Start{"["}, indicators::option::Fill{"="},
102+
indicators::option::Lead{">"}, indicators::option::End{"]"},
103+
indicators::option::PrefixText{pad_string(Repo2Engine(i.id))},
104+
indicators::option::ForegroundColor{indicators::Color::white},
105+
indicators::option::ShowRemainingTime{false}));
106+
107+
bars->push_back(*(items.at(i.id).second));
97108
}
98109
}
99110
for (int i = 0; i < ev.download_task_.items.size(); i++) {
@@ -113,32 +124,36 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
113124
(total - downloaded) / bytes_per_sec);
114125
}
115126

116-
(*bars)[i].set_option(indicators::option::PrefixText{
117-
pad_string(Repo2Engine(it.id)) +
118-
std::to_string(int(static_cast<double>(downloaded) / total * 100)) +
119-
'%'});
120-
(*bars)[i].set_progress(
127+
(*bars)[items.at(it.id).first].set_option(
128+
indicators::option::PrefixText{
129+
pad_string(Repo2Engine(it.id)) +
130+
std::to_string(
131+
int(static_cast<double>(downloaded) / total * 100)) +
132+
'%'});
133+
(*bars)[items.at(it.id).first].set_progress(
121134
int(static_cast<double>(downloaded) / total * 100));
122-
(*bars)[i].set_option(indicators::option::PostfixText{
123-
time_remaining + " " +
124-
format_utils::BytesToHumanReadable(downloaded) + "/" +
125-
format_utils::BytesToHumanReadable(total)});
135+
(*bars)[items.at(it.id).first].set_option(
136+
indicators::option::PostfixText{
137+
time_remaining + " " +
138+
format_utils::BytesToHumanReadable(downloaded) + "/" +
139+
format_utils::BytesToHumanReadable(total)});
126140
} else if (ev.type_ == DownloadStatus::DownloadSuccess) {
127141
uint64_t total =
128142
it.bytes.value_or(std::numeric_limits<uint64_t>::max());
129-
(*bars)[i].set_progress(100);
143+
(*bars)[items.at(it.id).first].set_progress(100);
130144
auto total_str = format_utils::BytesToHumanReadable(total);
131-
(*bars)[i].set_option(indicators::option::PostfixText{
132-
"00m:00s " + total_str + "/" + total_str});
133-
(*bars)[i].set_option(indicators::option::PrefixText{
134-
pad_string(Repo2Engine(it.id)) + "100%"});
135-
(*bars)[i].set_progress(100);
145+
(*bars)[items.at(it.id).first].set_option(
146+
indicators::option::PostfixText{"00m:00s " + total_str + "/" +
147+
total_str});
148+
(*bars)[items.at(it.id).first].set_option(
149+
indicators::option::PrefixText{pad_string(Repo2Engine(it.id)) +
150+
"100%"});
151+
(*bars)[items.at(it.id).first].set_progress(100);
136152

137153
CTL_INF("Download success");
138154
}
155+
status_[ev.download_task_.type] = ev.type_;
139156
}
140-
141-
status_ = ev.type_;
142157
};
143158

144159
while (ws_->getReadyState() != easywsclient::WebSocket::CLOSED &&
@@ -152,7 +167,9 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
152167
SetConsoleMode(h_out, dw_original_out_mode);
153168
}
154169
#endif
155-
if (status_ == DownloadStatus::DownloadError)
156-
return false;
170+
for (auto const& [_, v] : status_) {
171+
if (v == DownloadStatus::DownloadError)
172+
return false;
173+
}
157174
return true;
158175
}

engine/cli/utils/download_progress.h

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#include <atomic>
33
#include <memory>
44
#include <string>
5+
#include <unordered_set>
56
#include "common/event.h"
67
#include "easywsclient.hpp"
78

@@ -10,19 +11,25 @@ class DownloadProgress {
1011
public:
1112
bool Connect(const std::string& host, int port);
1213

13-
bool Handle(const DownloadType& event_type);
14+
bool Handle(const std::unordered_set<DownloadType>& event_type);
1415

1516
void ForceStop() { force_stop_ = true; }
1617

1718
private:
1819
bool should_stop() const {
19-
return (status_ != DownloadStatus::DownloadStarted &&
20-
status_ != DownloadStatus::DownloadUpdated) ||
21-
force_stop_;
20+
bool should_stop = true;
21+
for (auto const& [_, v] : status_) {
22+
should_stop &= (v == DownloadStatus::DownloadSuccess);
23+
}
24+
for (auto const& [_, v] : status_) {
25+
should_stop |= (v == DownloadStatus::DownloadError ||
26+
v == DownloadStatus::DownloadStopped);
27+
}
28+
return should_stop || force_stop_;
2229
}
2330

2431
private:
2532
std::unique_ptr<easywsclient::WebSocket> ws_;
26-
std::atomic<DownloadStatus> status_ = DownloadStatus::DownloadStarted;
33+
std::unordered_map<DownloadType, std::atomic<DownloadStatus>> status_;
2734
std::atomic<bool> force_stop_ = false;
2835
};

engine/e2e-test/test_api_model_start.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ def setup_and_teardown(self):
1212
success = start_server()
1313
if not success:
1414
raise Exception("Failed to start server")
15-
requests.post("http://localhost:3928/v1/engines/llama-cpp")
15+
run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60)
1616
run("Delete model", ["models", "delete", "tinyllama:gguf"])
1717
run(
1818
"Pull model",
@@ -27,5 +27,7 @@ def setup_and_teardown(self):
2727

2828
def test_models_start_should_be_successful(self):
2929
json_body = {"model": "tinyllama:gguf"}
30-
response = requests.post("http://localhost:3928/v1/models/start", json=json_body)
30+
response = requests.post(
31+
"http://localhost:3928/v1/models/start", json=json_body
32+
)
3133
assert response.status_code == 200, f"status_code: {response.status_code}"

engine/e2e-test/test_api_model_stop.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import requests
3-
from test_runner import start_server, stop_server
3+
from test_runner import run, start_server, stop_server
44

55

66
class TestApiModelStop:
@@ -13,16 +13,18 @@ def setup_and_teardown(self):
1313
if not success:
1414
raise Exception("Failed to start server")
1515

16-
requests.post("http://localhost:3928/engines/llama-cpp")
16+
run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60)
1717
yield
1818

19-
requests.delete("http://localhost:3928/engines/llama-cpp")
19+
run("Uninstall engine", ["engines", "uninstall", "llama-cpp"])
2020
# Teardown
2121
stop_server()
2222

2323
def test_models_stop_should_be_successful(self):
2424
json_body = {"model": "tinyllama:gguf"}
25-
response = requests.post("http://localhost:3928/v1/models/start", json=json_body)
25+
response = requests.post(
26+
"http://localhost:3928/v1/models/start", json=json_body
27+
)
2628
assert response.status_code == 200, f"status_code: {response.status_code}"
2729
response = requests.post("http://localhost:3928/v1/models/stop", json=json_body)
2830
assert response.status_code == 200, f"status_code: {response.status_code}"

0 commit comments

Comments
 (0)