From bc40a9481affd3b130c5f9f72bfb97f5292f5032 Mon Sep 17 00:00:00 2001 From: DwyaneShi Date: Tue, 5 Nov 2024 11:00:02 -0800 Subject: [PATCH 1/5] Add AIBrix LLM kv cache API - rolling hash based approach to maintain the token lineage - S3-FIFO inspired sync and GC mechanism Signed-off-by: DwyaneShi --- modules/llm-cache/CMakeLists.txt | 2 +- modules/llm-cache/ds/config.h | 24 + modules/llm-cache/ds/kv_cache_chunk.cc | 528 ++++++++++ modules/llm-cache/ds/kv_cache_chunk.h | 183 ++++ modules/llm-cache/ds/kv_cache_manager.cc | 25 + modules/llm-cache/ds/kv_cache_manager.h | 4 + modules/llm-cache/hash/hasher.h | 45 + modules/llm-cache/hash/md5.h | 47 + .../llm-cache/storage/aibrix_blob_storage.cc | 906 ++++++++++++++++++ .../llm-cache/storage/aibrix_blob_storage.h | 227 +++++ .../llm-cache/tests/aibrix_kv_cache_test.cc | 391 ++++++++ python/vineyard/llm/cache.cc | 45 +- python/vineyard/llm/cache.py | 101 +- src/client/client_base.cc | 25 +- src/client/client_base.h | 22 +- src/common/util/evicting_cache_map.h | 523 ++++++++++ src/common/util/protocols.cc | 50 +- src/common/util/protocols.h | 16 +- src/common/util/status.cc | 6 + src/common/util/status.h | 8 + src/server/async/socket_server.cc | 32 +- src/server/async/socket_server.h | 1 + src/server/server/vineyard_server.cc | 58 +- src/server/server/vineyard_server.h | 6 +- src/server/util/meta_tree.cc | 59 ++ src/server/util/meta_tree.h | 3 + test/evicting_cache_map_test.cc | 737 ++++++++++++++ test/runner.py | 14 + 28 files changed, 4056 insertions(+), 32 deletions(-) create mode 100644 modules/llm-cache/ds/kv_cache_chunk.cc create mode 100644 modules/llm-cache/ds/kv_cache_chunk.h create mode 100644 modules/llm-cache/hash/md5.h create mode 100644 modules/llm-cache/storage/aibrix_blob_storage.cc create mode 100644 modules/llm-cache/storage/aibrix_blob_storage.h create mode 100644 modules/llm-cache/tests/aibrix_kv_cache_test.cc create mode 100644 src/common/util/evicting_cache_map.h create mode 100644 test/evicting_cache_map_test.cc diff --git a/modules/llm-cache/CMakeLists.txt b/modules/llm-cache/CMakeLists.txt index 15f43a7cf..f53e8b6ca 100644 --- a/modules/llm-cache/CMakeLists.txt +++ b/modules/llm-cache/CMakeLists.txt @@ -16,7 +16,7 @@ file(GLOB VINEYARD_LLM_CACHE_SRCS "${CMAKE_CURRENT_SOURCE_DIR}" ) add_library(vineyard_llm_cache ${VINEYARD_LLM_CACHE_SRCS}) -target_link_libraries(vineyard_llm_cache PRIVATE libzstd_static ${GLOG_LIBRARIES}) +target_link_libraries(vineyard_llm_cache PRIVATE libzstd_static ${GLOG_LIBRARIES} ${OPENSSL_LIBRARIES}) target_link_libraries(vineyard_llm_cache PUBLIC vineyard_client) # install bundled thirdparty: rax and MurmurHash3 diff --git a/modules/llm-cache/ds/config.h b/modules/llm-cache/ds/config.h index 1ad50fac8..f9168da0a 100644 --- a/modules/llm-cache/ds/config.h +++ b/modules/llm-cache/ds/config.h @@ -85,6 +85,30 @@ struct FileCacheConfig : public KVCacheConfig { } }; +struct AIBrixCacheConfig : public KVCacheConfig { + int chunkSize; + std::string cacheNameSpace; + int localSyncInterval; // in seconds + bool enbaleGlobalGC; + int globalGCInterval; // in seconds + int globalTTL; // in seconds + + // Default local sync interval is 3 minutes and default global gc interval is + // 10 minutes. 
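+  //
+  // Illustrative usage (all values are placeholders, not recommendations),
+  // together with the KVCacheManager::Make overload added later in this
+  // patch:
+  //
+  //   AIBrixCacheConfig config(/*tensorByte=*/4096, /*cacheCapacity=*/1024,
+  //                            /*layer=*/32, /*chunkSize=*/16);
+  //   std::shared_ptr<KVCacheManager> manager;
+  //   RETURN_ON_ERROR(
+  //       KVCacheManager::Make(rpc_client, ipc_client, manager, config));
+  //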
+ AIBrixCacheConfig(int tensorByte = 10, int cacheCapacity = 10, int layer = 1, + int chunkSize = 4, std::string cacheNameSpace = "aibrix", + int localSyncInterval = 3 * 60, bool enbaleGlobalGC = true, + int globalGCInterval = 10 * 60, int globalTTL = 8 * 60) + : KVCacheConfig{tensorByte, cacheCapacity, layer} { + this->chunkSize = chunkSize; + this->cacheNameSpace = cacheNameSpace; + this->localSyncInterval = localSyncInterval; + this->enbaleGlobalGC = enbaleGlobalGC; + this->globalGCInterval = globalGCInterval; + this->globalTTL = globalTTL; + } +}; + } // namespace vineyard #endif // MODULES_LLM_CACHE_DS_CONFIG_H_ diff --git a/modules/llm-cache/ds/kv_cache_chunk.cc b/modules/llm-cache/ds/kv_cache_chunk.cc new file mode 100644 index 000000000..cc914642e --- /dev/null +++ b/modules/llm-cache/ds/kv_cache_chunk.cc @@ -0,0 +1,528 @@ +/** Copyright 2024 AIBrix. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include +#include +#include + +#include "client/client.h" +#include "common/memory/memcpy.h" +#include "common/util/logging.h" +#include "llm-cache/ds/kv_cache_block.h" // LLMKV +#include "llm-cache/ds/kv_cache_chunk.h" +#include "llm-cache/hash/md5.h" + +namespace vineyard { + +void KVCacheChunk::Construct(const ObjectMeta& meta) { + Object::Construct(meta); + + std::string tname = type_name(); + + VINEYARD_ASSERT( + meta.GetTypeName() == tname, + "Expect typename '" + tname + "', but got '" + meta.GetTypeName() + "'"); + + // 1. construct the member field + total_tokens_ = meta.GetKeyValue(KVCacheChunk::kFieldNameTotalTokens); + tensor_nbytes_ = meta.GetKeyValue(KVCacheChunk::kFieldNameTensorNBytes); + layer_ = meta.GetKeyValue(KVCacheChunk::kFieldNameLayer); + chunk_size_ = meta.GetKeyValue(KVCacheChunk::kFieldNameChunkSize); + access_time_ = + meta_.GetKeyValue(KVCacheChunk::kFieldNameAccessTime); + md5_ = meta_.GetKeyValue(KVCacheChunk::kFieldNameMd5); + ns_ = meta_.GetKeyValue(KVCacheChunk::kFieldNameNS); + + // 2. 
construct the buffer + ObjectMeta blob_meta; + meta_.GetMemberMeta(KVCacheChunk::kFieldNameBuffer, blob_meta); + ObjectID blob_id = blob_meta.GetId(); + meta.GetBuffer(blob_id, buffer_); +} + +Status KVCacheChunkBuilder::Make(std::shared_ptr& builder, + RPCClient& rpc_client, int max_tokens, + int tensor_nbytes, int layer, int chunk_size, + const std::string& kv_cache_ns) { + size_t size = static_cast(chunk_size) * layer * tensor_nbytes * 2 + + max_tokens * sizeof(int); + builder = std::make_shared( + rpc_client, tensor_nbytes, layer, chunk_size, kv_cache_ns); + builder->chunk_id_ = InvalidObjectID(); + builder->remote_buffer_writer_ = std::make_shared(size); + + auto now = std::chrono::system_clock::now().time_since_epoch(); + builder->g_access_time_ = + std::chrono::duration_cast(now).count(); + + return Status::OK(); +} + +Status KVCacheChunkBuilder::Make(std::shared_ptr& builder, + RPCClient& rpc_client, int tensor_nbytes, + int layer, int chunk_size, + const std::string& kv_cache_ns, + ObjectID chunk_id) { + builder = std::make_shared( + rpc_client, tensor_nbytes, layer, chunk_size, kv_cache_ns); + builder->chunk_id_ = chunk_id; + return Status::OK(); +} + +Status KVCacheChunkBuilder::IsSame(const ObjectMeta& meta) { + RETURN_ON_ASSERT(is_ready_); + if (VLOG_IS_ON(100)) { + LOG(INFO) << "this: {total_tokens=" << total_tokens_ + << ", chunk_size=" << chunk_size_ + << ", tensor_nbytes=" << tensor_nbytes_ << ", layer=" << layer_ + << ", md5=" << md5_ << "}"; + LOG(INFO) << "that: " << meta.ToString(); + } + + RETURN_ON_ASSERT(meta.HasKey(KVCacheChunk::kFieldNameTotalTokens)); + RETURN_ON_ASSERT(meta.HasKey(KVCacheChunk::kFieldNameChunkSize)); + RETURN_ON_ASSERT(meta.HasKey(KVCacheChunk::kFieldNameTensorNBytes)); + RETURN_ON_ASSERT(meta.HasKey(KVCacheChunk::kFieldNameLayer)); + RETURN_ON_ASSERT(meta.HasKey(KVCacheChunk::kFieldNameAccessTime)); + RETURN_ON_ASSERT(meta.HasKey(KVCacheChunk::kFieldNameMd5)); + RETURN_ON_ASSERT(meta.HasKey(KVCacheChunk::kFieldNameNS)); + + RETURN_ON_ASSERT(meta.GetKeyValue(KVCacheChunk::kFieldNameTotalTokens) == + total_tokens_); + RETURN_ON_ASSERT(meta.GetKeyValue(KVCacheChunk::kFieldNameChunkSize) == + chunk_size_); + RETURN_ON_ASSERT(meta.GetKeyValue( + KVCacheChunk::kFieldNameTensorNBytes) == tensor_nbytes_); + RETURN_ON_ASSERT(meta.GetKeyValue(KVCacheChunk::kFieldNameLayer) == + layer_); + // We assume it's not possilbe to have same name and md5 of all tokens + RETURN_ON_ASSERT(meta.GetKeyValue(KVCacheChunk::kFieldNameMd5) == + md5_); + + RETURN_ON_ASSERT(meta.GetKeyValue(KVCacheChunk::kFieldNameNS) == + ns_); + + return Status::OK(); +} + +Status KVCacheChunkBuilder::Query( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensor) { + return QueryImpl(prefix, tokens, kv_tensor); +} + +Status KVCacheChunkBuilder::Construct() { + ObjectMeta object_meta; + std::shared_ptr object = nullptr; + std::shared_ptr chunk = nullptr; + + RETURN_ON_ASSERT(rpc_client_.Connected()); + + VLOG(100) << "Constructing " << ObjectIDToString(chunk_id_); + auto status = Status::OK(); + status = rpc_client_.GetMetaData(chunk_id_, object_meta, true); + if (!status.ok()) { + VLOG(100) << "Get meta data failed: " << status.ToString(); + return Status::ObjectNotExists(); + } else if (object_meta.IsLocal()) { + object = rpc_client_.GetObject(chunk_id_); + } + + // fetch from remote + if (object == nullptr) { + std::map cluster_info; + RETURN_ON_ERROR(rpc_client_.ClusterInfo(cluster_info)); + std::string rpc_endpoint = + 
cluster_info[object_meta.GetInstanceId()].value("rpc_endpoint", ""); + if (!rpc_endpoint.empty()) { + std::string rdma_endpoint = + cluster_info[object_meta.GetInstanceId()].value("rdma_endpoint", ""); + RPCClient remote_rpc_client; + RETURN_ON_ERROR( + remote_rpc_client.Connect(rpc_endpoint, "", "", rdma_endpoint)); + object = remote_rpc_client.GetObject(chunk_id_); + RETURN_ON_ASSERT(object != nullptr); + ObjectID buffer_id = + object_meta.GetMember(KVCacheChunk::kFieldNameBuffer)->id(); + std::shared_ptr blob; + RETURN_ON_ERROR(remote_rpc_client.GetRemoteBlob(buffer_id, blob)); + std::dynamic_pointer_cast(object)->buffer_ = blob->Buffer(); + } + } + + RETURN_ON_ASSERT(object != nullptr, "object is nullptr"); + LOG(INFO) << "Got " << ObjectIDToString(chunk_id_) << " from instance " + << object_meta.GetInstanceId(); + + chunk = std::dynamic_pointer_cast(object); + if (chunk->buffer_ == nullptr) { + return Status::IOError(); + } + + if (chunk_id_ != chunk->id()) { + // If the object is migrated, we should delete the copied object. + status = rpc_client_.DelData(chunk->id()); + if (!status.ok()) { + LOG(ERROR) << "Delete object failed: " << status.ToString() + << " It may cause memory leak."; + } + } + + // sanity check + RETURN_ON_ASSERT(tensor_nbytes_ == chunk->tensor_nbytes_); + RETURN_ON_ASSERT(layer_ == chunk->layer_); + RETURN_ON_ASSERT(chunk_size_ == chunk->chunk_size_); + RETURN_ON_ASSERT(ns_ == chunk->ns_); + + // construct meta info + auto all_tokens_off = + chunk->chunk_size_ * chunk->layer_ * chunk->tensor_nbytes_ * 2; + total_tokens_ = chunk->total_tokens_; + g_access_time_ = chunk->access_time_; + buffer_ = chunk->buffer_; + + auto* buffer = buffer_->data(); + + all_tokens_.resize(chunk->total_tokens_); + std::memcpy(all_tokens_.data(), buffer + all_tokens_off, + chunk->total_tokens_ * sizeof(int)); + + md5_ = md5(std::string(all_tokens_.begin(), all_tokens_.end())); + RETURN_ON_ASSERT(md5_ == chunk->md5_); + + return Status::OK(); +} + +Status KVCacheChunkBuilder::QueryImpl( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensor) { + RETURN_ON_ASSERT(tokens.size() == chunk_size_, + "The size of tokens is not equal to chunk_size"); + RETURN_ON_ASSERT(kv_tensor.size() == chunk_size_, + "The size of kv tensor is not equal to chunk_size"); + if (!is_ready_) { + std::unique_lock wlock(mutex_); + if (!is_ready_ && chunk_id_ != InvalidObjectID()) { + // need to construct from given chunk + Construct(); + is_ready_ = true; + cv_.notify_all(); + } else if (!is_ready_) { + // need to wait for the completion of update + cv_.wait(wlock, [this]() -> bool { return this->is_ready_; }); + } + } + + return QueryInternal(prefix, tokens, kv_tensor); +} + +Status KVCacheChunkBuilder::QueryInternal( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensor) { + uint8_t* buffer = nullptr; + if (remote_buffer_writer_ != nullptr) { + buffer = reinterpret_cast(remote_buffer_writer_->data()); + } else if (buffer_ != nullptr) { + buffer = const_cast(buffer_->data()); + } + + RETURN_ON_ASSERT(buffer != nullptr, "failed chunk"); + + auto all_tokens = prefix; + all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end()); + if (VLOG_IS_ON(100) && (all_tokens != all_tokens_)) { + auto str0 = "all_tokens[" + std::to_string(all_tokens.size()) + "]: "; + for (int i = 0; i < all_tokens.size(); i++) { + str0 += std::to_string(all_tokens[i]) + ", "; + } + auto str1 = "all_tokens_[" + std::to_string(all_tokens_.size()) + "]: "; + for (int i = 0; i < 
all_tokens_.size(); i++) { + str1 += std::to_string(all_tokens_[i]) + ", "; + } + LOG(INFO) << str0; + LOG(INFO) << str1; + } + RETURN_ON_ASSERT(all_tokens == all_tokens_, "tokens not match"); + + if (kv_tensor[0][0].first.data == nullptr) { + for (int i = 0; i < chunk_size_; i++) { + VINEYARD_ASSERT(kv_tensor[i].size() == layer_); + uint8_t* key_tensor_chunk_data = buffer + i * layer_ * tensor_nbytes_; + uint8_t* value_tensor_chunk_data = + key_tensor_chunk_data + chunk_size_ * layer_ * tensor_nbytes_; + + for (int j = 0; j < layer_; j++) { + LLMKV& key_tensor = kv_tensor[i][j].first; + LLMKV& value_tensor = kv_tensor[i][j].second; + VINEYARD_ASSERT(key_tensor.data == nullptr); + VINEYARD_ASSERT(value_tensor.data == nullptr); + + key_tensor.data = key_tensor_chunk_data + j * tensor_nbytes_; + key_tensor.length = tensor_nbytes_; + value_tensor.data = value_tensor_chunk_data + j * tensor_nbytes_; + value_tensor.length = tensor_nbytes_; + } + } + } else { + std::vector dst_buffers; + std::vector src_buffers; + dst_buffers.reserve(chunk_size_ * layer_ * 2); + src_buffers.reserve(chunk_size_ * layer_ * 2); + + for (int i = 0; i < chunk_size_; i++) { + uint8_t* key_tensor_chunk_data = buffer + i * layer_ * tensor_nbytes_; + + for (int j = 0; j < layer_; j++) { + LLMKV& key_tensor = kv_tensor[i][j].first; + VINEYARD_ASSERT(key_tensor.data != nullptr); + VINEYARD_ASSERT(key_tensor.length == tensor_nbytes_); + + uint8_t* key_tensor_data = key_tensor_chunk_data + j * tensor_nbytes_; + + dst_buffers.push_back(key_tensor.data); + src_buffers.push_back(key_tensor_data); + } + } + + for (int i = 0; i < chunk_size_; i++) { + uint8_t* key_tensor_chunk_data = buffer + i * layer_ * tensor_nbytes_; + uint8_t* value_tensor_chunk_data = + key_tensor_chunk_data + chunk_size_ * layer_ * tensor_nbytes_; + + for (int j = 0; j < layer_; j++) { + LLMKV& value_tensor = kv_tensor[i][j].second; + VINEYARD_ASSERT(value_tensor.data != nullptr); + VINEYARD_ASSERT(value_tensor.length == tensor_nbytes_); + + uint8_t* value_tensor_data = + value_tensor_chunk_data + j * tensor_nbytes_; + + dst_buffers.push_back(value_tensor.data); + src_buffers.push_back(value_tensor_data); + } + } + + vineyard::memory::concurrent_memcpy_n(dst_buffers, src_buffers, + tensor_nbytes_); + } + + if (VLOG_IS_ON(200)) { + PrintKVCacheChunk(); + } + return Status::OK(); +} + +Status KVCacheChunkBuilder::Update( + const std::vector& prefix, const std::vector& tokens, + const std::vector>>& kv_tensor) { + return UpdateImpl(prefix, tokens, kv_tensor); +} + +Status KVCacheChunkBuilder::UpdateImpl( + const std::vector& prefix, const std::vector& tokens, + const std::vector>>& kv_tensor) { + VINEYARD_ASSERT(tokens.size() == chunk_size_, + "The size of tokens is not equal to chunk_size"); + VINEYARD_ASSERT(kv_tensor.size() == chunk_size_, + "The size of kv tensor is not equal to chunk_size"); + + VINEYARD_ASSERT(remote_buffer_writer_ != nullptr, + "remote_buffer_writer_ is nullptr"); + + VINEYARD_ASSERT(!is_ready_); + + std::unique_lock wlock(mutex_); + + all_tokens_.reserve(prefix.size() + tokens.size()); + all_tokens_.insert(all_tokens_.end(), prefix.begin(), prefix.end()); + all_tokens_.insert(all_tokens_.end(), tokens.begin(), tokens.end()); + + md5_ = md5(std::string(all_tokens_.begin(), all_tokens_.end())); + + total_tokens_ = all_tokens_.size(); + VINEYARD_ASSERT(total_tokens_ == prefix.size() + tokens.size()); + auto buffer = reinterpret_cast(remote_buffer_writer_->data()); + + std::vector dst_buffers; + std::vector src_buffers; + 
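+  // Layout of the chunk buffer written below:
+  //   [ keys   : chunk_size x layer x tensor_nbytes bytes ]
+  //   [ values : chunk_size x layer x tensor_nbytes bytes ]
+  //   [ token ids of prefix + tokens : total_tokens x sizeof(int) bytes ]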
dst_buffers.reserve(chunk_size_ * layer_ * 2); + src_buffers.reserve(chunk_size_ * layer_ * 2); + + for (int i = 0; i < chunk_size_; i++) { + uint8_t* key_tensor_chunk_data = buffer + i * layer_ * tensor_nbytes_; + VINEYARD_ASSERT(kv_tensor[i].size() == layer_); + + for (int j = 0; j < layer_; j++) { + LLMKV key_tensor = kv_tensor[i][j].first; + VINEYARD_ASSERT(key_tensor.length == tensor_nbytes_); + + uint8_t* key_tensor_data = key_tensor_chunk_data + j * tensor_nbytes_; + dst_buffers.push_back(key_tensor_data); + src_buffers.push_back(key_tensor.data); + } + } + + for (int i = 0; i < chunk_size_; i++) { + uint8_t* key_tensor_chunk_data = buffer + i * layer_ * tensor_nbytes_; + uint8_t* value_tensor_chunk_data = + key_tensor_chunk_data + chunk_size_ * layer_ * tensor_nbytes_; + + for (int j = 0; j < layer_; j++) { + LLMKV value_tensor = kv_tensor[i][j].second; + VINEYARD_ASSERT(value_tensor.length == tensor_nbytes_); + + uint8_t* value_tensor_data = value_tensor_chunk_data + j * tensor_nbytes_; + dst_buffers.push_back(value_tensor_data); + src_buffers.push_back(value_tensor.data); + } + } + + vineyard::memory::concurrent_memcpy_n(dst_buffers, src_buffers, + tensor_nbytes_); + + // write all tokens + buffer += chunk_size_ * layer_ * tensor_nbytes_ * 2; + std::memcpy(buffer, all_tokens_.data(), total_tokens_ * sizeof(int)); + + is_ready_ = true; + cv_.notify_all(); + + if (VLOG_IS_ON(200)) { + PrintKVCacheChunk(); + } + + return Status::OK(); +} + +std::shared_ptr KVCacheChunkBuilder::Seal() { + VINEYARD_ASSERT(buffer_ == nullptr); + + if (remote_buffer_writer_ == nullptr) { + return nullptr; + } + + auto chunk = std::make_shared(); + + // 1. seal the buffer + ObjectMeta blob_meta; + Status status = + rpc_client_.CreateRemoteBlob(remote_buffer_writer_, blob_meta); + if (!status.ok()) { + VLOG(100) << "Failed to CreateRemoteBlob, error=" << status.ToString(); + return nullptr; + } + + size_t nbytes = remote_buffer_writer_->size(); + chunk->meta_.AddMember(KVCacheChunk::kFieldNameBuffer, blob_meta); + + // 2. store the member field to meta + chunk->meta_.AddKeyValue(KVCacheChunk::kFieldNameTotalTokens, total_tokens_); + chunk->meta_.AddKeyValue(KVCacheChunk::kFieldNameChunkSize, chunk_size_); + chunk->meta_.AddKeyValue(KVCacheChunk::kFieldNameTensorNBytes, + tensor_nbytes_); + chunk->meta_.AddKeyValue(KVCacheChunk::kFieldNameLayer, layer_); + chunk->meta_.AddKeyValue(KVCacheChunk::kFieldNameAccessTime, access_time_); + chunk->meta_.AddKeyValue(KVCacheChunk::kFieldNameMd5, md5_); + chunk->meta_.AddKeyValue(KVCacheChunk::kFieldNameNS, ns_); + chunk->meta_.SetNBytes(nbytes); + + // 3. 
set the object type to meta + chunk->meta_.SetTypeName(type_name()); + + if (!rpc_client_.CreateMetaData(chunk->meta_, chunk->id_).ok()) { + return nullptr; + } + + return chunk; +} + +void KVCacheChunkBuilder::PrintKVCacheChunk() { + LOG(INFO) << "builder:" << this; + + uint8_t* buffer = nullptr; + if (remote_buffer_writer_ != nullptr) { + buffer = reinterpret_cast(remote_buffer_writer_->data()); + } else if (buffer_ != nullptr) { + buffer = const_cast(buffer_->data()); + } + + if (buffer == nullptr) { + LOG(INFO) << ">failed chunk"; + } + + if (total_tokens_ > chunk_size_) { + std::string prefix_tokens = ""; + for (size_t i = 0; i < total_tokens_ - chunk_size_; i++) { + prefix_tokens += std::to_string(all_tokens_[i]) + " "; + } + LOG(INFO) << ">prefix tokens:" << prefix_tokens; + } else { + LOG(INFO) << ">prefix tokens: N/A"; + } + + for (int i = 0; i < chunk_size_; i++) { + LOG(INFO) << ">index:" << i; + LOG(INFO) << ">token:" << all_tokens_[total_tokens_ - chunk_size_ + i]; + + uint8_t* key_tensor_chunk_data = buffer + i * layer_ * tensor_nbytes_; + uint8_t* value_tensor_chunk_data = + key_tensor_chunk_data + chunk_size_ * layer_ * tensor_nbytes_; + + for (int curr = 0; curr < layer_; curr++) { + LOG(INFO) << ">layer:" << curr; + uint8_t* key_tensor_data = key_tensor_chunk_data + curr * tensor_nbytes_; + uint8_t* value_tensor_data = + value_tensor_chunk_data + curr * tensor_nbytes_; + + // print the first tensor_nbytes bytes + std::string key_tensor = ""; + std::string value_tensor = ""; + for (int j = 0; j < tensor_nbytes_; j++) { + key_tensor += std::to_string(key_tensor_data[j]) + " "; + value_tensor += std::to_string(value_tensor_data[j]) + " "; + } + LOG(INFO) << ">>key_tensor:" << key_tensor; + LOG(INFO) << ">>value_tensor:" << value_tensor; + } + } + + static const auto get_ts = + [](std::chrono::duration time) { + auto duration_since_epoch = + std::chrono::duration_cast( + time); + std::chrono::time_point timestamp = + std::chrono::system_clock::time_point(duration_since_epoch); + time_t t = std::chrono::system_clock::to_time_t(timestamp); + + std::tm tm; + localtime_r(&t, &tm); + std::ostringstream oss; + oss << std::put_time(&tm, "%Y-%m-%d %H:%M:%S"); + return oss.str(); + }; + + LOG(INFO) << ">global_access_time:" + << get_ts(std::chrono::nanoseconds(g_access_time_)); + LOG(INFO) << ">access_time:" + << get_ts(std::chrono::nanoseconds(access_time_)); + + LOG(INFO) << "=========================="; +} + +} // namespace vineyard diff --git a/modules/llm-cache/ds/kv_cache_chunk.h b/modules/llm-cache/ds/kv_cache_chunk.h new file mode 100644 index 000000000..b7f3d133f --- /dev/null +++ b/modules/llm-cache/ds/kv_cache_chunk.h @@ -0,0 +1,183 @@ +/** Copyright 2024 AIBrix. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
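For orientation, a minimal sketch of the chunk-builder lifecycle implemented above (Make, then Update, then Seal). The WriteOneChunk wrapper, the constants, and the assumption that kv_buffer already holds chunk_size x layer x 2 x kTensorNBytes bytes of caller-owned tensor data are illustrative only and not part of this patch:

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "client/rpc_client.h"
#include "common/util/status.h"
#include "llm-cache/ds/kv_cache_block.h"  // LLMKV
#include "llm-cache/ds/kv_cache_chunk.h"

vineyard::Status WriteOneChunk(vineyard::RPCClient& rpc_client,
                               const std::vector<int>& prefix,
                               const std::vector<int>& tokens,
                               std::vector<uint8_t>& kv_buffer) {
  constexpr int kMaxTokens = 1024;     // upper bound on prefix + tokens
  constexpr int kTensorNBytes = 4096;  // bytes per K (or V) tensor per layer
  constexpr int kLayer = 32;
  const int chunk_size = static_cast<int>(tokens.size());

  std::shared_ptr<vineyard::KVCacheChunkBuilder> builder;
  RETURN_ON_ERROR(vineyard::KVCacheChunkBuilder::Make(
      builder, rpc_client, kMaxTokens, kTensorNBytes, kLayer, chunk_size,
      "aibrix"));

  // kv_tensor[i][j] describes the key/value tensors of token i at layer j,
  // each pointing into caller-owned memory of exactly kTensorNBytes bytes.
  std::vector<std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>>>
      kv_tensor(chunk_size,
                std::vector<std::pair<vineyard::LLMKV, vineyard::LLMKV>>(kLayer));
  uint8_t* base = kv_buffer.data();
  for (int i = 0; i < chunk_size; ++i) {
    for (int j = 0; j < kLayer; ++j) {
      kv_tensor[i][j].first.data = base;
      kv_tensor[i][j].first.length = kTensorNBytes;
      base += kTensorNBytes;
      kv_tensor[i][j].second.data = base;
      kv_tensor[i][j].second.length = kTensorNBytes;
      base += kTensorNBytes;
    }
  }

  // Update() copies the tensors and the full token lineage into the chunk
  // buffer; Seal() creates the remote blob plus the object metadata.
  RETURN_ON_ERROR(builder->Update(prefix, tokens, kv_tensor));
  auto chunk = builder->Seal();
  return chunk != nullptr ? vineyard::Status::OK()
                          : vineyard::Status::IOError();
}
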
+*/ + +#ifndef MODULES_LLM_CACHE_DS_KV_CACHE_CHUNK_H_ +#define MODULES_LLM_CACHE_DS_KV_CACHE_CHUNK_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "client/client.h" +#include "client/ds/blob.h" +#include "client/ds/i_object.h" +#include "client/ds/remote_blob.h" +#include "client/rpc_client.h" +#include "llm-cache/ds/kv_tensor.h" + +namespace vineyard { + +// forward declaration +struct LLMKV; + +class KVCacheChunk : public vineyard::Registered { + public: + inline static constexpr char kFieldNameNS[] = "namespace"; + inline static constexpr char kFieldNameBuffer[] = "buffer"; + inline static constexpr char kFieldNameTotalTokens[] = "total_tokens"; + inline static constexpr char kFieldNameTensorNBytes[] = "tensor_nbytes"; + inline static constexpr char kFieldNameLayer[] = "layer"; + inline static constexpr char kFieldNameChunkSize[] = "chunk_size"; + inline static constexpr char kFieldNameAccessTime[] = "access_time"; + inline static constexpr char kFieldNameMd5[] = "md5"; + + private: + std::shared_ptr buffer_; + int total_tokens_; + int tensor_nbytes_; + int layer_; + int chunk_size_; + uint64_t access_time_; + std::string md5_; + std::string ns_; + + public: + static std::unique_ptr Create() __attribute__((used)) { + return std::static_pointer_cast( + std::unique_ptr{new KVCacheChunk()}); + } + + void Construct(const ObjectMeta& meta) override; + + int GetChunkSize() { return chunk_size_; } + + static std::string GetNameSpace(const std::string& kv_cache_ns) { + return std::regex_replace(kv_cache_ns, std::regex("_+$"), ""); + } + + ~KVCacheChunk() = default; + + friend class KVCacheChunkBuilder; +}; + +class KVCacheChunkBuilder { + private: + RPCClient& rpc_client_; + std::vector all_tokens_; + std::shared_ptr remote_buffer_writer_ = nullptr; + ObjectID chunk_id_; + std::shared_ptr buffer_ = nullptr; + int total_tokens_; + int tensor_nbytes_; + int layer_; + int chunk_size_; + std::string ns_; + std::shared_mutex time_mu_; + uint64_t g_access_time_ = 0; + uint64_t access_time_ = 0; + std::mutex mutex_; + std::condition_variable cv_; + std::atomic is_ready_ = false; + std::string md5_; + + public: + static Status Make(std::shared_ptr& chunk_builder, + RPCClient& rpc_client, int max_tokens, int tensor_nbytes, + int layer, int chunk_size, const std::string& kv_cache_ns); + + static Status Make(std::shared_ptr& chunk_builder, + RPCClient& rpc_client, int tensor_nbytes, int layer, + int chunk_size, const std::string& kv_cache_ns, + ObjectID chunk_id); + + Status Update( + const std::vector& prefix, const std::vector& tokens, + const std::vector>>& kv_tensor); + + Status Query(const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensor); + + void SetAccessTime(uint64_t time) { + std::unique_lock wlock(time_mu_); + if (time > access_time_) { + access_time_ = time; + } + } + + void SetGlobalAccessTime(uint64_t time) { + std::unique_lock wlock(time_mu_); + if (time > g_access_time_) { + g_access_time_ = time; + } + } + + uint64_t GetGlobalAccessTime() { + std::shared_lock rlock(time_mu_); + return g_access_time_; + } + + uint64_t GetAccessTime() { + std::shared_lock rlock(time_mu_); + return access_time_; + } + + bool IsReady() { return is_ready_; } + + std::shared_ptr Seal(); + + uint64_t GetTensorNBytes() { return tensor_nbytes_; } + + int GetChunkSize() { return chunk_size_; } + + void PrintKVCacheChunk(); + + Status IsSame(const ObjectMeta& meta); + + KVCacheChunkBuilder(RPCClient& rpc_client, int 
tensor_nbytes, int layer, + int chunk_size, const std::string& kv_cache_ns) + : rpc_client_(rpc_client), + tensor_nbytes_(tensor_nbytes), + layer_(layer), + chunk_size_(chunk_size), + ns_(KVCacheChunk::GetNameSpace(kv_cache_ns)) {} + + ~KVCacheChunkBuilder() = default; + + private: + Status Construct(); + + Status UpdateImpl( + const std::vector& prefix, const std::vector& tokens, + const std::vector>>& kv_tensor); + + Status QueryImpl( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensor); + + Status QueryInternal( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensor); +}; + +} // namespace vineyard + +#endif // MODULES_LLM_CACHE_DS_KV_CACHE_CHUNK_H_ diff --git a/modules/llm-cache/ds/kv_cache_manager.cc b/modules/llm-cache/ds/kv_cache_manager.cc index 7a4912fe6..841747c0e 100644 --- a/modules/llm-cache/ds/kv_cache_manager.cc +++ b/modules/llm-cache/ds/kv_cache_manager.cc @@ -26,6 +26,7 @@ limitations under the License. #include "common/util/status.h" #include "llm-cache/ds/kv_cache.h" #include "llm-cache/ds/kv_cache_manager.h" +#include "llm-cache/storage/aibrix_blob_storage.h" #include "llm-cache/storage/blob_storage.h" #include "llm-cache/storage/local_file_storage.h" #include "llm-cache/storage/vineyard_file_storage.h" @@ -144,6 +145,30 @@ Status KVCacheManager::Make(RPCClient& rpc_client, Client& ipc_client, return Status::OK(); } +Status KVCacheManager::Make(RPCClient& rpc_client, Client& ipc_client, + std::shared_ptr& manager, + AIBrixCacheConfig& config) { + if (config.tensorByte <= 0 || config.cacheCapacity <= 0 || + config.layer <= 0) { + return Status::Invalid("Invalid tensor byte, cache capacity or layer."); + } + + if (config.chunkSize <= 0) { + return Status::Invalid("Invalid chunk size"); + } + + std::shared_ptr storage = + std::make_shared( + rpc_client, ipc_client, config.tensorByte, config.cacheCapacity, + config.layer, config.chunkSize, config.cacheNameSpace, + config.localSyncInterval, config.enbaleGlobalGC, + config.globalGCInterval, config.globalTTL); + manager = std::make_shared(storage); + RETURN_ON_ERROR(storage->Init()); + manager->config = std::make_shared(config); + return Status::OK(); +} + /** * @brief Update the kv state with the given token list in the kv state cache * manager. diff --git a/modules/llm-cache/ds/kv_cache_manager.h b/modules/llm-cache/ds/kv_cache_manager.h index 073994a26..c944ca9ff 100644 --- a/modules/llm-cache/ds/kv_cache_manager.h +++ b/modules/llm-cache/ds/kv_cache_manager.h @@ -45,6 +45,10 @@ class KVCacheManager { std::shared_ptr& manager, FileCacheConfig& config); + static Status Make(RPCClient& rpc_client, Client& ipc_client, + std::shared_ptr& manager, + AIBrixCacheConfig& config); + Status Update(const std::vector& tokenList, int nextToken, const std::vector>& kvState); diff --git a/modules/llm-cache/hash/hasher.h b/modules/llm-cache/hash/hasher.h index 0d936df3f..55b3ca44f 100644 --- a/modules/llm-cache/hash/hasher.h +++ b/modules/llm-cache/hash/hasher.h @@ -84,6 +84,51 @@ class Hasher { return Status::OK(); } + /* + * This function processes a sequence of tokens in fixed-size chunks to + * generate a sequence of hash values. + * + * For each chunk of tokens: + * - If it's not the first chunk, the hash of the previous chunk is added as + * a prefix to the current chunk. + * - The combined data (previous hash and current chunk) is hashed to + * produce a new hash value. + * - The new hash is converted to a hexadecimal string. 
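+   *
+   * For example (illustrative values), with chunkSize = 2 and
+   * tokens = {11, 12, 13, 14, 15}:
+   *   - the trailing token 15 is dropped (5 % 2 == 1);
+   *   - chunk 0 hashes {11, 12}        -> hash0, emitted via "%08x";
+   *   - chunk 1 hashes {hash0, 13, 14} -> hash1, emitted via "%08x".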
+ * + * The function thus produces a series of interdependent hash values, each + * influenced by the previous hash. + */ + Status computeChunkHashesForTokens(const std::vector& tokens, + int chunkSize, + std::vector& hashes) { + char hashBuffer[9]; + int tokenSize = tokens.size() - tokens.size() % chunkSize; + // if the token list (upper_bound) is less than the batch size, then return + // directly + if (tokenSize < chunkSize) { + return Status::OK(); + } + + std::vector candidates; + candidates.reserve(chunkSize + 1); + int prevHash = 0; + for (int i = 0; i < tokenSize; i += chunkSize) { + if (i > 0) { + candidates.push_back(prevHash); + } + candidates.insert(candidates.end(), tokens.begin() + i, + tokens.begin() + i + chunkSize); + auto currHash = + hashAlgorithm->hash(reinterpret_cast(candidates.data()), + candidates.size() * sizeof(int)); + std::snprintf(hashBuffer, sizeof(hashBuffer), "%08x", currHash); + hashes.push_back(hashBuffer); + candidates.clear(); + prevHash = currHash; + } + return Status::OK(); + } + private: IHashAlgorithm* hashAlgorithm; }; diff --git a/modules/llm-cache/hash/md5.h b/modules/llm-cache/hash/md5.h new file mode 100644 index 000000000..ff53df622 --- /dev/null +++ b/modules/llm-cache/hash/md5.h @@ -0,0 +1,47 @@ +/** Copyright 2024 AIBrix. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef MODULES_LLM_CACHE_HASH_MD5_H_ +#define MODULES_LLM_CACHE_HASH_MD5_H_ + +#include +#include +#include + +namespace vineyard { + +std::string md5(const std::string& content) { + auto* context = EVP_MD_CTX_new(); + const auto* md = EVP_md5(); + unsigned char md_value[EVP_MAX_MD_SIZE]; + unsigned int md_len; + std::string output; + + EVP_DigestInit_ex(context, md, nullptr); + EVP_DigestUpdate(context, content.c_str(), content.size()); + EVP_DigestFinal_ex(context, md_value, &md_len); + EVP_MD_CTX_free(context); + + output.resize(md_len * 2); + for (int i = 0; i < md_len; i++) { + std::sprintf(&output[i * 2], "%02x", // NOLINT(runtime/printf) + md_value[i]); + } + return output; +} + +} // namespace vineyard + +#endif // MODULES_LLM_CACHE_HASH_MD5_H_ diff --git a/modules/llm-cache/storage/aibrix_blob_storage.cc b/modules/llm-cache/storage/aibrix_blob_storage.cc new file mode 100644 index 000000000..4916fffd6 --- /dev/null +++ b/modules/llm-cache/storage/aibrix_blob_storage.cc @@ -0,0 +1,906 @@ +/** Copyright 2024 AIBrix. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
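A quick sanity sketch for the md5() helper above (compiling and linking against OpenSSL's libcrypto is assumed; the token values are placeholders):

#include <cassert>
#include <string>
#include <vector>

#include "llm-cache/hash/md5.h"

int main() {
  // RFC 1321 test vector: MD5("abc").
  assert(vineyard::md5("abc") == "900150983cd24fb0d6963f7d28e17f72");

  // KVCacheChunkBuilder feeds it the chunk's token ids narrowed to bytes,
  // i.e. md5(std::string(all_tokens_.begin(), all_tokens_.end())).
  std::vector<int> all_tokens = {1, 2, 3, 4};
  std::string digest =
      vineyard::md5(std::string(all_tokens.begin(), all_tokens.end()));
  return digest.empty() ? 1 : 0;
}
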
+*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm-cache/storage/aibrix_blob_storage.h" +#include "llm-cache/thread_group.h" + +namespace vineyard { +AIBrixBlobStorage::AIBrixBlobStorage( + RPCClient& rpc_client, Client& ipc_client, size_t tensor_nbytes, + int capacity, int layer, int chunk_size, std::string kv_cache_ns, + int64_t local_sync_interval_s, bool global_gc_enabled, + int64_t global_gc_interval_s, int64_t global_ttl_s) + : rpc_client_(rpc_client), + ipc_client_(ipc_client), + hash_alg_(std::make_shared()), + hasher_(std::make_shared(hash_alg_.get())), + tensor_nbytes_(tensor_nbytes), + capacity_(capacity), + layer_(layer), + chunk_size_(chunk_size), + chunk_obj_size_(tensor_nbytes * 2 * layer * chunk_size + + kMaxTokensPerSeq * sizeof(int)), + kv_cache_ns_(kv_cache_ns), + local_sync_interval_s_(std::chrono::seconds(local_sync_interval_s)), + global_gc_enabled_(global_gc_enabled), + global_gc_interval_s_(std::chrono::seconds(global_gc_interval_s)), + global_ttl_s_(std::chrono::seconds(global_ttl_s)), + ghost_fifo_(capacity_), + small_fifo_(capacity_ * kSmallFifoCapacityRatio), + main_fifo_(capacity_ - capacity_ * kSmallFifoCapacityRatio, + kMinEviction) { + kv_cache_ns_ = std::regex_replace(kv_cache_ns_, std::regex("/"), "_"); + kv_cache_ns_ = std::regex_replace(kv_cache_ns_ + "_", std::regex("_+"), "_"); +} + +Status AIBrixBlobStorage::Init() { + if (!rpc_client_.Connected()) { + RETURN_ON_ASSERT(ipc_client_.Connected()); + // check if rpc is enabled on server side + std::map cluster_info; + RETURN_ON_ERROR(ipc_client_.ClusterInfo(cluster_info)); + auto instance_id = ipc_client_.instance_id(); + std::string rpc_endpoint = + cluster_info[instance_id].value("rpc_endpoint", ""); + RETURN_ON_ASSERT(!rpc_endpoint.empty()); + std::string rdma_endpoint = + cluster_info[instance_id].value("rdma_endpoint", ""); + RETURN_ON_ERROR(rpc_client_.Connect(rpc_endpoint, "", "", rdma_endpoint)); + } + + RETURN_ON_ASSERT(rpc_client_.Connected()); + + main_fifo_.setPruneHook([this](auto&& pairs) { + std::vector delete_list; + MainFifoPruneHookLoop(std::forward(pairs), delete_list); + + auto status = this->Delete(delete_list); + if (!status.ok()) { + LOG(ERROR) << "Failed to delete objects, " << status.ToString(); + } + }); + + small_fifo_.setPruneHook([this](auto&& pairs) { + for (auto& pair : pairs) { + auto& key = pair.first; + auto& entry = pair.second; + VLOG(100) << "Evicting " << key << " from S"; + // Items in S should not have been persisted + VINEYARD_ASSERT(entry.object_id == InvalidObjectID()); + if (entry.access_bit) { + VLOG(100) << key << " has been accessed, promote it to M"; + // Promote the item to M + entry.access_bit = false; + + std::unique_lock lock(this->main_fifo_mu_); + this->main_fifo_.set(key, std::move(entry), /* promote */ false); + } else { + VLOG(100) << "Evict " << key << " to G"; + // Evict the item to G + entry.chunk_builder.reset(); + entry.access_bit = false; + + std::unique_lock lock(this->ghost_fifo_mu_); + this->ghost_fifo_.set(key, std::move(entry), /* promote */ false); + } + } + }); + + ghost_fifo_.setPruneHook([this](auto&& pairs) { + for (auto& pair : pairs) { + auto& key = pair.first; + auto& entry = pair.second; + VLOG(100) << "Evicting " << key << " from G"; + // Items in G should not have been persisted + VINEYARD_ASSERT(entry.object_id == InvalidObjectID()); + entry.chunk_builder.reset(); + } + }); + + VINEYARD_DISCARD(BuildMainFifo()); + + global_gc_thread_ = + 
std::thread(AIBrixBlobStorage::GlobalGCThread, shared_from_this()); + local_sync_thread_ = + std::thread(AIBrixBlobStorage::LocalSyncThread, shared_from_this()); + + return Status::OK(); +} + +template +void AIBrixBlobStorage::MainFifoPruneHookLoop( + PairsT&& pairs, std::vector& delete_list) { + for (auto& pair : pairs) { + auto& key = pair.first; + auto& entry = pair.second; + + VLOG(100) << "Evicting " << key << " from M"; + + if (entry.access_bit) { + VLOG(100) << key << " has been accessed, reinsert it back to M"; + entry.access_bit = false; + this->main_fifo_.set(key, std::move(entry)); + } else { + // If the entry has a valid object id, then it has been persisted to the + // object store, we have to explicitly delete it. + if (entry.object_id != InvalidObjectID()) { + VLOG(100) << "Deleting obj=" << entry.object_id; + delete_list.push_back(entry.object_id); + } + + entry.chunk_builder.reset(); + } + } +} + +Status AIBrixBlobStorage::BuildMainFifo() { + std::vector chunk_metas; + RETURN_ON_ERROR(ListKVCache(kv_cache_ns_, chunk_metas)); + VLOG(100) << "Global M: " << chunk_metas.size() << " chunks"; + + std::vector delete_list; + auto build_prune_hook = [this, &delete_list](auto&& pairs) { + MainFifoPruneHookLoop(std::forward(pairs), delete_list); + }; + + { + std::unique_lock lock(main_fifo_mu_); + for (const auto& meta : chunk_metas) { + if (!meta.HasKey("__name")) { + continue; + } + const auto& chunk_name = meta.GetKeyValue("__name"); + const auto chunk_id = meta.GetId(); + auto it = main_fifo_.findWithoutPromotion(chunk_name); + if (it == main_fifo_.end()) { + FifoEntry entry; + entry.object_id = chunk_id; + // For insert here, we use a dedicate prune hook to remove + // delete operations from the critical path + VLOG(100) << "Main fifo: insert " << chunk_name << ", obj id " + << ObjectIDToString(chunk_id); + main_fifo_.set(chunk_name, std::move(entry), /* promote */ false, + build_prune_hook); + } else if (it->second.chunk_builder && + it->second.chunk_builder->IsReady()) { + // try to update access time + auto access_time_label = meta.Label(KVCacheChunk::kFieldNameAccessTime); + if (access_time_label.empty()) { + access_time_label = std::to_string( + meta.GetKeyValue(KVCacheChunk::kFieldNameAccessTime)); + } + uint64_t time = std::stoull(access_time_label); + it->second.chunk_builder->SetGlobalAccessTime(time); + it->second.chunk_builder->SetAccessTime(time); + } + } + } + + if (!delete_list.empty()) { + Delete(delete_list); + } + + return Status::OK(); +} + +Status AIBrixBlobStorage::GetTokenChunkHashes( + const std::vector& prefix, const std::vector& tokens, + std::vector& chunk_hashes) { + std::vector all(prefix.begin(), prefix.end()); + all.insert(all.end(), tokens.begin(), tokens.end()); + + RETURN_ON_ERROR( + hasher_->computeChunkHashesForTokens(all, chunk_size_, chunk_hashes)); + auto sz = tokens.size() / chunk_size_; + chunk_hashes = + std::vector(chunk_hashes.end() - sz, chunk_hashes.end()); + return Status::OK(); +} + +#define DEFINE_TASK_FN(FN, OP, CB) \ + auto FN = [this, &prefix, &tokens, &kv_tensors, cb = CB]( \ + size_t i, \ + std::shared_ptr builder) -> Status { \ + auto chunk_size = this->chunk_size_; \ + if (builder == nullptr) { \ + return Status::OK(); \ + } \ + \ + std::vector my_prefix(prefix.begin(), prefix.end()); \ + if (i > 0) { \ + my_prefix.insert(my_prefix.end(), tokens.begin(), \ + tokens.begin() + i * chunk_size); \ + } \ + std::vector my_tokens(tokens.begin() + i * chunk_size, \ + tokens.begin() + (i + 1) * chunk_size); \ + \ + std::vector>> 
my_kv_tensors( \ + kv_tensors.begin() + i * chunk_size, \ + kv_tensors.begin() + (i + 1) * chunk_size); \ + \ + auto status = builder->OP(my_prefix, my_tokens, my_kv_tensors); \ + if (status.ok()) { \ + cb(i, my_kv_tensors); \ + } \ + return status; \ + } + +#define WAIT_TASK_RESULTS(TIDS, COUNTER, FIRST_ERROR, OBJ_NAMES) \ + { \ + bool skip_rest = false; \ + for (size_t i = 0; i < TIDS.size(); ++i) { \ + auto status = tg.TaskResult(TIDS[i]); \ + if (status.ok() && !skip_rest) { \ + COUNTER += chunk_size_; \ + } else if (!status.ok() && skip_rest == false) { \ + FIRST_ERROR = status; \ + skip_rest = true; \ + VLOG(100) << "First error: " << FIRST_ERROR.ToString(); \ + } else { \ + /* delete from index */ \ + { \ + std::unique_lock lock(main_fifo_mu_); \ + main_fifo_.erase(OBJ_NAMES[i]); \ + } \ + { \ + std::unique_lock lock(small_fifo_mu_); \ + small_fifo_.erase(OBJ_NAMES[i]); \ + } \ + VLOG(100) << "Error: " << status.ToString(); \ + } \ + } \ + } + +Status AIBrixBlobStorage::UpdateInternal( + const std::vector& prefix, const std::vector& tokens, + const std::vector>>& kv_tensors, + size_t& updated) { + updated = 0; + + if (exit_flag_) { + return Status::Invalid("Storage has been closed"); + } + + if (prefix.size() % chunk_size_ != 0) { + return Status::Invalid("Prefix size " + std::to_string(prefix.size()) + + " should be multiple of chunk size " + + std::to_string(chunk_size_)); + } + + if (tokens.size() != kv_tensors.size()) { + return Status::Invalid("Tokens size " + std::to_string(tokens.size()) + + " should be equal to kv tensors size " + + std::to_string(kv_tensors.size())); + } + + if (tokens.size() > kMaxTokensPerSeq) { + return Status::Invalid("Token list size exceeds the size limit"); + } + + std::vector chunk_hashes; + GetTokenChunkHashes(prefix, tokens, chunk_hashes); + std::vector obj_names; + for (const auto& chunk_hash : chunk_hashes) { + obj_names.push_back(kv_cache_ns_ + chunk_hash); + } + + auto now = std::chrono::system_clock::now().time_since_epoch(); + auto access_time = + std::chrono::duration_cast(now).count(); + + parallel::ThreadGroup tg( + std::min(obj_names.size(), + static_cast(std::thread::hardware_concurrency()))); + std::vector tids; + + DEFINE_TASK_FN(fn, Update, [](auto, auto&&) {}); + + auto ret = Status::OK(); + size_t index; + { + std::unique_lock slock(small_fifo_mu_); + std::unique_lock mlock(main_fifo_mu_); + for (index = 0; index < obj_names.size(); index++) { + const auto& obj_name = obj_names[index]; + auto sit = small_fifo_.findWithoutPromotion(obj_name); + if (sit == small_fifo_.end()) { + auto mit = main_fifo_.findWithoutPromotion(obj_name); + if (mit == main_fifo_.end()) { + break; + } + } + } + } + + updated += index * chunk_size_; + + if (index == obj_names.size()) { + // we find all chunks in local + return ret; + } + + // right now index points to the first missing chunk + for (size_t i = index; i < obj_names.size(); i++) { + const auto& obj_name = obj_names[i]; + EvictingCacheMap* target_fifo = nullptr; + std::mutex* target_fifo_mu = nullptr; + + { + std::unique_lock lock(ghost_fifo_mu_); + auto it = ghost_fifo_.findWithoutPromotion(obj_name); + if (it != ghost_fifo_.end()) { + target_fifo = &main_fifo_; + target_fifo_mu = &main_fifo_mu_; + } else { + target_fifo = &small_fifo_; + target_fifo_mu = &small_fifo_mu_; + } + } + + FifoEntry entry; + { + std::unique_lock lock(*target_fifo_mu); + auto it = target_fifo->findWithoutPromotion(obj_name); + if (it != target_fifo->end()) { + // chunk is already in the cache + 
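+        // A nullptr builder makes the task defined above a no-op that still
+        // returns Status::OK(), so this already-cached chunk is counted
+        // towards `updated` without rewriting its tensors.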
tids.push_back(tg.AddTask(fn, i, nullptr)); + continue; + } + + auto status = KVCacheChunkBuilder::Make( + entry.chunk_builder, rpc_client_, kMaxTokensPerSeq, tensor_nbytes_, + layer_, chunk_size_, kv_cache_ns_); + if (!status.ok()) { + VLOG(100) << "Failed to make chunk builder, " << status.ToString(); + ret += status; + // skip this and rest chunks + break; + } + + entry.chunk_builder->SetAccessTime(access_time); + if (VLOG_IS_ON(100)) { + if (target_fifo == &main_fifo_) { + LOG(INFO) << "Main fifo: insert " << obj_name; + } else { + LOG(INFO) << "Small fifo: insert " << obj_name; + } + } + target_fifo->set(obj_name, entry, /* promote */ false); + + tids.push_back(tg.AddTask(fn, i, entry.chunk_builder)); + } + } + + if (tids.empty()) { + return ret; + } + + Status first_error = Status::OK(); + WAIT_TASK_RESULTS(tids, updated, first_error, obj_names); + return first_error; +} + +Status AIBrixBlobStorage::QueryInternal( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensors, + size_t& matched) { + matched = 0; + + if (exit_flag_) { + return Status::Invalid("Storage has been closed"); + } + + if (prefix.size() % chunk_size_ != 0) { + return Status::Invalid("Prefix size " + std::to_string(prefix.size()) + + " should be multiple of chunk size " + + std::to_string(chunk_size_)); + } + + if (tokens.size() != kv_tensors.size()) { + return Status::Invalid("Tokens size " + std::to_string(tokens.size()) + + " should be equal to kv tensors size " + + std::to_string(kv_tensors.size())); + } + + if (tokens.size() > kMaxTokensPerSeq) { + return Status::Invalid("Token list size exceeds the size limit"); + } + + std::vector chunk_hashes; + GetTokenChunkHashes(prefix, tokens, chunk_hashes); + std::vector obj_names; + for (const auto& chunk_hash : chunk_hashes) { + obj_names.push_back(kv_cache_ns_ + chunk_hash); + } + + auto now = std::chrono::system_clock::now().time_since_epoch(); + auto access_time = + std::chrono::duration_cast(now).count(); + + parallel::ThreadGroup tg( + std::min(obj_names.size(), + static_cast(std::thread::hardware_concurrency()))); + std::vector tids; + + bool is_zero_copy = kv_tensors[0][0].first.data == nullptr; + + auto cb = [&kv_tensors, chunk_size = chunk_size_, is_zero_copy]( + size_t i, auto&& my_kv_tensors) { + if (!is_zero_copy) { + return; + } + + // for zero-copy use case, we need to copy descriptors back + size_t j = 0; + for (auto it = kv_tensors.begin() + i * chunk_size; + it != kv_tensors.begin() + (i + 1) * chunk_size; it++) { + *it = my_kv_tensors[j++]; + } + }; + + DEFINE_TASK_FN(fn, Query, cb); + + for (size_t i = 0; i < obj_names.size(); i++) { + const auto& obj_name = obj_names[i]; + + std::shared_ptr query_builder = nullptr; + // Check if the chunk is in S first, and then check M + { + std::unique_lock lock(small_fifo_mu_); + auto it = small_fifo_.findWithoutPromotion(obj_name); + if (it != small_fifo_.end()) { + VLOG(100) << "Hit " << obj_name << " in S"; + it->second.access_bit = true; + it->second.chunk_builder->SetAccessTime(access_time); + query_builder = it->second.chunk_builder; + } + } + + if (query_builder == nullptr) { + std::unique_lock lock(main_fifo_mu_); + auto it = main_fifo_.findWithoutPromotion(obj_name); + if (it != main_fifo_.end()) { + VLOG(100) << "Hit " << obj_name << " in M"; + if (it->second.chunk_builder == nullptr) { + VLOG(100) << "Loading " << obj_name; + VINEYARD_ASSERT(it->second.object_id != InvalidObjectID()); + + auto status = KVCacheChunkBuilder::Make( + it->second.chunk_builder, rpc_client_, 
tensor_nbytes_, layer_, + chunk_size_, kv_cache_ns_, it->second.object_id); + if (!status.ok()) { + VLOG(100) << "Failed to make chunk builder, " << status.ToString(); + // skip this and rest chunks + break; + } else { + VLOG(100) << "obj name=" << obj_name + << ", obj id=" << ObjectIDToString(it->second.object_id); + } + } + it->second.access_bit = true; + it->second.chunk_builder->SetAccessTime(access_time); + query_builder = it->second.chunk_builder; + } + } + + // cache miss + if (query_builder == nullptr) { + break; + } + + // cache hit + tids.push_back(tg.AddTask(fn, i, query_builder)); + } + + if (tids.empty()) { + return Status::ObjectNotExists(); + } + + Status first_error = Status::OK(); + WAIT_TASK_RESULTS(tids, matched, first_error, obj_names); + return first_error; +} + +Status AIBrixBlobStorage::Update( + const std::vector& tokens, + const std::vector>>& kv_tensors, + size_t& updated) { + return UpdateInternal({}, tokens, kv_tensors, updated); +} + +Status AIBrixBlobStorage::Update( + const std::vector& prefix, const std::vector& tokens, + const std::vector>>& kv_tensors, + size_t& updated) { + return UpdateInternal(prefix, tokens, kv_tensors, updated); +} + +Status AIBrixBlobStorage::Query( + const std::vector& tokens, + std::vector>>& kv_tensors, + size_t& matched) { + return QueryInternal({}, tokens, kv_tensors, matched); +} + +Status AIBrixBlobStorage::Query( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensors, + size_t& matched) { + return QueryInternal(prefix, tokens, kv_tensors, matched); +} + +Status AIBrixBlobStorage::SealAndPersist( + const std::string& name, + const std::shared_ptr& chunk_builder, + ObjectID& chunk_id) { + auto& client = GetClient(); + std::shared_ptr chunk = chunk_builder->Seal(); + if (chunk == nullptr) { + return Status::IOError(); + } + chunk_id = chunk->id(); + RETURN_ON_ERROR(client.Persist(chunk_id)); + VINEYARD_DISCARD( + client.Label(chunk_id, KVCacheChunk::kFieldNameAccessTime, + std::to_string(chunk_builder->GetAccessTime()))); + auto status = client.PutName(chunk_id, name, /* unique */ true); + if (status.IsNameExists()) { + ObjectID obj_id; + auto s = client.GetName(name, obj_id); + if (!s.ok()) { + VLOG(100) << "Failed to get obj id of existing chunk, name=" << name + << ", error=" << s.ToString(); + } else { + ObjectMeta meta; + s = client.GetMetaData(obj_id, meta); + + if (!s.ok()) { + VLOG(100) << "Failed to get meta of existing chunk, name=" << name + << ", id=" << ObjectIDToString(obj_id) + << ", error=" << s.ToString(); + } else { + if (chunk_builder->IsSame(meta).ok()) { + VLOG(100) + << "Existing chunk is the same as the persisting chunk, name=" + << name << ", obj ids: " << ObjectIDToString(chunk_id) << ", " + << ObjectIDToString(obj_id); + VINEYARD_DISCARD(client.DelData(chunk_id)); + // reuse existing chunk + chunk_id = obj_id; + return Status::OK(); + } else { + VLOG(100) << "A different chunk has taken name=" << name + << ", obj id=" << ObjectIDToString(obj_id); + // go to the following if branch + } + } + } + } + + if (!status.ok()) { + VLOG(100) << "Failed to put name " << name << ", " << status.ToString(); + // Just delete chunk id and wait for the next time + VINEYARD_DISCARD(client.DelData(chunk_id)); + } + return status; +} + +Status AIBrixBlobStorage::Delete(const std::vector& chunk_list) { + Status status = Status::OK(); + std::vector delete_ids; + auto& client = GetClient(); + for (const auto& name : chunk_list) { + ObjectID id; + if (client.GetName(name, id, false).ok()) { + 
delete_ids.push_back(id); + status += client.DropName(name); + } else { + VLOG(100) << "Failed to get obj id for name=" << name; + } + } + + status += Delete(delete_ids); + + return status; +} + +Status AIBrixBlobStorage::Delete(const std::vector& obj_ids) { + Status status = Status::OK(); + if (obj_ids.size() > 0) { + auto& client = GetClient(); + status += client.DelData(obj_ids); + if (status.ok()) { + if (VLOG_IS_ON(100)) { + for (auto id : obj_ids) { + LOG(INFO) << "Deleted obj " << ObjectIDToString(id); + } + } + } else { + if (VLOG_IS_ON(100)) { + for (auto id : obj_ids) { + LOG(INFO) << "Failed to delete obj " << ObjectIDToString(id) << ", " + << status.ToString(); + } + } + } + } + return status; +} + +void AIBrixBlobStorage::CloseCache() { + LOG(INFO) << "Close AIBrixBlobStorage"; + TerminateGCThreads(); +} + +std::string AIBrixBlobStorage::GetTimestamp( + std::chrono::duration time) { + auto duration_since_epoch = + std::chrono::duration_cast(time); + std::chrono::time_point timestamp = + std::chrono::system_clock::time_point(duration_since_epoch); + time_t t = std::chrono::system_clock::to_time_t(timestamp); + + std::tm tm; + localtime_r(&t, &tm); + std::ostringstream oss; + oss << std::put_time(&tm, "%Y-%m-%d %H:%M:%S"); + return oss.str(); +} + +Status AIBrixBlobStorage::ListKVCache(const std::string& prefix, + std::vector& metas) { + std::string ns = KVCacheChunk::GetNameSpace(prefix); + std::vector ids; + auto status = rpc_client_.ListBy(KVCacheChunk::kFieldNameNS, ns, false, + UINT64_MAX, ids); + if (!status.ok()) { + VLOG(100) << "Failed to list by namespace=" << ns + << ", error=" << status.ToString(); + return status; + } + + return rpc_client_.GetMetaData(ids, metas, /* sync remote */ true); +} + +Status AIBrixBlobStorage::ProcessPersistList( + const std::vector>& persist_list) { + VLOG(100) << "ProcessPersistList: #persist chunks=" << persist_list.size(); + + auto& client = GetClient(); + for (const auto& pair : persist_list) { + auto& key = pair.first; + auto& chunk_builder = pair.second.chunk_builder; + ObjectID chunk_id; + auto s = SealAndPersist(key, chunk_builder, chunk_id); + if (s.ok()) { + VLOG(100) << "Persist " << key + << ", obj id=" << ObjectIDToString(chunk_id); + std::unique_lock lock(main_fifo_mu_); + auto it = main_fifo_.findWithoutPromotion(key); + if (it != main_fifo_.end()) { + VLOG(100) << "Main fifo: " << key + << " obj id=" << ObjectIDToString(chunk_id); + it->second.object_id = chunk_id; + } else { + VLOG(100) << "Main fifo: " << key << " not in fifo, delete obj " + << chunk_id; + VINEYARD_DISCARD(client.DelData(chunk_id)); + } + } else { + VLOG(100) << "Failed to seal and persist " << key + << ", error=" << s.ToString(); + } + } + + return Status::OK(); +} + +Status AIBrixBlobStorage::ProcessUpdateList( + const std::vector>& update_list) { + VLOG(100) << "ProcessUpdateList: #update chunks=" << update_list.size(); + + auto& client = GetClient(); + for (const auto& pair : update_list) { + auto& key = pair.first; + auto& chunk_builder = pair.second.chunk_builder; + auto& chunk_id = pair.second.object_id; + + auto update_status = + client.Label(chunk_id, KVCacheChunk::kFieldNameAccessTime, + std::to_string(chunk_builder->GetAccessTime())); + + chunk_builder->SetGlobalAccessTime(chunk_builder->GetAccessTime()); + + if (update_status.ok()) { + VLOG(100) << "Updated " << key << "'s access time to " + << GetTimestamp(std::chrono::nanoseconds( + chunk_builder->GetAccessTime())); + } else { + VLOG(100) << "Failed to update " << key << "'s access time" + 
<< ", error=" << update_status.ToString(); + } + } + + return Status::OK(); +} + +Status AIBrixBlobStorage::LocalSyncFunc() { + // load global main fifo and merge it with the local one + VINEYARD_DISCARD(BuildMainFifo()); + + using PairT = std::pair; + std::vector persist_list; + std::vector update_list; + { + std::unique_lock lock(main_fifo_mu_); + for (auto& pair : main_fifo_) { + auto& key = pair.first; + auto& entry = pair.second; + if (!entry.chunk_builder || !entry.chunk_builder->IsReady()) { + // skip never accessed chunks + // skip not ready chunks + continue; + } + + if (entry.object_id == InvalidObjectID()) { + persist_list.push_back({key, entry}); + } + + if (entry.object_id != InvalidObjectID() && + entry.chunk_builder->GetAccessTime() > + entry.chunk_builder->GetGlobalAccessTime() + + local_sync_interval_s_.count() * 1000000000) { + update_list.push_back({key, entry}); + } + } + } + + auto status = ProcessPersistList(persist_list); + + status += ProcessUpdateList(update_list); + + return status; +} + +Status AIBrixBlobStorage::GlobalGCFunc() { + auto now = std::chrono::high_resolution_clock::now(); + auto nanoseconds_since_epoch = + std::chrono::duration_cast( + now.time_since_epoch()); + std::vector chunk_metas; + std::vector delete_chunks; + std::vector delete_list; + RETURN_ON_ERROR(ListKVCache(kv_cache_ns_, chunk_metas)); + VLOG(100) << "Global GC: " << chunk_metas.size() << " chunks to check"; + for (const auto& meta : chunk_metas) { + const auto chunk_id = meta.GetId(); + std::string chunk_name("unknown"); + if (meta.HasKey("__name")) { + chunk_name = meta.GetKeyValue("__name"); + } + auto access_time_label = meta.Label(KVCacheChunk::kFieldNameAccessTime); + if (access_time_label.empty()) { + access_time_label = std::to_string( + meta.GetKeyValue(KVCacheChunk::kFieldNameAccessTime)); + } + uint64_t time = std::stoull(access_time_label); + auto access_time = std::chrono::nanoseconds(time); + VLOG(100) << "Chunk TTL: " << global_ttl_s_.count() << " s"; + if ((access_time + global_ttl_s_).count() < + nanoseconds_since_epoch.count()) { + VLOG(100) << "Global GC: " << chunk_name << " is GC'ed"; + VLOG(100) << "Access time: " << GetTimestamp(access_time); + VLOG(100) << "Now: " << GetTimestamp(nanoseconds_since_epoch); + delete_chunks.emplace_back(chunk_name); + delete_list.emplace_back(chunk_id); + } else { + VLOG(100) << "Global GC: " << chunk_name << " is alive"; + VLOG(100) << "Access time: " << GetTimestamp(access_time); + VLOG(100) << "Now: " << GetTimestamp(nanoseconds_since_epoch); + } + } + + if (delete_list.size() > 0) { + { + std::unique_lock lock(main_fifo_mu_); + for (const auto& name : delete_chunks) { + main_fifo_.erase(name); + } + } + VINEYARD_DISCARD(Delete(delete_list)); + } + return Status::OK(); +} + +#define DEFINE_GC_THREAD(NAME, ENABLED, GC_MU, GC_CV, GC_INTERVAL) \ + void AIBrixBlobStorage::NAME##Thread( \ + std::shared_ptr self) { \ + int64_t last_time = \ + std::chrono::duration_cast( \ + std::chrono::high_resolution_clock::now().time_since_epoch()) \ + .count(); \ + while (1) { \ + std::unique_lock lock(self->GC_MU); \ + auto interval = self->GC_INTERVAL; \ + if (self->GC_CV.wait_for(lock, interval, [self, &last_time, &interval] { \ + int64_t current_time = \ + std::chrono::duration_cast( \ + std::chrono::high_resolution_clock::now() \ + .time_since_epoch()) \ + .count(); \ + return self->exit_flag_ || \ + (current_time - last_time) > interval.count(); \ + })) { \ + if (!(self->ENABLED)) { \ + LOG(INFO) << #NAME " skipped"; \ + return; \ + } \ + if 
(self->exit_flag_) { \ + LOG(INFO) << #NAME " exit"; \ + return; \ + } \ + LOG(INFO) << #NAME " started"; \ + Status status = self->NAME##Func(); \ + if (!status.ok()) { \ + LOG(ERROR) << #NAME " failed: " << status.ToString(); \ + /* Not a fatal error and wait for next time */ \ + } else { \ + LOG(INFO) << #NAME " completed"; \ + } \ + last_time = std::chrono::duration_cast( \ + std::chrono::system_clock::now().time_since_epoch()) \ + .count(); \ + } \ + } \ + } + +DEFINE_GC_THREAD(LocalSync, local_gc_enabled_, local_sync_mu_, local_sync_cv_, + local_sync_interval_s_); +DEFINE_GC_THREAD(GlobalGC, global_gc_enabled_, global_gc_mu_, global_gc_cv_, + global_gc_interval_s_); + +void AIBrixBlobStorage::TerminateGCThreads() { + std::lock_guard local_lock(local_sync_mu_); + std::lock_guard global_lock(global_gc_mu_); + if (!exit_flag_) { + exit_flag_ = true; + + VLOG(100) << "Terminating global GC thread"; + global_gc_mu_.unlock(); + global_gc_cv_.notify_all(); + global_gc_thread_.join(); + + VLOG(100) << "Terminating local sync thread"; + local_sync_mu_.unlock(); + local_sync_cv_.notify_all(); + local_sync_thread_.join(); + } +} + +ClientBase& AIBrixBlobStorage::GetClient() { return rpc_client_; } + +} // namespace vineyard diff --git a/modules/llm-cache/storage/aibrix_blob_storage.h b/modules/llm-cache/storage/aibrix_blob_storage.h new file mode 100644 index 000000000..c6e141deb --- /dev/null +++ b/modules/llm-cache/storage/aibrix_blob_storage.h @@ -0,0 +1,227 @@ +/** Copyright 2024 AIBrix. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef MODULES_LLM_CACHE_STORAGE_AIBRIX_BLOB_STORAGE_H_ +#define MODULES_LLM_CACHE_STORAGE_AIBRIX_BLOB_STORAGE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "client/client.h" +#include "common/util/evicting_cache_map.h" +#include "common/util/logging.h" + +#include "llm-cache/ds/kv_cache_chunk.h" +#include "llm-cache/hash/hasher.h" +#include "llm-cache/storage/storage.h" + +namespace vineyard { + +class AIBrixBlobStorage + : public IStorage, + public std::enable_shared_from_this { + private: + static constexpr int kMaxTokensPerSeq = 64 * 1024; + static constexpr double kSmallFifoCapacityRatio = 0.3; + static constexpr int kMinEviction = 32; + + RPCClient& rpc_client_; + Client& ipc_client_; + + std::shared_ptr hash_alg_; + std::shared_ptr hasher_; + + size_t tensor_nbytes_; + int layer_; + int chunk_size_; + int capacity_; + size_t chunk_obj_size_; + std::string kv_cache_ns_; + + // intervals in seconds + std::chrono::duration local_sync_interval_s_; + std::chrono::duration global_gc_interval_s_; + // TTL in seconds + std::chrono::duration global_ttl_s_; + + bool exit_flag_ = false; + + // global GC is carried out in the global GC thread. + // it checks global chunks' access time and deletes + // the expired chunks. + bool global_gc_enabled_ = false; + std::condition_variable global_gc_cv_; + std::mutex global_gc_mu_; + std::thread global_gc_thread_; + + // local sync is carried out in the local sync thread. 
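  // it runs every local_sync_interval_s_ seconds (see LocalSyncFunc).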
+ // it persists newly added chunks and deletes evicted chunks. + const bool local_gc_enabled_ = true; + std::condition_variable local_sync_cv_; + std::mutex local_sync_mu_; + std::thread local_sync_thread_; + + // S3-FIFO + // + // - a small FIFO map (S) that quickly removes new and unpopular + // objects (quick demotion) + // - a main FIFO map (M) that keeps popular objects in the cache + // with reinsertion (lazy promotion), and + // - a ghost FIFO map (G) that stores the id of objects recently + // evicted from S. + // + // G stores same number of entries as M. Note that, a request + // found in G is not a cache hit since the items in G do not + // have data. + // + // Each entry in the cache uses one bit of metadata to track + // hotness. + // + // Upon a cache miss, if the id of the requested object is not + // tracked in G, it is inserted into S; however, if the requested + // object is tracked in G, then the object is inserted into M. + // + // When S performs an eviction if the object has not been reused + // since insertion, it is evicted to G, and only its id is tracked + // in G. Otherwise, the object is promoted to M with the access + // bit reset to zero. + // + // When M performs an eviction, an object is directly evicted if + // it has not been reused since the insertion. Otherwise, the + // object is reinserted into M. (lazy promotion) + struct FifoEntry { + std::shared_ptr chunk_builder = nullptr; + ObjectID object_id = InvalidObjectID(); + bool access_bit = false; + }; + + std::mutex ghost_fifo_mu_; + EvictingCacheMap ghost_fifo_; + std::mutex small_fifo_mu_; + EvictingCacheMap small_fifo_; + std::mutex main_fifo_mu_; + EvictingCacheMap + main_fifo_; // mirror of global chunk list + + std::vector evict_list_; + + public: + AIBrixBlobStorage(RPCClient& rpc_client, Client& ipc_client, + size_t tensor_nbytes, int capacity, int layer, + int chunk_size, std::string kv_cache_ns, + int64_t local_sync_interval_s, bool global_gc_enabled, + int64_t global_gc_interval_s, int64_t global_ttl_s); + + Status Update( + const std::vector& token_list, int next_token, + const std::vector>& kv_tensors) override { + return Status::NotImplemented(); + } + + Status Update( + const std::vector& token_list, + const std::vector>>& kv_tensors, + size_t& updated) override; + + Status Update( + const std::vector& prefix, const std::vector& token_list, + const std::vector>>& kv_tensors, + size_t& updated) override; + + Status Query(const std::vector& token_list, + std::vector>>& kv_tensors, + size_t& matched) override; + + Status Query(const std::vector& prefix, int token, + std::vector>& kv_tensors) override { + return Status::NotImplemented(); + } + + Status Query(const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensors, + size_t& matched) override; + + void CloseCache() override; + + ~AIBrixBlobStorage() = default; + + Status Init(); + + void StopGlobalGCThread() override { global_gc_enabled_ = false; } + + void StartGlobalGCThread() override { global_gc_enabled_ = true; } + + static std::string GetTimestamp( + std::chrono::duration time); + + Status ListKVCache(const std::string& prefix, std::vector& metas); + + private: + template + void MainFifoPruneHookLoop(PairsT&& pairs, + std::vector& delete_list); + + Status BuildMainFifo(); + + ClientBase& GetClient(); + + Status GetTokenChunkHashes(const std::vector& prefix, + const std::vector& tokens, + std::vector& chunk_hashes); + + Status UpdateInternal( + const std::vector& prefix, const std::vector& token_list, + const 
std::vector>>& kv_tensors, + size_t& updated); + + Status QueryInternal( + const std::vector& prefix, const std::vector& tokens, + std::vector>>& kv_tensors, + size_t& matched); + + Status SealAndPersist( + const std::string& name, + const std::shared_ptr& chunk_builder, + ObjectID& chunk_id); + + Status Delete(const std::vector& chunk_list); + + Status Delete(const std::vector& obj_ids); + + Status LocalSyncFunc(); + + Status GlobalGCFunc(); + + Status ProcessPersistList( + const std::vector>& persist_list); + + Status ProcessUpdateList( + const std::vector>& udpate_list); + + static void LocalSyncThread(std::shared_ptr self); + + static void GlobalGCThread(std::shared_ptr self); + + void TerminateGCThreads(); +}; + +} // namespace vineyard + +#endif // MODULES_LLM_CACHE_STORAGE_AIBRIX_BLOB_STORAGE_H_ diff --git a/modules/llm-cache/tests/aibrix_kv_cache_test.cc b/modules/llm-cache/tests/aibrix_kv_cache_test.cc new file mode 100644 index 000000000..e1e3278fa --- /dev/null +++ b/modules/llm-cache/tests/aibrix_kv_cache_test.cc @@ -0,0 +1,391 @@ +/** Copyright 2024 AIBrix. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include + +#include "client/client.h" +#include "client/ds/object_meta.h" +#include "client/rpc_client.h" +#include "common/util/logging.h" +#include "llm-cache/ds/config.h" +#include "llm-cache/ds/kv_cache_manager.h" + +using namespace vineyard; // NOLINT(build/namespaces) + +size_t nr_rounds = 3; +int tensorNBytes = 80; +int capacity = 20; +int layer = 3; +int chunk_size = 5; +std::string cache_prefix = "aibrix_test"; +int global_ttl = 5; + +AIBrixCacheConfig config; + +std::vector round_1_tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, + 67, 68, 69, 70}; // 70 tokens +std::vector round_2_tokens = {1, 2, 3, 4, 5, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 51}; // 15 tokens +std::vector round_3_tokens = {1, 2, 3, 9, 10, + 11, 12, 13, 14, 21}; // 10 tokens +std::vector round_4_tokens = {1, 2, 3, 4, 5}; // 5 tokens + +std::vector> tokens_list = {round_1_tokens, round_2_tokens, + round_3_tokens, round_4_tokens}; + +std::shared_ptr init(RPCClient& rpc_client, Client& client) { + std::shared_ptr kv_cache_manager; + VINEYARD_CHECK_OK( + KVCacheManager::Make(rpc_client, client, kv_cache_manager, config)); + return kv_cache_manager; +} + +void print_current_tokens(const std::vector& prefix, int next_token) { + std::string tokens_str = ""; + for (size_t i = 0; i < prefix.size(); ++i) { + tokens_str += std::to_string(prefix[i]) + " "; + } + tokens_str += std::to_string(next_token); + LOG(INFO) << "Current tokens: " + tokens_str; +} + +void print_kv_state(const std::vector>& kv_state) { + VLOG(100) << "kv_state: "; + for (size_t i = 0; i < kv_state.size(); ++i) { + uint8_t* key_state_data = + reinterpret_cast(kv_state[i].first.data); + uint8_t* 
value_state_data = + reinterpret_cast(kv_state[i].second.data); + // print the first tensorNBytes bytes + std::string key_state_str = ""; + std::string value_state_str = ""; + for (int j = 0; j < tensorNBytes; j++) { + key_state_str += std::to_string(key_state_data[j]) + " "; + value_state_str += std::to_string(value_state_data[j]) + " "; + } + VLOG(100) << "layer " << i << ":"; + VLOG(100) << "key_state: " << key_state_str; + VLOG(100) << "value_state: " << value_state_str; + VLOG(100) << "---------------------"; + } +} + +// we do not consider the layer. +std::vector> generate_kv_state(int token) { + std::vector> kv_state; + for (int currentLayer = 0; currentLayer < layer; currentLayer++) { + LLMKV key_state; + LLMKV value_state; + key_state.data = malloc(tensorNBytes); + value_state.data = malloc(tensorNBytes); + + key_state.length = tensorNBytes; + value_state.length = tensorNBytes; + + for (int i = 0; i < tensorNBytes; ++i) { + (reinterpret_cast(key_state.data))[i] = + (static_cast(token)) + i + currentLayer; + (reinterpret_cast(value_state.data))[i] = + (static_cast(token)) + i + currentLayer; + } + kv_state.emplace_back(key_state, value_state); + } + return kv_state; +} + +void check_kv_state(const std::vector>& kv_state, + int& token) { + VINEYARD_ASSERT(kv_state.size() == (size_t) layer); + for (size_t index = 0; index < kv_state.size(); ++index) { + VINEYARD_ASSERT(kv_state[index].first.length == (size_t) tensorNBytes); + VINEYARD_ASSERT(kv_state[index].second.length == (size_t) tensorNBytes); + for (int i = 0; i < tensorNBytes; ++i) { + if ((reinterpret_cast(kv_state[index].first.data))[i] != + (static_cast(token)) + i + index) { + VLOG(100) << "token:" << token << " tensorNBytes" << tensorNBytes + << " layer:" << index; + VLOG(100) << "key_state[" << i << "]: " + << (reinterpret_cast(kv_state[index].first.data))[i] + << ". But is should be " + << (static_cast(token)) + i + index; + throw std::runtime_error("key_state error!"); + } + if (reinterpret_cast(kv_state[index].second.data)[i] != + (static_cast(token)) + i + index) { + VLOG(100) << "token:" << token << " tensorNBytes" << tensorNBytes + << " layer:" << index; + VLOG(100) << "value_state[" << i << "]: " + << (reinterpret_cast( + kv_state[index].second.data))[i] + << ". But is should be " + << (static_cast(token)) + i + index; + throw std::runtime_error("value_state error!"); + } + } + } +} + +void inference(std::shared_ptr& kv_cache_manager, + std::vector tokens, bool block = false) { + std::vector inference_tokens; + std::vector>> kv_state; + for (size_t i = 0; i < tokens.size(); ++i) { + std::vector> current_kv_state = + generate_kv_state(tokens[i]); + print_kv_state(current_kv_state); + kv_state.push_back(current_kv_state); + inference_tokens.push_back(tokens[i]); + } + + size_t updated = 0; + Status result = kv_cache_manager->Update(inference_tokens, kv_state, updated); + + std::vector>> kv_state_to_query; + for (size_t i = 0; i < tokens.size(); ++i) { + std::vector> current_kv_state = + generate_kv_state(0); + kv_state_to_query.push_back(current_kv_state); + } + size_t matched = 0; + Status query_result = + kv_cache_manager->Query(inference_tokens, kv_state_to_query, matched); + if (!query_result.ok()) { + LOG(INFO) << "Query failed!"; + } + + LOG(INFO) << "Match tokens:" << matched << ". 
Total tokens:" << tokens.size(); + for (size_t i = 0; i < matched; ++i) { + check_kv_state(kv_state_to_query[i], tokens[i]); + } +} + +void inference(std::shared_ptr& kv_cache_manager, + std::vector prefix, std::vector tokens, + bool block = false) { + std::vector inference_tokens; + std::vector> kv_state_to_query; + std::vector>> kv_state_to_query_list; + + // Get tokens with batched query interface + inference_tokens = std::vector(prefix.begin(), prefix.end()); + + for (int i = 0; i < tokens.size(); i++) { + kv_state_to_query.clear(); + for (int currentLayer = 0; currentLayer < layer; currentLayer++) { + kv_state_to_query.emplace_back(LLMKV{nullptr, 0}, LLMKV{nullptr, 0}); + } + kv_state_to_query_list.emplace_back(kv_state_to_query); + } + size_t matched = 0; + Status result = kv_cache_manager->Query(inference_tokens, tokens, + kv_state_to_query_list, matched); + LOG(INFO) << "Find " << matched << " matched tokens from token lists."; + for (size_t i = 0; i < matched; i++) { + print_current_tokens(inference_tokens, tokens[i]); + inference_tokens.emplace_back(tokens[i]); + check_kv_state(kv_state_to_query_list[i], tokens[i]); + } +} + +void thread_func(std::string socket) { + RPCClient dummy_rpc_client; + Client client; + VINEYARD_CHECK_OK(client.Connect(socket)); + std::shared_ptr manager = init(dummy_rpc_client, client); + + for (size_t r = 0; r < nr_rounds; r++) { + for (size_t i = 0; i < tokens_list.size(); i++) { + LOG(INFO) << "Round " << i << " :"; + inference(manager, tokens_list[i]); + } + sleep(1); + } + + sleep(5); + + for (size_t r = 0; r < nr_rounds; r++) { + for (size_t i = 0; i < tokens_list.size(); i++) { + LOG(INFO) << "Round " << i << " :"; + inference(manager, tokens_list[i]); + } + sleep(1); + } + + sleep(5); + + for (size_t r = 0; r < nr_rounds; r++) { + for (size_t i = 0; i < tokens_list.size(); i++) { + size_t total_chunks = tokens_list[i].size() / chunk_size; + size_t prefix_chunks = total_chunks / 2; + std::vector prefix( + tokens_list[i].begin(), + tokens_list[i].begin() + prefix_chunks * chunk_size); + std::vector tokens_list_rest( + tokens_list[i].begin() + prefix_chunks * chunk_size, + tokens_list[i].end()); + inference(manager, prefix, tokens_list_rest); + } + sleep(1); + } + + LOG(INFO) << "inference end"; + + // sleep a while to trigger local sync and global gc + sleep(3 * global_ttl); + + manager->Close(); + client.Disconnect(); + if (dummy_rpc_client.Connected()) { + dummy_rpc_client.Disconnect(); + } +} + +void list_objects(const std::string& socket) { + Client client; + VINEYARD_CHECK_OK(client.Connect(socket)); + std::shared_ptr status; + VINEYARD_CHECK_OK(client.InstanceStatus(status)); + + std::vector metas; + if (status->memory_usage != 0) { + metas = client.ListObjectMeta(".*", true); + LOG(INFO) << "Object:"; + for (size_t i = 0; i < metas.size(); i++) { + LOG(INFO) << metas[i].ToString(); + } + } + + client.Disconnect(); +} + +void clear_objects(const std::string& socket) { + Client client; + VINEYARD_CHECK_OK(client.Connect(socket)); + std::shared_ptr status; + VINEYARD_CHECK_OK(client.InstanceStatus(status)); + + std::vector metas; + if (status->memory_usage != 0) { + metas = client.ListObjectMeta(".*", true); + for (size_t i = 0; i < metas.size(); i++) { + LOG(INFO) << "Client " << client.instance_id() << " deletes obj " + << ObjectIDToString(metas[i].GetId()) + << " instance_id=" << metas[i].GetInstanceId(); + client.DelData(metas[i].GetId()); + } + } + + client.Disconnect(); +} + +void test(const std::vector& sockets, bool enableGlobalGC) { + 
config = AIBrixCacheConfig(tensorNBytes, capacity, layer, chunk_size, + cache_prefix, 1, enableGlobalGC, 3, global_ttl); + + std::vector threads; + for (int i = 0; i < sockets.size(); i++) { + threads.push_back(std::thread(thread_func, sockets[i])); + } + + for (int i = 0; i < sockets.size(); i++) { + threads[i].join(); + LOG(INFO) << "Thread:" << i << " exit."; + list_objects(sockets[i]); + } + + for (int i = 0; i < sockets.size(); i++) { + clear_objects(sockets[i]); + } + + size_t total_memory_usage = 0; + for (size_t i = 0; i < sockets.size(); i++) { + Client client; + VINEYARD_CHECK_OK(client.Connect(sockets[i])); + std::shared_ptr status; + VINEYARD_CHECK_OK(client.InstanceStatus(status)); + LOG(INFO) << "Client " << client.instance_id() + << " memory usage:" << status->memory_usage; + total_memory_usage += status->memory_usage; + client.Disconnect(); + } + LOG(INFO) << "Total memory usage:" << total_memory_usage; +} + +int main(int argc, char** argv) { + std::vector sockets; + if (argc < 2) { + printf( + "usage ./aibrix_kv_cache_test --client-num " + "--vineyard-ipc-sockets ... -d " + " -c -l -b \n"); + return 1; + } + + if (strcmp(argv[1], "--client-num") != 0) { + return 1; + } + + int client_num = std::stoi(argv[2]); + + for (int i = 3; i < argc; i++) { + if (strcmp(argv[i], "-d") == 0) { + tensorNBytes = atoi(argv[i + 1]); + } else if (strcmp(argv[i], "-c") == 0) { + capacity = atoi(argv[i + 1]); + } else if (strcmp(argv[i], "-l") == 0) { + layer = atoi(argv[i + 1]); + } else if (strcmp(argv[i], "-b") == 0) { + chunk_size = atoi(argv[i + 1]); + } else if (strcmp(argv[i], "-s") == 0) { + for (int j = i + 1; j < argc; j++) { + if (strcmp(argv[j], "1") == 0) { + tokens_list.push_back(round_1_tokens); + } else if (strcmp(argv[j], "2") == 0) { + tokens_list.push_back(round_2_tokens); + } else if (strcmp(argv[j], "3") == 0) { + tokens_list.push_back(round_3_tokens); + } else if (strcmp(argv[j], "4") == 0) { + tokens_list.push_back(round_4_tokens); + } else { + break; + } + } + } else if (strcmp(argv[i], "--vineyard-ipc-sockets") == 0) { + for (int j = 0; j < client_num; j++) { + sockets.push_back(std::string(argv[i + j + 1])); + } + } + } + + LOG(INFO) << "Test AIBrixKVCache with tensorNBytes: " << tensorNBytes + << ", capacity: " << capacity << ", layer: " << layer + << ", chunk_size: " << chunk_size + << ", cache_prefix: " << cache_prefix << " and use " << client_num + << " client."; + + test(sockets, /* global gc */ false); + test(sockets, /* global gc */ true); + + LOG(INFO) << "Passed AIBrixKVCache tests..."; + return 0; +} diff --git a/python/vineyard/llm/cache.cc b/python/vineyard/llm/cache.cc index 823764190..75ec892f8 100644 --- a/python/vineyard/llm/cache.cc +++ b/python/vineyard/llm/cache.cc @@ -138,13 +138,39 @@ PYBIND11_MODULE(_llm_C, m) { py::arg("enable_global_gc") = false, py::arg("global_gc_interval") = 30 * 60, py::arg("global_ttl") = 30 * 60) + .def( + py::init([](py::object rpc_client, py::object ipc_client, + int tensor_nbytes, int cache_capacity, int layer, + int chunk_size, std::string kv_cache_ns, + int64_t local_sync_interval_s, bool global_gc_enabled, + int64_t global_gc_interval_s, + int64_t global_ttl_s) -> std::shared_ptr { + AIBrixCacheConfig config(tensor_nbytes, cache_capacity, layer, + chunk_size, kv_cache_ns, + local_sync_interval_s, global_gc_enabled, + global_gc_interval_s, global_ttl_s); + Client& ipc_client_ = ipc_client.cast(); + RPCClient& rpc_client_ = rpc_client.cast(); + std::shared_ptr manager; + 
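                // Construct an AIBrix-backed KVCacheManager from the config
                // above; both the RPC and the IPC client are handed over.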
VINEYARD_CHECK_OK(vineyard::KVCacheManager::Make( + rpc_client_, ipc_client_, manager, config)); + return manager; + }), + py::arg("rpc_client"), py::arg("ipc_client"), + py::arg("tensor_nbytes") = 1024, py::arg("cache_capacity") = 1024, + py::arg("layer") = 1, py::arg("chunk_size") = 16, + py::arg("kv_cache_ns") = "aibrix", + py::arg("local_sync_interval_s") = 3 * 60, + py::arg("enable_global_gc") = true, + py::arg("global_gc_interval_s") = 10 * 60, + py::arg("global_ttl_s") = 8 * 60) .def( "update", [](KVCacheManager* self, const std::vector& tokenList, int& next_token, const std::vector>& kv_state) { - VINEYARD_CHECK_OK(self->Update(tokenList, next_token, kv_state)); + VINEYARD_DISCARD(self->Update(tokenList, next_token, kv_state)); }, py::arg("tokens"), py::arg("next_token"), py::arg("kv_state")) .def( @@ -153,7 +179,7 @@ PYBIND11_MODULE(_llm_C, m) { const std::vector>>& kv_states) -> size_t { size_t updated = 0; - VINEYARD_CHECK_OK(self->Update(tokens, kv_states, updated)); + VINEYARD_DISCARD(self->Update(tokens, kv_states, updated)); return updated; }, py::arg("tokens"), py::arg("kv_states")) @@ -164,7 +190,7 @@ PYBIND11_MODULE(_llm_C, m) { const std::vector>>& kv_states) -> size_t { size_t updated = 0; - VINEYARD_CHECK_OK(self->Update(prefix, tokens, kv_states, updated)); + VINEYARD_DISCARD(self->Update(prefix, tokens, kv_states, updated)); return updated; }, py::arg("prefix"), py::arg("tokens"), py::arg("kv_states")) @@ -174,7 +200,7 @@ PYBIND11_MODULE(_llm_C, m) { const std::vector>>& kv_states) -> size_t { size_t updated = 0; - VINEYARD_CHECK_OK(self->BatchedUpdate(tokens, kv_states, updated)); + VINEYARD_DISCARD(self->BatchedUpdate(tokens, kv_states, updated)); return updated; }, py::arg("tokens"), py::arg("kv_states")) @@ -186,7 +212,7 @@ PYBIND11_MODULE(_llm_C, m) { kv_cache_list .cast>>>(); size_t matched = 0; - VINEYARD_CHECK_OK(self->Query(tokens, kv_state_vec, matched)); + VINEYARD_DISCARD(self->Query(tokens, kv_state_vec, matched)); for (size_t i = 0; i < kv_state_vec.size() && i < matched; ++i) { for (size_t j = 0; j < kv_state_vec[i].size(); ++j) { kv_cache_list[i].cast()[j] = @@ -202,8 +228,8 @@ PYBIND11_MODULE(_llm_C, m) { int& next_token, py::list& kv_state) { std::vector> kv_state_vec = kv_state.cast>>(); - VINEYARD_CHECK_OK(self->Query(prefix, next_token, kv_state_vec)); - for (size_t i = 0; i < kv_state_vec.size(); ++i) { + auto status = self->Query(prefix, next_token, kv_state_vec); + for (size_t i = 0; i < kv_state_vec.size() && status.ok(); ++i) { kv_state[i] = py::cast(kv_state_vec[i]); } }, @@ -217,7 +243,7 @@ PYBIND11_MODULE(_llm_C, m) { kv_cache_list .cast>>>(); size_t matched = 0; - VINEYARD_CHECK_OK( + VINEYARD_DISCARD( self->Query(prefix, tokens, kv_state_vec, matched)); for (size_t i = 0; i < kv_state_vec.size() && i < matched; ++i) { for (size_t j = 0; j < kv_state_vec[i].size(); ++j) { @@ -236,8 +262,7 @@ PYBIND11_MODULE(_llm_C, m) { kv_cache_list .cast>>>(); size_t matched = 0; - VINEYARD_CHECK_OK( - self->BatchedQuery(tokens, kv_state_vec, matched)); + VINEYARD_DISCARD(self->BatchedQuery(tokens, kv_state_vec, matched)); for (size_t i = 0; i < kv_state_vec.size() && i < matched; ++i) { for (size_t j = 0; j < kv_state_vec[i].size(); ++j) { kv_cache_list[i].cast()[j] = diff --git a/python/vineyard/llm/cache.py b/python/vineyard/llm/cache.py index 8cfcf95cf..ed8ce8c10 100644 --- a/python/vineyard/llm/cache.py +++ b/python/vineyard/llm/cache.py @@ -204,12 +204,80 @@ def __repr__(self): ) +class AIBrixCacheConfig: + """AIBrixCacheConfig is a class to 
configure the AIBrix llm kv cache.""" + + def __init__( + self, + chunk_size: int = 16, + kv_cache_ns: str = "aibrix", + local_sync_interval_s: int = 3 * 60, + enable_global_gc: bool = True, + global_gc_interval_s: int = 10 * 60, + global_ttl_s: int = 8 * 60, + socket: str = "", + rpc_endpoint: str = "", + rdma_endpoint: str = "", + ): + """Create an AIBrix cache config. + + Args: + chunk_size (int): + Divide the token list into batches, each batch + contains chunk_size tokens. Defaults to 16. + kv_cache_ns (str): + The namespace of kv cache objects. + Defaults to "aibrix". + local_sync_interval_s (int): + The interval of local sync func (seconds). + Defaults to 3 * 60 seconds. + enable_global_gc (bool): + Enable the global gc or not. Defaults to True. + global_gc_interval_s (int): + The interval of the global gc (seconds). + Defaults to 10 * 60 seconds. + global_ttl_s (int): + The time to live of the global kv cache objects (seconds). + Defaults to 8 * 60 seconds. + """ + self.chunk_size = chunk_size + self.kv_cache_ns = kv_cache_ns + self.local_sync_interval_s = local_sync_interval_s + self.enable_global_gc = enable_global_gc + self.global_gc_interval_s = global_gc_interval_s + self.global_ttl_s = global_ttl_s + + import vineyard + + self.ipc_client = vineyard.connect(socket).ipc_client + splits = rpc_endpoint.split(":") + if len(splits) == 2: + rpc_host = splits[0] + rpc_port = splits[1] + self.rpc_client = vineyard.connect( + host=rpc_host, port=rpc_port, rdma_endpoint=rdma_endpoint + ).rpc_client + + def __repr__(self): + return ( + f'AIBrixCacheConfig(' + f'chunk_size={self.chunk_size}, ' + f'kv_cache_ns={self.kv_cache_ns}, ' + f'local_sync_interval_s={self.local_sync_interval_s}, ' + f'enable_global_gc={self.enable_global_gc}, ' + f'global_gc_interval_s={self.global_gc_interval_s}, ' + f'global_ttl_s={self.global_ttl_s}), ' + ) + + class KVCache: # pylint: disable=too-many-instance-attributes """KVCache is a class that manages the llm kv cache in vineyard.""" def __init__( self, - cache_config: Optional[Union[VineyardCacheConfig, FileCacheConfig]] = None, + cache_config: Optional[ + Union[VineyardCacheConfig, FileCacheConfig, AIBrixCacheConfig] + ] = None, tensor_nbytes: int = 1024, cache_capacity: int = 1024, layer: int = 1, @@ -292,19 +360,44 @@ def __init__( config, 'VINEYARD_LLM_CACHE_FILESYSTEM', 'rdma_endpoint', str ) cache_config = FileCacheConfig(**config) + if 'AIBRIX_LLM_KV_CACHE' in os.environ: + config = {} + _argument_from_env(config, 'AIBRIX_LLM_KV_CACHE', 'chunk_size', int) + _argument_from_env(config, 'AIBRIX_LLM_KV_CACHE', 'kv_cache_ns', str) + _argument_from_env( + config, 'AIBRIX_LLM_KV_CACHE', 'local_sync_interval_s', int + ) + _argument_from_env( + config, 'AIBRIX_LLM_KV_CACHE', 'enable_global_gc', bool + ) + _argument_from_env( + config, 'AIBRIX_LLM_KV_CACHE', 'global_gc_interval_s', int + ) + _argument_from_env(config, 'AIBRIX_LLM_KV_CACHE', 'global_ttl_s', int) + _argument_from_env(config, 'AIBRIX_LLM_KV_CACHE', 'socket', str) + _argument_from_env(config, 'AIBRIX_LLM_KV_CACHE', 'rpc_endpoint', str) + _argument_from_env(config, 'AIBRIX_LLM_KV_CACHE', 'rdma_endpoint', str) + cache_config = AIBrixCacheConfig(**config) if rank is not None and world_size is not None: if isinstance(cache_config, FileCacheConfig): cache_config.root = os.path.join( cache_config.root, f'{world_size}-{rank}' ) + if isinstance(cache_config, AIBrixCacheConfig): + cache_config.kv_cache_ns = ( + f'{cache_config.kv_cache_ns}_{world_size}_{rank}' + ) logger.info("Initializing vineyard llm 
cache with config: %r", cache_config) - if not isinstance(cache_config, VineyardCacheConfig) and not isinstance( - cache_config, FileCacheConfig + if ( + not isinstance(cache_config, VineyardCacheConfig) + and not isinstance(cache_config, FileCacheConfig) + and not isinstance(cache_config, AIBrixCacheConfig) ): raise ValueError( - "The cache_config should be VineyardCacheConfig or FileCacheConfig." + "The cache_config should be VineyardCacheConfig or FileCacheConfig " + "or AIBrixCacheConfig." ) self.cache_config = cache_config self.tensor_nbytes = tensor_nbytes diff --git a/src/client/client_base.cc b/src/client/client_base.cc index 6788bb0e6..4239819c3 100644 --- a/src/client/client_base.cc +++ b/src/client/client_base.cc @@ -55,9 +55,14 @@ Status ClientBase::GetData(const std::vector& ids, RETURN_ON_ERROR(doRead(message_in)); std::unordered_map meta_trees; RETURN_ON_ERROR(ReadGetDataReply(message_in, meta_trees)); + if (meta_trees.empty()) { + return Status::OK(); + } trees.reserve(ids.size()); for (auto const& id : ids) { - trees.emplace_back(meta_trees.at(id)); + if (meta_trees.count(id) > 0) { + trees.emplace_back(meta_trees.at(id)); + } } return Status::OK(); } @@ -252,6 +257,19 @@ Status ClientBase::ListNames(std::string const& pattern, bool const regex, return Status::OK(); } +Status ClientBase::ListBy(std::string const& field, std::string const& pattern, + bool const regex, size_t const limit, + std::vector& ids) { + ENSURE_CONNECTED(this); + std::string message_out; + WriteListByRequest(field, pattern, regex, limit, message_out); + RETURN_ON_ERROR(doWrite(message_out)); + json message_in; + RETURN_ON_ERROR(doRead(message_in)); + RETURN_ON_ERROR(ReadListByReply(message_in, ids)); + return Status::OK(); +} + Status ClientBase::CreateStream(const ObjectID& id) { ENSURE_CONNECTED(this); std::string message_out; @@ -394,10 +412,11 @@ Status ClientBase::ShallowCopy(const ObjectID id, json const& extra_metadata, return Status::OK(); } -Status ClientBase::PutName(const ObjectID id, std::string const& name) { +Status ClientBase::PutName(const ObjectID id, std::string const& name, + const bool unique) { ENSURE_CONNECTED(this); std::string message_out; - WritePutNameRequest(id, name, message_out); + WritePutNameRequest(id, name, unique, message_out); RETURN_ON_ERROR(doWrite(message_out)); json message_in; RETURN_ON_ERROR(doRead(message_in)); diff --git a/src/client/client_base.h b/src/client/client_base.h index 34955a341..4c70479ee 100644 --- a/src/client/client_base.h +++ b/src/client/client_base.h @@ -238,6 +238,24 @@ class ClientBase { Status ListNames(std::string const& pattern, bool const regex, size_t const limit, std::map& names); + /** + * @brief List by field in vineyard, using the given patterns. + * + * @param field The field that will be used to match against + * @param pattern The pattern string that will be used + * @param regex Whether the pattern is a regular expression pattern. Default + * is false. When `regex` is false, the pattern will be treated as a glob + * pattern. + * @param limit The number limit for how many objects will be returned at + * most. + * @param ids corresponding object ids. + * + * @return Status that indicates whether the list action has succeeded. + */ + Status ListBy(std::string const& field, std::string const& pattern, + bool const regex, size_t const limit, + std::vector& ids); + /** * @brief Allocate a stream on vineyard. The metadata of parameter `id` must * has already been created on vineyard. 
@@ -400,10 +418,12 @@ class ClientBase { * @param id The ID of the object. * @param name The user-specific name that will be associated with the given * object. + * @param unique Fail if the given name has already been registered * * @return Status that indicates whether the request has succeeded. */ - Status PutName(const ObjectID id, std::string const& name); + Status PutName(const ObjectID id, std::string const& name, + const bool unique = false); /** * @brief Retrieve the object ID by associated name. diff --git a/src/common/util/evicting_cache_map.h b/src/common/util/evicting_cache_map.h new file mode 100644 index 000000000..dac72ed37 --- /dev/null +++ b/src/common/util/evicting_cache_map.h @@ -0,0 +1,523 @@ +/* + * Copyright 2014-present Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace vineyard { + +// Folly v2021.03.08.00 +/** + * A general purpose LRU evicting cache. Designed to support constant time + * set/get operations. It maintains a doubly linked list of items that are + * threaded through an index (a hash map). The access ordered is maintained + * on the list by moving an element to the front of list on a get. New elements + * are added to the front of the list. The index size is set to half the + * capacity (setting capacity to 0 is a special case. see notes at the end of + * this section). So assuming uniform distribution of keys, set/get are both + * constant time operations. + * + * On reaching capacity limit, clearSize_ LRU items are evicted at a time. If + * a callback is specified with setPruneHook, it is invoked for each eviction. + * + * This is NOT a thread-safe implementation. + * + * Configurability: capacity of the cache, number of items to evict, eviction + * callback and the hasher to hash the keys can all be supplied by the caller. + * + * If at a given state, N1 - N6 are the nodes in MRU to LRU order and hashing + * to index keys as {(N1,N5)->H1, (N4,N5,N5)->H2, N3->Hi}, the datastructure + * layout is as below. N1 .. N6 is a list threaded through the hash. + * Assuming, each the number of nodes hashed to each index key is bounded, the + * following operations run in constant time. + * i) get computes the index key, walks the list of elements hashed to + * the key and moves it to the front of the list, if found. + * ii) set inserts a new node into the list and places the same node on to the + * list of elements hashing to the corresponding index key. + * ii) prune deletes nodes from the end of the list as well from the index. + * + * +----+ +----+ +----+ + * | H1 | <-> | N1 | <-> | N5 | + * +----+ +----+ +----+ + * ^ ^ ^ + * | ___/ \ + * | / \ + * |_ /________ \___ + * / | \ + * / | \ + * v v v + * +----+ +----+ +----+ +----+ + * | H2 | <-> | N4 | <-> | N2 | <-> | N6 | + * +----+ +----+ +----+ +----+ + * . ^ ^ + * . | | + * . | | + * . | _____| + * . 
| / + * v v + * +----+ +----+ + * | Hi | <-> | N3 | + * +----+ +----+ + * + * N.B 1 : Changing the capacity with setMaxSize does not change the index size + * and it could end up in too many elements indexed to the same slot in index. + * The set/get performance will get worse in this case. So it is best to avoid + * resizing. + * + * N.B 2 : Setting capacity to 0, using setMaxSize or initialization, turns off + * evictions based on sizeof the cache making it an INFINITE size cache + * unless evictions of LRU items are triggered by calling prune() by clients + * (using their own eviction criteria). + */ +template , + class TKeyEqual = std::equal_to> +class EvictingCacheMap { + private: + // typedefs for brevity + struct Node; + struct KeyHasher; + struct KeyValueEqual; + typedef boost::intrusive::link_mode link_mode; + typedef boost::intrusive::unordered_set< + Node, boost::intrusive::hash, + boost::intrusive::equal> + NodeMap; + typedef boost::intrusive::list NodeList; + typedef std::pair TPair; + + public: + typedef std::function&&)> PruneHookCall; + + // iterator base : returns TPair on dereference + template + class iterator_base + : public boost::iterator_adaptor, + TIterator, Value, + boost::bidirectional_traversal_tag> { + public: + iterator_base() {} + + explicit iterator_base(TIterator it) + : iterator_base::iterator_adaptor_(it) {} + + template ::value && + std::is_convertible::value, + int> = 0> + /* implicit */ iterator_base(iterator_base const& other) + : iterator_base::iterator_adaptor_(other.base()) {} + + Value& dereference() const { return this->base_reference()->pr; } + }; + + // iterators + typedef iterator_base iterator; + typedef iterator_base + const_iterator; + typedef iterator_base + reverse_iterator; + typedef iterator_base + const_reverse_iterator; + + // the default map typedefs + using key_type = TKey; + using mapped_type = TValue; + using hasher = THash; + + /** + * Construct a EvictingCacheMap + * @param maxSize maximum size of the cache map. Once the map size exceeds + * maxSize, the map will begin to evict. + * @param clearSize the number of elements to clear at a time when the + * eviction size is reached. + */ + explicit EvictingCacheMap(std::size_t maxSize, std::size_t clearSize = 1, + const THash& keyHash = THash(), + const TKeyEqual& keyEqual = TKeyEqual()) + : nIndexBuckets_(std::max(maxSize / 2, std::size_t(kMinNumIndexBuckets))), + indexBuckets_(new typename NodeMap::bucket_type[nIndexBuckets_]), + indexTraits_(indexBuckets_.get(), nIndexBuckets_), + keyHash_(keyHash), + keyEqual_(keyEqual), + index_(indexTraits_, keyHash_, keyEqual_), + maxSize_(maxSize), + clearSize_(clearSize) {} + + EvictingCacheMap(const EvictingCacheMap&) = delete; + EvictingCacheMap& operator=(const EvictingCacheMap&) = delete; + EvictingCacheMap(EvictingCacheMap&&) = default; + EvictingCacheMap& operator=(EvictingCacheMap&&) = default; + + ~EvictingCacheMap() { + setPruneHook(nullptr); + // ignore any potential exceptions from pruneHook_ + pruneWithFailSafeOption(size(), nullptr, true); + } + + /** + * Adjust the max size of EvictingCacheMap. Note that this does not update + * nIndexBuckets_ accordingly. This API can cause performance to get very + * bad, e.g., the nIndexBuckets_ is still 100 after maxSize is updated to 1M. + * + * Calling this function with an arugment of 0 removes the limit on the cache + * size and elements are not evicted unless clients explicitly call prune. 
+ * + * If you intend to resize dynamically using this, then picking an index size + * that works well and initializing with corresponding maxSize is the only + * reasonable option. + * + * @param maxSize new maximum size of the cache map. + * @param pruneHook callback to use on eviction. + */ + void setMaxSize(size_t maxSize, PruneHookCall pruneHook = nullptr) { + if (maxSize != 0 && maxSize < size()) { + // Prune the excess elements with our new constraints. + prune(std::max(size() - maxSize, clearSize_), pruneHook); + } + maxSize_ = maxSize; + } + + size_t getMaxSize() const { return maxSize_; } + + void setClearSize(size_t clearSize) { clearSize_ = clearSize; } + + /** + * Check for existence of a specific key in the map. This operation has + * no effect on LRU order. + * @param key key to search for + * @return true if exists, false otherwise + */ + bool exists(const TKey& key) const { + return findInIndex(key) != index_.end(); + } + + /** + * Get the value associated with a specific key. This function always + * promotes a found value to the head of the LRU. + * @param key key associated with the value + * @return the value if it exists + * @throw std::out_of_range exception of the key does not exist + */ + TValue& get(const TKey& key) { + auto it = find(key); + if (it == end()) { + throw std::out_of_range("Key does not exist"); + } + return it->second; + } + + /** + * Get the iterator associated with a specific key. This function always + * promotes a found value to the head of the LRU. + * @param key key to associate with value + * @return the iterator of the object (a std::pair of const TKey, TValue) or + * end() if it does not exist + */ + iterator find(const TKey& key) { + auto it = findInIndex(key); + if (it == index_.end()) { + return end(); + } + lru_.splice(lru_.begin(), lru_, lru_.iterator_to(*it)); + return iterator(lru_.iterator_to(*it)); + } + + /** + * Get the value associated with a specific key. This function never + * promotes a found value to the head of the LRU. + * @param key key associated with the value + * @return the value if it exists + * @throw std::out_of_range exception of the key does not exist + */ + const TValue& getWithoutPromotion(const TKey& key) const { + auto it = findWithoutPromotion(key); + if (it == end()) { + throw std::out_of_range("Key does not exist"); + } + return it->second; + } + + TValue& getWithoutPromotion(const TKey& key) { + auto const& cThis = *this; + return const_cast(cThis.getWithoutPromotion(key)); + } + + /** + * Get the iterator associated with a specific key. This function never + * promotes a found value to the head of the LRU. + * @param key key to associate with value + * @return the iterator of the object (a std::pair of const TKey, TValue) or + * end() if it does not exist + */ + const_iterator findWithoutPromotion(const TKey& key) const { + auto it = findInIndex(key); + return (it == index_.end()) ? end() : const_iterator(lru_.iterator_to(*it)); + } + + iterator findWithoutPromotion(const TKey& key) { + auto it = findInIndex(key); + return (it == index_.end()) ? end() : iterator(lru_.iterator_to(*it)); + } + + /** + * Erase the key-value pair associated with key if it exists. 
+ * @param key key associated with the value + * @return true if the key existed and was erased, else false + */ + bool erase(const TKey& key) { + auto it = findInIndex(key); + if (it != index_.end()) { + erase(const_iterator(lru_.iterator_to(*it))); + return true; + } + return false; + } + + /** + * Erase the key-value pair associated with pos + * @param pos iterator to the element to be erased + * @return iterator to the following element or end() if pos was the last + * element + */ + iterator erase(const_iterator pos) { + auto* node = const_cast(&(*pos.base())); + std::unique_ptr nptr(node); + index_.erase(index_.iterator_to(*node)); + return iterator(lru_.erase(pos.base())); + } + + /** + * Set a key-value pair in the dictionary + * @param key key to associate with value + * @param value value to associate with the key + * @param promote boolean flag indicating whether or not to move something + * to the front of an LRU. This only really matters if you're setting + * a value that already exists. + * @param pruneHook callback to use on eviction (if it occurs). + */ + void set(const TKey& key, TValue value, bool promote = true, + PruneHookCall pruneHook = nullptr) { + auto it = findInIndex(key); + if (it != index_.end()) { + it->pr.second = std::move(value); + if (promote) { + lru_.splice(lru_.begin(), lru_, lru_.iterator_to(*it)); + } + } else { + auto node = new Node(key, std::move(value)); + index_.insert(*node); + lru_.push_front(*node); + + // no evictions if maxSize_ is 0 i.e. unlimited capacity + if (maxSize_ > 0 && size() > maxSize_) { + prune(clearSize_, pruneHook); + } + } + } + + /** + * Insert a new key-value pair in the dictionary if no element exists for key + * @param key key to associate with value + * @param value value to associate with the key + * @param pruneHook callback to use on eviction (if it occurs). + * @return a pair consisting of an iterator to the inserted element (or to the + * element that prevented the insertion) and a bool denoting whether the + * insertion took place. + */ + std::pair insert(const TKey& key, TValue value, + PruneHookCall pruneHook = nullptr) { + auto node = std::make_unique(key, std::move(value)); + auto pair = index_.insert(*node); + if (pair.second) { + lru_.push_front(*node); + node.release(); + + // no evictions if maxSize_ is 0 i.e. unlimited capacity + if (maxSize_ > 0 && size() > maxSize_) { + prune(clearSize_, pruneHook); + } + } + return std::make_pair(iterator(lru_.iterator_to(*pair.first)), pair.second); + } + + /** + * Get the number of elements in the dictionary + * @return the size of the dictionary + */ + std::size_t size() const { return index_.size(); } + + /** + * Typical empty function + * @return true if empty, false otherwise + */ + bool empty() const { return index_.empty(); } + + void clear(PruneHookCall pruneHook = nullptr) { prune(size(), pruneHook); } + + /** + * Set the prune hook, which is the function invoked on the key and value + * on each eviction. Will throw If the pruneHook throws, unless the + * EvictingCacheMap object is being destroyed in which case it will + * be ignored. + * @param pruneHook new callback to use on eviction. + * @param promote boolean flag indicating whether or not to move something + * to the front of an LRU. + * @return the iterator of the object (a std::pair of const TKey, TValue) or + * end() if it does not exist + */ + void setPruneHook(PruneHookCall pruneHook) { pruneHook_ = pruneHook; } + + /** + * Prune the minimum of pruneSize and size() from the back of the LRU. 
+ * Will throw if pruneHook throws. + * @param pruneSize minimum number of elements to prune + * @param pruneHook a custom pruneHook function + */ + void prune(std::size_t pruneSize, PruneHookCall pruneHook = nullptr) { + // do not swallow exceptions for prunes not triggered from destructor + pruneWithFailSafeOption(pruneSize, pruneHook, false); + } + + // Iterators and such + iterator begin() { return iterator(lru_.begin()); } + iterator end() { return iterator(lru_.end()); } + const_iterator begin() const { return const_iterator(lru_.begin()); } + const_iterator end() const { return const_iterator(lru_.end()); } + + const_iterator cbegin() const { return const_iterator(lru_.cbegin()); } + const_iterator cend() const { return const_iterator(lru_.cend()); } + + reverse_iterator rbegin() { return reverse_iterator(lru_.rbegin()); } + reverse_iterator rend() { return reverse_iterator(lru_.rend()); } + + const_reverse_iterator rbegin() const { + return const_reverse_iterator(lru_.rbegin()); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(lru_.rend()); + } + + const_reverse_iterator crbegin() const { + return const_reverse_iterator(lru_.crbegin()); + } + const_reverse_iterator crend() const { + return const_reverse_iterator(lru_.crend()); + } + + private: + struct Node : public boost::intrusive::unordered_set_base_hook, + public boost::intrusive::list_base_hook { + Node(const TKey& key, TValue&& value) + : pr(std::make_pair(key, std::move(value))) {} + TPair pr; + }; + + struct KeyHasher { + KeyHasher(const THash& keyHash) // NOLINT(runtime/explicit) + : hash(keyHash) {} + std::size_t operator()(const Node& node) const { + return hash(node.pr.first); + } + std::size_t operator()(const TKey& key) const { return hash(key); } + THash hash; + }; + + struct KeyValueEqual { + KeyValueEqual(const TKeyEqual& keyEqual) // NOLINT(runtime/explicit) + : equal(keyEqual) {} + bool operator()(const TKey& lhs, const Node& rhs) const { + return equal(lhs, rhs.pr.first); + } + bool operator()(const Node& lhs, const TKey& rhs) const { + return equal(lhs.pr.first, rhs); + } + bool operator()(const Node& lhs, const Node& rhs) const { + return equal(lhs.pr.first, rhs.pr.first); + } + TKeyEqual equal; + }; + + /** + * Get the iterator in in the index associated with a specific key. This is + * merely a search in the index and does not promote the object. + * @param key key to associate with value + * @return the NodeMap::iterator to the Node containing the object + * (a std::pair of const TKey, TValue) or index_.end() if it does not exist + */ + typename NodeMap::iterator findInIndex(const TKey& key) { + return index_.find(key, KeyHasher(keyHash_), KeyValueEqual(keyEqual_)); + } + + typename NodeMap::const_iterator findInIndex(const TKey& key) const { + return index_.find(key, KeyHasher(keyHash_), KeyValueEqual(keyEqual_)); + } + + /** + * Prune the minimum of pruneSize and size() from the back of the LRU. + * @param pruneSize minimum number of elements to prune + * @param pruneHook a custom pruneHook function + * @param failSafe true if exceptions are to ignored, false by default + */ + void pruneWithFailSafeOption(std::size_t pruneSize, PruneHookCall pruneHook, + bool failSafe) { + auto& ph = (nullptr == pruneHook) ? 
pruneHook_ : pruneHook; + + std::vector pairs; + pairs.reserve(pruneSize); + for (std::size_t i = 0; i < pruneSize && !lru_.empty(); i++) { + auto* node = &(*lru_.rbegin()); + std::unique_ptr nptr(node); + + lru_.erase(lru_.iterator_to(*node)); + index_.erase(index_.iterator_to(*node)); + pairs.emplace_back(std::move(nptr->pr)); + } + + if (ph) { + try { + ph(std::move(pairs)); + } catch (...) { + if (!failSafe) { + throw; + } + } + } + } + + static const std::size_t kMinNumIndexBuckets = 100; + PruneHookCall pruneHook_; + std::size_t nIndexBuckets_; + std::unique_ptr indexBuckets_; + typename NodeMap::bucket_traits indexTraits_; + THash keyHash_; + TKeyEqual keyEqual_; + NodeMap index_; + NodeList lru_; + std::size_t maxSize_; + std::size_t clearSize_; +}; + +} // namespace vineyard diff --git a/src/common/util/protocols.cc b/src/common/util/protocols.cc index ae67eefba..26f6d6468 100644 --- a/src/common/util/protocols.cc +++ b/src/common/util/protocols.cc @@ -117,6 +117,8 @@ const std::string command_t::GET_DATA_REQUEST = "get_data_request"; const std::string command_t::GET_DATA_REPLY = "get_data_reply"; const std::string command_t::LIST_DATA_REQUEST = "list_data_request"; const std::string command_t::LIST_DATA_REPLY = "list_data_reply"; +const std::string command_t::LIST_BY_REQUEST = "list_by_request"; +const std::string command_t::LIST_BY_REPLY = "list_by_reply"; const std::string command_t::DELETE_DATA_REQUEST = "del_data_request"; const std::string command_t::DELETE_DATA_REPLY = "del_data_reply"; const std::string command_t::EXISTS_REQUEST = "exists_request"; @@ -1723,20 +1725,26 @@ Status ReadDropStreamReply(const json& root) { } void WritePutNameRequest(const ObjectID object_id, const std::string& name, - std::string& msg) { + const bool unique, std::string& msg) { json root; root["type"] = command_t::PUT_NAME_REQUEST; root["object_id"] = object_id; root["name"] = name; + root["unique"] = unique; encode_msg(root, msg); } Status ReadPutNameRequest(const json& root, ObjectID& object_id, - std::string& name) { + std::string& name, bool& unique) { CHECK_IPC_ERROR(root, command_t::PUT_NAME_REQUEST); object_id = root["object_id"].get(); name = root["name"].get_ref(); + if (root.contains("unique")) { + unique = root["unique"].get(); + } else { + unique = false; + } return Status::OK(); } @@ -1820,6 +1828,44 @@ Status ReadListNameReply(const json& root, return Status::OK(); } +void WriteListByRequest(std::string const& field, std::string const& pattern, + bool const regex, size_t const limit, + std::string& msg) { + json root; + root["type"] = command_t::LIST_BY_REQUEST; + root["field"] = field; + root["pattern"] = pattern; + root["regex"] = regex; + root["limit"] = limit; + + encode_msg(root, msg); +} + +Status ReadListByRequest(const json& root, std::string& field, + std::string& pattern, bool& regex, size_t& limit) { + CHECK_IPC_ERROR(root, command_t::LIST_BY_REQUEST); + field = root["field"].get_ref(); + pattern = root["pattern"].get_ref(); + regex = root.value("regex", false); + limit = root["limit"].get(); + return Status::OK(); +} + +void WriteListByReply(std::vector const& ids, std::string& msg) { + json root; + root["type"] = command_t::LIST_BY_REPLY; + root["size"] = ids.size(); + root["ids"] = ids; + + encode_msg(root, msg); +} + +Status ReadListByReply(const json& root, std::vector& ids) { + CHECK_IPC_ERROR(root, command_t::LIST_BY_REPLY); + ids = root.value("ids", std::vector{}); + return Status::OK(); +} + void WriteDropNameRequest(const std::string& name, std::string& msg) { 
json root; root["type"] = command_t::DROP_NAME_REQUEST; diff --git a/src/common/util/protocols.h b/src/common/util/protocols.h index e79dc6aa4..bd1481eff 100644 --- a/src/common/util/protocols.h +++ b/src/common/util/protocols.h @@ -92,6 +92,8 @@ struct command_t { static const std::string GET_DATA_REPLY; static const std::string LIST_DATA_REQUEST; static const std::string LIST_DATA_REPLY; + static const std::string LIST_BY_REQUEST; + static const std::string LIST_BY_REPLY; static const std::string DELETE_DATA_REQUEST; static const std::string DELETE_DATA_REPLY; static const std::string EXISTS_REQUEST; @@ -508,6 +510,16 @@ void WriteListDataRequest(std::string const& pattern, bool const regex, Status ReadListDataRequest(const json& root, std::string& pattern, bool& regex, size_t& limit); +void WriteListByRequest(std::string const& field, std::string const& pattern, + bool const regex, size_t const limit, std::string& msg); + +Status ReadListByRequest(const json& root, std::string& field, + std::string& pattern, bool& regex, size_t& limit); + +void WriteListByReply(std::vector const& ids, std::string& msg); + +Status ReadListByReply(const json& root, std::vector& ids); + void WriteDelDataRequest(const ObjectID id, const bool force, const bool deep, const bool memory_trim, const bool fastpath, std::string& msg); @@ -661,10 +673,10 @@ void WriteDropStreamReply(std::string& msg); Status ReadDropStreamReply(const json& root); void WritePutNameRequest(const ObjectID object_id, const std::string& name, - std::string& msg); + const bool unique, std::string& msg); Status ReadPutNameRequest(const json& root, ObjectID& object_id, - std::string& name); + std::string& name, bool& unique); void WritePutNameReply(std::string& msg); diff --git a/src/common/util/status.cc b/src/common/util/status.cc index 8e1edf47b..fa7b1e246 100644 --- a/src/common/util/status.cc +++ b/src/common/util/status.cc @@ -183,6 +183,9 @@ std::string Status::CodeAsString() const { case StatusCode::kGlobalObjectInvalid: type = "Global object invalid"; break; + case StatusCode::kNameExists: + type = "Name exists"; + break; case StatusCode::kUnknownError: default: type = "Unknown error"; @@ -300,6 +303,9 @@ std::string Status::CodeAsLabel() const { case StatusCode::kGlobalObjectInvalid: type = "GlobalObjectInvalid"; break; + case StatusCode::kNameExists: + type = "NameExists"; + break; case StatusCode::kUnknownError: default: type = "UnknownError"; diff --git a/src/common/util/status.h b/src/common/util/status.h index bf12f1e51..916be7787 100644 --- a/src/common/util/status.h +++ b/src/common/util/status.h @@ -257,6 +257,8 @@ enum class StatusCode : unsigned char { kGlobalObjectInvalid = 51, + kNameExists = 91, + kUnknownError = 255 }; @@ -595,6 +597,10 @@ class VINEYARD_MUST_USE_TYPE Status { return Status(StatusCode::kGlobalObjectInvalid, message); } + static Status NameExists(std::string const& name) { + return Status(StatusCode::kNameExists, "Name " + name + " already exists"); + } + /// Return an error status for unknown errors static Status UnknownError(std::string const& message = "") { return Status(StatusCode::kUnknownError, message); @@ -705,6 +711,8 @@ class VINEYARD_MUST_USE_TYPE Status { bool IsGlobalObjectInvalid() const { return code() == StatusCode::kGlobalObjectInvalid; } + /// Return true iff the name already exists + bool IsNameExists() const { return code() == StatusCode::kNameExists; } /// Return true iff the status indicates an unknown error. 
bool IsUnknownError() const { return code() == StatusCode::kUnknownError; } diff --git a/src/server/async/socket_server.cc b/src/server/async/socket_server.cc index ff6a273e0..3c1e3f44e 100644 --- a/src/server/async/socket_server.cc +++ b/src/server/async/socket_server.cc @@ -295,6 +295,8 @@ bool SocketConnection::processMessage(const std::string& message_in) { return doDelData(root); } else if (cmd == command_t::LIST_DATA_REQUEST) { return doListData(root); + } else if (cmd == command_t::LIST_BY_REQUEST) { + return doListBy(root); } else if (cmd == command_t::EXISTS_REQUEST) { return doExists(root); } else if (cmd == command_t::PERSIST_REQUEST) { @@ -1430,10 +1432,11 @@ bool SocketConnection::doPutName(const json& root) { auto self(shared_from_this()); ObjectID object_id; std::string name; - TRY_READ_REQUEST(ReadPutNameRequest, root, object_id, name); + bool unique; + TRY_READ_REQUEST(ReadPutNameRequest, root, object_id, name, unique); name = escape_json_pointer(name); - RESPONSE_ON_ERROR( - server_ptr_->PutName(object_id, name, [self](const Status& status) { + RESPONSE_ON_ERROR(server_ptr_->PutName( + object_id, name, unique, [self](const Status& status) { std::string message_out; if (status.ok()) { WritePutNameReply(message_out); @@ -1500,6 +1503,29 @@ bool SocketConnection::doListName(const json& root) { return false; } +bool SocketConnection::doListBy(const json& root) { + auto self(shared_from_this()); + std::string field; + std::string pattern; + bool regex; + size_t limit; + TRY_READ_REQUEST(ReadListByRequest, root, field, pattern, regex, limit); + RESPONSE_ON_ERROR(server_ptr_->ListBy( + field, pattern, regex, limit, + [self](const Status& status, const std::vector& ids) { + std::string message_out; + if (status.ok()) { + WriteListByReply(ids, message_out); + } else { + VLOG(100) << "Error: " << status.ToString(); + WriteErrorReply(status, message_out); + } + self->doWrite(message_out); + return Status::OK(); + })); + return false; +} + bool SocketConnection::doDropName(const json& root) { auto self(shared_from_this()); std::string name; diff --git a/src/server/async/socket_server.h b/src/server/async/socket_server.h index da2c3dda5..b3b219b72 100644 --- a/src/server/async/socket_server.h +++ b/src/server/async/socket_server.h @@ -102,6 +102,7 @@ class SocketConnection : public std::enable_shared_from_this { bool doCreateDatas(json const& root); bool doGetData(json const& root); bool doListData(json const& root); + bool doListBy(json const& root); bool doDelData(json const& root); bool doExists(json const& root); bool doPersist(json const& root); diff --git a/src/server/server/vineyard_server.cc b/src/server/server/vineyard_server.cc index 1e493f3d6..d6146add5 100644 --- a/src/server/server/vineyard_server.cc +++ b/src/server/server/vineyard_server.cc @@ -452,6 +452,28 @@ Status VineyardServer::ListName( return Status::OK(); } +Status VineyardServer::ListBy( + std::string const& field, std::string const& pattern, bool const regex, + size_t const limit, callback_t&> callback) { + ENSURE_VINEYARDD_READY(); + auto self(shared_from_this()); + meta_service_ptr_->RequestToGetData(true, [field, pattern, regex, limit, + callback](const Status& status, + const json& meta) { + if (status.ok()) { + std::vector ids; + Status s; + VCATCH_JSON_ERROR( + meta, s, meta_tree::ListBy(meta, field, pattern, regex, limit, ids)); + return callback(s, ids); + } else { + VLOG(100) << "Error: " << status.ToString(); + return status; + } + }); + return Status::OK(); +} + namespace detail { Status 
validate_metadata(const json& tree, json& result, Signature& signature, @@ -868,12 +890,13 @@ Status VineyardServer::DeleteAllAt(const json& meta, } Status VineyardServer::PutName(const ObjectID object_id, - const std::string& name, callback_t<> callback) { + const std::string& name, const bool unique, + callback_t<> callback) { ENSURE_VINEYARDD_READY(); auto self(shared_from_this()); meta_service_ptr_->RequestToPersist( - [object_id, name](const Status& status, const json& meta, - std::vector& ops) { + [object_id, name, unique](const Status& status, const json& meta, + std::vector& ops) { if (status.ok()) { // TODO: do proper validation: // 1. global objects can have name, local ones cannot. @@ -909,6 +932,19 @@ Status VineyardServer::PutName(const ObjectID object_id, "transient objects cannot have name, please persist it first"); } + if (unique) { + std::map names; + Status s; + VCATCH_JSON_ERROR( + meta, s, + meta_tree::ListName(meta, name, false, /* limit */ 1, names)); + VINEYARD_DISCARD(s); + // if unique is requested and name exists + if (!names.empty()) { + return Status::NameExists(name); + } + } + ops.emplace_back(meta_tree::op_t::Put("/names/" + name, object_id)); ops.emplace_back(meta_tree::op_t::Put( "/data/" + ObjectIDToString(object_id) + "/__name", @@ -1201,12 +1237,18 @@ Status VineyardServer::LabelObjects(const ObjectID object_id, if (is_transient) { self->meta_service_ptr_->RequestToBulkUpdate( [callback, object_id, label_string]( - const Status& status, const json&, + const Status& status, const json& new_tree, std::vector& ops, ObjectID&, Signature&, InstanceID&) { if (!status.ok()) { return callback(status); } + if (!new_tree.contains("data") || + !new_tree["data"].contains(ObjectIDToString(object_id))) { + return callback(Status::ObjectNotExists( + "object " + ObjectIDToString(object_id) + + " doesn't exist")); + } ops.emplace_back(meta_tree::op_t::Put( "/data/" + ObjectIDToString(object_id) + "/__labels", label_string)); @@ -1217,11 +1259,17 @@ Status VineyardServer::LabelObjects(const ObjectID object_id, } else { self->meta_service_ptr_->RequestToPersist( [callback, object_id, label_string]( - const Status& status, const json&, + const Status& status, const json& new_tree, std::vector& ops) { if (!status.ok()) { return callback(status); } + if (!new_tree.contains("data") || + !new_tree["data"].contains(ObjectIDToString(object_id))) { + return callback(Status::ObjectNotExists( + "object " + ObjectIDToString(object_id) + + " doesn't exist")); + } ops.emplace_back(meta_tree::op_t::Put( "/data/" + ObjectIDToString(object_id) + "/__labels", label_string)); diff --git a/src/server/server/vineyard_server.h b/src/server/server/vineyard_server.h index 085310aad..ca6c0ef53 100644 --- a/src/server/server/vineyard_server.h +++ b/src/server/server/vineyard_server.h @@ -125,6 +125,10 @@ class VineyardServer : public std::enable_shared_from_this { size_t const limit, callback_t&> callback); + Status ListBy(std::string const& field, std::string const& pattern, + bool const regex, size_t const limit, + callback_t&> callback); + Status CreateData( const json& tree, callback_t callback); @@ -168,7 +172,7 @@ class VineyardServer : public std::enable_shared_from_this { Status DeleteAllAt(const json& meta, InstanceID const instance_id); Status PutName(const ObjectID object_id, const std::string& name, - callback_t<> callback); + const bool unique, callback_t<> callback); Status GetName(const std::string& name, const bool wait, DeferredReq::alive_t alive, // if connection is still alive 
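The unique-name path added to PutName above is what lets multiple cache clients race to register the same chunk name safely: the first writer wins, and later writers observe kNameExists instead of silently overwriting the mapping. A minimal client-side sketch of that flow follows; it assumes the client wrapper exposes the new flag as a trailing `unique` argument of PutName, mirroring the server-side change in this patch, so treat that exact signature as an assumption.

#include <string>

#include "client/client.h"

// Sketch only: register a chunk under a unique name, and fall back to the
// object that already owns the name when another client won the race.
vineyard::Status PutChunkNameOrReuse(vineyard::Client& client,
                                     vineyard::ObjectID chunk_id,
                                     const std::string& name,
                                     vineyard::ObjectID& owner_id) {
  auto status = client.PutName(chunk_id, name, /*unique=*/true);
  if (status.ok()) {
    owner_id = chunk_id;
    return vineyard::Status::OK();
  }
  if (status.IsNameExists()) {
    // Another client registered this name first: reuse its chunk instead of
    // overwriting the name-to-id mapping.
    return client.GetName(name, owner_id);
  }
  return status;
}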
diff --git a/src/server/util/meta_tree.cc b/src/server/util/meta_tree.cc index 4d48bf75b..a9aafa013 100644 --- a/src/server/util/meta_tree.cc +++ b/src/server/util/meta_tree.cc @@ -232,6 +232,31 @@ static Status get_name(const json& tree, std::string& name, return Status::OK(); } +static Status get_field(const json& tree, std::string const& field, + std::string& value, bool const decode = false) { + json::const_iterator iter = tree.find(field); + if (iter == tree.end()) { + return Status::MetaTreeInvalid(field + + " not found in metadata: " + tree.dump(4)); + } + if (iter->is_object()) { + LOG(ERROR) << "meta tree id invalid. " << *iter; + return Status::MetaTreeInvalid("type of '" + field + + "' in metadata is wrong: " + iter->dump(4)); + } + value = iter->get_ref(); + if (decode) { + NodeType node_type = NodeType::InvalidType; + decode_value(value, node_type, value); + if (node_type != NodeType::Value) { + return Status::MetaTreeInvalid( + "value of '" + field + + "' in metadata is not of value type: " + value); + } + } + return Status::OK(); +} + static Status get_type(const json& tree, std::string& type, bool const decode = false) { // type: get the typename @@ -432,6 +457,40 @@ Status ListName(const json& tree, std::string const& pattern, bool const regex, return Status::OK(); } +Status ListBy(const json& tree, std::string const& field, + std::string const& pattern, bool const regex, size_t const limit, + std::vector& ids) { + if (!tree.contains("data")) { + return Status::OK(); + } + + size_t found = 0; + for (auto const& item : tree["data"].items()) { + if (!item.value().is_object() || item.value().empty()) { + LOG(INFO) << "metadata tree in vineyardd shouldn't be empty"; + return Status::MetaTreeInvalid( + "metadata tree in vineyard server is empty in 'ListBy'"); + } + + if (found >= limit) { + break; + } + + std::string value; + auto status = get_field(item.value(), field, value, true); + if (!status.ok()) { + continue; + } + + // match on pattern + if (MatchTypeName(regex, pattern, value)) { + found += 1; + ids.push_back(ObjectIDFromString(item.key())); + } + } + return Status::OK(); +} + Status DelDataOps(const json& tree, const ObjectID id, std::vector& ops, bool& sync_remote) { if (IsBlob(id)) { diff --git a/src/server/util/meta_tree.h b/src/server/util/meta_tree.h index 43152afa9..00ead646b 100644 --- a/src/server/util/meta_tree.h +++ b/src/server/util/meta_tree.h @@ -94,6 +94,9 @@ Status ListData(const json& tree, const std::string& instance_name, Status ListAllData(const json& tree, std::vector& objects); Status ListName(const json& tree, std::string const& pattern, bool const regex, size_t const limit, std::map& names); +Status ListBy(const json& tree, std::string const& field, + std::string const& pattern, bool const regex, size_t const limit, + std::vector& ids); Status IfPersist(const json& tree, const ObjectID id, bool& persist); Status Exists(const json& tree, const ObjectID id, bool& exists); diff --git a/test/evicting_cache_map_test.cc b/test/evicting_cache_map_test.cc new file mode 100644 index 000000000..71dfbe2bf --- /dev/null +++ b/test/evicting_cache_map_test.cc @@ -0,0 +1,737 @@ +/* + * Copyright 2014-present Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "common/util/evicting_cache_map.h" +#include "common/util/logging.h" + +using namespace vineyard; // NOLINT(build/namespaces) + +void SanityTest() { + EvictingCacheMap map(0); + + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + CHECK(!map.exists(1)); + map.set(1, 1); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + CHECK_EQ(1, map.get(1)); + CHECK(map.exists(1)); + map.set(1, 2); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + CHECK_EQ(2, map.get(1)); + CHECK(map.exists(1)); + map.erase(1); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + CHECK(!map.exists(1)); + + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + CHECK(!map.exists(1)); + map.set(1, 1); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + CHECK_EQ(1, map.get(1)); + CHECK(map.exists(1)); + map.set(1, 2); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + CHECK_EQ(2, map.get(1)); + CHECK(map.exists(1)); + + CHECK(!map.exists(2)); + map.set(2, 1); + CHECK(map.exists(2)); + CHECK_EQ(2, map.size()); + CHECK(!map.empty()); + CHECK_EQ(1, map.get(2)); + map.set(2, 2); + CHECK_EQ(2, map.size()); + CHECK(!map.empty()); + CHECK_EQ(2, map.get(2)); + CHECK(map.exists(2)); + map.erase(2); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + CHECK(!map.exists(2)); + map.erase(1); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + CHECK(!map.exists(1)); +} + +void PruneTest() { + EvictingCacheMap map(0); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(1000000); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(100); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(99); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 99; i++) { + CHECK(!map.exists(i)); + } + CHECK(map.exists(99)); + CHECK_EQ(99, map.get(99)); + + map.prune(100); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(90); + CHECK_EQ(10, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 90; i++) { + CHECK(!map.exists(i)); + } + for (int i = 90; i < 100; i++) { + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } +} + +void PruneHookTest() { + EvictingCacheMap map(0); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + 
CHECK(!map.exists(i)); + } + + int sum = 0; + auto pruneCb = [&](auto&& pairs) { + for (auto&& [k, v] : pairs) { + CHECK_EQ(k, v); + sum += k; + } + }; + + map.setPruneHook(pruneCb); + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(1000000); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + CHECK_EQ((99 * 100) / 2, sum); + sum = 0; + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(100); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + CHECK_EQ((99 * 100) / 2, sum); + sum = 0; + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(99); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 99; i++) { + CHECK(!map.exists(i)); + } + CHECK(map.exists(99)); + CHECK_EQ(99, map.get(99)); + + CHECK_EQ((98 * 99) / 2, sum); + sum = 0; + + map.prune(100); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + CHECK_EQ(99, sum); + sum = 0; + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + map.prune(90); + CHECK_EQ(10, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 90; i++) { + CHECK(!map.exists(i)); + } + for (int i = 90; i < 100; i++) { + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + CHECK_EQ((89 * 90) / 2, sum); + sum = 0; +} + +void SetMaxSize() { + EvictingCacheMap map(100, 20); + for (int i = 0; i < 90; i++) { + map.set(i, i); + CHECK(map.exists(i)); + } + + CHECK_EQ(90, map.size()); + map.setMaxSize(50); + CHECK_EQ(map.size(), 50); + + for (int i = 0; i < 90; i++) { + map.set(i, i); + CHECK(map.exists(i)); + } + CHECK_EQ(40, map.size()); + map.setMaxSize(0); + CHECK_EQ(40, map.size()); + map.setMaxSize(10); + CHECK_EQ(10, map.size()); +} + +void SetClearSize() { + EvictingCacheMap map(100, 20); + for (int i = 0; i < 90; i++) { + map.set(i, i); + CHECK(map.exists(i)); + } + + CHECK_EQ(90, map.size()); + map.setClearSize(40); + map.setMaxSize(50); + CHECK_EQ(map.size(), 50); + + for (int i = 0; i < 90; i++) { + map.set(i, i); + CHECK(map.exists(i)); + } + CHECK_EQ(20, map.size()); + map.setMaxSize(0); + CHECK_EQ(20, map.size()); + map.setMaxSize(10); + CHECK_EQ(0, map.size()); +} + +void DestructorInvocationTest() { + struct SumInt { + SumInt(int val_, int* ref_) : val(val_), ref(ref_) {} + ~SumInt() { *ref += val; } + + SumInt(SumInt const&) = delete; + SumInt& operator=(SumInt const&) = delete; + + SumInt(SumInt&& other) : val(std::exchange(other.val, 0)), ref(other.ref) {} + SumInt& operator=(SumInt&& other) { + std::swap(val, other.val); + std::swap(ref, other.ref); + return *this; + } + + int val; + int* ref; + }; + + int sum; + EvictingCacheMap map(0); + + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, SumInt(i, &sum)); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i).val); + } + + sum = 0; + map.prune(1000000); + CHECK_EQ(0, 
map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + CHECK_EQ((99 * 100) / 2, sum); + + for (int i = 0; i < 100; i++) { + map.set(i, SumInt(i, &sum)); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i).val); + } + + sum = 0; + map.prune(100); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + CHECK_EQ((99 * 100) / 2, sum); + + for (int i = 0; i < 100; i++) { + map.set(i, SumInt(i, &sum)); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i).val); + } + + sum = 0; + map.prune(99); + CHECK_EQ(1, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 99; i++) { + CHECK(!map.exists(i)); + } + CHECK(map.exists(99)); + CHECK_EQ(99, map.get(99).val); + + CHECK_EQ((98 * 99) / 2, sum); + + sum = 0; + map.prune(100); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + CHECK_EQ(99, sum); + for (int i = 0; i < 100; i++) { + map.set(i, SumInt(i, &sum)); + CHECK_EQ(i + 1, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i).val); + } + + sum = 0; + map.prune(90); + CHECK_EQ(10, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 90; i++) { + CHECK(!map.exists(i)); + } + for (int i = 90; i < 100; i++) { + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i).val); + } + CHECK_EQ((89 * 90) / 2, sum); + sum = 0; + for (int i = 0; i < 90; i++) { + auto pair = map.insert(i, SumInt(i + 1, &sum)); + CHECK_EQ(i + 1, pair.first->second.val); + CHECK(pair.second); + CHECK(map.exists(i)); + } + CHECK_EQ(0, sum); + for (int i = 90; i < 100; i++) { + auto pair = map.insert(i, SumInt(i + 1, &sum)); + CHECK_EQ(i, pair.first->second.val); + CHECK(!pair.second); + CHECK(map.exists(i)); + } + CHECK_EQ((10 * 191) / 2, sum); + sum = 0; + map.prune(100); + CHECK_EQ((90 * 91) / 2 + (10 * 189) / 2, sum); + + sum = 0; + map.set(3, SumInt(3, &sum)); + map.set(2, SumInt(2, &sum)); + map.set(1, SumInt(1, &sum)); + CHECK_EQ(0, sum); + CHECK_EQ(2, map.erase(map.find(1))->second.val); + CHECK_EQ(1, sum); + CHECK(map.end() == map.erase(map.findWithoutPromotion(3))); + CHECK_EQ(4, sum); + map.prune(1); + CHECK_EQ(6, sum); +} + +void LruSanityTest() { + EvictingCacheMap map(10); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_GE(10, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + CHECK_EQ(10, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 90; i++) { + CHECK(!map.exists(i)); + } + for (int i = 90; i < 100; i++) { + CHECK(map.exists(i)); + } +} + +void LruPromotionTest() { + EvictingCacheMap map(10); + CHECK_EQ(0, map.size()); + CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_GE(10, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + for (int j = 0; j < std::min(i + 1, 9); j++) { + CHECK(map.exists(j)); + CHECK_EQ(j, map.get(j)); + } + } + + CHECK_EQ(10, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 9; i++) { + CHECK(map.exists(i)); + } + CHECK(map.exists(99)); + for (int i = 10; i < 99; i++) { + CHECK(!map.exists(i)); + } +} + +void LruNoPromotionTest() { + EvictingCacheMap map(10); + CHECK_EQ(0, map.size()); + 
CHECK(map.empty()); + for (int i = 0; i < 100; i++) { + CHECK(!map.exists(i)); + } + + for (int i = 0; i < 100; i++) { + map.set(i, i); + CHECK_GE(10, map.size()); + CHECK(!map.empty()); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + for (int j = 0; j < std::min(i + 1, 9); j++) { + if (map.exists(j)) { + CHECK_EQ(j, map.getWithoutPromotion(j)); + } + } + } + + CHECK_EQ(10, map.size()); + CHECK(!map.empty()); + for (int i = 0; i < 90; i++) { + CHECK(!map.exists(i)); + } + for (int i = 90; i < 100; i++) { + CHECK(map.exists(i)); + } +} + +void IteratorSanityTest() { + const int nItems = 1000; + EvictingCacheMap map(nItems); + CHECK(map.begin() == map.end()); + for (int i = 0; i < nItems; i++) { + CHECK(!map.exists(i)); + map.set(i, i * 2); + CHECK(map.exists(i)); + CHECK_EQ(i * 2, map.get(i)); + } + + std::set seen; + for (auto& it : map) { + CHECK_EQ(0, seen.count(it.first)); + seen.insert(it.first); + CHECK_EQ(it.first * 2, it.second); + } + CHECK_EQ(nItems, seen.size()); +} + +void FindTest() { + const int nItems = 1000; + EvictingCacheMap map(nItems); + for (int i = 0; i < nItems; i++) { + map.set(i * 2, i * 2); + CHECK(map.exists(i * 2)); + CHECK_EQ(i * 2, map.get(i * 2)); + } + for (int i = 0; i < nItems * 2; i++) { + if (i % 2 == 0) { + auto it = map.find(i); + CHECK(it != map.end()); + CHECK_EQ(i, it->first); + CHECK_EQ(i, it->second); + } else { + CHECK(map.find(i) == map.end()); + } + } + for (int i = nItems * 2 - 1; i >= 0; i--) { + if (i % 2 == 0) { + auto it = map.find(i); + CHECK(it != map.end()); + CHECK_EQ(i, it->first); + CHECK_EQ(i, it->second); + } else { + CHECK(map.find(i) == map.end()); + } + } + CHECK_EQ(0, map.begin()->first); +} + +void FindWithoutPromotionTest() { + const int nItems = 1000; + EvictingCacheMap map(nItems); + for (int i = 0; i < nItems; i++) { + map.set(i * 2, i * 2); + CHECK(map.exists(i * 2)); + CHECK_EQ(i * 2, map.get(i * 2)); + } + for (int i = nItems * 2 - 1; i >= 0; i--) { + if (i % 2 == 0) { + auto it = map.findWithoutPromotion(i); + CHECK(it != map.end()); + CHECK_EQ(i, it->first); + CHECK_EQ(i, it->second); + } else { + CHECK(map.findWithoutPromotion(i) == map.end()); + } + } + CHECK_EQ((nItems - 1) * 2, map.begin()->first); +} + +void IteratorOrderingTest() { + const int nItems = 1000; + EvictingCacheMap map(nItems); + for (int i = 0; i < nItems; i++) { + map.set(i, i); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + int expected = nItems - 1; + for (auto it = map.begin(); it != map.end(); ++it) { + CHECK_EQ(expected, it->first); + expected--; + } + + expected = 0; + for (auto it = map.rbegin(); it != map.rend(); ++it) { + CHECK_EQ(expected, it->first); + expected++; + } + + { + auto it = map.end(); + expected = 0; + CHECK(it != map.begin()); + do { + --it; + CHECK_EQ(expected, it->first); + expected++; + } while (it != map.begin()); + CHECK_EQ(nItems, expected); + } + + { + auto it = map.rend(); + expected = nItems - 1; + do { + --it; + CHECK_EQ(expected, it->first); + expected--; + } while (it != map.rbegin()); + CHECK_EQ(-1, expected); + } +} + +void MoveTest() { + const int nItems = 1000; + EvictingCacheMap map(nItems); + for (int i = 0; i < nItems; i++) { + map.set(i, i); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + } + + EvictingCacheMap map2 = std::move(map); + CHECK(map.empty()); + for (int i = 0; i < nItems; i++) { + CHECK(map2.exists(i)); + CHECK_EQ(i, map2.get(i)); + } +} + +void CustomKeyEqual() { + const int nItems = 100; + struct Eq { + bool operator()(const int& a, const int& b) const { + return (a % 
mod) == (b % mod); + } + int mod; + }; + struct Hash { + size_t operator()(const int& a) const { return std::hash()(a % mod); } + int mod; + }; + EvictingCacheMap map(nItems, 1 /* clearSize */, + Hash{nItems}, Eq{nItems}); + for (int i = 0; i < nItems; i++) { + map.set(i, i); + CHECK(map.exists(i)); + CHECK_EQ(i, map.get(i)); + CHECK(map.exists(i + nItems)); + CHECK_EQ(i, map.get(i + nItems)); + } +} + +void IteratorConversion() { + using type = EvictingCacheMap; + using i = type::iterator; + using ci = type::const_iterator; + using ri = type::reverse_iterator; + using cri = type::const_reverse_iterator; + + CHECK((std::is_convertible::value)); + CHECK((std::is_convertible::value)); + CHECK(!(std::is_convertible::value)); + CHECK((std::is_convertible::value)); + + CHECK((std::is_convertible::value)); + CHECK((std::is_convertible::value)); + CHECK(!(std::is_convertible::value)); + CHECK((std::is_convertible::value)); +} + +int main(int argc, char** argv) { + SanityTest(); + PruneTest(); + PruneHookTest(); + SetMaxSize(); + SetClearSize(); + DestructorInvocationTest(); + LruSanityTest(); + LruPromotionTest(); + LruNoPromotionTest(); + IteratorSanityTest(); + FindTest(); + FindWithoutPromotionTest(); + IteratorOrderingTest(); + MoveTest(); + CustomKeyEqual(); + IteratorConversion(); + + LOG(INFO) << "All tests passed"; + return 0; +} diff --git a/test/runner.py b/test/runner.py index a02d4223b..4e493ec44 100755 --- a/test/runner.py +++ b/test/runner.py @@ -931,6 +931,20 @@ def run_llm_tests(meta, allocator, endpoints): cwd=os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'), ) + subprocess.check_call( + [ + './build/bin/aibrix_kv_cache_test', + '--client-num', + '4', + '--vineyard-ipc-sockets', + vineyard_ipc_socket_1, + vineyard_ipc_socket_2, + vineyard_ipc_socket_1, + vineyard_ipc_socket_2, + ], + cwd=os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'), + ) + def run_llm_python_tests(meta, allocator, endpoints, test_args): meta_prefix = 'vineyard_test_%s' % time.time() From def5a4570f8f317dc67904fc25d3632ded6c0892 Mon Sep 17 00:00:00 2001 From: DwyaneShi Date: Tue, 5 Nov 2024 22:09:26 -0800 Subject: [PATCH 2/5] build: use arrow 17.0.0 for testing some unit tests fail if using arrow 18.0.0 Signed-off-by: DwyaneShi --- .github/workflows/build-test.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index b91764ae1..d63692ddf 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -126,12 +126,12 @@ jobs: wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb sudo apt update - sudo apt install -y libarrow-dev \ - libarrow-dataset-dev \ - libarrow-acero-dev \ - libarrow-flight-dev \ - libgandiva-dev \ - libparquet-dev + sudo apt install -y libarrow-dev=17.0.0-1 \ + libarrow-dataset-dev=17.0.0-1 \ + libarrow-acero-dev=17.0.0-1 \ + libarrow-flight-dev=17.0.0-1 \ + libgandiva-dev=17.0.0-1 \ + libparquet-dev=17.0.0-1 # install deps for java sudo apt install -y default-jdk-headless maven From 09ebf7fee5374aa2802741d3eb290b16a33804f9 Mon Sep 17 00:00:00 2001 From: DwyaneShi Date: Mon, 11 Nov 2024 14:32:21 -0800 Subject: [PATCH 3/5] build: fix dependency issue Signed-off-by: DwyaneShi --- CMakeLists.txt | 2 +- docker/Dockerfile.vineyard-python-dev | 2 +- 
docker/Makefile | 2 +- docker/pypa/Dockerfile.manylinux1 | 31 ++++++++++++++++++++++++- docker/pypa/Dockerfile.manylinux1-wheel | 2 +- 5 files changed, 34 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d6600cdc5..31bd7c352 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -745,7 +745,7 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux") endif() # boost is only required by some components -if(BUILD_VINEYARD_SERVER OR BUILD_VINEYARD_IO OR BUILD_VINEYARD_GRAPH) +if(BUILD_VINEYARD_SERVER OR BUILD_VINEYARD_IO OR BUILD_VINEYARD_GRAPH OR BUILD_VINEYARD_LLM_CACHE) find_boost() endif() diff --git a/docker/Dockerfile.vineyard-python-dev b/docker/Dockerfile.vineyard-python-dev index 6511b9d29..eb0fef241 100644 --- a/docker/Dockerfile.vineyard-python-dev +++ b/docker/Dockerfile.vineyard-python-dev @@ -13,7 +13,7 @@ # limitations under the License. # build vineyard-python-dev -FROM ghcr.io/aibrix/v6d/vineyard-manylinux2014:20241014 as wheel +FROM ghcr.io/aibrix/v6d/vineyard-manylinux2014:20241108 as wheel ENV python=cp310-cp310 diff --git a/docker/Makefile b/docker/Makefile index 19dd080a4..30613c014 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -18,7 +18,7 @@ ALPINE_TAG := $(ALPINE_MANIFEST_TAG)_$(PLATFORM) WHEEL_BUILDER_REGISTRY := $(REGISTRY) WHEEL_BUILDER_IMAGE := vineyard-manylinux2014 -WHEEL_BUILDER_MANIFEST_TAG := 20241014 +WHEEL_BUILDER_MANIFEST_TAG := 20241108 WHEEL_BUILDER_TAG := $(WHEEL_BUILDER_MANIFEST_TAG)_$(PLATFORM) WHEEL_PYTHON := cp311-cp311 diff --git a/docker/pypa/Dockerfile.manylinux1 b/docker/pypa/Dockerfile.manylinux1 index 8ddee3229..826baf2fe 100644 --- a/docker/pypa/Dockerfile.manylinux1 +++ b/docker/pypa/Dockerfile.manylinux1 @@ -26,7 +26,7 @@ RUN mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backu curl -o /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-altarch-7.repo; \ fi && \ yum -y update && \ - yum -y install devtoolset-10-libatomic-devel libtool + yum -y install devtoolset-10-libatomic-devel libtool openmpi-devel wget # target: ghcr.io/aibrix/v6d/vineyard-manylinux2014:20240218_$PLATFORM @@ -67,6 +67,35 @@ RUN echo "Installing gflags ..." && \ make install -j`nproc` && \ rm -rf /deps +# Install boost +RUN echo "Installing boost ..." && \ + cd /tmp && \ + wget -q https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz && \ + tar zxf boost_1_75_0.tar.gz && \ + cd boost_1_75_0 && \ + ./bootstrap.sh && \ + ./b2 install -j`nproc` link=static runtime-link=static variant=release threading=multi \ + --with-atomic \ + --with-chrono \ + --with-date_time \ + --with-filesystem \ + --with-random \ + --with-system \ + --with-thread && \ + cd /tmp && \ + rm -rf boost_1_75_0.tar.gz boost_1_75_0 + +# Install openssl +RUN echo "Installing openssl ..." && \ + cd /tmp && \ + wget -q --no-check-certificate https://www.openssl.org/source/openssl-1.1.1j.tar.gz && \ + tar zxf openssl-1.1.1j.tar.gz && \ + cd openssl-1.1.1j && \ + ./config -no-shared -no-tests && \ + make -j`nproc` && make install -j`nproc` || true && \ + cd /tmp && \ + rm -rf openssl-1.1.1j.tar.gz openssl-1.1.1j + RUN echo "Installing apache-arrow ..." && \ mkdir -p /deps && \ cd /deps && \ diff --git a/docker/pypa/Dockerfile.manylinux1-wheel b/docker/pypa/Dockerfile.manylinux1-wheel index e5c928754..0f5f3a68e 100644 --- a/docker/pypa/Dockerfile.manylinux1-wheel +++ b/docker/pypa/Dockerfile.manylinux1-wheel @@ -32,7 +32,7 @@ ADD . 
/work/v6d RUN cd /work/v6d && \ mkdir build && \ cd build && \ - export PATH=/opt/python/$python/bin:$PATH && \ + export PATH=/usr/lib64/openmpi/bin/:/opt/python/$python/bin:$PATH && \ pip install -U pip setuptools wheel libclang parsec && \ cmake .. -DCMAKE_CXX_STANDARD=17 \ -DBUILD_SHARED_LIBS=OFF \ From 06be803b99e65f2a261bbd223807358be3d8298d Mon Sep 17 00:00:00 2001 From: DwyaneShi Date: Mon, 11 Nov 2024 14:32:21 -0800 Subject: [PATCH 4/5] Enhance comments for major components Signed-off-by: DwyaneShi --- modules/llm-cache/ds/kv_cache_chunk.h | 55 +++++++++++++++++++ .../llm-cache/storage/aibrix_blob_storage.h | 48 +++++++++++++++- 2 files changed, 101 insertions(+), 2 deletions(-) diff --git a/modules/llm-cache/ds/kv_cache_chunk.h b/modules/llm-cache/ds/kv_cache_chunk.h index b7f3d133f..4181c4da2 100644 --- a/modules/llm-cache/ds/kv_cache_chunk.h +++ b/modules/llm-cache/ds/kv_cache_chunk.h @@ -39,6 +39,24 @@ namespace vineyard { // forward declaration struct LLMKV; +// A KVCacheChunk contains all the KV tensors of a fixed number of +// tokens (i.e., `chunk_size`). +// +// In its object blob, we first store all the KV tensors, and then +// store all the tokens (including prefix tokens and current tokens +// cached in the chunk), which will be used to avoid hash conflicts. +// +// In its metadata, we store the namespace (i.e., `ns_`), which will +// be used as the name prefix of each chunk. Clients can also use the +// namespace to list all the chunks. Access time (i.e., `access_time_`) +// in its metadata is used for the TTL-based global GC. We also have +// the md5sum of all tokens (including prefix tokens and current tokens) +// in its metadata. When we reconstruct a chunk from the object blob +// and metadata, we calculate the md5sum of all tokens in the blob and +// compare it with the md5sum in the metadata. If they are the same, +// we consider the chunk valid. Otherwise, we consider it +// corrupted. For now, we don't use the md5sum of the tensors, to reduce +// the compute overhead. class KVCacheChunk : public vineyard::Registered { public: inline static constexpr char kFieldNameNS[] = "namespace"; @@ -52,12 +70,17 @@ class KVCacheChunk : public vineyard::Registered { private: std::shared_ptr buffer_; + // number of prefix tokens and current tokens in the chunk int total_tokens_; int tensor_nbytes_; int layer_; int chunk_size_; + // access time is used for TTL-based global GC uint64_t access_time_; + // md5sum of all tokens (including prefix tokens and current tokens) std::string md5_; + // namespace. chunks within the same namespace will be shared + // among different clients std::string ns_; public: @@ -79,6 +102,29 @@ class KVCacheChunk : public vineyard::Registered { friend class KVCacheChunkBuilder; }; +// A KVCacheChunkBuilder is used to build a KVCacheChunk. +// +// We have two kinds of builders: +// 1. The builder to build a new chunk. +// 2. The builder to rebuild a chunk from the object blob and metadata. +// +// For the first kind of builder, `Make` creates an empty chunk and an +// `Update` fills the chunk with KV tensors. After `Update`, the chunk +// is marked as ready and waiting readers will be notified. This kind +// of builder can be sealed to a KVCacheChunk. +// +// For the second kind of builder, `Make` only assigns the chunk id and +// the first `Query` will trigger the construction of the chunk, i.e., +// constructing the corresponding chunk with fetched metadata and blob.
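As the KVCacheChunk comment above describes, a chunk rebuilt from its blob and metadata is only trusted when the md5 of the tokens stored in the blob matches the md5 recorded in the metadata. A minimal sketch of that check follows; `Md5Sum` stands in for whatever helper llm-cache/hash/md5.h actually exposes, so both the helper name and the token layout here are assumptions.

#include <cstdint>
#include <string>
#include <vector>

#include "common/util/status.h"

// Assumed helper: raw bytes in, hex digest out.
std::string Md5Sum(const uint8_t* data, size_t nbytes);

// Sketch of the reconstruction-time validity check described above.
vineyard::Status ValidateChunkTokens(const std::vector<int>& tokens_in_blob,
                                     const std::string& md5_in_meta) {
  std::string computed =
      Md5Sum(reinterpret_cast<const uint8_t*>(tokens_in_blob.data()),
             tokens_in_blob.size() * sizeof(int));
  if (computed != md5_in_meta) {
    // The token lineage does not match the metadata: treat the chunk as
    // corrupted and let the caller recompute the KV tensors instead.
    return vineyard::Status::Invalid("token md5 mismatch, chunk is corrupted");
  }
  return vineyard::Status::OK();
}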
+// After construction, the chunk is marked as ready and other waiting +// readers will be notified. This kind of builder will never be sealed +// since the chunk already exists in the object store. +// +// We also track the access time of the chunk in the builder. Global +// access time is the latest known access time of the global object. +// Access time is the local access time that is updated by each access. +// The local access time will eventually be propagated to the global access +// time based on the policy used in AIBrixBlobStorage. class KVCacheChunkBuilder { private: RPCClient& rpc_client_; @@ -86,16 +132,23 @@ class KVCacheChunkBuilder { std::shared_ptr remote_buffer_writer_ = nullptr; ObjectID chunk_id_; std::shared_ptr buffer_ = nullptr; + int total_tokens_; int tensor_nbytes_; int layer_; int chunk_size_; std::string ns_; + + // `time_mu_` protects the access times of the chunk. std::shared_mutex time_mu_; uint64_t g_access_time_ = 0; uint64_t access_time_ = 0; + + // `mutex_` and `cv_` are used to block readers until the chunk + // is ready to be read. std::mutex mutex_; std::condition_variable cv_; + std::atomic is_ready_ = false; std::string md5_; @@ -140,6 +193,7 @@ class KVCacheChunkBuilder { return access_time_; } + // Whether the chunk is ready to be read. bool IsReady() { return is_ready_; } std::shared_ptr Seal(); @@ -150,6 +204,7 @@ class KVCacheChunkBuilder { void PrintKVCacheChunk(); + // Whether the chunk is the same as the chunk with the given metadata. Status IsSame(const ObjectMeta& meta); KVCacheChunkBuilder(RPCClient& rpc_client, int tensor_nbytes, int layer, diff --git a/modules/llm-cache/storage/aibrix_blob_storage.h b/modules/llm-cache/storage/aibrix_blob_storage.h index c6e141deb..bc09dd82f 100644 --- a/modules/llm-cache/storage/aibrix_blob_storage.h +++ b/modules/llm-cache/storage/aibrix_blob_storage.h @@ -34,17 +34,59 @@ limitations under the License. namespace vineyard { +// AIBrixBlobStorage is the storage backend of KVCacheChunk. +// It employs the S3-FIFO replacement policy to remain scan- +// resistant and recognize hot chunks. Please refer to member +// variable comments for more details on the S3-FIFO policy. +// +// In our implementation, the Main FIFO list of S3-FIFO is a +// mirror of the global chunk list. New chunks in the Main +// FIFO list will be periodically persisted to the global +// chunk list by the LocalSync function. Persisted chunks +// evicted from the Main FIFO list will be deleted from the +// global chunk list. +// +// Each chunk has an associated name that is generated by the +// equation: name = namespace + "_" + hash(hash(previous chunk) +// + tokens of current chunk) +// Please refer to computeChunkHashesForTokens for more details. +// +// Each name is supposed to be unique. For given prefix tokens +// and query tokens, after generating the chunk names, we will +// use the names to get the corresponding chunks if they exist. +// +// Each global chunk has an associated label called "access_time", +// which indicates the last access time of the chunk. For those +// chunks cached in the local FIFO lists, we will update their +// access time upon each access but only push the access time +// to the global chunk list during the LocalSync function. +// +// In GlobalGC, we will list all the global chunks within the +// namespace, and check if any chunks have reached the TTL. If so, we +// will delete them from the global chunk list. +// +// We use a thread pool to perform memory copies in parallel for +// both `Query` and `Update` to speed up the cache.
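The naming equation above is what gives the cache its rolling, lineage-aware keys: each chunk name folds in the hash of the previous chunk, so two requests that share a token prefix also share their leading chunk names and therefore resolve to the same objects. The sketch below restates that construction; `HashBytes` is a placeholder for the configured hasher, not the real API, and computeChunkHashesForTokens remains the authoritative implementation.

#include <string>
#include <vector>

// Assumed helper: raw bytes in, hash string out (stands in for hasher_).
std::string HashBytes(const std::string& bytes);

// Rolling chunk-name construction as described by the equation above.
std::vector<std::string> ComputeChunkNames(const std::string& ns,
                                           const std::vector<int>& tokens,
                                           size_t chunk_size) {
  std::vector<std::string> names;
  std::string prev_hash;  // empty for the very first chunk in the lineage
  for (size_t off = 0; off + chunk_size <= tokens.size(); off += chunk_size) {
    // hash(hash(previous chunk) + tokens of current chunk)
    std::string bytes = prev_hash;
    bytes.append(reinterpret_cast<const char*>(tokens.data() + off),
                 chunk_size * sizeof(int));
    prev_hash = HashBytes(bytes);
    names.push_back(ns + "_" + prev_hash);  // name = namespace + "_" + hash
  }
  return names;
}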
The return +// of `Query` and `Update` indicates the completion of all the +// memory copies and it is safe to reuse the input buffers. class AIBrixBlobStorage : public IStorage, public std::enable_shared_from_this { private: + // Max number of tokens supported by the cache. If the total + // number of prefix tokens and current tokens of an update + // exceeds the max tokens, we will drop the update. static constexpr int kMaxTokensPerSeq = 64 * 1024; static constexpr double kSmallFifoCapacityRatio = 0.3; + // The preferred number of evicted items for each eviction of + // the Main FIFO list to amortize the cost of deleting from + // the object store. static constexpr int kMinEviction = 32; RPCClient& rpc_client_; Client& ipc_client_; + // hash algorithm and hasher used to generate chunk hashes std::shared_ptr hash_alg_; std::shared_ptr hasher_; @@ -53,6 +95,8 @@ class AIBrixBlobStorage int chunk_size_; int capacity_; size_t chunk_obj_size_; + // namespace. chunks within the same namespace will be shared + // among different clients std::string kv_cache_ns_; // intervals in seconds @@ -61,6 +105,7 @@ class AIBrixBlobStorage // TTL in seconds std::chrono::duration global_ttl_s_; + // indicates whether the cache is closed bool exit_flag_ = false; // global GC is carried out in the global GC thread. @@ -120,8 +165,6 @@ class AIBrixBlobStorage EvictingCacheMap main_fifo_; // mirror of global chunk list - std::vector evict_list_; - public: AIBrixBlobStorage(RPCClient& rpc_client, Client& ipc_client, size_t tensor_nbytes, int capacity, int layer, @@ -196,6 +239,7 @@ class AIBrixBlobStorage std::vector>>& kv_tensors, size_t& matched); + // Seal and persist the chunk, and then put the given name for the chunk. Status SealAndPersist( const std::string& name, const std::shared_ptr& chunk_builder, From a034f48ed328b2fccf7216cd88249f6c4f327e57 Mon Sep 17 00:00:00 2001 From: DwyaneShi Date: Mon, 11 Nov 2024 14:32:21 -0800 Subject: [PATCH 5/5] Add exception handling Signed-off-by: DwyaneShi --- .../llm-cache/storage/aibrix_blob_storage.cc | 78 ++++++++++--------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/modules/llm-cache/storage/aibrix_blob_storage.cc b/modules/llm-cache/storage/aibrix_blob_storage.cc index 4916fffd6..512db437b 100644 --- a/modules/llm-cache/storage/aibrix_blob_storage.cc +++ b/modules/llm-cache/storage/aibrix_blob_storage.cc @@ -16,6 +16,7 @@ limitations under the License. 
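For readers less familiar with S3-FIFO, the interplay between the small, ghost, and main FIFO lists described above can be summarized as: new chunk names start in the small FIFO, names evicted from the small FIFO are remembered in the ghost FIFO, and a name that is seen again while ghosted is admitted into the main FIFO, which mirrors the global chunk list. The sketch below restates that rule on top of EvictingCacheMap; the queue capacities and value types are illustrative only, and the real logic lives in aibrix_blob_storage.cc.

#include <string>

#include "common/util/evicting_cache_map.h"

// Illustrative S3-FIFO admission rule, not the actual implementation.
struct S3FifoSketch {
  vineyard::EvictingCacheMap<std::string, bool> ghost_q{1024};
  vineyard::EvictingCacheMap<std::string, int> small_q{256};
  vineyard::EvictingCacheMap<std::string, int> main_q{768};

  S3FifoSketch() {
    // When the small queue evicts an entry, leave a ghost record behind so a
    // second access within the ghost window promotes the chunk to main.
    small_q.setPruneHook([this](auto&& pairs) {
      for (auto&& kv : pairs) {
        ghost_q.set(kv.first, true);
      }
    });
  }

  void Admit(const std::string& name, int chunk) {
    if (small_q.exists(name) || main_q.exists(name)) {
      return;  // already cached locally
    }
    if (ghost_q.exists(name)) {
      ghost_q.erase(name);
      main_q.set(name, chunk);   // proven reuse: admit straight into main
    } else {
      small_q.set(name, chunk);  // first sighting: probation in small
    }
  }
};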
#include #include #include +#include #include #include #include @@ -49,9 +50,10 @@ AIBrixBlobStorage::AIBrixBlobStorage( global_gc_enabled_(global_gc_enabled), global_gc_interval_s_(std::chrono::seconds(global_gc_interval_s)), global_ttl_s_(std::chrono::seconds(global_ttl_s)), - ghost_fifo_(capacity_), - small_fifo_(capacity_ * kSmallFifoCapacityRatio), - main_fifo_(capacity_ - capacity_ * kSmallFifoCapacityRatio, + ghost_fifo_(capacity_ / chunk_size), + small_fifo_(capacity_ / chunk_size * kSmallFifoCapacityRatio), + main_fifo_(capacity_ / chunk_size - + capacity_ / chunk_size * kSmallFifoCapacityRatio, kMinEviction) { kv_cache_ns_ = std::regex_replace(kv_cache_ns_, std::regex("/"), "_"); kv_cache_ns_ = std::regex_replace(kv_cache_ns_ + "_", std::regex("_+"), "_"); @@ -221,32 +223,34 @@ Status AIBrixBlobStorage::GetTokenChunkHashes( return Status::OK(); } -#define DEFINE_TASK_FN(FN, OP, CB) \ - auto FN = [this, &prefix, &tokens, &kv_tensors, cb = CB]( \ - size_t i, \ - std::shared_ptr builder) -> Status { \ - auto chunk_size = this->chunk_size_; \ - if (builder == nullptr) { \ - return Status::OK(); \ - } \ - \ - std::vector my_prefix(prefix.begin(), prefix.end()); \ - if (i > 0) { \ - my_prefix.insert(my_prefix.end(), tokens.begin(), \ - tokens.begin() + i * chunk_size); \ - } \ - std::vector my_tokens(tokens.begin() + i * chunk_size, \ - tokens.begin() + (i + 1) * chunk_size); \ - \ - std::vector>> my_kv_tensors( \ - kv_tensors.begin() + i * chunk_size, \ - kv_tensors.begin() + (i + 1) * chunk_size); \ - \ - auto status = builder->OP(my_prefix, my_tokens, my_kv_tensors); \ - if (status.ok()) { \ - cb(i, my_kv_tensors); \ - } \ - return status; \ +#define DEFINE_TASK_FN(FN, OP, CB) \ + auto FN = [this, &prefix, &tokens, &kv_tensors, cb = CB]( \ + size_t i, \ + std::shared_ptr builder) -> Status { \ + auto chunk_size = this->chunk_size_; \ + if (builder == nullptr) { \ + return Status::OK(); \ + } \ + \ + std::vector my_prefix(prefix.begin(), prefix.end()); \ + if (i > 0) { \ + my_prefix.insert(my_prefix.end(), tokens.begin(), \ + tokens.begin() + i * chunk_size); \ + } \ + std::vector my_tokens(tokens.begin() + i * chunk_size, \ + tokens.begin() + (i + 1) * chunk_size); \ + \ + std::vector>> my_kv_tensors( \ + kv_tensors.begin() + i * chunk_size, \ + kv_tensors.begin() + (i + 1) * chunk_size); \ + \ + try { \ + auto status = builder->OP(my_prefix, my_tokens, my_kv_tensors); \ + if (status.ok()) { \ + cb(i, my_kv_tensors); \ + } \ + return status; \ + } catch (const std::exception& e) { return Status::IOError(e.what()); } \ } #define WAIT_TASK_RESULTS(TIDS, COUNTER, FIRST_ERROR, OBJ_NAMES) \ @@ -864,12 +868,16 @@ Status AIBrixBlobStorage::GlobalGCFunc() { return; \ } \ LOG(INFO) << #NAME " started"; \ - Status status = self->NAME##Func(); \ - if (!status.ok()) { \ - LOG(ERROR) << #NAME " failed: " << status.ToString(); \ - /* Not a fatal error and wait for next time */ \ - } else { \ - LOG(INFO) << #NAME " completed"; \ + try { \ + Status status = self->NAME##Func(); \ + if (!status.ok()) { \ + LOG(ERROR) << #NAME " failed: " << status.ToString(); \ + /* Not a fatal error and wait for next time */ \ + } else { \ + LOG(INFO) << #NAME " completed"; \ + } \ + } catch (const std::exception& e) { \ + LOG(ERROR) << #NAME " failed: " << e.what(); \ } \ last_time = std::chrono::duration_cast( \ std::chrono::system_clock::now().time_since_epoch()) \
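Patch 5 applies the same defensive pattern in two places: the per-chunk copy tasks submitted to the thread pool convert exceptions into Status::IOError, and the background LocalSync/GlobalGC loops log exceptions and simply wait for the next round. In isolation the pattern looks like the sketch below; it is a restatement for reference, not code lifted from the patch.

#include <exception>
#include <functional>

#include "common/util/status.h"

// Convert any exception escaping a Status-returning task into a Status, so a
// throwing memory copy or GC pass degrades into a reported error instead of
// terminating the thread pool or the background thread.
vineyard::Status RunGuarded(const std::function<vineyard::Status()>& task) {
  try {
    return task();
  } catch (const std::exception& e) {
    return vineyard::Status::IOError(e.what());
  } catch (...) {
    return vineyard::Status::IOError("unknown exception in kv cache task");
  }
}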