From 4ea0c2f8d95470b98fec5ee66132964d77581247 Mon Sep 17 00:00:00 2001 From: Raahul Kalyaan Jakka Date: Tue, 9 Dec 2025 01:20:44 -0800 Subject: [PATCH] Adding returnKVTensorMetaData flag to Staging Read Strategy (#5200) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2196 **Context:** When using SSD Offloaded Tensors during Training, we create a Partially Materialized Tensor (i.e., tensor-like) structure to be used for post-training (checkpointing and publishing) In this diff, we add a new flag "returnKVTensorMetaData" to the Staging Read Strategy, which helps us to read only the metadata of the KVTensor and not the KVTensor itself Reviewed By: chunzhao Differential Revision: D86634084 --- .../ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp | 6 ++++++ .../ssd_split_table_batched_embeddings.cpp | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp index 2a91a1faa2..de1e78e9e6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp @@ -108,4 +108,10 @@ std::string KVTensorWrapper::layout_str() { oss << options_.layout(); return oss.str(); } + +std::vector KVTensorWrapper::get_kvtensor_serializable_metadata() + const { + FBEXCEPTION("Not implemented"); + return std::vector{}; +} } // namespace ssd diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 541ae73c9d..24afc776ed 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -415,6 +415,10 @@ std::string KVTensorWrapper::serialize() const { std::vector KVTensorWrapper::get_kvtensor_serializable_metadata() const { std::vector metadata; + // Return empty metadata if checkpoint_handle_ is not initialized yet + if (checkpoint_handle_ == nullptr) { + return metadata; + } auto* db = dynamic_cast(db_.get()); auto checkpoint_paths = db->get_checkpoints(checkpoint_handle_->uuid); metadata.push_back(std::to_string(checkpoint_paths.size()));