Skip to content

Commit 81f63f4

Browse files
committed
Use psirng for sampling
1 parent 947538a commit 81f63f4

File tree

6 files changed

+52
-4
lines changed

6 files changed

+52
-4
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
[submodule "kompute"]
22
path = ggml/src/kompute
33
url = https://github.com/nomic-ai/kompute.git
4+
[submodule "libpsirngclient"]
5+
path = libpsirngclient
6+
url = https://github.com/nullspook/libpsirngclient.git

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,9 @@ if (LLAMA_BUILD_EXAMPLES)
199199
add_subdirectory(examples)
200200
add_subdirectory(pocs)
201201
endif()
202+
203+
#
204+
# psirng
205+
#
206+
207+
add_subdirectory(libpsirngclient)

src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ add_library(llama
2222
unicode-data.cpp
2323
)
2424

25-
target_include_directories(llama PUBLIC . ../include)
25+
target_include_directories(llama PUBLIC . ../include ../libpsirngclient/src)
2626
target_compile_features (llama PUBLIC cxx_std_11) # don't bump
2727

28-
target_link_libraries(llama PUBLIC ggml)
28+
target_link_libraries(llama PUBLIC ggml psirngclient)
2929

3030
if (BUILD_SHARED_LIBS)
3131
set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)

src/llama-sampling.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,21 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama
619619
probs.push_back(candidates->data[i].p);
620620
}
621621

622-
std::discrete_distribution<> dist(probs.begin(), probs.end());
623-
int idx = dist(rng);
622+
std::vector<float> cdf(probs.size());
623+
cdf[0] = probs[0];
624+
for (size_t i = 1; i < probs.size(); ++i) {
625+
cdf[i] = cdf[i - 1] + probs[i];
626+
}
627+
628+
int idx;
629+
double u;
630+
631+
int rand_result = psirngclient_randuniform(smpl->psirngclient_ptr, &u, 1, 0.0, 1.0);
632+
if (rand_result != PSIRNGCLIENT_RESULT_OK) {
633+
GGML_ABORT("psirngclient_randuniform error: %d", rand_result);
634+
}
635+
636+
idx = static_cast<int>(std::distance(cdf.begin(), std::lower_bound(cdf.begin(), cdf.end(), u)));
624637

625638
llama_token result = candidates->data[idx].id;
626639

src/llama-sampling.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
#include "llama-impl.h"
44

5+
#include "psirngclient.h"
6+
57
struct llama_sampling {
68
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
79

810
std::mt19937 rng;
911

12+
psirngclient * psirngclient_ptr;
13+
1014
int32_t n_vocab = 0;
1115

1216
mutable int64_t t_sample_us = 0;

src/llama.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18193,6 +18193,28 @@ struct llama_context * llama_new_context_with_model(
1819318193
ctx->abort_callback_data = params.abort_callback_data;
1819418194

1819518195
ctx->sampling.rng = std::mt19937(params.seed);
18196+
18197+
ctx->sampling.psirngclient_ptr = nullptr;
18198+
18199+
const char* psirng_host = std::getenv("PSIRNG_HOST");
18200+
const char* psirng_grpc_port = std::getenv("PSIRNG_GRPC_PORT");
18201+
const char* psirng_cert_path = std::getenv("PSIRNG_CERT_PATH");
18202+
18203+
if (psirng_host != nullptr && psirng_grpc_port != nullptr && psirng_cert_path != nullptr) {
18204+
psirngclient_init(&ctx->sampling.psirngclient_ptr, psirng_host, std::stoi(psirng_grpc_port), psirng_cert_path);
18205+
if (!psirngclient_ishealthy(ctx->sampling.psirngclient_ptr)) {
18206+
LLAMA_LOG_ERROR("%s: psirng is not healthy\n", __func__);
18207+
llama_free(ctx);
18208+
return nullptr;
18209+
} else {
18210+
LLAMA_LOG_INFO("%s: Using psirng running on %s:%s\n", __func__, psirng_host, psirng_grpc_port);
18211+
}
18212+
} else {
18213+
LLAMA_LOG_ERROR("%s: psirng is not configured\n", __func__);
18214+
llama_free(ctx);
18215+
return nullptr;
18216+
}
18217+
1819618218
ctx->logits_all = params.logits_all;
1819718219
// build worst-case graph for encoder if a model contains encoder
1819818220
ctx->is_encoding = llama_model_has_encoder(model);

0 commit comments

Comments
 (0)