From 2afb7123bf3903ab01af53edd4b213eb6659aa97 Mon Sep 17 00:00:00 2001 From: zhouyuhan Date: Sun, 26 Apr 2026 16:21:03 +0800 Subject: [PATCH 01/49] chore(byte-artifact): add verbose logs for batch_set observability --- .../controllers/byte_artifact_controller.cc | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/daemon/service/controllers/byte_artifact_controller.cc b/daemon/service/controllers/byte_artifact_controller.cc index 15acb096..f68f12f2 100644 --- a/daemon/service/controllers/byte_artifact_controller.cc +++ b/daemon/service/controllers/byte_artifact_controller.cc @@ -1732,8 +1732,16 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( std::size_t grpc_fetch_count{0}; std::size_t remote_mirror_count{0}; std::uint64_t remote_mirror_bytes{0}; + std::size_t remote_communicator_source_count{0}; + std::size_t remote_communicator_source_direct_write_count{0}; + std::size_t remote_communicator_source_batched_direct_write_count{0}; + std::size_t remote_full_pack_mirror_items{0}; std::size_t stage_body_count{0}; std::size_t fast_cpu_stage_count{0}; + std::size_t stage_loader_count{0}; + std::size_t stage_local_source_count{0}; + std::size_t stage_p2p_count{0}; + std::size_t stage_local_replica_count{0}; std::size_t reuse_attempt_count{0}; std::size_t reuse_hit_count{0}; std::size_t authority_item_count{0}; @@ -1887,6 +1895,28 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( timing_stats.communicator_open_elapsed += absl::Now() - resolve_started_at; ++timing_stats.communicator_open_count; resolved_it = resolved_batch_sources.emplace(transport_id, std::move(*resolved_or)).first; + const bool source_supports_direct_write = + resolved_it->second.source != nullptr && resolved_it->second.source->supports_direct_write_at(); + const bool source_supports_batched_direct_write = + resolved_it->second.source != nullptr && resolved_it->second.source->supports_batched_direct_write_at(); + if (resolved_it->second.remote) { + ++timing_stats.remote_communicator_source_count; + if (source_supports_direct_write) { + ++timing_stats.remote_communicator_source_direct_write_count; + } + if (source_supports_batched_direct_write) { + ++timing_stats.remote_communicator_source_batched_direct_write_count; + } + } + LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_open" + << " operation_id=" << operation_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " remote=" << resolved_it->second.remote + << " item_count=" << transport_it->second->manifest().entries_size() + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " source_direct_write_at=" << source_supports_direct_write + << " source_batched_direct_write_at=" << source_supports_batched_direct_write + << " resolve_ms=" << absl::ToDoubleMilliseconds(absl::Now() - resolve_started_at); } source_kind = resolved_it->second.remote ? store::loading::MaterializationSource::kP2P : store::loading::MaterializationSource::kLocalReplica; @@ -1907,14 +1937,22 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( timing_stats.remote_mirror_elapsed += mirror_elapsed; ++timing_stats.remote_mirror_count; timing_stats.remote_mirror_bytes += transport_it->second->manifest().total_size(); + timing_stats.remote_full_pack_mirror_items += transport_it->second->manifest().entries_size(); mirrored_it = mirrored_remote_batch_payloads.emplace(transport_id, std::move(*mirrored_or)).first; + const bool source_supports_direct_write = + resolved_it->second.source != nullptr && resolved_it->second.source->supports_direct_write_at(); + const bool source_supports_batched_direct_write = + resolved_it->second.source != nullptr && resolved_it->second.source->supports_batched_direct_write_at(); LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_mirror" << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" << " remote=true" << " read_mode=full_pack" + << " realization=full_pack_mirror" << " payload_bytes=" << transport_it->second->manifest().total_size() << " item_count=" << transport_it->second->manifest().entries_size() + << " source_direct_write_at=" << source_supports_direct_write + << " source_batched_direct_write_at=" << source_supports_batched_direct_write << " mirror_ms=" << absl::ToDoubleMilliseconds(mirror_elapsed) << " subsequent_item_slices_local=true"; } @@ -2092,6 +2130,16 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( .local_source = std::move(prepared_item.local_source), .source_kind = prepared_item.source_kind, }); + if (stage_work_items.back().local_source.has_value()) { + ++timing_stats.stage_local_source_count; + } else { + ++timing_stats.stage_loader_count; + } + if (stage_work_items.back().source_kind == store::loading::MaterializationSource::kP2P) { + ++timing_stats.stage_p2p_count; + } else { + ++timing_stats.stage_local_replica_count; + } } const auto run_stage_work_item = [&](StageWorkItem stage_work_item) -> StageWorkItemResult { @@ -2284,10 +2332,20 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( << " payload_ref_items=" << timing_stats.payload_ref_items << " communicator_open_count=" << timing_stats.communicator_open_count << " grpc_fetch_count=" << timing_stats.grpc_fetch_count + << " remote_communicator_source_count=" << timing_stats.remote_communicator_source_count + << " remote_communicator_source_direct_write_count=" + << timing_stats.remote_communicator_source_direct_write_count + << " remote_communicator_source_batched_direct_write_count=" + << timing_stats.remote_communicator_source_batched_direct_write_count << " remote_mirror_count=" << timing_stats.remote_mirror_count << " remote_mirror_bytes=" << timing_stats.remote_mirror_bytes + << " remote_full_pack_mirror_items=" << timing_stats.remote_full_pack_mirror_items << " stage_body_count=" << timing_stats.stage_body_count << " fast_cpu_stage_count=" << timing_stats.fast_cpu_stage_count + << " stage_loader_count=" << timing_stats.stage_loader_count + << " stage_local_source_count=" << timing_stats.stage_local_source_count + << " stage_p2p_count=" << timing_stats.stage_p2p_count + << " stage_local_replica_count=" << timing_stats.stage_local_replica_count << " stage_body_wall_ms=" << absl::ToDoubleMilliseconds(stage_wall_elapsed) << " reuse_attempt_count=" << timing_stats.reuse_attempt_count << " reuse_hit_count=" << timing_stats.reuse_hit_count @@ -4451,8 +4509,12 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( std::size_t remote_batch_transport_grpc_count{0}; std::size_t remote_batch_transport_item_count{0}; std::uint64_t remote_batch_transport_bytes{0}; + std::size_t remote_batch_pack_count{0}; + std::size_t remote_batch_pack_item_count{0}; + std::uint64_t remote_batch_pack_bytes{0}; absl::Duration local_stage_elapsed{absl::ZeroDuration()}; absl::Duration remote_stage_elapsed{absl::ZeroDuration()}; + absl::Duration remote_batch_pack_elapsed{absl::ZeroDuration()}; absl::Duration local_home_apply_elapsed{absl::ZeroDuration()}; absl::Duration remote_home_rpc_elapsed{absl::ZeroDuration()}; } stats; @@ -4961,6 +5023,10 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( << " requested_items=" << task.batch.items.size() << " local_stage_ms=" << absl::ToDoubleMilliseconds(task_stats.local_stage_elapsed) << " remote_stage_ms=" << absl::ToDoubleMilliseconds(task_stats.remote_stage_elapsed) + << " remote_batch_pack_count=" << task_stats.remote_batch_pack_count + << " remote_batch_pack_items=" << task_stats.remote_batch_pack_item_count + << " remote_batch_pack_bytes=" << task_stats.remote_batch_pack_bytes + << " remote_batch_pack_ms=" << absl::ToDoubleMilliseconds(task_stats.remote_batch_pack_elapsed) << " local_home_apply_ms=" << absl::ToDoubleMilliseconds(task_stats.local_home_apply_elapsed) << " remote_home_rpc_ms=" << absl::ToDoubleMilliseconds(task_stats.remote_home_rpc_elapsed) << " total_ms=" << absl::ToDoubleMilliseconds(task_result.total_elapsed); @@ -5021,14 +5087,32 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( batch_entry_outcome_indices.push_back(pending.outcome_index); } if (!batch_entries.empty()) { + const absl::Time pack_started_at = absl::Now(); auto packs_or = pack_batch_payload_entries( batch_entries, d_.payload_transport_broker.max_batch_payload_bytes(), d_.payload_transport_broker.max_batch_items()); + const absl::Duration pack_elapsed = absl::Now() - pack_started_at; + task_stats.remote_batch_pack_elapsed += pack_elapsed; if (!packs_or.ok()) { LOG(WARNING) << "batch_put_if_absent_from_region batch transport fallback to payload_ref: " << packs_or.status(); } else { + std::uint64_t packed_bytes = 0; + std::size_t packed_items = 0; + for (const auto& pack : *packs_or) { + packed_bytes += pack.manifest.total_size(); + packed_items += pack.source_indices.size(); + } + task_stats.remote_batch_pack_count += packs_or->size(); + task_stats.remote_batch_pack_item_count += packed_items; + task_stats.remote_batch_pack_bytes += packed_bytes; + LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_pack_realization" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " remote=true" + << " mode=staged_slab" + << " pack_count=" << packs_or->size() << " item_count=" << packed_items + << " payload_bytes=" << packed_bytes << " pack_ms=" << absl::ToDoubleMilliseconds(pack_elapsed); for (auto& pack : *packs_or) { const bool use_communicator_transport = peer_transport_support.supports_v2() && d_.payload_transport_broker.batch_transport_communicator_enabled(); @@ -5129,6 +5213,7 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( << " operation_id=" << operation_id << " shard_id=" << task.shard_id << " holder_daemon_id=" << task.route.holder_daemon_id << " transport_id=" << transport_id << " kind=" << (emitted_communicator_transport ? "communicator_source" : "grpc_chunk_ref") + << " source_realization_mode=staged_slab" << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size(); @@ -5363,8 +5448,12 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( stats.remote_batch_transport_grpc_count += delta.remote_batch_transport_grpc_count; stats.remote_batch_transport_item_count += delta.remote_batch_transport_item_count; stats.remote_batch_transport_bytes += delta.remote_batch_transport_bytes; + stats.remote_batch_pack_count += delta.remote_batch_pack_count; + stats.remote_batch_pack_item_count += delta.remote_batch_pack_item_count; + stats.remote_batch_pack_bytes += delta.remote_batch_pack_bytes; stats.local_stage_elapsed += delta.local_stage_elapsed; stats.remote_stage_elapsed += delta.remote_stage_elapsed; + stats.remote_batch_pack_elapsed += delta.remote_batch_pack_elapsed; stats.local_home_apply_elapsed += delta.local_home_apply_elapsed; stats.remote_home_rpc_elapsed += delta.remote_home_rpc_elapsed; }; @@ -5427,8 +5516,12 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( << " remote_batch_transport_grpc=" << stats.remote_batch_transport_grpc_count << " remote_batch_transport_items=" << stats.remote_batch_transport_item_count << " remote_batch_transport_bytes=" << stats.remote_batch_transport_bytes + << " remote_batch_pack_count=" << stats.remote_batch_pack_count + << " remote_batch_pack_items=" << stats.remote_batch_pack_item_count + << " remote_batch_pack_bytes=" << stats.remote_batch_pack_bytes << " local_stage_ms=" << absl::ToDoubleMilliseconds(stats.local_stage_elapsed) << " remote_stage_ms=" << absl::ToDoubleMilliseconds(stats.remote_stage_elapsed) + << " remote_batch_pack_ms=" << absl::ToDoubleMilliseconds(stats.remote_batch_pack_elapsed) << " local_home_apply_ms=" << absl::ToDoubleMilliseconds(stats.local_home_apply_elapsed) << " remote_home_rpc_ms=" << absl::ToDoubleMilliseconds(stats.remote_home_rpc_elapsed); rctx.mark_success(); From 0327b0907824c4775dace667a5e1a3884fa28e86 Mon Sep 17 00:00:00 2001 From: zhouyuhan Date: Sun, 26 Apr 2026 17:48:13 +0800 Subject: [PATCH 02/49] feat(byte-artifact): remove full-pack mirror for RDMA-based batch_set --- daemon/BUILD | 7 + .../controllers/byte_artifact_controller.cc | 193 ++++++--- .../grpc_service_impl_batch_runtime_test.cc | 394 ++++++++++++++++++ ...e-and-routed-byte-artifact-architecture.md | 201 +++++++-- 4 files changed, 695 insertions(+), 100 deletions(-) diff --git a/daemon/BUILD b/daemon/BUILD index 244f00f3..e5cfe650 100644 --- a/daemon/BUILD +++ b/daemon/BUILD @@ -2138,12 +2138,19 @@ cc_test( ":routed_authority_protocol_hdr", ":routed_authority_wire_hdr", "//core/common:artifact_hash_lib", + "//core/communicator:routing_adapter_lib", + "//core/communicator:routing_context_lib", + "//core/communicator:topology_lib", "//core/store:store_engine", "//core/store:testing_global_store_client_stub", "//core/store:testing_recording_global_store_client", "//core/testing:test_helpers", "//proto/tensorcast/daemon/v2:daemon_grpc_cc", + "@abseil-cpp//absl/log:log_entry", + "@abseil-cpp//absl/log:log_sink", + "@abseil-cpp//absl/log:log_sink_registry", "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/synchronization", "@abseil-cpp//absl/types:span", "@catch2//:catch2_main", ], diff --git a/daemon/service/controllers/byte_artifact_controller.cc b/daemon/service/controllers/byte_artifact_controller.cc index f68f12f2..e8acbe90 100644 --- a/daemon/service/controllers/byte_artifact_controller.cc +++ b/daemon/service/controllers/byte_artifact_controller.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -366,8 +367,15 @@ HomeBatchGetResponseShape inspect_home_batch_get_response_shape(const v2::HomeBa class SourceSlice final : public store::loader::SeekableSource { public: - SourceSlice(std::shared_ptr source, std::uint64_t base_offset, std::uint64_t length) - : source_(std::move(source)), base_offset_(base_offset), length_(length) {} + SourceSlice( + std::shared_ptr source, + std::uint64_t base_offset, + std::uint64_t length, + std::shared_ptr source_mutex = nullptr) + : source_(std::move(source)), + source_mutex_(std::move(source_mutex)), + base_offset_(base_offset), + length_(length) {} [[nodiscard]] uint64_t total_bytes() const override { return length_; @@ -386,8 +394,12 @@ class SourceSlice final : public store::loader::SeekableSource { if (offset >= length_ || bytes == 0) { return static_cast(0); } - return source_->read_at( - base_offset_ + offset, dst, static_cast(std::min(bytes, length_ - offset))); + const size_t bounded_bytes = static_cast(std::min(bytes, length_ - offset)); + if (source_mutex_ != nullptr) { + std::lock_guard lock(*source_mutex_); + return source_->read_at(base_offset_ + offset, dst, bounded_bytes); + } + return source_->read_at(base_offset_ + offset, dst, bounded_bytes); } [[nodiscard]] bool supports_direct_write_at() const override { @@ -402,15 +414,17 @@ class SourceSlice final : public store::loader::SeekableSource { if (src_offset >= length_ || bytes == 0) { return static_cast(0); } - return source_->read_into_at( - base_offset_ + src_offset, - dest_va_offset, - static_cast(std::min(bytes, length_ - src_offset)), - grant); + const size_t bounded_bytes = static_cast(std::min(bytes, length_ - src_offset)); + if (source_mutex_ != nullptr) { + std::lock_guard lock(*source_mutex_); + return source_->read_into_at(base_offset_ + src_offset, dest_va_offset, bounded_bytes, grant); + } + return source_->read_into_at(base_offset_ + src_offset, dest_va_offset, bounded_bytes, grant); } private: std::shared_ptr source_; + std::shared_ptr source_mutex_; std::uint64_t base_offset_{0}; std::uint64_t length_{0}; std::uint64_t cursor_{0}; @@ -429,9 +443,28 @@ std::unique_ptr make_loader_from_payload_slice( std::unique_ptr make_loader_from_source_slice( std::shared_ptr source, std::uint64_t offset, - std::uint64_t length) { + std::uint64_t length, + std::shared_ptr source_mutex = nullptr) { return std::make_unique( - std::make_shared(std::move(source), offset, length), length); + std::make_shared(std::move(source), offset, length, std::move(source_mutex)), length); +} + +struct PutRemoteCommunicatorSourceEligibility { + bool remote{false}; + bool source_supports_direct_write{false}; + bool source_supports_batched_direct_write{false}; + bool direct_remote_slice{false}; +}; + +PutRemoteCommunicatorSourceEligibility classify_put_remote_communicator_source( + const PayloadTransportBroker::BatchPayloadSource& source) { + PutRemoteCommunicatorSourceEligibility eligibility; + eligibility.remote = source.remote; + eligibility.source_supports_direct_write = source.source != nullptr && source.source->supports_direct_write_at(); + eligibility.source_supports_batched_direct_write = + source.source != nullptr && source.source->supports_batched_direct_write_at(); + eligibility.direct_remote_slice = eligibility.remote && eligibility.source_supports_direct_write; + return eligibility; } absl::StatusOr> mirror_seekable_source_payload( @@ -1735,6 +1768,9 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( std::size_t remote_communicator_source_count{0}; std::size_t remote_communicator_source_direct_write_count{0}; std::size_t remote_communicator_source_batched_direct_write_count{0}; + std::size_t remote_direct_slice_transport_count{0}; + std::size_t remote_direct_slice_items{0}; + std::uint64_t remote_direct_slice_bytes{0}; std::size_t remote_full_pack_mirror_items{0}; std::size_t stage_body_count{0}; std::size_t fast_cpu_stage_count{0}; @@ -1788,6 +1824,8 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( absl::flat_hash_map resolved_batch_payloads; absl::flat_hash_map resolved_batch_sources; absl::flat_hash_map> mirrored_remote_batch_payloads; + absl::flat_hash_map> remote_direct_source_mutexes; + absl::flat_hash_set direct_remote_batch_payloads; struct PreparedHomeBatchPutItem { const v2::HomeBatchPutIfAbsentItem* item{nullptr}; @@ -1895,79 +1933,103 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( timing_stats.communicator_open_elapsed += absl::Now() - resolve_started_at; ++timing_stats.communicator_open_count; resolved_it = resolved_batch_sources.emplace(transport_id, std::move(*resolved_or)).first; - const bool source_supports_direct_write = - resolved_it->second.source != nullptr && resolved_it->second.source->supports_direct_write_at(); - const bool source_supports_batched_direct_write = - resolved_it->second.source != nullptr && resolved_it->second.source->supports_batched_direct_write_at(); - if (resolved_it->second.remote) { + const auto eligibility = classify_put_remote_communicator_source(resolved_it->second); + if (eligibility.remote) { ++timing_stats.remote_communicator_source_count; - if (source_supports_direct_write) { + if (eligibility.source_supports_direct_write) { ++timing_stats.remote_communicator_source_direct_write_count; } - if (source_supports_batched_direct_write) { + if (eligibility.source_supports_batched_direct_write) { ++timing_stats.remote_communicator_source_batched_direct_write_count; } } LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_open" << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" - << " remote=" << resolved_it->second.remote + << " remote=" << eligibility.remote << " item_count=" << transport_it->second->manifest().entries_size() << " payload_bytes=" << transport_it->second->manifest().total_size() - << " source_direct_write_at=" << source_supports_direct_write - << " source_batched_direct_write_at=" << source_supports_batched_direct_write + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write << " resolve_ms=" << absl::ToDoubleMilliseconds(absl::Now() - resolve_started_at); } source_kind = resolved_it->second.remote ? store::loading::MaterializationSource::kP2P : store::loading::MaterializationSource::kLocalReplica; if (resolved_it->second.remote) { - auto mirrored_it = mirrored_remote_batch_payloads.find(transport_id); - if (mirrored_it == mirrored_remote_batch_payloads.end()) { - const absl::Time mirror_started_at = absl::Now(); - auto mirrored_or = mirror_seekable_source_payload( - resolved_it->second.source, transport_it->second->manifest().total_size()); - if (!mirrored_or.ok()) { + const auto eligibility = classify_put_remote_communicator_source(resolved_it->second); + if (eligibility.direct_remote_slice) { + if (direct_remote_batch_payloads.emplace(transport_id).second) { + ++timing_stats.remote_direct_slice_transport_count; + timing_stats.remote_direct_slice_bytes += transport_it->second->manifest().total_size(); + timing_stats.remote_direct_slice_items += transport_it->second->manifest().entries_size(); + LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_read_mode" + << " operation_id=" << operation_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " remote=true" + << " read_mode=direct_remote_slice" + << " realization=source_slice_loader" + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " item_count=" << transport_it->second->manifest().entries_size() + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write + << " mirror_ms=0" + << " subsequent_item_slices_local=false"; + } + auto& source_mutex = remote_direct_source_mutexes[transport_id]; + if (source_mutex == nullptr) { + source_mutex = std::make_shared(); + } + loader = make_loader_from_source_slice( + resolved_it->second.source, + item.batch_payload_slice().offset(), + item.batch_payload_slice().length(), + source_mutex); + } else { + auto mirrored_it = mirrored_remote_batch_payloads.find(transport_id); + if (mirrored_it == mirrored_remote_batch_payloads.end()) { + const absl::Time mirror_started_at = absl::Now(); + auto mirrored_or = mirror_seekable_source_payload( + resolved_it->second.source, transport_it->second->manifest().total_size()); + if (!mirrored_or.ok()) { + deferred_outcomes[index] = make_outcome( + artifact_id, + batch_item_status_from_absl_status(mirrored_or.status()), + std::string(mirrored_or.status().message())); + continue; + } + const absl::Duration mirror_elapsed = absl::Now() - mirror_started_at; + timing_stats.remote_mirror_elapsed += mirror_elapsed; + ++timing_stats.remote_mirror_count; + timing_stats.remote_mirror_bytes += transport_it->second->manifest().total_size(); + timing_stats.remote_full_pack_mirror_items += transport_it->second->manifest().entries_size(); + mirrored_it = mirrored_remote_batch_payloads.emplace(transport_id, std::move(*mirrored_or)).first; + LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_mirror" + << " operation_id=" << operation_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " remote=true" + << " read_mode=full_pack" + << " realization=full_pack_mirror" + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " item_count=" << transport_it->second->manifest().entries_size() + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write + << " mirror_ms=" << absl::ToDoubleMilliseconds(mirror_elapsed) + << " subsequent_item_slices_local=true"; + } + if (item.batch_payload_slice().offset() + item.batch_payload_slice().length() > + mirrored_it->second->size()) { deferred_outcomes[index] = make_outcome( - artifact_id, - batch_item_status_from_absl_status(mirrored_or.status()), - std::string(mirrored_or.status().message())); + artifact_id, v2::BATCH_ITEM_STATUS_FAILED_PRECONDITION, "batch transport payload is truncated"); continue; } - const absl::Duration mirror_elapsed = absl::Now() - mirror_started_at; - timing_stats.remote_mirror_elapsed += mirror_elapsed; - ++timing_stats.remote_mirror_count; - timing_stats.remote_mirror_bytes += transport_it->second->manifest().total_size(); - timing_stats.remote_full_pack_mirror_items += transport_it->second->manifest().entries_size(); - mirrored_it = mirrored_remote_batch_payloads.emplace(transport_id, std::move(*mirrored_or)).first; - const bool source_supports_direct_write = - resolved_it->second.source != nullptr && resolved_it->second.source->supports_direct_write_at(); - const bool source_supports_batched_direct_write = - resolved_it->second.source != nullptr && resolved_it->second.source->supports_batched_direct_write_at(); - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_mirror" - << " operation_id=" << operation_id << " transport_id=" << transport_id - << " kind=communicator_source" - << " remote=true" - << " read_mode=full_pack" - << " realization=full_pack_mirror" - << " payload_bytes=" << transport_it->second->manifest().total_size() - << " item_count=" << transport_it->second->manifest().entries_size() - << " source_direct_write_at=" << source_supports_direct_write - << " source_batched_direct_write_at=" << source_supports_batched_direct_write - << " mirror_ms=" << absl::ToDoubleMilliseconds(mirror_elapsed) - << " subsequent_item_slices_local=true"; - } - if (item.batch_payload_slice().offset() + item.batch_payload_slice().length() > mirrored_it->second->size()) { - deferred_outcomes[index] = make_outcome( - artifact_id, v2::BATCH_ITEM_STATUS_FAILED_PRECONDITION, "batch transport payload is truncated"); - continue; + local_source = BodyBackingManager::LocalByteSpan{ + .owner = std::shared_ptr( + mirrored_it->second, static_cast(mirrored_it->second->data())), + .data = reinterpret_cast(mirrored_it->second->data()) + + item.batch_payload_slice().offset(), + .size_bytes = item.batch_payload_slice().length(), + }; } - local_source = BodyBackingManager::LocalByteSpan{ - .owner = std::shared_ptr( - mirrored_it->second, static_cast(mirrored_it->second->data())), - .data = reinterpret_cast(mirrored_it->second->data()) + - item.batch_payload_slice().offset(), - .size_bytes = item.batch_payload_slice().length(), - }; } else { loader = make_loader_from_source_slice( resolved_it->second.source, item.batch_payload_slice().offset(), item.batch_payload_slice().length()); @@ -2337,6 +2399,9 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( << timing_stats.remote_communicator_source_direct_write_count << " remote_communicator_source_batched_direct_write_count=" << timing_stats.remote_communicator_source_batched_direct_write_count + << " remote_direct_slice_transport_count=" << timing_stats.remote_direct_slice_transport_count + << " remote_direct_slice_items=" << timing_stats.remote_direct_slice_items + << " remote_direct_slice_bytes=" << timing_stats.remote_direct_slice_bytes << " remote_mirror_count=" << timing_stats.remote_mirror_count << " remote_mirror_bytes=" << timing_stats.remote_mirror_bytes << " remote_full_pack_mirror_items=" << timing_stats.remote_full_pack_mirror_items diff --git a/daemon/service/grpc_service_impl_batch_runtime_test.cc b/daemon/service/grpc_service_impl_batch_runtime_test.cc index 34364b3f..2ca5c559 100644 --- a/daemon/service/grpc_service_impl_batch_runtime_test.cc +++ b/daemon/service/grpc_service_impl_batch_runtime_test.cc @@ -12,13 +12,20 @@ #include #include +#include "absl/log/log_entry.h" +#include "absl/log/log_sink.h" +#include "absl/log/log_sink_registry.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "core/common/artifact_hash.h" #include "core/common/artifact_identity.h" +#include "core/communicator/routing/adapter.h" +#include "core/communicator/routing/routing_context.h" +#include "core/communicator/topology/topology.h" #include "core/cuda/cuda_api.h" #include "core/cuda/cuda_ipc.h" #include "core/store/components/endpoint_id.h" @@ -40,6 +47,17 @@ namespace { +using tensorcast::communicator::routing::EndpointBinding; +using tensorcast::communicator::routing::PcieAdapter; +using tensorcast::communicator::routing::RoutingContext; +using tensorcast::communicator::topology::Endpoint; +using tensorcast::communicator::topology::EndpointKind; +using tensorcast::communicator::topology::EndpointType; +using tensorcast::communicator::topology::Link; +using tensorcast::communicator::topology::LinkType; +using tensorcast::communicator::topology::Pool; +using tensorcast::communicator::topology::PoolType; +using tensorcast::communicator::topology::Topology; using tensorcast::daemon::DaemonOptions; using tensorcast::daemon::DaemonServiceHarness; using tensorcast::daemon::v2::BatchExistsRequest; @@ -136,6 +154,137 @@ class WorkerDirectoryTestGlobalStoreClient final : public tensorcast::store::tes } }; +class CollectingLogSink final : public absl::LogSink { + public: + void Send(const absl::LogEntry& entry) override { + absl::MutexLock lock(&mu_); + messages_.push_back(std::string(entry.text_message())); + } + + bool Contains(std::string_view needle) const { + absl::MutexLock lock(&mu_); + for (const auto& message : messages_) { + if (message.find(needle) != std::string::npos) { + return true; + } + } + return false; + } + + bool ContainsAll(std::initializer_list needles) const { + absl::MutexLock lock(&mu_); + for (const auto needle : needles) { + bool found = false; + for (const auto& message : messages_) { + if (message.find(needle) != std::string::npos) { + found = true; + break; + } + } + if (!found) { + return false; + } + } + return true; + } + + private: + mutable absl::Mutex mu_; + std::vector messages_ ABSL_GUARDED_BY(mu_); +}; + +class ScopedCollectingLogSink { + public: + explicit ScopedCollectingLogSink(CollectingLogSink& sink) : sink_(sink) { + absl::AddLogSink(&sink_); + } + + ~ScopedCollectingLogSink() { + absl::RemoveLogSink(&sink_); + } + + ScopedCollectingLogSink(const ScopedCollectingLogSink&) = delete; + ScopedCollectingLogSink& operator=(const ScopedCollectingLogSink&) = delete; + + private: + CollectingLogSink& sink_; +}; + +static Topology make_pcie_batch_payload_topology( + std::string_view local_endpoint_id, + std::string_view remote_endpoint_id) { + std::vector pools; + pools.push_back(Pool{"cpu0", "cpu0", PoolType::kCpu}); + pools.push_back(Pool{"gpu0", "gpu0", PoolType::kGpu}); + pools.push_back(Pool{"gpu1", "gpu1", PoolType::kGpu}); + + std::vector endpoints; + Endpoint local; + local.id = std::string(local_endpoint_id); + local.name = local.id; + local.kind = EndpointKind::kClient; + local.type = EndpointType::kPcie; + local.pool_ids = {"cpu0", "gpu0"}; + endpoints.push_back(std::move(local)); + + Endpoint remote; + remote.id = std::string(remote_endpoint_id); + remote.name = remote.id; + remote.kind = EndpointKind::kClient; + remote.type = EndpointType::kPcie; + remote.pool_ids = {"cpu0", "gpu1"}; + endpoints.push_back(std::move(remote)); + + std::vector links; + Link link; + link.id = absl::StrCat(local_endpoint_id, "_to_", remote_endpoint_id); + link.name = link.id; + link.type = LinkType::kP2P; + link.src_endpoint_id = std::string(local_endpoint_id); + link.dst_endpoint_id = std::string(remote_endpoint_id); + links.push_back(std::move(link)); + + auto topology_or = Topology::Build( + std::move(pools), + std::move(endpoints), + std::move(links), + {.require_endpoint_links = true, .require_connected = false}); + INFO(topology_or.status().message()); + REQUIRE(topology_or.ok()); + return std::move(*topology_or); +} + +static void install_pcie_batch_payload_routing_context( + const std::shared_ptr& comm_manager, + std::string_view local_endpoint_id, + std::string_view remote_endpoint_id, + std::string_view local_host, + uint16_t local_port, + std::string_view remote_host, + uint16_t remote_port) { + auto engine = comm_manager->get_shared_engine(); + auto routing_context = std::make_shared( + RoutingContext::Options{}, engine, nullptr, std::make_shared(engine)); + REQUIRE(routing_context->set_topology(make_pcie_batch_payload_topology(local_endpoint_id, remote_endpoint_id)).ok()); + REQUIRE(routing_context + ->set_endpoint_bindings({ + EndpointBinding{ + .endpoint_id = std::string(local_endpoint_id), + .node_id = "node-pcie-direct", + .ip = std::string(local_host), + .port = local_port, + }, + EndpointBinding{ + .endpoint_id = std::string(remote_endpoint_id), + .node_id = "node-pcie-direct", + .ip = std::string(remote_host), + .port = remote_port, + }, + }) + .ok()); + comm_manager->set_routing_context(std::move(routing_context)); +} + static tensorcast::communicator::v1::CommunicatorConfig make_single_node_communicator_config( int node_index, std::string_view nic_name) { @@ -1660,6 +1809,251 @@ TEST_CASE("HomeBatchPutIfAbsent accepts batch payload slices", "[daemon][batch][ REQUIRE(get_resp.items(1).inline_payload() == payload_b); } +TEST_CASE( + "HomeBatchPutIfAbsent consumes remote direct communicator slices without full-pack mirror", + "[daemon][batch][batch_payload_ref][put][communicator]") { + constexpr std::string_view kSourceDaemonId = "daemon-put-source-direct"; + constexpr std::string_view kHomeDaemonId = "daemon-put-home-direct"; + constexpr std::string_view kHost = "127.0.0.1"; + constexpr std::string_view kLocalEndpoint = "pcie-home-put-direct"; + constexpr std::string_view kRemoteEndpoint = "pcie-source-put-direct"; + + auto source_opts = make_opts_basic(); + source_opts.comm_manager = + make_comm_manager_with_config(make_single_node_communicator_config(/*node_index=*/0, /*nic_name=*/"eth0")); + auto source_engine = std::make_shared(source_opts); + + auto home_opts = make_opts_basic(); + home_opts.comm_manager = + make_comm_manager_with_config(make_single_node_communicator_config(/*node_index=*/1, /*nic_name=*/"eth1")); + install_pcie_batch_payload_routing_context( + home_opts.comm_manager, + kLocalEndpoint, + kRemoteEndpoint, + kHost, + home_opts.comm_manager->listen_port(), + kHost, + source_opts.comm_manager->listen_port()); + auto home_engine = std::make_shared(home_opts); + + auto source_daemon_options = make_daemon_options(); + source_daemon_options.daemon_id = std::string(kSourceDaemonId); + source_daemon_options.storage_path = make_test_storage_root("batch-runtime-put-direct-source"); + auto source = make_harness(source_engine, source_daemon_options); + + auto home_daemon_options = make_daemon_options(); + home_daemon_options.daemon_id = std::string(kHomeDaemonId); + home_daemon_options.storage_path = make_test_storage_root("batch-runtime-put-direct-home"); + auto home = make_harness(home_engine, home_daemon_options); + + const std::string artifact_id_a = make_test_byte_artifact_id("batch-put-direct-a:blk-6a"); + const std::string artifact_id_b = artifact_on_same_shard(artifact_id_a, "batchputdirect"); + const std::string payload_a = "remote-direct-put-alpha"; + const std::string payload_b = "remote-direct-put-beta-more-bytes"; + const std::string slab = payload_a + payload_b; + const std::uint64_t shard_id = shard_for_artifact(artifact_id_a); + + tensorcast::daemon::v2::BatchPayloadManifest manifest; + auto* entry_a = manifest.add_entries(); + entry_a->set_artifact_id(artifact_id_a); + entry_a->set_offset(0); + entry_a->set_length(payload_a.size()); + entry_a->set_digest_alg("sha256"); + entry_a->set_digest_hex(sha256_hex(payload_a)); + auto* entry_b = manifest.add_entries(); + entry_b->set_artifact_id(artifact_id_b); + entry_b->set_offset(payload_a.size()); + entry_b->set_length(payload_b.size()); + entry_b->set_digest_alg("sha256"); + entry_b->set_digest_hex(sha256_hex(payload_b)); + manifest.set_total_size(slab.size()); + + auto export_or = source->kernel().payload_transport_broker().issue_batch_payload_communicator_export( + manifest, + std::make_shared(slab), + tensorcast::common::v1::PAYLOAD_REF_DIRECTION_PUT, + /*operation_id=*/"op-home-batch-put-direct-slice", + absl::Now() + absl::Minutes(1), + kHomeDaemonId); + REQUIRE(export_or.ok()); + + HomeBatchPutIfAbsentRequest put_req; + put_req.mutable_fence()->set_shard_id(shard_id); + put_req.mutable_fence()->set_lease_generation(1); + put_req.mutable_fence()->set_holder_daemon_id(std::string(kHomeDaemonId)); + put_req.mutable_fence()->set_routing_epoch(1); + put_req.set_operation_id("op-home-batch-put-direct-slice"); + auto* transport = put_req.add_batch_transports(); + transport->set_transport_id("transport-put-direct"); + transport->mutable_manifest()->CopyFrom(manifest); + auto* communicator_source = transport->mutable_communicator_source(); + communicator_source->set_batch_payload_ref(export_or->batch_payload_ref); + communicator_source->set_protocol_version(2); + communicator_source->set_producer_daemon_id(std::string(kSourceDaemonId)); + communicator_source->set_consumer_daemon_id(std::string(kHomeDaemonId)); + communicator_source->set_producer_host(std::string(kHost)); + communicator_source->set_producer_port(source_opts.comm_manager->listen_port()); + communicator_source->set_remote_endpoint_id(std::string(kRemoteEndpoint)); + communicator_source->set_local_endpoint_id_hint(std::string(kLocalEndpoint)); + communicator_source->set_memory_location(tensorcast::daemon::v2::BATCH_PAYLOAD_MEMORY_LOCATION_HOST); + communicator_source->set_total_payload_bytes(slab.size()); + for (const auto& memory_key : export_or->export_registration.remote_memory_keys) { + communicator_source->add_remote_memory_keys(memory_key); + } + for (const auto buffer_size : export_or->export_registration.buffer_sizes) { + communicator_source->add_buffer_sizes(buffer_size); + } + + auto* item_a = put_req.add_items(); + item_a->set_artifact_id(artifact_id_a); + set_invariant(item_a->mutable_invariant(), "layout_v1", payload_a); + item_a->mutable_batch_payload_slice()->set_transport_id("transport-put-direct"); + item_a->mutable_batch_payload_slice()->set_offset(0); + item_a->mutable_batch_payload_slice()->set_length(payload_a.size()); + + auto* item_b = put_req.add_items(); + item_b->set_artifact_id(artifact_id_b); + set_invariant(item_b->mutable_invariant(), "layout_v1", payload_b); + item_b->mutable_batch_payload_slice()->set_transport_id("transport-put-direct"); + item_b->mutable_batch_payload_slice()->set_offset(payload_a.size()); + item_b->mutable_batch_payload_slice()->set_length(payload_b.size()); + + CollectingLogSink sink; + HomeBatchPutIfAbsentResponse put_resp; + grpc::ServerContext put_ctx; + { + ScopedCollectingLogSink scoped_sink(sink); + REQUIRE(home->service().HomeBatchPutIfAbsent(&put_ctx, &put_req, &put_resp).ok()); + } + REQUIRE(put_resp.outcomes_size() == 2); + REQUIRE(put_resp.outcomes(0).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + REQUIRE(put_resp.outcomes(1).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + CHECK(sink.Contains( + "byte_artifact.home_batch_put_if_absent_transport_read_mode " + "operation_id=op-home-batch-put-direct-slice")); + CHECK(sink.Contains("read_mode=direct_remote_slice")); + CHECK(sink.Contains("realization=source_slice_loader")); + CHECK_FALSE(sink.Contains( + "byte_artifact.home_batch_put_if_absent_transport_mirror " + "operation_id=op-home-batch-put-direct-slice")); + + HomeBatchGetRequest get_req; + get_req.mutable_fence()->CopyFrom(put_req.fence()); + get_req.add_artifact_ids(artifact_id_a); + get_req.add_artifact_ids(artifact_id_b); + HomeBatchGetResponse get_resp; + grpc::ServerContext get_ctx; + REQUIRE(home->service().HomeBatchGet(&get_ctx, &get_req, &get_resp).ok()); + REQUIRE(get_resp.items_size() == 2); + REQUIRE(get_resp.items(0).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + REQUIRE(get_resp.items(1).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + REQUIRE(get_resp.items(0).inline_payload() == payload_a); + REQUIRE(get_resp.items(1).inline_payload() == payload_b); +} + +TEST_CASE( + "HomeBatchPutIfAbsent keeps full-pack mirror for remote non-direct communicator sources", + "[daemon][batch][batch_payload_ref][put][communicator]") { + constexpr std::string_view kSourceDaemonId = "daemon-put-source-fallback"; + constexpr std::string_view kHomeDaemonId = "daemon-put-home-fallback"; + constexpr std::string_view kSourceHost = "127.0.0.1"; + + auto source_opts = make_opts_basic(); + source_opts.comm_manager = + make_comm_manager_with_config(make_single_node_communicator_config(/*node_index=*/0, /*nic_name=*/"eth0")); + auto source_engine = std::make_shared(source_opts); + + auto home_opts = make_opts_basic(); + home_opts.comm_manager = + make_comm_manager_with_config(make_single_node_communicator_config(/*node_index=*/1, /*nic_name=*/"eth1")); + auto home_engine = std::make_shared(home_opts); + + auto source_daemon_options = make_daemon_options(); + source_daemon_options.daemon_id = std::string(kSourceDaemonId); + source_daemon_options.storage_path = make_test_storage_root("batch-runtime-put-fallback-source"); + auto source = make_harness(source_engine, source_daemon_options); + + auto home_daemon_options = make_daemon_options(); + home_daemon_options.daemon_id = std::string(kHomeDaemonId); + home_daemon_options.storage_path = make_test_storage_root("batch-runtime-put-fallback-home"); + auto home = make_harness(home_engine, home_daemon_options); + + const std::string artifact_id = make_test_byte_artifact_id("batch-put-fallback-a:blk-6a"); + const std::string payload = "remote-fallback-put-payload"; + const std::uint64_t shard_id = shard_for_artifact(artifact_id); + + tensorcast::daemon::v2::BatchPayloadManifest manifest; + auto* entry = manifest.add_entries(); + entry->set_artifact_id(artifact_id); + entry->set_offset(0); + entry->set_length(payload.size()); + entry->set_digest_alg("sha256"); + entry->set_digest_hex(sha256_hex(payload)); + manifest.set_total_size(payload.size()); + + auto export_or = source->kernel().payload_transport_broker().issue_batch_payload_communicator_export( + manifest, + std::make_shared(payload), + tensorcast::common::v1::PAYLOAD_REF_DIRECTION_PUT, + /*operation_id=*/"op-home-batch-put-fallback-mirror", + absl::Now() + absl::Minutes(1), + kHomeDaemonId); + REQUIRE(export_or.ok()); + + HomeBatchPutIfAbsentRequest put_req; + put_req.mutable_fence()->set_shard_id(shard_id); + put_req.mutable_fence()->set_lease_generation(1); + put_req.mutable_fence()->set_holder_daemon_id(std::string(kHomeDaemonId)); + put_req.mutable_fence()->set_routing_epoch(1); + put_req.set_operation_id("op-home-batch-put-fallback-mirror"); + auto* transport = put_req.add_batch_transports(); + transport->set_transport_id("transport-put-fallback"); + transport->mutable_manifest()->CopyFrom(manifest); + auto* communicator_source = transport->mutable_communicator_source(); + communicator_source->set_batch_payload_ref(export_or->batch_payload_ref); + communicator_source->set_protocol_version(2); + communicator_source->set_producer_daemon_id(std::string(kSourceDaemonId)); + communicator_source->set_consumer_daemon_id(std::string(kHomeDaemonId)); + communicator_source->set_producer_host(std::string(kSourceHost)); + communicator_source->set_producer_port(source_opts.comm_manager->listen_port()); + communicator_source->set_remote_endpoint_id( + tensorcast::store::components::derive_endpoint_id( + "node-put-source-fallback", tensorcast::common::memory::MemoryLocation::CPU, /*device_id=*/0)); + communicator_source->set_memory_location(tensorcast::daemon::v2::BATCH_PAYLOAD_MEMORY_LOCATION_HOST); + communicator_source->set_total_payload_bytes(payload.size()); + for (const auto& memory_key : export_or->export_registration.remote_memory_keys) { + communicator_source->add_remote_memory_keys(memory_key); + } + for (const auto buffer_size : export_or->export_registration.buffer_sizes) { + communicator_source->add_buffer_sizes(buffer_size); + } + + auto* item = put_req.add_items(); + item->set_artifact_id(artifact_id); + set_invariant(item->mutable_invariant(), "layout_v1", payload); + item->mutable_batch_payload_slice()->set_transport_id("transport-put-fallback"); + item->mutable_batch_payload_slice()->set_offset(0); + item->mutable_batch_payload_slice()->set_length(payload.size()); + + CollectingLogSink sink; + HomeBatchPutIfAbsentResponse put_resp; + grpc::ServerContext put_ctx; + { + ScopedCollectingLogSink scoped_sink(sink); + REQUIRE(home->service().HomeBatchPutIfAbsent(&put_ctx, &put_req, &put_resp).ok()); + } + REQUIRE(put_resp.outcomes_size() == 1); + INFO(put_resp.outcomes(0).message()); + REQUIRE(put_resp.outcomes(0).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + CHECK(sink.Contains( + "byte_artifact.home_batch_put_if_absent_transport_mirror " + "operation_id=op-home-batch-put-fallback-mirror")); + CHECK(sink.Contains("realization=full_pack_mirror")); + CHECK_FALSE(sink.Contains( + "byte_artifact.home_batch_put_if_absent_transport_read_mode " + "operation_id=op-home-batch-put-fallback-mirror")); +} + TEST_CASE( "HomeBatchGet emits segmented communicator transport without staged batch payload slabs", "[daemon][batch][batch_payload_ref][get][communicator]") { diff --git a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md index 9ad45415..93f8b697 100644 --- a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md +++ b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md @@ -4,7 +4,7 @@ title: Unified Artifact Runtime and Routed Byte Artifact Architecture status: implemented areas: ["daemon", "sdk", "global_store", "proto", "core", "integrations", "docs"] created: 2026-03-08 -last_updated: 2026-04-23 +last_updated: 2026-04-26 related_code: - docs/designs/0017-client-generated-artifact-id.md - docs/designs/0056-programmable-framework-adv.md @@ -90,8 +90,8 @@ authoritative for the artifact value model and routed byte-artifact runtime itse # Current Implementation Snapshot -As of 2026-04-23, the live implementation matches this design, with one -narrower RDMA source-side follow-on still in progress: +As of 2026-04-26, the live implementation matches this design, with narrower +RDMA realization follow-ons still in progress: - `StoreDaemonServiceImpl` delegates `Batch*`, `HomeBatch*`, `FetchPayloadRefChunk`, and `FetchBatchPayloadRefChunk` through controller entrypoints. @@ -114,10 +114,16 @@ narrower RDMA source-side follow-on still in progress: or as a segmented `communicator_source` over retained-body export views for eligible RDMA packs. - Remote `communicator_source` consume paths now diverge by transport - realization: RDMA lowers the remote pack into the shared batched - direct-write path without a mandatory full-pack local mirror, while MTCP and - staged fallback paths still materialize one full local pack payload per - `transport_id` before per-item slicing. + realization on both get and put paths: eligible RDMA lowers remote pack + slices directly from the opened source without a mandatory full-pack local + mirror, while MTCP and staged fallback paths still materialize one full local + pack payload per `transport_id` before per-item slicing. +- `HomeBatchPutIfAbsent` accepts batch transports on the put path and consumes + eligible remote RDMA `communicator_source` transports through per-item + `SourceSlice` loaders staged by `BodyBackingManager::stage_body(...)`. The + first put-side cut intentionally keeps source-side staged-slab pack + construction and per-item home staging in place; it only removes the home + daemon's mandatory remote full-pack mirror. - `BodyHandle` now exposes the export-view API used by source-side segmented communicator export. - The remaining RDMA follow-on is producer-side read-plan servicing: @@ -538,10 +544,13 @@ Normative rules: digest}` tuple. 4. A response or request may carry multiple transports. Batching is per remote-home bucket, and segmentation is still allowed by size, item count, or staging constraints. -5. Current realizations still differ by transport: +5. Current realizations still differ by direction and transport: - `v1 grpc_chunk_ref` materializes one local payload buffer per transport, - - RDMA `v2 communicator_source` already provides a reusable remote slice - loader over one open remote pack through the shared direct-write path, + - get-side RDMA `v2 communicator_source` already provides a reusable + remote slice loader over one open remote pack through the shared + direct-write path, + - put-side RDMA `v2 communicator_source` uses the same no-mirror + remote-slice consume shape for eligible direct-write-capable sources, - MTCP or staged-fallback `v2 communicator_source` paths still mirror one full pack into local host memory before serving subsequent item slices from that local mirror. @@ -631,8 +640,12 @@ Normative rules: - `BatchPayloadCommunicatorSource` is the serializable control-plane descriptor. The broker lowers it into a runtime `RemoteKeySource` backed by the shared communicator rather than using `P2PSource` itself as the wire schema. - The current remote consume path is also transport-dependent: - - RDMA lowers item slices directly from the remote source into the shared - batched direct-write path without a mandatory full-pack local mirror, + - RDMA get lowers item slices directly from the remote source into the + shared batched direct-write path without a mandatory full-pack local + mirror, + - RDMA put lowers item slices directly from the remote source into per-item + `SourceSlice` loaders and home staging without a mandatory full-pack local + mirror, - MTCP and staged fallback paths still mirror one full pack into local host memory before per-item slicing. - The remaining accepted RDMA follow-on below is producer-side servicing: @@ -717,8 +730,12 @@ Put-path rules: 4. The current implementation realizes transport state as one contiguous staged slab per pack. The wire contract does not depend on that choice, but the live controller and broker paths do. 5. The home daemon still verifies each item's invariant and still installs routed join truth per artifact. -6. The current home-daemon consume path reads one full pack per transport, then stages or reuses per-item slices from - that local pack payload before join installation. +6. The home-daemon consume path is transport-dependent. Eligible remote RDMA + `communicator_source` sources are consumed as per-item remote `SourceSlice` + loaders without a local full-pack mirror; non-direct communicator sources + and `grpc_chunk_ref` fallbacks still read one full pack per transport, then + stage or reuse per-item slices from that local pack payload before join + installation. 7. When the consumed pack slice is already available as local host bytes, the home daemon currently prefers a one-pass fast CPU staging path that copies into final backing while computing the required content digest, instead of rereading the bytes through the generic loader pipeline. @@ -728,6 +745,98 @@ Put-path rules: 9. Partial success remains legal. One failed item in a pack must not force unrelated items in the same pack to be reported as semantic failure if their bytes verified and their join succeeded. +#### 5.5.6a Implemented put-side no-mirror consume + +The put path is directionally symmetric with get after authority routing has +selected a remote home: the source daemon owns a logical pack byte space, while +the home daemon is the consumer that stages each successful item into retained +home backing before installing `PUT_IF_ABSENT_JOIN` truth. The first put-side +realization step removes only the home daemon's mandatory remote full-pack +mirror. It does not change source-side pack construction, manifest semantics, +home authority, or per-item join behavior. + +Eligibility rules: + +1. The candidate is a `HomeBatchPutIfAbsent` item with + `batch_payload_slice.transport_id` referring to a `v2 communicator_source` + transport opened with `PAYLOAD_REF_DIRECTION_PUT`. +2. The opened source is remote from the home daemon and resolves to a + non-null `SeekableSource`. +3. First-cut executable eligibility requires + `SeekableSource::supports_direct_write_at() == true`. The + `supports_batched_direct_write_at()` bit is logged for parity with get-side + composite execution and future put-side vectored work, but by itself is not + the first-cut per-item `SourceSlice` eligibility signal unless the source + also advertises scalar `read_into_at(...)` support. +4. Manifest lookup must prove that the item slice exactly matches one + `{artifact_id, offset, length}` entry in the transport manifest before any + loader is constructed. + +Direct consume realization: + +1. The home daemon must not call `mirror_seekable_source_payload(...)` for an + eligible remote source. +2. The home daemon builds a per-item `SourceSlice` over the opened remote + source and wraps it in `SeekableSourceLoader`. +3. The item keeps `source_kind = kP2P` and is staged through + `BodyBackingManager::stage_body(...)`, preserving the existing body + placement, invariant validation, digest, and verified-content output + semantics. +4. The direct-slice loader holds the resolved source through shared ownership + until item staging completes; there is no local full-pack owner buffer and + no `LocalByteSpan` pointing into a mirrored pack. +5. This first cut may still execute one home staging operation per item. It is + not the put-side composite or vectored direct-write phase, and it is not + required to remove all per-item lowering, `StreamingPinnedBuffer`, or + final-backing copy costs. + +Fallback rules: + +1. Remote `communicator_source` paths that do not advertise scalar direct-write + support keep the current `full_pack_mirror` realization. +2. `grpc_chunk_ref` keeps the current full-pack fetch realization. +3. Same-daemon or otherwise local `communicator_source` paths keep local source + slicing and do not count as remote no-mirror RDMA validation. +4. Fallback must be pre-issue. After a direct-slice staging attempt has begun, + the controller must not silently retry the same item through full-pack + mirror because the staging path may already have performed partial writes + into transient backing. Such failures remain current-operation item + failures and must not rewrite routed truth. + +Correctness and authority rules: + +1. `HomeBatchPutIfAbsent` remains the only authority point that installs + first-writer routed truth. Direct-slice consumption is only a byte-movement + realization. +2. A failed transport open or manifest validation failure affects only items + that reference the invalid transport or slice, subject to existing + batch-scoped capability validation rules. +3. If direct-slice staging fails for one item, the item must not be passed to + `ByteArtifactAuthorityService::batch_put_if_absent(...)`; unrelated items + in the same request or different transports may still complete. +4. Digest and verification behavior remains item-scoped. The home daemon still + validates the item invariant against the staged body descriptor before join + installation. +5. Implementation may serialize or parallelize direct-slice staging per + transport. The semantic contract requires explicit-offset reads and correct + lifetime ownership, not a particular scheduling order. + +Observability rules: + +1. Opening a put-side remote `communicator_source` must continue to log + source direct-write and batched direct-write capability. +2. Eligible direct-slice consumption must emit a read-mode log such as + `byte_artifact.home_batch_put_if_absent_transport_read_mode` with + `read_mode=direct_remote_slice`, `realization=source_slice_loader`, + `mirror_ms=0`, and `subsequent_item_slices_local=false`. +3. Fallback mirror consumption must continue to emit + `byte_artifact.home_batch_put_if_absent_transport_mirror` with + `realization=full_pack_mirror`. +4. The home summary must distinguish remote direct-slice transports and items + from remote full-pack mirror transports and bytes so benchmark analysis can + prove that `remote_mirror_count` and `remote_mirror_bytes` collapse to zero + on the intended RDMA put path. + #### 5.5.7 Implemented v2 communicator-backed realization `v2 communicator_source` is the current communicator-backed realization. It moves routed byte-artifact remote transport @@ -761,15 +870,17 @@ Rules: - eligible RDMA get paths may export one logical pack as segmented retained body views, - MTCP-compatible and fallback paths may still realize one staged host pack. -6. Current remote `v2` consume paths open one communicator source per +6. Current get-side remote `v2` consume paths open one communicator source per `transport_id`: - - on RDMA, item slices lower directly from that remote source into the + - on RDMA get, item slices lower directly from that remote source into the shared `0115` batched direct-write path without a mandatory full-pack local mirror, - on MTCP or staged fallback paths, the consumer still mirrors one full pack into local host memory before per-item slicing, - local same-daemon communicator packs can still serve source slices directly. + Put-side home consume matches this no-mirror RDMA shape for eligible + direct-write-capable sources as described in 5.5.6a. 7. Semantic success remains item-scoped and whole-artifact-scoped. `v2` does not introduce sub-artifact success, partial artifact visibility, or sub-artifact digest semantics. 8. `v2` does not require per-pack Global Store transport sessions. Peer addressability and export descriptors come from @@ -850,6 +961,9 @@ Design intent: direct source-to-target writes, - current RDMA get paths already remove the daemon-owned pack slab and the old mandatory sink full-pack mirror on eligible paths, +- put-side `HomeBatchPutIfAbsent` now removes the same mandatory home-daemon + full-pack mirror for eligible RDMA `communicator_source` transports while + keeping source-side staged-slab pack construction in place, - the remaining RDMA bottleneck is producer-side servicing: CPU source slices are still copied from retained backing into pinned staged response buffers before remote reads, @@ -878,66 +992,72 @@ Normative rules: 3. `v1 grpc_chunk_ref` keeps the staged contiguous pack realization. This remains the preferred realization for MTCP or other non-direct-write transports, and it remains a valid per-pack fallback even when protocol version `2` is available. -4. Sink-side RDMA consume paths must not unconditionally mirror a full remote pack into local host memory. When a - remote `communicator_source` lowers to a `SeekableSource` that supports `read_into_at(...)`, the consumer should - lower per-item `SourceSlice` loaders directly from that remote source and let the shared materialization dataplane - choose direct-write execution. -5. Source-side RDMA communicator export should continue to prefer no-pack-copy +4. Sink-side RDMA consume paths must not unconditionally mirror a full remote pack into local host memory. On the get + path, when a remote `communicator_source` lowers to a `SeekableSource` that supports `read_into_at(...)`, the + consumer should lower per-item `SourceSlice` loaders directly from that remote source and let the shared + materialization dataplane choose direct-write execution. +5. Put-side home consume must follow the same no-mirror rule for eligible + remote `communicator_source` sources: `HomeBatchPutIfAbsent` should build + per-item `SourceSlice` loaders over the remote pack and stage those items + through `BodyBackingManager::stage_body(...)` instead of first materializing + a full local pack mirror. +6. Source-side RDMA communicator export should continue to prefer no-pack-copy segmented export over daemon-owned pack slab realization. The producer may expose one logical pack byte space by concatenating per-entry or per-backing exported segments through `remote_memory_keys[]`, `buffer_sizes[]`, and `total_payload_bytes`; it does not need to copy those bytes into one daemon-owned slab first. -6. The next RDMA follow-on is source servicing, not a new sink API. Eligible +7. The next RDMA get-side follow-on is source servicing, not a new sink API. Eligible retained CPU backings should be served as direct-readable source segments in the read-plan response instead of first being copied into pinned staged buffers. -7. The accepted source-side realization seam is `BodyHandle`, as further +8. The accepted source-side realization seam is `BodyHandle`, as further specified by `0089`. `BodyHandle` provides the transport-neutral export-view acquisition API that `PayloadTransportBroker` uses to obtain exportable backing views and keepalive state without reimplementing replica-runtime inspection or export logic. -8. Direct-source RDMA response windows are descriptor-driven, not +9. Direct-source RDMA response windows are descriptor-driven, not staging-driven. They must not consume `FlowCreditLedger`, `StageLease`, or staged ACK-release semantics, and they must not be split merely because staged `buffers_per_flow` credit is exhausted. -9. Direct-source window sizing should instead be bounded by descriptor/control +10. Direct-source window sizing should instead be bounded by descriptor/control limits such as segment count, control payload bytes, and request budgeting. The accepted benchmark target is that one routed transport's `32` source segments fit in one direct-source response window and one sink `read_multi()` call, even if transport realization still posts many RDMA WRs internally. -10. RDMA producer hot paths may optionally retain a publish-time export-view +11. RDMA producer hot paths may optionally retain a publish-time export-view keepalive keyed by backing identity as an optimization hint for later direct-source servicing. This cache is RDMA-only, best-effort, and advisory: missing, expired, or invalidated retained exports must never change authority truth or manifest semantics, and they must not disable request- time export acquisition or staged fallback. -11. Publish-time retained exports must live outside `BackingRecord` snapshots. +12. Publish-time retained exports must live outside `BackingRecord` snapshots. Snapshot copies of backing metadata must not silently extend export lifetime; source-side preregistration is a separate bounded cache over previously acquired `BodyHandle` export views. -12. Publish-time retained-export cache lifetime must be explicitly bounded by +13. Publish-time retained-export cache lifetime must be explicitly bounded by TTL and live-entry/live-byte budgets, and it must be invalidated on backing lifecycle changes such as invalidation, rebind, prune, or replacement. -13. RDMA zero-copy in this design means "no mandatory pack copy, no mandatory +14. RDMA zero-copy in this design means "no mandatory pack copy, no mandatory source-side staging copy, and no mandatory full-pack mirror" when direct- write source and target paths exist. It does not remove item-scoped digests, item-scoped lowering, or per-item success and failure outcomes. -14. If a candidate pack or item cannot produce the required export view, cannot satisfy lifetime requirements, or +15. If a candidate pack or item cannot produce the required export view, cannot satisfy lifetime requirements, or resolves to a non-direct-write transport, the daemon may fall back per pack to the staged contiguous realization and the existing `grpc_chunk_ref` or staged `communicator_source` paths. MTCP-validated behavior must remain available. -15. The intended implementation order is: +16. The intended implementation order is: - sink-side no-mirror consume path first, because the lower dataplane already supports direct-write remote sources, - `BodyHandle` export-view API second, - source-side no-pack segmented communicator export third, - source-side direct-readable RDMA servicing fourth, + - put-side home no-mirror consume as the first batch-set parity step, - and source-side publish-time retained-export warming as an optional follow-on optimization on top of the same `BodyHandle` seam. -16. Session-scoped staging reuse remains complementary follow-on work after this transport-specific split. It must not +17. Session-scoped staging reuse remains complementary follow-on work after this transport-specific split. It must not be used as a reason to keep RDMA on the forced pack-plus-mirror path. -17. Shared composite direct-write and routed vectored pull semantics are +18. Shared composite direct-write and routed vectored pull semantics are defined by `0115`. `0087` owns byte-artifact authority plus the consumer- side and producer-side realization rules that decide whether a routed pack is staged or direct-readable; it does not own a transport-private RDMA sink @@ -1019,10 +1139,15 @@ Rules: - `byte_artifact.batch_get_into_region_apply_plan` - `byte_artifact.batch_get_into_region_transport_apply_summary` - `byte_artifact.batch_get_into_region_summary` + - `byte_artifact.home_batch_put_if_absent_transport_open` + - `byte_artifact.home_batch_put_if_absent_transport_read_mode` - `byte_artifact.home_batch_put_if_absent_transport_mirror` - `byte_artifact.home_batch_put_if_absent_stage_plan` + - `byte_artifact.home_batch_put_if_absent_summary` + - `byte_artifact.batch_put_if_absent_from_region_pack_realization` - `byte_artifact.batch_put_if_absent_from_region_transport_emit` - `byte_artifact.batch_put_if_absent_from_region_home_rpc_result` + - `byte_artifact.put_shard_task_summary` - `byte_artifact.batch_put_if_absent_from_region_summary` - `batch_payload_ref.communicator_export_summary` - `batch_payload_ref.communicator_open_summary` @@ -1345,9 +1470,10 @@ design delta. - `HomeBatch*`, `FetchPayloadRefChunk`, and `FetchBatchPayloadRefChunk` assume a trusted intra-cluster network unless a separate peer-auth layer is introduced, - batch-native transport is now implemented, but the remaining RDMA cost is no - longer daemon-owned pack slabs or mandatory sink full-pack mirrors on the get - path; it is producer-side servicing that still stages retained CPU backing - bytes into pinned response buffers before remote reads, + longer daemon-owned pack slabs or mandatory sink/home full-pack mirrors on + eligible RDMA get/put consume paths; it is producer-side servicing that still + stages retained CPU backing bytes into pinned response buffers before remote + reads, - `v1 grpc_chunk_ref` improves transport shape but still inherits gRPC chunk framing and server or client memory-copy costs; it should be treated as an incremental step rather than the final cross-host performance target, - `v2 communicator_source` reduces that gap but the current realization still pays unnecessary RDMA copy costs on both @@ -1384,6 +1510,9 @@ design delta. repack, source-side staged response copy, or full-pack mirror when the same manifest and per-item semantics can be preserved through direct-write and `BodyHandle`-backed export views. +- Put-side `HomeBatchPutIfAbsent` RDMA `communicator_source` consumption now + follows the same no-mirror remote-slice rule before any larger batch-set + composite or vectored optimization is considered accepted. - `PayloadTransportBroker` remains the transport boundary, but source-side no-copy export must consume the `BodyHandle` export-view seam described by `0089` rather than growing broker-private `StoreEngine` inspection logic. - `GetServerConfig` is the peer-discovery surface for batch-transport protocol version and realization support. From 908628f26aef5d8282894b102a6a53539a9f984a Mon Sep 17 00:00:00 2001 From: zhouyuhan Date: Sun, 26 Apr 2026 19:43:50 +0800 Subject: [PATCH 03/49] feat(byte-artifact): enable composite batch stage into final bodies for HomeBatchPut --- .../runtime/ingestion/ingestion_runtime.cc | 11 + .../runtime/ingestion/ingestion_runtime.h | 8 + .../ingestion/materialization_facade.cc | 166 ++++++++++++ .../ingestion/materialization_facade.h | 21 ++ daemon/service/body_backing_manager.cc | 157 +++++++++++ daemon/service/body_backing_manager.h | 27 ++ .../controllers/byte_artifact_controller.cc | 249 +++++++++++++++++- .../grpc_service_impl_batch_runtime_test.cc | 158 +++++++++++ ...e-and-routed-byte-artifact-architecture.md | 161 ++++++++++- ...terialization-and-vectored-direct-write.md | 67 ++++- 10 files changed, 1007 insertions(+), 18 deletions(-) diff --git a/core/store/runtime/ingestion/ingestion_runtime.cc b/core/store/runtime/ingestion/ingestion_runtime.cc index 07fae62f..e781284f 100644 --- a/core/store/runtime/ingestion/ingestion_runtime.cc +++ b/core/store/runtime/ingestion/ingestion_runtime.cc @@ -89,6 +89,17 @@ absl::StatusOr IngestionRuntime::ingest_mapped_loader_in logical_artifact_id, physical_artifact_id, target_device, target, std::move(loader), mapping, hints, source_kind); } +absl::StatusOr IngestionRuntime:: + ingest_mapped_sources_into_replicas( + std::vector targets, + std::vector> sources, + const loader::ByteRangeMap& mapping, + const loading::MaterializeHints& hints, + loading::MaterializationSource source_kind) { + return materialization_facade_->ingest_mapped_sources_into_replicas( + std::move(targets), std::move(sources), mapping, hints, source_kind); +} + absl::StatusOr IngestionRuntime::execute_artifact_lowering_plan( ingestion::ArtifactLoweringPlan plan) { return materialization_facade_->execute_artifact_lowering_plan(std::move(plan)); diff --git a/core/store/runtime/ingestion/ingestion_runtime.h b/core/store/runtime/ingestion/ingestion_runtime.h index f8d12b3a..6cb6efde 100644 --- a/core/store/runtime/ingestion/ingestion_runtime.h +++ b/core/store/runtime/ingestion/ingestion_runtime.h @@ -91,6 +91,14 @@ class IngestionRuntime { const loading::MaterializeHints& hints, loading::MaterializationSource source_kind); + absl::StatusOr + ingest_mapped_sources_into_replicas( + std::vector targets, + std::vector> sources, + const loader::ByteRangeMap& mapping, + const loading::MaterializeHints& hints, + loading::MaterializationSource source_kind); + absl::StatusOr execute_artifact_lowering_plan( ingestion::ArtifactLoweringPlan plan); diff --git a/core/store/runtime/ingestion/materialization_facade.cc b/core/store/runtime/ingestion/materialization_facade.cc index 59adc29e..cfc967db 100644 --- a/core/store/runtime/ingestion/materialization_facade.cc +++ b/core/store/runtime/ingestion/materialization_facade.cc @@ -5043,6 +5043,172 @@ absl::StatusOr MaterializationFacade::mate target_device, target_layout, std::move(sources), mapping, hints, source_kind); } +absl::StatusOr MaterializationFacade:: + ingest_mapped_sources_into_replicas( + std::vector targets, + std::vector> sources, + const loader::ByteRangeMap& mapping, + const loading::MaterializeHints& hints, + loading::MaterializationSource source_kind) { + if (targets.empty()) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas requires at least one target"); + } + if (sources.empty()) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas requires at least one source"); + } + if (sources.size() < mapping.num_sources) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas sources do not satisfy map.num_sources"); + } + + const DeviceKey target_device = targets.front().target_device; + const common::memory::MemoryLocation target_location = targets.front().target.location.type; + if (target_location != common::memory::MemoryLocation::CPU && + target_location != common::memory::MemoryLocation::GPU) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas requires CPU or GPU targets"); + } + + struct PreparedReplicaTarget { + std::string logical_artifact_id; + loading::ReplicaKey key; + std::shared_ptr replica; + common::memory::MemoryLocation target_location{common::memory::MemoryLocation::NONE}; + std::uint64_t size_bytes{0}; + }; + + auto registry = &config_.replica_runtime->registry(); + std::vector prepared; + prepared.reserve(targets.size()); + loading::IntoTargetLayout target_layout; + target_layout.storages.reserve(targets.size()); + std::uint64_t total_bytes = 0; + + for (const auto& target : targets) { + if (target.logical_artifact_id.empty()) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas requires logical_artifact_id"); + } + if (target.physical_artifact_id.empty()) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas requires physical_artifact_id"); + } + if (target.size_bytes == 0) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas requires non-empty target size"); + } + if (target.target.location.type != target_location) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas target locations must match"); + } + const DeviceKey target_location_device = target.target.location.to_device_key(); + if (target.target_device.type != target_device.type || target.target_device.ordinal != target_device.ordinal || + target.target_device.uuid != target_device.uuid || target_location_device.type != target.target_device.type || + target_location_device.ordinal != target.target_device.ordinal || + target_location_device.uuid != target.target_device.uuid) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas target device mismatch"); + } + if (total_bytes > std::numeric_limits::max() - target.size_bytes) { + return absl::OutOfRangeError("ingest_mapped_sources_into_replicas target size overflow"); + } + + loading::ReplicaKey key{ + .artifact_id = target.physical_artifact_id, + .view_id = std::nullopt, + .device = target.target_device, + .replica = 0, + }; + auto existing_or = registry->find(key); + if (existing_or.ok()) { + return absl::AlreadyExistsError("ingest_mapped_sources_into_replicas target replica already exists"); + } + if (!absl::IsNotFound(existing_or.status())) { + return existing_or.status(); + } + + loading::InlineBufferSource inline_source{.data = nullptr, .size_bytes = target.size_bytes}; + replica::ReplicaConfig cfg{ + .source = inline_source, + .artifact_identifier = key.artifact_id, + .device_type = target.target_device.type, + .local_device_id = target.target_device.type == DeviceType::GPU ? target.target_device.ordinal : -1, + .pinned_buffer_pool = config_.runtime_context->pinned_buffer_pool(), + .async_runtime = gsl::not_null>{config_.runtime_context->async_runtime()}, + .artifact_chunk_bytes = config_.artifact_chunk_bytes, + .expected_artifact_size = target.size_bytes, + .byte_mapping_config = config_.options->byte_mapping, + .materialization_strategy = config_.options->materialization_strategy, + .memory_tier_config = config_.options->memory_tier_config, + }; + cfg.pinned_memory_timeout = hints.pinned_timeout.count() > 0 ? hints.pinned_timeout : config_.pinned_memory_timeout; + cfg.streaming_buffer_chunks = std::max(1, config_.runtime_context->options().streaming_buffer_chunks); + cfg.cpu_shared_memory_enabled = config_.runtime_context->options().cpu_shared_memory_enabled; + + auto replica_or = replica::Replica::create(cfg); + if (!replica_or.ok()) { + return replica_or.status(); + } + auto replica = std::shared_ptr(std::move(replica_or.value())); + auto allocate_status = replica->get_memory_manager().allocate_memory(target_location); + if (!allocate_status.ok()) { + return allocate_status; + } + auto ptrs = replica->get_data_pointer(target_location); + if (ptrs.empty() || ptrs.front() == nullptr) { + return absl::FailedPreconditionError("ingest_mapped_sources_into_replicas target pointer unavailable"); + } + + target_layout.storages.push_back( + loading::IntoTargetStorage{ + .base_ptr = gsl::not_null{ptrs.front()}, + .length = target.size_bytes, + .keepalive = replica, + }); + total_bytes += target.size_bytes; + prepared.push_back( + PreparedReplicaTarget{ + .logical_artifact_id = target.logical_artifact_id, + .key = std::move(key), + .replica = std::move(replica), + .target_location = target_location, + .size_bytes = target.size_bytes, + }); + } + + if (mapping.total_bytes != total_bytes) { + return absl::InvalidArgumentError("ingest_mapped_sources_into_replicas mapping total_bytes mismatch"); + } + target_layout.total_size = total_bytes; + + auto materialize_or = materialize_mapped_sources_into_target( + target_device, target_layout, std::move(sources), mapping, hints, source_kind); + if (!materialize_or.ok()) { + return materialize_or.status(); + } + + std::vector inserted_keys; + inserted_keys.reserve(prepared.size()); + IngestMappedSourcesIntoReplicasResult result; + result.materialize_result = std::move(*materialize_or); + result.replica_handles.reserve(prepared.size()); + for (const auto& entry : prepared) { + auto mark_status = entry.replica->mark_loaded(entry.target_location); + if (!mark_status.ok()) { + for (const auto& inserted_key : inserted_keys) { + (void)registry->erase(inserted_key); + } + return mark_status; + } + entry.replica->set_ready_signal(entry.target_location, absl::OkStatus()); + auto emplace_status = registry->emplace(entry.key, gsl::not_null{entry.replica}); + if (!emplace_status.ok()) { + for (const auto& inserted_key : inserted_keys) { + (void)registry->erase(inserted_key); + } + return emplace_status; + } + inserted_keys.push_back(entry.key); + loading::ReplicaHandle handle = build_local_replica_handle(entry.key, entry.replica, entry.target_location); + handle.source = source_kind; + result.replica_handles.push_back(std::move(handle)); + } + return result; +} + absl::StatusOr MaterializationFacade::execute_artifact_lowering_plan( ArtifactLoweringPlan plan) { auto validation_status = validate_artifact_lowering_plan(plan); diff --git a/core/store/runtime/ingestion/materialization_facade.h b/core/store/runtime/ingestion/materialization_facade.h index ea740edc..33366252 100644 --- a/core/store/runtime/ingestion/materialization_facade.h +++ b/core/store/runtime/ingestion/materialization_facade.h @@ -22,6 +22,7 @@ #include "core/store/materialization/contracts/byte_range/byte_range_map.h" #include "core/store/materialization/control/materialization_backend.h" #include "core/store/materialization/dataplane/contracts/loader.h" +#include "core/store/materialization/dataplane/contracts/source.h" #include "core/store/materialization/runtime/pipeline/ingestion_pipeline.h" #include "core/store/replica/collective_disk_loader.h" #include "core/store/runtime/context/runtime_context.h" @@ -86,6 +87,19 @@ class MaterializationFacade : public materialization::control::MaterializationBa public: using SealProgressCallback = std::function; + struct MappedReplicaTarget { + std::string logical_artifact_id; + std::string physical_artifact_id; + DeviceKey target_device; + loading::ReplicaTarget target; + std::uint64_t size_bytes{0}; + }; + + struct IngestMappedSourcesIntoReplicasResult { + std::vector replica_handles; + loading::MaterializeIntoTargetResult materialize_result; + }; + struct SealAssemblyCutInput { struct BoundCanonicalSpan { int device_id{-1}; @@ -164,6 +178,13 @@ class MaterializationFacade : public materialization::control::MaterializationBa const loading::MaterializeHints& hints, loading::MaterializationSource source_kind); + absl::StatusOr ingest_mapped_sources_into_replicas( + std::vector targets, + std::vector> sources, + const loader::ByteRangeMap& mapping, + const loading::MaterializeHints& hints, + loading::MaterializationSource source_kind); + absl::StatusOr execute_artifact_lowering_plan(ArtifactLoweringPlan plan); absl::StatusOr ingest_from_disk( diff --git a/daemon/service/body_backing_manager.cc b/daemon/service/body_backing_manager.cc index 4aebf839..0813c924 100644 --- a/daemon/service/body_backing_manager.cc +++ b/daemon/service/body_backing_manager.cc @@ -801,6 +801,163 @@ absl::StatusOr BodyBackingManager::stage_body(S return stage_result_or; } +absl::StatusOr BodyBackingManager::stage_bodies_composite( + StageBodiesCompositeRequest request) const { + if (request.source == nullptr) { + return absl::InvalidArgumentError("composite body staging requires source"); + } + if (request.items.empty()) { + return absl::InvalidArgumentError("composite body staging requires at least one item"); + } + if (!request.source->supports_direct_write_at() || !request.source->supports_batched_direct_write_at()) { + return absl::FailedPreconditionError("composite body staging requires batched direct-write source"); + } + + struct PreparedCompositeItem { + const CompositeStageItem* item{nullptr}; + ResolvedStorePolicy resolved_policy; + BodyBackingIntent intent; + std::string physical_artifact_id; + v2::ByteArtifactVerificationMode verification_mode{v2::BYTE_ARTIFACT_VERIFICATION_MODE_UNSPECIFIED}; + }; + + std::vector prepared_items; + prepared_items.reserve(request.items.size()); + std::vector targets; + targets.reserve(request.items.size()); + store::loader::ByteRangeMap mapping; + mapping.num_sources = 1; + mapping.segments.reserve(request.items.size()); + + std::uint64_t target_cursor = 0; + const std::uint64_t source_total_bytes = request.source->total_bytes(); + for (const auto& item : request.items) { + if (item.artifact_id.empty()) { + return absl::InvalidArgumentError("composite body staging requires artifact_id"); + } + if (item.invariant.layout_id().empty()) { + return absl::InvalidArgumentError("composite body staging requires invariant.layout_id"); + } + if (item.length == 0 || item.length != item.invariant.byte_length()) { + return absl::InvalidArgumentError("composite body staging length must match invariant.byte_length"); + } + if (item.source_offset > std::numeric_limits::max() - item.length) { + return absl::OutOfRangeError("composite body staging source range overflow"); + } + if (source_total_bytes != 0 && item.source_offset + item.length > source_total_bytes) { + return absl::OutOfRangeError("composite body staging source range exceeds source size"); + } + if (target_cursor > std::numeric_limits::max() - item.length) { + return absl::OutOfRangeError("composite body staging target range overflow"); + } + + const auto verification_mode = invariant_verification_mode(item.invariant); + if (verification_mode_requires_payload_digest(verification_mode)) { + return absl::FailedPreconditionError("composite body staging supports layout-and-size verification only"); + } + + auto resolved_policy_or = resolve_body_store_policy(item.access_class, item.route_role, item.resolved_store_policy); + if (!resolved_policy_or.ok()) { + return resolved_policy_or.status(); + } + const BodyPlacementContext context = + normalize_placement_context(item.access_class, item.route_role, item.invariant.byte_length()); + const BodyBackingIntent intent = classify_intent(context, *resolved_policy_or); + if (intent.preferred_residency != BodyPreferredResidency::kCpu) { + record_body_backing_metrics("stage_composite", item.access_class, intent, "target_not_cpu"); + return absl::FailedPreconditionError("composite body staging requires CPU final backings"); + } + + const std::string physical_artifact_id = build_body_backing_artifact_id(item.artifact_id, item.invariant); + const auto target_device = resolve_target_device(intent); + targets.push_back( + store::runtime::ingestion::MaterializationFacade::MappedReplicaTarget{ + .logical_artifact_id = item.artifact_id, + .physical_artifact_id = physical_artifact_id, + .target_device = target_device, + .target = build_replica_target(intent), + .size_bytes = item.invariant.byte_length(), + }); + mapping.segments.push_back( + store::loader::ByteRangeSegment{ + .kind = store::loader::ByteRangeSegment::Kind::kData, + .dst_offset = target_cursor, + .length = item.length, + .src_offset = item.source_offset, + .source_index = 0, + }); + prepared_items.push_back( + PreparedCompositeItem{ + .item = &item, + .resolved_policy = std::move(*resolved_policy_or), + .intent = intent, + .physical_artifact_id = physical_artifact_id, + .verification_mode = verification_mode, + }); + target_cursor += item.length; + } + mapping.total_bytes = target_cursor; + if (mapping.total_bytes == 0) { + return absl::InvalidArgumentError("composite body staging requires non-empty total bytes"); + } + + auto hints = build_lowering_hints( + request.transport_id.empty() ? std::string_view("home_batch_put_if_absent_composite") + : std::string_view(request.transport_id), + request.operation_id); + auto materialized_or = engine_.ingestion_runtime().ingest_mapped_sources_into_replicas( + std::move(targets), + std::vector>{std::move(request.source)}, + mapping, + hints, + request.source_kind); + if (!materialized_or.ok()) { + return materialized_or.status(); + } + if (materialized_or->replica_handles.size() != prepared_items.size()) { + for (auto& handle : materialized_or->replica_handles) { + (void)engine_.retire_replica_status(handle.key()); + } + return absl::InternalError("composite body staging returned unexpected replica handle count"); + } + + StageBodiesCompositeResult result; + result.materialize_result = std::move(materialized_or->materialize_result); + result.staged_bodies.reserve(prepared_items.size()); + auto retire_unpublished = [&]() { + for (auto& staged_body : result.staged_bodies) { + (void)staged_body.body_handle.retire(); + } + }; + for (std::size_t index = 0; index < prepared_items.size(); ++index) { + const auto& prepared = prepared_items[index]; + const auto* item = prepared.item; + auto stage_result_or = finalize_staged_replica( + engine_, + "stage_composite", + item->access_class, + prepared.intent, + prepared.resolved_policy, + item->invariant.layout_id(), + prepared.verification_mode, + std::move(materialized_or->replica_handles[index]), + store::runtime::ingestion::BackingIdentity{ + .physical_artifact_id = prepared.physical_artifact_id, + }, + build_invariant_verified_content_descriptor(item->invariant), + build_layout_and_size_verification_record(item->invariant.layout_id(), absl::Now())); + if (!stage_result_or.ok()) { + retire_unpublished(); + for (std::size_t remaining = index + 1; remaining < materialized_or->replica_handles.size(); ++remaining) { + (void)engine_.retire_replica_status(materialized_or->replica_handles[remaining].key()); + } + return stage_result_or.status(); + } + result.staged_bodies.push_back(std::move(*stage_result_or)); + } + return result; +} + absl::StatusOr BodyBackingManager::stage_body_fast_cpu_verified( std::string artifact_id, const v2::PutIfAbsentInvariant& invariant, diff --git a/daemon/service/body_backing_manager.h b/daemon/service/body_backing_manager.h index ae22994c..dd022698 100644 --- a/daemon/service/body_backing_manager.h +++ b/daemon/service/body_backing_manager.h @@ -7,8 +7,10 @@ #include #include #include +#include #include "absl/status/statusor.h" +#include "core/store/materialization/dataplane/contracts/source.h" #include "core/store/store_engine.h" #include "daemon/service/body_backing_types.h" #include "daemon/service/byte_artifact_body_handle.h" @@ -45,6 +47,29 @@ class BodyBackingManager { store::runtime::ingestion::BackingIdentity backing_identity; }; + struct CompositeStageItem { + std::string artifact_id; + v2::PutIfAbsentInvariant invariant; + std::uint64_t source_offset{0}; + std::uint64_t length{0}; + BodyAccessClass access_class{BodyAccessClass::kHomeDefault}; + BodyRouteRole route_role{BodyRouteRole::kHomeAuthority}; + std::optional resolved_store_policy; + }; + + struct StageBodiesCompositeRequest { + std::shared_ptr source; + std::vector items; + store::loading::MaterializationSource source_kind{store::loading::MaterializationSource::kUnspecified}; + std::string operation_id; + std::string transport_id; + }; + + struct StageBodiesCompositeResult { + std::vector staged_bodies; + store::loading::MaterializeIntoTargetResult materialize_result; + }; + struct ReuseRequest { std::string artifact_id; v2::PutIfAbsentInvariant invariant; @@ -59,6 +84,8 @@ class BodyBackingManager { explicit BodyBackingManager(store::StoreEngine& engine); [[nodiscard]] absl::StatusOr stage_body(StageRequest request) const; + [[nodiscard]] absl::StatusOr stage_bodies_composite( + StageBodiesCompositeRequest request) const; [[nodiscard]] absl::StatusOr stage_body_fast_cpu_verified( std::string artifact_id, const v2::PutIfAbsentInvariant& invariant, diff --git a/daemon/service/controllers/byte_artifact_controller.cc b/daemon/service/controllers/byte_artifact_controller.cc index e8acbe90..2588b767 100644 --- a/daemon/service/controllers/byte_artifact_controller.cc +++ b/daemon/service/controllers/byte_artifact_controller.cc @@ -1771,6 +1771,13 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( std::size_t remote_direct_slice_transport_count{0}; std::size_t remote_direct_slice_items{0}; std::uint64_t remote_direct_slice_bytes{0}; + std::size_t remote_composite_stage_transport_count{0}; + std::size_t remote_composite_stage_items{0}; + std::uint64_t remote_composite_stage_bytes{0}; + std::size_t remote_composite_materialize_calls{0}; + std::size_t remote_composite_batched_direct_write_count{0}; + std::size_t remote_composite_fallback_count{0}; + std::size_t remote_composite_fallback_items{0}; std::size_t remote_full_pack_mirror_items{0}; std::size_t stage_body_count{0}; std::size_t fast_cpu_stage_count{0}; @@ -1784,6 +1791,7 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( absl::Duration communicator_open_elapsed{absl::ZeroDuration()}; absl::Duration grpc_fetch_elapsed{absl::ZeroDuration()}; absl::Duration remote_mirror_elapsed{absl::ZeroDuration()}; + absl::Duration remote_composite_stage_elapsed{absl::ZeroDuration()}; absl::Duration stage_body_elapsed{absl::ZeroDuration()}; absl::Duration fast_cpu_stage_elapsed{absl::ZeroDuration()}; absl::Duration reuse_elapsed{absl::ZeroDuration()}; @@ -1827,11 +1835,18 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( absl::flat_hash_map> remote_direct_source_mutexes; absl::flat_hash_set direct_remote_batch_payloads; + struct CompositeStageCandidate { + std::string transport_id; + std::uint64_t source_offset{0}; + std::uint64_t length{0}; + }; + struct PreparedHomeBatchPutItem { const v2::HomeBatchPutIfAbsentItem* item{nullptr}; std::string artifact_id; std::unique_ptr loader; std::optional local_source; + std::optional composite_candidate; store::loading::MaterializationSource source_kind = store::loading::MaterializationSource::kLocalReplica; std::optional staged_body; }; @@ -1890,6 +1905,7 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( std::unique_ptr loader; std::optional local_source; + std::optional composite_candidate; store::loading::MaterializationSource source_kind = store::loading::MaterializationSource::kLocalReplica; std::optional staged_body; if (has_batch_payload_slice) { @@ -1957,7 +1973,16 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( : store::loading::MaterializationSource::kLocalReplica; if (resolved_it->second.remote) { const auto eligibility = classify_put_remote_communicator_source(resolved_it->second); - if (eligibility.direct_remote_slice) { + const bool composite_stage_eligible = eligibility.direct_remote_slice && + eligibility.source_supports_batched_direct_write && + !verification_mode_requires_payload_digest(invariant_verification_mode(item.invariant())); + if (composite_stage_eligible) { + composite_candidate = CompositeStageCandidate{ + .transport_id = transport_id, + .source_offset = item.batch_payload_slice().offset(), + .length = item.batch_payload_slice().length(), + }; + } else if (eligibility.direct_remote_slice) { if (direct_remote_batch_payloads.emplace(transport_id).second) { ++timing_stats.remote_direct_slice_transport_count; timing_stats.remote_direct_slice_bytes += transport_it->second->manifest().total_size(); @@ -2080,7 +2105,7 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( continue; } } - if (loader != nullptr || local_source.has_value()) { + if (loader != nullptr || local_source.has_value() || composite_candidate.has_value()) { // Batch transport already provided a readable source. } else if (!item.inline_payload().empty()) { auto payload = std::make_shared(item.inline_payload()); @@ -2164,11 +2189,223 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( .artifact_id = artifact_id, .loader = std::move(loader), .local_source = std::move(local_source), + .composite_candidate = std::move(composite_candidate), .source_kind = source_kind, .staged_body = std::move(staged_body), }; } + const auto fallback_composite_group_to_direct_slice = + [&](const std::string& transport_id, const std::vector& indices, std::string_view reason) { + ++timing_stats.remote_composite_fallback_count; + timing_stats.remote_composite_fallback_items += indices.size(); + const auto transport_it = batch_transports_by_id.find(transport_id); + const auto resolved_it = resolved_batch_sources.find(transport_id); + if (transport_it == batch_transports_by_id.end() || resolved_it == resolved_batch_sources.end() || + resolved_it->second.source == nullptr || !resolved_it->second.source->supports_direct_write_at()) { + for (const int index : indices) { + deferred_outcomes[index] = make_outcome( + req.items(index).artifact_id(), + v2::BATCH_ITEM_STATUS_FAILED_PRECONDITION, + absl::StrCat("composite fallback unavailable: ", reason)); + } + return; + } + const auto eligibility = classify_put_remote_communicator_source(resolved_it->second); + LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=true" + << " read_mode=direct_remote_slice" + << " materialize_mode=per_item" + << " stage_mode=source_slice_loader" + << " batched_direct_write=false" + << " source_count=1" + << " mapping_segments=0" + << " item_count=" << indices.size() << " item_bytes=0" + << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" + << " fallback_reason=" << reason; + if (direct_remote_batch_payloads.emplace(transport_id).second) { + ++timing_stats.remote_direct_slice_transport_count; + timing_stats.remote_direct_slice_bytes += transport_it->second->manifest().total_size(); + timing_stats.remote_direct_slice_items += transport_it->second->manifest().entries_size(); + LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_read_mode" + << " operation_id=" << operation_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " remote=true" + << " read_mode=direct_remote_slice" + << " realization=source_slice_loader" + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " item_count=" << transport_it->second->manifest().entries_size() + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write + << " mirror_ms=0" + << " subsequent_item_slices_local=false"; + } + auto& source_mutex = remote_direct_source_mutexes[transport_id]; + if (source_mutex == nullptr) { + source_mutex = std::make_shared(); + } + for (const int index : indices) { + if (deferred_outcomes[index].has_value()) { + continue; + } + auto& prepared_item = prepared_items[static_cast(index)]; + if (!prepared_item.composite_candidate.has_value()) { + continue; + } + prepared_item.loader = make_loader_from_source_slice( + resolved_it->second.source, + prepared_item.composite_candidate->source_offset, + prepared_item.composite_candidate->length, + source_mutex); + prepared_item.composite_candidate.reset(); + } + }; + + absl::flat_hash_map> composite_groups; + for (int index = 0; index < req.items_size(); ++index) { + if (deferred_outcomes[index].has_value()) { + continue; + } + const auto& prepared_item = prepared_items[static_cast(index)]; + if (prepared_item.composite_candidate.has_value()) { + composite_groups[prepared_item.composite_candidate->transport_id].push_back(index); + } + } + for (const auto& [transport_id, indices] : composite_groups) { + if (indices.empty()) { + continue; + } + const auto transport_it = batch_transports_by_id.find(transport_id); + const auto resolved_it = resolved_batch_sources.find(transport_id); + if (transport_it == batch_transports_by_id.end() || resolved_it == resolved_batch_sources.end() || + resolved_it->second.source == nullptr) { + for (const int index : indices) { + deferred_outcomes[index] = make_outcome( + req.items(index).artifact_id(), + v2::BATCH_ITEM_STATUS_FAILED_PRECONDITION, + "composite transport source is missing"); + } + continue; + } + absl::flat_hash_set seen_artifacts; + bool has_duplicate = false; + for (const int index : indices) { + if (!seen_artifacts.emplace(prepared_items[static_cast(index)].artifact_id).second) { + has_duplicate = true; + break; + } + } + if (has_duplicate) { + fallback_composite_group_to_direct_slice(transport_id, indices, "duplicate_key"); + continue; + } + + std::vector composite_items; + composite_items.reserve(indices.size()); + std::uint64_t item_bytes = 0; + bool invalid_group = false; + for (const int index : indices) { + const auto& prepared_item = prepared_items[static_cast(index)]; + if (!prepared_item.composite_candidate.has_value() || prepared_item.item == nullptr) { + invalid_group = true; + break; + } + if (item_bytes > std::numeric_limits::max() - prepared_item.composite_candidate->length) { + invalid_group = true; + break; + } + item_bytes += prepared_item.composite_candidate->length; + composite_items.push_back( + BodyBackingManager::CompositeStageItem{ + .artifact_id = prepared_item.artifact_id, + .invariant = prepared_item.item->invariant(), + .source_offset = prepared_item.composite_candidate->source_offset, + .length = prepared_item.composite_candidate->length, + .access_class = BodyAccessClass::kHomeDefault, + .route_role = BodyRouteRole::kHomeAuthority, + }); + } + if (invalid_group) { + fallback_composite_group_to_direct_slice(transport_id, indices, "mapping_invalid"); + continue; + } + + const absl::Time composite_started_at = absl::Now(); + auto composite_or = body_backing_manager_.stage_bodies_composite( + BodyBackingManager::StageBodiesCompositeRequest{ + .source = resolved_it->second.source, + .items = std::move(composite_items), + .source_kind = store::loading::MaterializationSource::kP2P, + .operation_id = std::string(operation_id), + .transport_id = transport_id, + }); + const absl::Duration composite_elapsed = absl::Now() - composite_started_at; + timing_stats.remote_composite_stage_elapsed += composite_elapsed; + if (!composite_or.ok()) { + LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=true" + << " read_mode=batched_direct_write" + << " materialize_mode=single_source_composite" + << " stage_mode=composite_final_body" + << " batched_direct_write=true" + << " source_count=1" + << " mapping_segments=" << indices.size() << " item_count=" << indices.size() + << " item_bytes=" << item_bytes + << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" + << " materialize_ms=" << absl::ToDoubleMilliseconds(composite_elapsed) << " outcome=failed" + << " status=" << composite_or.status(); + for (const int index : indices) { + deferred_outcomes[index] = make_outcome( + prepared_items[static_cast(index)].artifact_id, + batch_item_status_from_absl_status(composite_or.status()), + std::string(composite_or.status().message())); + } + continue; + } + if (composite_or->staged_bodies.size() != indices.size()) { + for (auto& staged_body : composite_or->staged_bodies) { + (void)staged_body.body_handle.retire(); + } + for (const int index : indices) { + deferred_outcomes[index] = make_outcome( + prepared_items[static_cast(index)].artifact_id, + v2::BATCH_ITEM_STATUS_INTERNAL_ERROR, + "composite stage returned unexpected item count"); + } + continue; + } + ++timing_stats.remote_composite_stage_transport_count; + timing_stats.remote_composite_stage_items += indices.size(); + timing_stats.remote_composite_stage_bytes += item_bytes; + ++timing_stats.remote_composite_materialize_calls; + if (composite_or->materialize_result.direct_write_supported) { + ++timing_stats.remote_composite_batched_direct_write_count; + } + LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=true" + << " read_mode=batched_direct_write" + << " materialize_mode=single_source_composite" + << " stage_mode=composite_final_body" + << " batched_direct_write=true" + << " source_count=1" + << " mapping_segments=" << indices.size() << " item_count=" << indices.size() + << " item_bytes=" << item_bytes + << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" + << " materialize_ms=" << absl::ToDoubleMilliseconds(composite_elapsed) + << " direct_write_supported=" << composite_or->materialize_result.direct_write_supported + << " fallback_reason=none"; + for (std::size_t local_index = 0; local_index < indices.size(); ++local_index) { + auto& prepared_item = prepared_items[static_cast(indices[local_index])]; + prepared_item.staged_body = std::move(composite_or->staged_bodies[local_index]); + prepared_item.loader.reset(); + prepared_item.local_source.reset(); + prepared_item.composite_candidate.reset(); + } + } + std::vector stage_work_items; stage_work_items.reserve(prepared_items.size()); for (int index = 0; index < req.items_size(); ++index) { @@ -2402,6 +2639,13 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( << " remote_direct_slice_transport_count=" << timing_stats.remote_direct_slice_transport_count << " remote_direct_slice_items=" << timing_stats.remote_direct_slice_items << " remote_direct_slice_bytes=" << timing_stats.remote_direct_slice_bytes + << " remote_composite_stage_transport_count=" << timing_stats.remote_composite_stage_transport_count + << " remote_composite_stage_items=" << timing_stats.remote_composite_stage_items + << " remote_composite_stage_bytes=" << timing_stats.remote_composite_stage_bytes + << " remote_composite_materialize_calls=" << timing_stats.remote_composite_materialize_calls + << " remote_composite_batched_direct_write_count=" << timing_stats.remote_composite_batched_direct_write_count + << " remote_composite_fallback_count=" << timing_stats.remote_composite_fallback_count + << " remote_composite_fallback_items=" << timing_stats.remote_composite_fallback_items << " remote_mirror_count=" << timing_stats.remote_mirror_count << " remote_mirror_bytes=" << timing_stats.remote_mirror_bytes << " remote_full_pack_mirror_items=" << timing_stats.remote_full_pack_mirror_items @@ -2418,6 +2662,7 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( << " communicator_open_ms=" << absl::ToDoubleMilliseconds(timing_stats.communicator_open_elapsed) << " grpc_fetch_ms=" << absl::ToDoubleMilliseconds(timing_stats.grpc_fetch_elapsed) << " remote_mirror_ms=" << absl::ToDoubleMilliseconds(timing_stats.remote_mirror_elapsed) + << " remote_composite_stage_ms=" << absl::ToDoubleMilliseconds(timing_stats.remote_composite_stage_elapsed) << " stage_body_ms=" << absl::ToDoubleMilliseconds(timing_stats.stage_body_elapsed) << " fast_cpu_stage_ms=" << absl::ToDoubleMilliseconds(timing_stats.fast_cpu_stage_elapsed) << " reuse_ms=" << absl::ToDoubleMilliseconds(timing_stats.reuse_elapsed) diff --git a/daemon/service/grpc_service_impl_batch_runtime_test.cc b/daemon/service/grpc_service_impl_batch_runtime_test.cc index 2ca5c559..b65e9b60 100644 --- a/daemon/service/grpc_service_impl_batch_runtime_test.cc +++ b/daemon/service/grpc_service_impl_batch_runtime_test.cc @@ -3919,6 +3919,164 @@ TEST_CASE("BodyBackingManager fast CPU staging hashes during local byte ingress" REQUIRE(staged_or->body_handle.retire().ok()); } +class RecordingCompositeSource final : public tensorcast::store::loader::SeekableSource { + public: + explicit RecordingCompositeSource(std::string data) : data_(std::move(data)) {} + + [[nodiscard]] uint64_t total_bytes() const override { + return data_.size(); + } + + absl::StatusOr read(void* dst, size_t max_bytes) override { + auto read_or = read_at(cursor_, dst, max_bytes); + if (!read_or.ok()) { + return read_or.status(); + } + cursor_ += *read_or; + return *read_or; + } + + absl::StatusOr read_at(uint64_t offset, void* dst, size_t bytes) override { + if (offset >= data_.size() || bytes == 0) { + return static_cast(0); + } + const size_t to_copy = static_cast(std::min(bytes, data_.size() - offset)); + std::memcpy(dst, data_.data() + offset, to_copy); + return to_copy; + } + + [[nodiscard]] bool supports_direct_write_at() const override { + return true; + } + + [[nodiscard]] bool supports_batched_direct_write_at() const override { + return true; + } + + absl::StatusOr read_into_at( + uint64_t src_offset, + uint64_t dest_va_offset, + size_t bytes, + const tensorcast::store::DirectWriteGrant& grant) override { + if (src_offset > data_.size() || bytes > data_.size() - src_offset) { + return absl::OutOfRangeError("source read exceeds source data"); + } + size_t copied = 0; + uint64_t cursor = dest_va_offset; + while (copied < bytes) { + bool matched = false; + for (const auto& window : grant.windows) { + if (cursor < window.va_offset || cursor >= window.va_offset + window.length) { + continue; + } + const uint64_t window_offset = cursor - window.va_offset; + const size_t take = static_cast(std::min(bytes - copied, window.length - window_offset)); + std::memcpy( + reinterpret_cast(window.local_addr + window_offset), data_.data() + src_offset + copied, take); + copied += take; + cursor += take; + matched = true; + break; + } + if (!matched) { + return absl::OutOfRangeError("direct write grant does not cover destination range"); + } + } + return copied; + } + + absl::StatusOr readv_into_at( + absl::Span ops, + const tensorcast::store::DirectWriteGrant& grant) override { + ++readv_calls_; + size_t total = 0; + for (const auto& op : ops) { + auto wrote_or = read_into_at(op.src_offset, op.dest_va_offset, op.bytes, grant); + if (!wrote_or.ok()) { + return wrote_or.status(); + } + total += *wrote_or; + } + return total; + } + + [[nodiscard]] int readv_calls() const { + return readv_calls_; + } + + private: + std::string data_; + uint64_t cursor_{0}; + int readv_calls_{0}; +}; + +TEST_CASE( + "BodyBackingManager composite staging writes one source into multiple final bodies", + "[daemon][body_backing][composite]") { + auto engine = std::make_shared(make_opts_basic()); + tensorcast::daemon::BodyBackingManager manager(*engine); + + const std::string payload_a = "composite-body-alpha"; + const std::string payload_b = "composite-body-beta-longer"; + const std::string prefix = "source-prefix:"; + const std::string gap = ":gap:"; + const std::string slab = prefix + payload_a + gap + payload_b; + const std::uint64_t offset_a = prefix.size(); + const std::uint64_t offset_b = prefix.size() + payload_a.size() + gap.size(); + const auto source = std::make_shared(slab); + + const std::string artifact_id_a = make_test_byte_artifact_id("composite-stage-a:blk-4"); + const std::string artifact_id_b = make_test_byte_artifact_id("composite-stage-b:blk-4"); + tensorcast::daemon::v2::PutIfAbsentInvariant invariant_a; + tensorcast::daemon::v2::PutIfAbsentInvariant invariant_b; + set_invariant(&invariant_a, "layout_v1", payload_a); + set_invariant(&invariant_b, "layout_v1", payload_b); + invariant_a.set_verification_mode(tensorcast::daemon::v2::BYTE_ARTIFACT_VERIFICATION_MODE_LAYOUT_AND_SIZE_ONLY); + invariant_b.set_verification_mode(tensorcast::daemon::v2::BYTE_ARTIFACT_VERIFICATION_MODE_LAYOUT_AND_SIZE_ONLY); + + auto staged_or = manager.stage_bodies_composite( + tensorcast::daemon::BodyBackingManager::StageBodiesCompositeRequest{ + .source = source, + .items = + { + tensorcast::daemon::BodyBackingManager::CompositeStageItem{ + .artifact_id = artifact_id_a, + .invariant = invariant_a, + .source_offset = offset_a, + .length = payload_a.size(), + }, + tensorcast::daemon::BodyBackingManager::CompositeStageItem{ + .artifact_id = artifact_id_b, + .invariant = invariant_b, + .source_offset = offset_b, + .length = payload_b.size(), + }, + }, + .source_kind = tensorcast::store::loading::MaterializationSource::kP2P, + .operation_id = "op-body-composite-stage", + .transport_id = "transport-body-composite-stage", + }); + REQUIRE(staged_or.ok()); + REQUIRE(staged_or->staged_bodies.size() == 2); + CHECK(source->readv_calls() > 0); + CHECK(staged_or->materialize_result.direct_write_supported); + + auto read_a_or = staged_or->staged_bodies[0].body_handle.read_all_bytes(); + auto read_b_or = staged_or->staged_bodies[1].body_handle.read_all_bytes(); + REQUIRE(read_a_or.ok()); + REQUIRE(read_b_or.ok()); + CHECK(*read_a_or == payload_a); + CHECK(*read_b_or == payload_b); + CHECK(staged_or->staged_bodies[0].descriptor.payload_digest_alg.empty()); + CHECK(staged_or->staged_bodies[0].descriptor.payload_digest_hex.empty()); + CHECK( + staged_or->staged_bodies[0].verification_record.verification_method == + tensorcast::store::runtime::ingestion::VerificationMethod::kLayoutAndSizeContract); + + REQUIRE(staged_or->staged_bodies[0].body_handle.retire().ok()); + REQUIRE(staged_or->staged_bodies[1].body_handle.retire().ok()); +} + TEST_CASE("HomeBatchTouchTtl keeps immortal entries immortal", "[daemon][batch][ttl]") { auto engine = std::make_shared(make_opts_basic()); auto harness = make_harness(engine, make_daemon_options()); diff --git a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md index 93f8b697..a989c8d2 100644 --- a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md +++ b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md @@ -17,6 +17,8 @@ related_code: - daemon/state/daemon_kernel.h - daemon/state/worker_directory_cache.h - daemon/service/artifact_profile_registry.h + - daemon/service/body_backing_manager.h + - daemon/service/body_backing_manager.cc - daemon/service/byte_artifact_body_handle.h - daemon/service/byte_artifact_body_handle.cc - daemon/service/byte_artifact_body_store.h @@ -124,6 +126,11 @@ RDMA realization follow-ons still in progress: first put-side cut intentionally keeps source-side staged-slab pack construction and per-item home staging in place; it only removes the home daemon's mandatory remote full-pack mirror. +- The put-side composite final-body staging cut is implemented for eligible + `LAYOUT_AND_SIZE_ONLY` byte artifacts: the home daemon prepares unpublished + retained body backings, lowers one remote pack source into a composite target + layout, and lets the shared `0115` dataplane execute one batched direct-write + materialization before authority join installation. - `BodyHandle` now exposes the export-view API used by source-side segmented communicator export. - The remaining RDMA follow-on is producer-side read-plan servicing: @@ -837,6 +844,132 @@ Observability rules: prove that `remote_mirror_count` and `remote_mirror_bytes` collapse to zero on the intended RDMA put path. +#### 5.5.6b Accepted put-side composite final-body stage + +After the no-mirror consume cut, the remaining put-side batch-set cost is +per-item home staging. The accepted next realization is to stage one eligible +remote pack directly into the unpublished final retained body backings for all +eligible items in that transport, using the shared composite/vectored +direct-write execution contract from `0115`. + +Scope and eligibility: + +1. This path applies only inside `HomeBatchPutIfAbsent` after the home shard, + fence, route epoch, request operation, transport capability, and manifest + slice validation have already succeeded. +2. The first implementation is limited to remote `v2 communicator_source` + transports opened with `PAYLOAD_REF_DIRECTION_PUT` whose resolved + `SeekableSource` is non-null and advertises + `supports_batched_direct_write_at() == true`. +3. The group is initially per `transport_id`: all composite items share one + opened source and one transport manifest. Ineligible items in the same RPC + continue through the 5.5.6a per-item direct-slice path or the existing + full-pack fallback path. +4. Each item must use + `BYTE_ARTIFACT_VERIFICATION_MODE_LAYOUT_AND_SIZE_ONLY`. Modes that require + payload digest verification, including the default strict SHA256 mode, stay + on the existing per-item staging path until a later stream or post-write + digest phase is accepted. +5. The resolved body policy must produce CPU final backings that can be exposed + as writable target-layout storage for the shared materialization dataplane. + GPU-only, non-direct-writable, zero-length, or otherwise unsupported target + shapes are pre-issue fallback cases. +6. Duplicate artifact ids or duplicate join keys inside one composite group are + pre-issue fallback cases until a batch-level de-duplication policy is + explicitly defined. This keeps authority join and cleanup semantics + identical to the existing item-scoped path. + +Body staging seam: + +1. `BodyBackingManager` needs an internal batch staging seam such as + `stage_bodies_composite(...)`, or an equivalent + `prepare_body_stage_targets(...)` plus `finalize_staged_bodies(...)` split. +2. That seam owns final body backing preparation. It creates one unpublished + retained backing per item, with the same logical body identity, + `BodyBackingIntent`, `ResolvedStorePolicy`, stable-retention admission, and + `BodyHandle` lifetime rules that single-item `stage_body(...)` uses. +3. Prepared bodies are not routed truth. They are transient unpublished + candidates until `ByteArtifactAuthorityService::batch_put_if_absent(...)` + accepts the corresponding item. +4. The first implementation may expose one all-in-one + `stage_bodies_composite(...)` API that prepares, materializes, finalizes, + and cleans up on failure. A later split is allowed if tests need finer + control over pre-issue preparation. + +Composite mapping semantics: + +1. The source byte space is the transport pack manifest byte space. Each item + contributes exactly one source slice + `{source_index=0, src_offset=manifest.offset, length=manifest.length}`. +2. The target byte space is a synthetic concatenation of the unpublished final + body backings in composite item order. Item `i` owns target range + `[cursor_i, cursor_i + invariant.byte_length)`, and the backing storage for + that range must have exactly that length. +3. The `ByteRangeMap` uses `total_bytes=sum(item.byte_length)`, + `num_sources=1`, and one segment per item mapping the source pack slice to + the corresponding composite target range. `mapping.total_bytes` is the + composite target byte count, not the remote pack's total advertised size. +4. The `IntoTargetLayout` storages are the final body backing writable spans. + Stable local backing metadata may be attached when the backing manager can + prove the storage is daemon-managed and long-lived enough for the `0115` + direct-write grant, but this metadata remains local placement state and not + routed artifact identity. +5. The controller calls the shared materialization seam with one source vector + entry, the composite `ByteRangeMap`, `source_kind=kP2P`, and a + transport-scoped operation hint. The expected fast path is + `readv_into_at(...)` / `ReadPlan` / RDMA vectored direct-write, not a + byte-artifact-private communicator API. + +Verification, authority, and cleanup: + +1. For `LAYOUT_AND_SIZE_ONLY`, successful composite materialization produces + the same per-item `StageResult` shape as `stage_body(...)`: descriptor, + observation, `BodyHandle`, `VerifiedContentDescriptor`, verification + record, and backing identity. Payload digest fields remain advisory and + must not block publication. +2. Each finalized item still runs + `validate_invariant_body_descriptor(...)` through + `ByteArtifactAuthorityService::batch_put_if_absent(...)`; composite staging + does not install authority truth by itself. +3. Conflict, duplicate-writer, invalid-artifact, or invariant failures after + finalization retire the corresponding unpublished body handle using the + existing authority cleanup path. +4. If preparation, capability validation, policy resolution, target layout + construction, or mapping validation fails before the composite dataplane is + issued, the controller may fall back to the per-item 5.5.6a path for the + affected items. +5. Once the composite materialization has crossed the `0115` issue boundary, + hidden full-pack mirror or per-item staged fallback is forbidden. A + post-issue failure marks the affected composite items failed and retires all + prepared body handles that were not installed. +6. Partial success is item-scoped only after composite materialization + succeeds. A composite execution failure before per-item finalization fails + the whole composite group because individual final backings may have been + partially dirtied. + +Observability rules: + +1. Eligible composite execution must log a put-side apply summary such as + `byte_artifact.home_batch_put_if_absent_transport_apply_summary` with + `read_mode=batched_direct_write`, + `materialize_mode=single_source_composite`, + `stage_mode=composite_final_body`, + `batched_direct_write=true`, `source_count=1`, `mapping_segments`, + `item_count`, `item_bytes`, `transport_payload_bytes`, and `mirror_ms=0`. +2. Pre-issue fallback must log a bounded `fallback_reason` and preserve the + existing `read_mode=direct_remote_slice` or `full_pack_mirror` records so + benchmarks can distinguish no-mirror scalar staging from composite staging. +3. Home summaries should add composite counters for transport count, item + count, byte count, materialization calls, batched-direct-write calls, + fallback count, fallback items, and cleanup/retire count, while retaining + the existing remote mirror and direct-slice counters. +4. Expected RDMA SGLang KV evidence is: source batch-set packs still show + `mode=staged_slab`, home full-pack mirror remains zero, eligible home + transports move from `read_mode=direct_remote_slice` to + `read_mode=batched_direct_write`, and `stage_loader_count` for eligible + layout-and-size items drops toward zero because per-item + `SeekableSourceLoader` staging is bypassed. + #### 5.5.7 Implemented v2 communicator-backed realization `v2 communicator_source` is the current communicator-backed realization. It moves routed byte-artifact remote transport @@ -964,6 +1097,10 @@ Design intent: - put-side `HomeBatchPutIfAbsent` now removes the same mandatory home-daemon full-pack mirror for eligible RDMA `communicator_source` transports while keeping source-side staged-slab pack construction in place, +- put-side `HomeBatchPutIfAbsent` may next consume the same shared `0115` + composite execution contract to batch-stage one remote pack into unpublished + final body backings for `LAYOUT_AND_SIZE_ONLY` items, while preserving + first-writer authority semantics, - the remaining RDMA bottleneck is producer-side servicing: CPU source slices are still copied from retained backing into pinned staged response buffers before remote reads, @@ -1001,22 +1138,28 @@ Normative rules: per-item `SourceSlice` loaders over the remote pack and stage those items through `BodyBackingManager::stage_body(...)` instead of first materializing a full local pack mirror. -6. Source-side RDMA communicator export should continue to prefer no-pack-copy +6. The accepted next put-side step is composite final-body staging, not a new + transport API: `HomeBatchPutIfAbsent` prepares unpublished final retained + backings through the body-backing seam, constructs a one-source + `ByteRangeMap` from pack offsets to those backing offsets, and delegates + materialization to the shared `0115` composite/vectored direct-write + dataplane before authority join installation. +7. Source-side RDMA communicator export should continue to prefer no-pack-copy segmented export over daemon-owned pack slab realization. The producer may expose one logical pack byte space by concatenating per-entry or per-backing exported segments through `remote_memory_keys[]`, `buffer_sizes[]`, and `total_payload_bytes`; it does not need to copy those bytes into one daemon-owned slab first. -7. The next RDMA get-side follow-on is source servicing, not a new sink API. Eligible +8. The next RDMA get-side follow-on is source servicing, not a new sink API. Eligible retained CPU backings should be served as direct-readable source segments in the read-plan response instead of first being copied into pinned staged buffers. -8. The accepted source-side realization seam is `BodyHandle`, as further +9. The accepted source-side realization seam is `BodyHandle`, as further specified by `0089`. `BodyHandle` provides the transport-neutral export-view acquisition API that `PayloadTransportBroker` uses to obtain exportable backing views and keepalive state without reimplementing replica-runtime inspection or export logic. -9. Direct-source RDMA response windows are descriptor-driven, not +10. Direct-source RDMA response windows are descriptor-driven, not staging-driven. They must not consume `FlowCreditLedger`, `StageLease`, or staged ACK-release semantics, and they must not be split merely because staged `buffers_per_flow` credit is exhausted. @@ -1142,6 +1285,7 @@ Rules: - `byte_artifact.home_batch_put_if_absent_transport_open` - `byte_artifact.home_batch_put_if_absent_transport_read_mode` - `byte_artifact.home_batch_put_if_absent_transport_mirror` + - `byte_artifact.home_batch_put_if_absent_transport_apply_summary` - `byte_artifact.home_batch_put_if_absent_stage_plan` - `byte_artifact.home_batch_put_if_absent_summary` - `byte_artifact.batch_put_if_absent_from_region_pack_realization` @@ -1511,8 +1655,13 @@ design delta. manifest and per-item semantics can be preserved through direct-write and `BodyHandle`-backed export views. - Put-side `HomeBatchPutIfAbsent` RDMA `communicator_source` consumption now - follows the same no-mirror remote-slice rule before any larger batch-set - composite or vectored optimization is considered accepted. + follows the same no-mirror remote-slice rule, which is the required baseline + for larger batch-set composite or vectored optimizations. +- Put-side composite final-body staging is accepted as the next + `LAYOUT_AND_SIZE_ONLY` batch-set optimization only if it uses the shared + `0115` composite/vectored direct-write dataplane, stages into unpublished + final body backings, and keeps authority join, conflict, and cleanup behavior + item-scoped after successful materialization. - `PayloadTransportBroker` remains the transport boundary, but source-side no-copy export must consume the `BodyHandle` export-view seam described by `0089` rather than growing broker-private `StoreEngine` inspection logic. - `GetServerConfig` is the peer-discovery surface for batch-transport protocol version and realization support. diff --git a/docs/designs/0115-composite-materialization-and-vectored-direct-write.md b/docs/designs/0115-composite-materialization-and-vectored-direct-write.md index da573ece..ba83734d 100644 --- a/docs/designs/0115-composite-materialization-and-vectored-direct-write.md +++ b/docs/designs/0115-composite-materialization-and-vectored-direct-write.md @@ -4,12 +4,14 @@ title: Composite Materialization and Vectored Direct-Write status: implemented areas: ["core", "daemon", "docs", "benchmarks", "integrations"] created: 2026-04-19 -last_updated: 2026-04-25 +last_updated: 2026-04-26 related_code: - docs/designs/0088-unified-artifact-profiles-with-shared-dataplane.md - docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md - docs/designs/0089-core-backed-body-handles-and-backing-policy.md - docs/designs/0108-tensor-aware-materialization-strategy-plane.md + - daemon/service/body_backing_manager.h + - daemon/service/body_backing_manager.cc - core/store/materialization/contracts/loading_spec.h - core/store/materialization/dataplane/contracts/source.h - core/store/materialization/dataplane/contracts/sink.h @@ -56,8 +58,8 @@ byte-artifact-specific optimization. Instead: semantics, - communicator owns efficient execution of a vectored pull plan, - RDMA transport owns chained WR realization, -- and byte-artifact batch-get becomes only one consumer of that shared - capability. +- and byte-artifact batch-get plus the accepted put-side final-body staging + follow-on become consumers of that shared capability. This design keeps the repository's current architectural boundaries: @@ -68,10 +70,16 @@ This design keeps the repository's current architectural boundaries: lifetimes, - and `0115` owns the new common execution contract below that seam. -One boundary matters for the next RDMA optimization step: `0115` owns the sink- -side composite execution contract and the routed vectored pull API, but it does -not by itself remove producer-side staged response windows for CPU sources. -That direct-readable source-side follow-on remains owned by `0087-01`. +Two boundaries matter for the next RDMA optimization steps: + +1. `0115` owns the sink-side composite execution contract and the routed + vectored pull API, but it does not by itself remove producer-side staged + response windows for CPU sources. That direct-readable source-side follow-on + remains owned by `0087-01`. +2. `0115` also does not own byte-artifact authority, final body backing + allocation, or `BodyHandle` cleanup. Put-side composite final-body staging + may use the `0115` execution contract only after `0087`/`0089` code has + prepared unpublished target backings and a valid `IntoTargetLayout`. # Implementation Status @@ -105,6 +113,13 @@ Landed outcomes: `materialize_mode=single_source_composite batched_direct_write=true`, confirming that the consumer uses the shared `0115` seams rather than a byte-artifact-private communicator API. +- put-side byte-artifact composite final-body staging now consumes the same + capability for eligible `LAYOUT_AND_SIZE_ONLY` remote put transports. Its + profile-specific work remains target acquisition and authority cleanup: + `BodyBackingManager` prepares unpublished final body backings, builds the + composite target layout, and finalizes per-item `StageResult` values around + the shared materialization call. `0115` does not gain a put-specific + communicator API. Remaining follow-on work outside `0115`: @@ -331,6 +346,34 @@ Rules: 4. `ByteRangeMappedSource` remains the generic composite-source executor and fallback surface. +### Put-side final-body staging consumer + +Routed byte-artifact put-side staging may consume this internal helper once the +profile-specific body-backing layer has prepared a target layout. The contract +is intentionally the same as batch-get composite consume: + +1. The source vector contains the opened remote pack source, usually one + `RemoteKeySource` reached through `BatchPayloadCommunicatorSource`. +2. The target layout is a synthetic concatenation of unpublished final body + backing writable spans. `0115` treats those spans as ordinary target + storage; it does not know whether a span will later become routed byte truth. +3. The `ByteRangeMap` maps source pack offsets to target-layout offsets. + `mapping.total_bytes` is the target composite byte count, and + `mapping.num_sources` names the source vector size. +4. Successful materialization reports committed bytes and execution mode only. + The caller remains responsible for per-item descriptor construction, + verification records, `BodyHandle` creation, authority join, conflict + cleanup, and unpublished backing retirement. +5. In the first accepted put-side consumer, only + `LAYOUT_AND_SIZE_ONLY` items should enter the composite path. Strict digest + modes remain a profile-level fallback until the caller can compute or + validate payload digests over the composite write without weakening the + `0115` issue/failure model. +6. The generic post-issue failure rule applies unchanged: after a composite + direct-write batch is issued, `0115` will not silently fall back to staged + per-item writes. The byte-artifact caller must mark the affected unpublished + targets failed and retire them. + ## 4. `ByteRangeMappedSource` and `RemoteKeySource` `ByteRangeMappedSource` and `RemoteKeySource` must both become batch-aware. @@ -1390,11 +1433,15 @@ Compatibility rules: 2. `ByteRangeMap` remains the exact fallback IR. 3. `DirectWriteGrant` stays transport-neutral in the first cut. 4. MTCP and staged paths remain valid fallbacks. -5. `0087` byte-artifact batch-get may consume this capability later but may not +5. `0087` byte-artifact batch-get consumes this capability, but may not redefine it. -6. Stronger profile-level retry policy, if ever added, must layer above this +6. `0087` byte-artifact batch-set may consume this capability for put-side + final-body staging only after unpublished target backing and authority + cleanup semantics are owned above this seam; it may not redefine the + communicator or direct-write API. +7. Stronger profile-level retry policy, if ever added, must layer above this contract rather than weaken the generic failure model. -7. Stable local backing metadata remains local placement state and must not be +8. Stable local backing metadata remains local placement state and must not be promoted into routed identity or remote destination capability. Acceptance criteria: From 197bb7b59e501dc54e624d544b00f5019fe764a0 Mon Sep 17 00:00:00 2001 From: zhouyuhan Date: Mon, 27 Apr 2026 14:28:22 +0800 Subject: [PATCH 04/49] feat(byte-artifact): remove source-side full-pack slab and use direct write RDMA serving for HomeBatchPut --- daemon/service/byte_artifact_region_layout.cc | 44 ++ daemon/service/byte_artifact_region_layout.h | 13 + ...artifact_region_layout_host_shared_test.cc | 15 + .../controllers/byte_artifact_controller.cc | 378 +++++++++++++++++- ...pc_service_impl_batch_redirect_e2e_test.cc | 135 +++++++ daemon/service/payload_transport_broker.cc | 127 ++++++ daemon/service/payload_transport_broker.h | 14 + ...e-and-routed-byte-artifact-architecture.md | 269 +++++++++++-- 8 files changed, 949 insertions(+), 46 deletions(-) diff --git a/daemon/service/byte_artifact_region_layout.cc b/daemon/service/byte_artifact_region_layout.cc index 1f013363..24e96cdf 100644 --- a/daemon/service/byte_artifact_region_layout.cc +++ b/daemon/service/byte_artifact_region_layout.cc @@ -184,11 +184,15 @@ absl::StatusOr ByteArtifactRegionLayout::acquire( if (!region_id_or.ok()) { return region_id_or.status(); } + IpcRegionRegistry::HostRegionClass storage_host_region_class = IpcRegionRegistry::HostRegionClass::kNone; + bool storage_daemon_managed = false; if (*storage_memory_kind_or == IpcRegionRegistry::MemoryKind::kHostShared) { auto region_desc_or = registry.describe(*region_id_or); if (!region_desc_or.ok()) { return region_desc_or.status(); } + storage_host_region_class = region_desc_or->host_region_class; + storage_daemon_managed = region_desc_or->daemon_managed; if (!host_region_class.has_value()) { host_region_class = region_desc_or->host_region_class; } else if (*host_region_class != region_desc_or->host_region_class) { @@ -202,6 +206,8 @@ absl::StatusOr ByteArtifactRegionLayout::acquire( .logical_base = logical_cursor, .length = storage.storage_length(), .memory_kind = *storage_memory_kind_or, + .host_region_class = storage_host_region_class, + .daemon_managed = storage_daemon_managed, .base_ptr = result.storage_lease_.storages().at(static_cast(i)).base_ptr.get(), .device_id = storage.device_id(), .stable_backing = result.storage_lease_.storages().at(static_cast(i)).stable_backing, @@ -371,6 +377,44 @@ absl::StatusOr> ByteArtifactRegio gsl::not_null{item_base_ptr}, device_id_, range.logical_length)); } +absl::StatusOr ByteArtifactRegionLayout::open_host_shared_source_span( + std::string_view artifact_id) const { + const auto it = items_.find(std::string(artifact_id)); + if (it == items_.end()) { + return absl::NotFoundError("artifact_id is not mapped in region layout"); + } + const auto& range = it->second; + const auto& storage = storages_.at(range.storage_index); + if (storage.memory_kind != IpcRegionRegistry::MemoryKind::kHostShared) { + return absl::FailedPreconditionError("source layout item is not HOST_SHARED"); + } + if (!storage.daemon_managed) { + return absl::FailedPreconditionError("HOST_SHARED source layout item is not daemon-managed"); + } + if (storage.base_ptr == nullptr) { + return absl::FailedPreconditionError("HOST_SHARED source layout base pointer is unavailable"); + } + if (range.logical_length == 0) { + return absl::InvalidArgumentError("HOST_SHARED source layout item length must be > 0"); + } + if (storage.host_region_class == IpcRegionRegistry::HostRegionClass::kAllocator && !range.slot_token.has_value()) { + return absl::FailedPreconditionError("allocator-backed HOST_SHARED source span requires slot token"); + } + if (storage.keepalive == nullptr) { + return absl::FailedPreconditionError("HOST_SHARED source span requires region keepalive"); + } + const void* item_base_ptr = static_cast(storage.base_ptr) + range.storage_local_offset; + return HostSharedSourceSpan{ + .data = item_base_ptr, + .length = range.logical_length, + .region_id = storage.region_id, + .host_region_class = storage.host_region_class, + .daemon_managed = storage.daemon_managed, + .slot_token = range.slot_token, + .keepalive = storage.keepalive, + }; +} + absl::Status ByteArtifactRegionLayout::activate_stable_local_backings( store::components::CommunicationManager& comm_manager) const { if (!comm_manager.is_enabled()) { diff --git a/daemon/service/byte_artifact_region_layout.h b/daemon/service/byte_artifact_region_layout.h index 31a591e2..937147b8 100644 --- a/daemon/service/byte_artifact_region_layout.h +++ b/daemon/service/byte_artifact_region_layout.h @@ -28,6 +28,16 @@ class ByteArtifactRegionLayout { std::optional slot_generation; }; + struct HostSharedSourceSpan { + const void* data{nullptr}; + std::uint64_t length{0}; + std::string region_id; + IpcRegionRegistry::HostRegionClass host_region_class{IpcRegionRegistry::HostRegionClass::kNone}; + bool daemon_managed{false}; + std::optional slot_token; + std::shared_ptr keepalive; + }; + ByteArtifactRegionLayout() = default; ByteArtifactRegionLayout(const ByteArtifactRegionLayout&) = delete; ByteArtifactRegionLayout& operator=(const ByteArtifactRegionLayout&) = delete; @@ -50,6 +60,7 @@ class ByteArtifactRegionLayout { std::string_view artifact_id) const; [[nodiscard]] absl::StatusOr> open_item_source( std::string_view artifact_id) const; + [[nodiscard]] absl::StatusOr open_host_shared_source_span(std::string_view artifact_id) const; [[nodiscard]] absl::Status activate_stable_local_backings( store::components::CommunicationManager& comm_manager) const; @@ -60,6 +71,8 @@ class ByteArtifactRegionLayout { std::uint64_t logical_base{0}; std::uint64_t length{0}; IpcRegionRegistry::MemoryKind memory_kind{IpcRegionRegistry::MemoryKind::kVram}; + IpcRegionRegistry::HostRegionClass host_region_class{IpcRegionRegistry::HostRegionClass::kNone}; + bool daemon_managed{false}; void* base_ptr{nullptr}; int device_id{-1}; std::optional stable_backing; diff --git a/daemon/service/byte_artifact_region_layout_host_shared_test.cc b/daemon/service/byte_artifact_region_layout_host_shared_test.cc index ee6d4194..c3a34056 100644 --- a/daemon/service/byte_artifact_region_layout_host_shared_test.cc +++ b/daemon/service/byte_artifact_region_layout_host_shared_test.cc @@ -112,6 +112,21 @@ TEST_CASE( REQUIRE(item_layout_or->storages.size() == 1); auto* region_bytes = static_cast(item_layout_or->storages[0].base_ptr.get()); REQUIRE(std::memcmp(region_bytes, buffer, sizeof(buffer)) == 0); + + auto source_span_or = validated_or->layout.open_host_shared_source_span("artifact-a"); + REQUIRE(source_span_or.ok()); + REQUIRE(source_span_or->data != nullptr); + REQUIRE(source_span_or->length == kItemBytes); + REQUIRE(source_span_or->region_id == desc_or->region_id); + REQUIRE(source_span_or->host_region_class == tensorcast::daemon::IpcRegionRegistry::HostRegionClass::kAllocator); + REQUIRE(source_span_or->daemon_managed); + REQUIRE(source_span_or->slot_token.has_value()); + REQUIRE(source_span_or->slot_token->slot_index.has_value()); + REQUIRE(source_span_or->slot_token->slot_generation.has_value()); + REQUIRE(*source_span_or->slot_token->slot_index == 7); + REQUIRE(*source_span_or->slot_token->slot_generation == 11); + REQUIRE(source_span_or->keepalive != nullptr); + REQUIRE(std::memcmp(source_span_or->data, buffer, sizeof(buffer)) == 0); } TEST_CASE( diff --git a/daemon/service/controllers/byte_artifact_controller.cc b/daemon/service/controllers/byte_artifact_controller.cc index 2588b767..30956d9c 100644 --- a/daemon/service/controllers/byte_artifact_controller.cc +++ b/daemon/service/controllers/byte_artifact_controller.cc @@ -100,6 +100,18 @@ void attach_slot_tokens_to_outcomes( } } +std::string_view host_region_class_label(IpcRegionRegistry::HostRegionClass host_region_class) { + switch (host_region_class) { + case IpcRegionRegistry::HostRegionClass::kScratch: + return "scratch"; + case IpcRegionRegistry::HostRegionClass::kAllocator: + return "allocator"; + case IpcRegionRegistry::HostRegionClass::kNone: + default: + return "none"; + } +} + v2::BatchItemOutcome make_outcome( std::string_view artifact_id, v2::BatchItemStatus status, @@ -311,6 +323,15 @@ struct BatchPayloadPackEntry { absl::Time capability_expires_at{absl::InfiniteFuture()}; }; +struct SourceLayoutBatchPayloadEntry { + std::string artifact_id; + std::uint64_t payload_size_bytes{0}; + ByteArtifactRegionLayout::HostSharedSourceSpan source_span; + std::string digest_alg; + std::string digest_hex; + absl::Time capability_expires_at{absl::InfiniteFuture()}; +}; + struct PlannedBatchPayload { v2::BatchPayloadManifest manifest; std::vector source_indices; @@ -519,6 +540,28 @@ absl::Status validate_batch_payload_pack_entry(const BatchPayloadPackEntry& entr return absl::OkStatus(); } +absl::Status validate_source_layout_batch_payload_entry(const SourceLayoutBatchPayloadEntry& entry) { + if (entry.artifact_id.empty()) { + return absl::InvalidArgumentError("source-layout batch payload entry requires artifact_id"); + } + if (entry.payload_size_bytes == 0) { + return absl::InvalidArgumentError("source-layout batch payload entry payload_size_bytes must be > 0"); + } + if (entry.source_span.data == nullptr) { + return absl::InvalidArgumentError("source-layout batch payload entry requires source span data"); + } + if (entry.source_span.length != entry.payload_size_bytes) { + return absl::InvalidArgumentError("source-layout batch payload entry source span size mismatch"); + } + if (entry.source_span.keepalive == nullptr) { + return absl::InvalidArgumentError("source-layout batch payload entry requires source span keepalive"); + } + if (entry.digest_alg.empty() != entry.digest_hex.empty()) { + return absl::InvalidArgumentError("source-layout batch payload entry digest_alg and digest_hex must both be set"); + } + return absl::OkStatus(); +} + absl::Status fill_batch_payload_pack_entry(const BatchPayloadPackEntry& entry, char* dst) { auto validate_status = validate_batch_payload_pack_entry(entry); if (!validate_status.ok()) { @@ -674,6 +717,90 @@ absl::StatusOr> plan_batch_payload_entries( return packs; } +absl::StatusOr> plan_source_layout_batch_payload_entries( + const std::vector& entries, + std::uint64_t max_payload_bytes, + std::uint32_t max_items) { + std::vector packs; + if (entries.empty()) { + return packs; + } + + struct PendingPack { + std::vector entry_indices; + std::uint64_t total_bytes{0}; + absl::Time capability_expires_at{absl::InfiniteFuture()}; + }; + + const auto flush_pack = [&](const PendingPack& pending) -> absl::StatusOr { + PlannedBatchPayload packed; + packed.source_indices = pending.entry_indices; + packed.capability_expires_at = pending.capability_expires_at; + + std::uint64_t offset = 0; + for (const auto entry_index : pending.entry_indices) { + const auto& entry = entries[entry_index]; + auto validate_status = validate_source_layout_batch_payload_entry(entry); + if (!validate_status.ok()) { + return validate_status; + } + + auto* manifest_entry = packed.manifest.add_entries(); + manifest_entry->set_artifact_id(entry.artifact_id); + manifest_entry->set_offset(offset); + manifest_entry->set_length(entry.payload_size_bytes); + manifest_entry->set_digest_alg(entry.digest_alg); + manifest_entry->set_digest_hex(entry.digest_hex); + + v2::BatchPayloadSlice slice; + slice.set_offset(offset); + slice.set_length(entry.payload_size_bytes); + packed.slices.push_back(std::move(slice)); + offset += entry.payload_size_bytes; + } + packed.manifest.set_total_size(offset); + return packed; + }; + + PendingPack pending; + for (std::size_t entry_index = 0; entry_index < entries.size(); ++entry_index) { + const auto& entry = entries[entry_index]; + auto validate_status = validate_source_layout_batch_payload_entry(entry); + if (!validate_status.ok()) { + return validate_status; + } + const std::uint64_t entry_bytes = entry.payload_size_bytes; + if (max_payload_bytes != 0 && entry_bytes > max_payload_bytes) { + return absl::InvalidArgumentError("source-layout batch payload entry exceeds max_payload_bytes"); + } + + const bool reaches_item_limit = max_items != 0 && pending.entry_indices.size() >= max_items; + const bool reaches_byte_limit = + max_payload_bytes != 0 && pending.total_bytes != 0 && pending.total_bytes + entry_bytes > max_payload_bytes; + if (!pending.entry_indices.empty() && (reaches_item_limit || reaches_byte_limit)) { + auto packed_or = flush_pack(pending); + if (!packed_or.ok()) { + return packed_or.status(); + } + packs.push_back(std::move(*packed_or)); + pending = PendingPack{}; + } + + pending.entry_indices.push_back(entry_index); + pending.total_bytes += entry_bytes; + pending.capability_expires_at = std::min(pending.capability_expires_at, entry.capability_expires_at); + } + + if (!pending.entry_indices.empty()) { + auto packed_or = flush_pack(pending); + if (!packed_or.ok()) { + return packed_or.status(); + } + packs.push_back(std::move(*packed_or)); + } + return packs; +} + absl::StatusOr> realize_staged_batch_payload( const std::vector& entries, const PlannedBatchPayload& plan) { @@ -773,6 +900,32 @@ acquire_segmented_batch_payload_source_segments( return source_segments; } +absl::StatusOr> +acquire_segmented_region_source_segments( + const std::vector& entries, + const PlannedBatchPayload& plan) { + std::vector source_segments; + source_segments.reserve(plan.source_indices.size()); + for (std::size_t plan_index = 0; plan_index < plan.source_indices.size(); ++plan_index) { + const auto entry_index = plan.source_indices[plan_index]; + const auto& entry = entries[entry_index]; + auto validate_status = validate_source_layout_batch_payload_entry(entry); + if (!validate_status.ok()) { + return validate_status; + } + if (entry.payload_size_bytes != plan.slices[plan_index].length()) { + return absl::FailedPreconditionError("segmented region source size mismatch"); + } + source_segments.push_back( + PayloadTransportBroker::BatchCommunicatorRegionSourceSegment{ + .data = entry.source_span.data, + .size_bytes = entry.source_span.length, + .keepalive = entry.source_span.keepalive, + }); + } + return source_segments; +} + absl::StatusOr open_loader_from_resolved_source_capability( PayloadTransportBroker& payload_transport_broker, WorkerDirectoryCache& worker_directory_cache, @@ -4822,9 +4975,13 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( std::size_t remote_batch_pack_count{0}; std::size_t remote_batch_pack_item_count{0}; std::uint64_t remote_batch_pack_bytes{0}; + std::size_t remote_batch_segmented_region_export_count{0}; + std::size_t remote_batch_segmented_region_export_item_count{0}; + std::uint64_t remote_batch_segmented_region_export_bytes{0}; absl::Duration local_stage_elapsed{absl::ZeroDuration()}; absl::Duration remote_stage_elapsed{absl::ZeroDuration()}; absl::Duration remote_batch_pack_elapsed{absl::ZeroDuration()}; + absl::Duration remote_batch_segmented_region_export_elapsed{absl::ZeroDuration()}; absl::Duration local_home_apply_elapsed{absl::ZeroDuration()}; absl::Duration remote_home_rpc_elapsed{absl::ZeroDuration()}; } stats; @@ -5337,6 +5494,12 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( << " remote_batch_pack_items=" << task_stats.remote_batch_pack_item_count << " remote_batch_pack_bytes=" << task_stats.remote_batch_pack_bytes << " remote_batch_pack_ms=" << absl::ToDoubleMilliseconds(task_stats.remote_batch_pack_elapsed) + << " remote_batch_segmented_region_export_count=" << task_stats.remote_batch_segmented_region_export_count + << " remote_batch_segmented_region_export_items=" + << task_stats.remote_batch_segmented_region_export_item_count + << " remote_batch_segmented_region_export_bytes=" << task_stats.remote_batch_segmented_region_export_bytes + << " remote_batch_segmented_region_export_ms=" + << absl::ToDoubleMilliseconds(task_stats.remote_batch_segmented_region_export_elapsed) << " local_home_apply_ms=" << absl::ToDoubleMilliseconds(task_stats.local_home_apply_elapsed) << " remote_home_rpc_ms=" << absl::ToDoubleMilliseconds(task_stats.remote_home_rpc_elapsed) << " total_ms=" << absl::ToDoubleMilliseconds(task_result.total_elapsed); @@ -5365,14 +5528,199 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( ++task_stats.remote_home_batch_count; task_stats.remote_home_item_count += task.batch.items.size(); - for (auto& pending : task.batch.items) { - stage_pending_body(&pending, BodyAccessClass::kTransientForward); - } - const PeerBatchTransportSupport peer_transport_support = resolve_peer_batch_transport_support_for_task(task.route.holder_daemon_id); std::vector batch_transports; absl::flat_hash_map batch_slice_by_outcome_index; + + const auto log_source_no_pack_fallback = + [&](const PendingPut& pending, std::string_view reason, std::string_view message = "") { + VLOG(2) << "byte_artifact.batch_put_if_absent_from_region_segmented_export_fallback" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " artifact_id=" << pending.artifact_id + << " reason=" << reason << " message=" << message; + }; + + const auto stage_source_layout_fallback_entries = + [&](absl::Span pending_items, std::string_view reason, std::string_view message = "") { + for (PendingPut* pending : pending_items) { + if (pending == nullptr || has_outcome(pending->outcome_index) || pending->body_handle.has_value()) { + continue; + } + log_source_no_pack_fallback(*pending, reason, message); + stage_pending_body(pending, BodyAccessClass::kTransientForward); + } + }; + + const bool source_layout_no_pack_possible = peer_transport_support.supports_segmented_communicator_export() && + d_.payload_transport_broker.batch_transport_segmented_communicator_export_enabled() && + local_producer_endpoint.available && local_producer_endpoint.p2p_port != 0; + + std::vector source_layout_entries; + std::vector source_layout_pending; + source_layout_entries.reserve(task.batch.items.size()); + source_layout_pending.reserve(task.batch.items.size()); + for (auto& pending : task.batch.items) { + if (has_outcome(pending.outcome_index)) { + continue; + } + bool admitted_source_layout_no_pack = false; + if (pending.needs_source_layout) { + if (!source_layout_no_pack_possible) { + std::string_view reason = "peer_lacks_segmented_export"; + if (peer_transport_support.supports_segmented_communicator_export() && + !d_.payload_transport_broker.batch_transport_segmented_communicator_export_enabled()) { + reason = "local_segmented_export_disabled"; + } else if (peer_transport_support.supports_segmented_communicator_export()) { + reason = "producer_endpoint_unavailable"; + } + log_source_no_pack_fallback(pending, reason); + } else if (!source_layout.has_value()) { + log_source_no_pack_fallback(pending, "source_layout_missing"); + } else if (verification_mode_requires_payload_digest(invariant_verification_mode(pending.invariant))) { + log_source_no_pack_fallback(pending, "strict_digest"); + } else { + auto source_span_or = source_layout->open_host_shared_source_span(pending.artifact_id); + if (!source_span_or.ok()) { + log_source_no_pack_fallback(pending, "not_host_shared", std::string(source_span_or.status().message())); + } else if (source_span_or->length != pending.invariant.byte_length()) { + log_source_no_pack_fallback(pending, "source_span_length_mismatch"); + } else { + source_layout_entries.push_back( + SourceLayoutBatchPayloadEntry{ + .artifact_id = pending.artifact_id, + .payload_size_bytes = source_span_or->length, + .source_span = std::move(*source_span_or), + .digest_alg = normalize_body_digest_value(pending.invariant.payload_digest_alg()), + .digest_hex = normalize_body_digest_value(pending.invariant.payload_digest_hex()), + .capability_expires_at = absl::InfiniteFuture(), + }); + source_layout_pending.push_back(&pending); + admitted_source_layout_no_pack = true; + } + } + } + if (!admitted_source_layout_no_pack) { + stage_pending_body(&pending, BodyAccessClass::kTransientForward); + } + } + + if (!source_layout_entries.empty()) { + const absl::Time export_started_at = absl::Now(); + auto plans_or = plan_source_layout_batch_payload_entries( + source_layout_entries, + d_.payload_transport_broker.max_batch_payload_bytes(), + d_.payload_transport_broker.max_batch_items()); + if (!plans_or.ok()) { + stage_source_layout_fallback_entries( + absl::MakeSpan(source_layout_pending), "segment_budget_exceeded", std::string(plans_or.status().message())); + } else { + for (const auto& pack : *plans_or) { + std::vector pack_pending; + pack_pending.reserve(pack.source_indices.size()); + std::string host_region_class = "mixed"; + for (const auto entry_index : pack.source_indices) { + pack_pending.push_back(source_layout_pending[entry_index]); + const auto current_label = + host_region_class_label(source_layout_entries[entry_index].source_span.host_region_class); + if (host_region_class == "mixed") { + host_region_class = std::string(current_label); + } else if (host_region_class != current_label) { + host_region_class = "mixed"; + } + } + + auto source_segments_or = acquire_segmented_region_source_segments(source_layout_entries, pack); + if (!source_segments_or.ok()) { + stage_source_layout_fallback_entries( + absl::MakeSpan(pack_pending), + "export_registration_failed", + std::string(source_segments_or.status().message())); + continue; + } + auto communicator_export_or = d_.payload_transport_broker.issue_batch_payload_communicator_export( + pack.manifest, + absl::MakeSpan(*source_segments_or), + tensorcast::common::v1::PAYLOAD_REF_DIRECTION_PUT, + operation_id, + pack.capability_expires_at, + task.route.holder_daemon_id); + if (!communicator_export_or.ok()) { + stage_source_layout_fallback_entries( + absl::MakeSpan(pack_pending), + "export_registration_failed", + std::string(communicator_export_or.status().message())); + continue; + } + + v2::BatchPayloadTransport transport; + const std::string transport_id = absl::StrCat("batch-transport-", batch_transports.size() + 1); + transport.set_transport_id(transport_id); + transport.mutable_manifest()->CopyFrom(pack.manifest); + auto* communicator_source = transport.mutable_communicator_source(); + communicator_source->set_batch_payload_ref(communicator_export_or->batch_payload_ref); + communicator_source->set_protocol_version(d_.payload_transport_broker.batch_transport_protocol_version()); + communicator_source->set_producer_daemon_id(local_daemon_id); + communicator_source->set_consumer_daemon_id(task.route.holder_daemon_id); + communicator_source->set_producer_host(local_producer_endpoint.node_address); + communicator_source->set_producer_port(local_producer_endpoint.p2p_port); + for (const auto& remote_memory_key : communicator_export_or->export_registration.remote_memory_keys) { + communicator_source->add_remote_memory_keys(remote_memory_key); + } + for (const auto buffer_size : communicator_export_or->export_registration.buffer_sizes) { + communicator_source->add_buffer_sizes(buffer_size); + } + if (!local_producer_endpoint.node_id.empty()) { + communicator_source->set_remote_endpoint_id( + store::components::derive_endpoint_id( + local_producer_endpoint.node_id, common::memory::MemoryLocation::CPU, /*device_id=*/0)); + } + const auto consumer_endpoint_it = prebuilt_remote_cpu_endpoint_ids.find(task.route.holder_daemon_id); + if (consumer_endpoint_it != prebuilt_remote_cpu_endpoint_ids.end()) { + communicator_source->set_local_endpoint_id_hint(consumer_endpoint_it->second); + } + communicator_source->set_memory_location(v2::BATCH_PAYLOAD_MEMORY_LOCATION_HOST); + communicator_source->set_total_payload_bytes(pack.manifest.total_size()); + + ++task_stats.remote_batch_transport_count; + ++task_stats.remote_batch_transport_communicator_count; + task_stats.remote_batch_transport_item_count += pack.source_indices.size(); + task_stats.remote_batch_transport_bytes += pack.manifest.total_size(); + ++task_stats.remote_batch_segmented_region_export_count; + task_stats.remote_batch_segmented_region_export_item_count += pack.source_indices.size(); + task_stats.remote_batch_segmented_region_export_bytes += pack.manifest.total_size(); + + LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_pack_realization" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " remote=true" + << " mode=segmented_region_export" + << " staged_slab=false" + << " source_realization_mode=source_layout_host_shared" + << " host_region_class=" << host_region_class << " pack_count=1" + << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size() + << " source_segments=" << source_segments_or->size() + << " remote_keys=" << communicator_export_or->export_registration.remote_memory_keys.size() + << " registration_ownership=broker_owned"; + + LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_transport_emit" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " source_realization_mode=segmented_region_export" + << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size(); + + for (std::size_t pack_index = 0; pack_index < pack.source_indices.size(); ++pack_index) { + auto slice = pack.slices[pack_index]; + slice.set_transport_id(transport_id); + batch_slice_by_outcome_index.emplace( + source_layout_pending[pack.source_indices[pack_index]]->outcome_index, std::move(slice)); + } + batch_transports.push_back(std::move(transport)); + } + } + task_stats.remote_batch_segmented_region_export_elapsed += absl::Now() - export_started_at; + } + if (peer_transport_support.supports_v1()) { std::vector batch_entries; std::vector batch_entry_outcome_indices; @@ -5617,21 +5965,24 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( prepared_remote_batch.outcome_slots.reserve(task.batch.items.size()); prepared_remote_batch.retire_handles.reserve(task.batch.items.size()); for (const auto& pending : task.batch.items) { - if (!pending.body_handle.has_value()) { + const auto batch_slice_it = batch_slice_by_outcome_index.find(pending.outcome_index); + const bool has_batch_slice = batch_slice_it != batch_slice_by_outcome_index.end(); + if (!pending.body_handle.has_value() && !has_batch_slice) { continue; } - prepared_remote_batch.retire_handles.push_back(*pending.body_handle); + if (pending.body_handle.has_value()) { + prepared_remote_batch.retire_handles.push_back(*pending.body_handle); + } auto* dst = prepared_remote_batch.home_req.add_items(); dst->set_artifact_id(pending.artifact_id); dst->mutable_invariant()->CopyFrom(pending.invariant); - const auto batch_slice_it = batch_slice_by_outcome_index.find(pending.outcome_index); - if (batch_slice_it != batch_slice_by_outcome_index.end()) { + if (has_batch_slice) { dst->mutable_batch_payload_slice()->CopyFrom(batch_slice_it->second); } const auto payload_ref_it = payload_ref_by_artifact.find(pending.artifact_id); if (payload_ref_it != payload_ref_by_artifact.end()) { dst->set_payload_ref(payload_ref_it->second); - } else if (batch_slice_it == batch_slice_by_outcome_index.end()) { + } else if (!has_batch_slice) { prepared_remote_batch.home_req.mutable_items()->RemoveLast(); continue; } @@ -5761,9 +6112,13 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( stats.remote_batch_pack_count += delta.remote_batch_pack_count; stats.remote_batch_pack_item_count += delta.remote_batch_pack_item_count; stats.remote_batch_pack_bytes += delta.remote_batch_pack_bytes; + stats.remote_batch_segmented_region_export_count += delta.remote_batch_segmented_region_export_count; + stats.remote_batch_segmented_region_export_item_count += delta.remote_batch_segmented_region_export_item_count; + stats.remote_batch_segmented_region_export_bytes += delta.remote_batch_segmented_region_export_bytes; stats.local_stage_elapsed += delta.local_stage_elapsed; stats.remote_stage_elapsed += delta.remote_stage_elapsed; stats.remote_batch_pack_elapsed += delta.remote_batch_pack_elapsed; + stats.remote_batch_segmented_region_export_elapsed += delta.remote_batch_segmented_region_export_elapsed; stats.local_home_apply_elapsed += delta.local_home_apply_elapsed; stats.remote_home_rpc_elapsed += delta.remote_home_rpc_elapsed; }; @@ -5829,9 +6184,14 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( << " remote_batch_pack_count=" << stats.remote_batch_pack_count << " remote_batch_pack_items=" << stats.remote_batch_pack_item_count << " remote_batch_pack_bytes=" << stats.remote_batch_pack_bytes + << " remote_batch_segmented_region_export_count=" << stats.remote_batch_segmented_region_export_count + << " remote_batch_segmented_region_export_items=" << stats.remote_batch_segmented_region_export_item_count + << " remote_batch_segmented_region_export_bytes=" << stats.remote_batch_segmented_region_export_bytes << " local_stage_ms=" << absl::ToDoubleMilliseconds(stats.local_stage_elapsed) << " remote_stage_ms=" << absl::ToDoubleMilliseconds(stats.remote_stage_elapsed) << " remote_batch_pack_ms=" << absl::ToDoubleMilliseconds(stats.remote_batch_pack_elapsed) + << " remote_batch_segmented_region_export_ms=" + << absl::ToDoubleMilliseconds(stats.remote_batch_segmented_region_export_elapsed) << " local_home_apply_ms=" << absl::ToDoubleMilliseconds(stats.local_home_apply_elapsed) << " remote_home_rpc_ms=" << absl::ToDoubleMilliseconds(stats.remote_home_rpc_elapsed); rctx.mark_success(); diff --git a/daemon/service/grpc_service_impl_batch_redirect_e2e_test.cc b/daemon/service/grpc_service_impl_batch_redirect_e2e_test.cc index ed0ecebc..993bda2f 100644 --- a/daemon/service/grpc_service_impl_batch_redirect_e2e_test.cc +++ b/daemon/service/grpc_service_impl_batch_redirect_e2e_test.cc @@ -399,6 +399,15 @@ void set_invariant( invariant->set_payload_digest_hex(sha256_hex(payload)); } +void set_layout_only_invariant( + tensorcast::daemon::v2::PutIfAbsentInvariant* invariant, + std::string_view layout_id, + std::string_view payload) { + invariant->set_layout_id(std::string(layout_id)); + invariant->set_byte_length(payload.size()); + invariant->set_verification_mode(tensorcast::daemon::v2::BYTE_ARTIFACT_VERIFICATION_MODE_LAYOUT_AND_SIZE_ONLY); +} + struct LeaseState { uint64_t shard_id{0}; std::string holder_daemon_id; @@ -1575,6 +1584,132 @@ TEST_CASE( release_test_host_shared_region(*front.harness, target_region); } +TEST_CASE( + "BatchPutIfAbsentFromRegion exports remote HOST_SHARED source layouts without staged slab", + "[daemon][batch][redirect][transport][communicator][host_shared][put]") { + auto gs = std::make_shared(); + gs->connected = true; + + const auto root_front = make_tmp_dir("front_put_segmented_region"); + const auto root_home = make_tmp_dir("home_put_segmented_region"); + + const auto lease_ttl = std::chrono::seconds(5); + const auto route_budget = std::chrono::seconds(10); + const auto worker_budget = std::chrono::seconds(10); + const auto keepalive_interval = std::chrono::hours(1); + auto engine_front = std::make_shared(make_engine_opts(root_front)); + auto engine_home = std::make_shared(make_engine_opts(root_home)); + + GrpcDaemon home( + kHomeDaemonId, + root_home, + engine_home, + gs, + lease_ttl, + route_budget, + worker_budget, + keepalive_interval, + /*shard_home_eligible=*/true, + /*routing_epoch=*/1, + /*start_server=*/true, + /*inline_payload_threshold_bytes=*/8); + const uint16_t home_p2p_port = engine_home->get_shared_comm_manager()->listen_port(); + gs->upsert_worker( + kHomeDaemonId, "127.0.0.1", grpc_port_from_address(home.address), kShardHomeEligibleFlag, home_p2p_port); + + GrpcDaemon front( + kFrontDaemonId, + root_front, + engine_front, + gs, + lease_ttl, + route_budget, + worker_budget, + keepalive_interval, + /*shard_home_eligible=*/true, + /*routing_epoch=*/1, + /*start_server=*/true, + /*inline_payload_threshold_bytes=*/8); + const uint16_t front_p2p_port = engine_front->get_shared_comm_manager()->listen_port(); + gs->upsert_worker( + kFrontDaemonId, "127.0.0.1", grpc_port_from_address(front.address), kShardHomeEligibleFlag, front_p2p_port); + + const std::vector daemon_ids{std::string(kFrontDaemonId), std::string(kHomeDaemonId)}; + const std::string artifact_id_a = find_artifact_id_for_expected_home(kHomeDaemonId, daemon_ids); + const std::string artifact_id_b = artifact_on_same_shard(artifact_id_a, "put-segmented-region"); + const std::string payload_a(64, 'x'); + const std::string payload_b(96, 'y'); + const std::string slab = payload_a + payload_b; + const std::uint64_t shard_id = shard_for_artifact(artifact_id_a, /*shard_count=*/4096ULL); + gs->seed_lease(shard_id, kHomeDaemonId, /*lease_generation=*/1); + + auto source_region = register_test_host_shared_region( + *front.harness, slab.size(), tensorcast::daemon::IpcRegionRegistry::HostRegionClass::kAllocator); + std::memcpy(source_region.base_ptr, slab.data(), slab.size()); + + BatchPutIfAbsentFromRegionRequest put_req; + auto* put_item_a = put_req.add_items(); + put_item_a->mutable_selection()->set_artifact_id(artifact_id_a); + set_layout_only_invariant(put_item_a->mutable_invariant(), "layout_v1", payload_a); + auto* put_item_b = put_req.add_items(); + put_item_b->mutable_selection()->set_artifact_id(artifact_id_b); + set_layout_only_invariant(put_item_b->mutable_invariant(), "layout_v1", payload_b); + populate_two_item_host_shared_region_layout( + put_req.mutable_source_layout(), source_region, artifact_id_a, payload_a.size(), artifact_id_b, payload_b.size()); + put_req.mutable_source_layout()->mutable_offsets(0)->set_slot_index(7); + put_req.mutable_source_layout()->mutable_offsets(0)->set_slot_generation(17); + put_req.mutable_source_layout()->mutable_offsets(1)->set_slot_index(8); + put_req.mutable_source_layout()->mutable_offsets(1)->set_slot_generation(18); + put_req.set_pid(source_region.owner_pid); + put_req.set_operation_id("op-put-segmented-region"); + + CollectingLogSink sink; + absl::AddLogSink(&sink); + BatchPutIfAbsentFromRegionResponse put_resp; + grpc::ServerContext put_ctx; + const auto put_st = front.harness->service().BatchPutIfAbsentFromRegion(&put_ctx, &put_req, &put_resp); + absl::RemoveLogSink(&sink); + + REQUIRE(put_st.ok()); + REQUIRE(put_resp.outcomes_size() == 2); + CAPTURE(put_resp.outcomes(0).message()); + CAPTURE(put_resp.outcomes(1).message()); + REQUIRE(put_resp.outcomes(0).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + REQUIRE(put_resp.outcomes(1).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + CHECK(put_resp.outcomes(0).slot_index() == 7); + CHECK(put_resp.outcomes(0).slot_generation() == 17); + CHECK(put_resp.outcomes(1).slot_index() == 8); + CHECK(put_resp.outcomes(1).slot_generation() == 18); + CHECK(sink.Contains("mode=segmented_region_export")); + CHECK(sink.Contains("source_realization_mode=source_layout_host_shared")); + CHECK_FALSE(sink.Contains("mode=staged_slab")); + + auto target_region = register_test_host_shared_region(*front.harness, slab.size()); + std::memset(target_region.base_ptr, 0, static_cast(target_region.size_bytes)); + + BatchGetIntoRegionRequest get_req; + get_req.add_selections()->set_artifact_id(artifact_id_a); + get_req.add_selections()->set_artifact_id(artifact_id_b); + populate_two_item_host_shared_region_layout( + get_req.mutable_target_layout(), target_region, artifact_id_a, payload_a.size(), artifact_id_b, payload_b.size()); + get_req.set_pid(target_region.owner_pid); + get_req.set_operation_id("op-put-segmented-region"); + + BatchGetIntoRegionResponse get_resp; + grpc::ServerContext get_ctx; + const auto get_st = front.harness->service().BatchGetIntoRegion(&get_ctx, &get_req, &get_resp); + REQUIRE(get_st.ok()); + REQUIRE(get_resp.outcomes_size() == 2); + REQUIRE(get_resp.outcomes(0).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + REQUIRE(get_resp.outcomes(1).status() == BatchItemStatus::BATCH_ITEM_STATUS_OK); + REQUIRE(std::string(static_cast(target_region.base_ptr), payload_a.size()) == payload_a); + REQUIRE( + std::string(static_cast(target_region.base_ptr) + payload_a.size(), payload_b.size()) == payload_b); + + release_test_host_shared_region(*front.harness, source_region); + release_test_host_shared_region(*front.harness, target_region); +} + TEST_CASE( "payload_ref remote issuer validation uses canonical routed owner rpc", "[daemon][batch][redirect][issuer_route]") { diff --git a/daemon/service/payload_transport_broker.cc b/daemon/service/payload_transport_broker.cc index 5352d109..5d20314b 100644 --- a/daemon/service/payload_transport_broker.cc +++ b/daemon/service/payload_transport_broker.cc @@ -1912,6 +1912,133 @@ absl::StatusOr PayloadTransport }; } +absl::StatusOr PayloadTransportBroker:: + issue_batch_payload_communicator_export( + const v2::BatchPayloadManifest& manifest, + absl::Span source_segments, + tensorcast::common::v1::PayloadRefDirection direction, + std::string_view operation_id, + absl::Time capability_expires_at, + std::string_view consumer_daemon_id) { + if (!batch_transport_segmented_communicator_export_enabled()) { + return absl::FailedPreconditionError("segmented batch communicator transport is disabled"); + } + if (options_.comm_manager == nullptr || !options_.comm_manager->is_enabled()) { + return absl::FailedPreconditionError("communication manager is unavailable for region source export"); + } + auto manifest_status = validate_batch_payload_manifest(manifest); + if (!manifest_status.ok()) { + return manifest_status; + } + if (static_cast(source_segments.size()) != manifest.entries_size()) { + return absl::InvalidArgumentError("segmented region source count mismatch"); + } + const absl::Time now = absl::Now(); + auto expires_at_or = resolve_payload_ref_expiry(now, options_.ttl, capability_expires_at); + if (!expires_at_or.ok()) { + return expires_at_or.status(); + } + const absl::Time expires_at = *expires_at_or; + if (expires_at - now < options_.minimum_batch_transport_ttl) { + return absl::FailedPreconditionError("batch communicator transport ttl below minimum"); + } + + std::vector buffer_addresses; + std::vector buffer_sizes; + std::vector> keepalives; + buffer_addresses.reserve(source_segments.size()); + buffer_sizes.reserve(source_segments.size()); + keepalives.reserve(source_segments.size()); + + std::uint64_t total_payload_bytes = 0; + for (int entry_index = 0; entry_index < manifest.entries_size(); ++entry_index) { + const auto& segment = source_segments[entry_index]; + if (segment.data == nullptr) { + return absl::FailedPreconditionError("segmented region source requires data pointer"); + } + if (segment.size_bytes == 0) { + return absl::FailedPreconditionError("segmented region source requires non-empty segment"); + } + if (segment.keepalive == nullptr) { + return absl::FailedPreconditionError("segmented region source requires keepalive"); + } + if (segment.size_bytes != manifest.entries(entry_index).length()) { + return absl::FailedPreconditionError("segmented region source length mismatch"); + } + if (segment.size_bytes > std::numeric_limits::max()) { + return absl::OutOfRangeError("segmented region source exceeds host memory limits"); + } + if (total_payload_bytes > std::numeric_limits::max() - segment.size_bytes) { + return absl::OutOfRangeError("segmented region source payload exceeds uint64 range"); + } + total_payload_bytes += segment.size_bytes; + buffer_addresses.push_back(const_cast(segment.data)); + buffer_sizes.push_back(static_cast(segment.size_bytes)); + keepalives.push_back(segment.keepalive); + } + if (total_payload_bytes != manifest.total_size()) { + return absl::FailedPreconditionError("segmented region source payload size mismatch"); + } + + const absl::Time export_started_at = absl::Now(); + const absl::Time register_started_at = absl::Now(); + auto registration_or = options_.comm_manager->register_memory(buffer_addresses, buffer_sizes, /*device_id=*/-1); + const absl::Duration register_elapsed = absl::Now() - register_started_at; + if (!registration_or.ok()) { + return registration_or.status(); + } + registration_or->location = common::memory::MemoryLocation::CPU; + registration_or->device_id = -1; + registration_or->artifact_size = manifest.total_size(); + + const auto unregister_registered = [&]() { + for (const auto& tensor_key : registration_or->remote_memory_keys) { + (void)options_.comm_manager->get_engine().unregister_tensor(tensor_key); + } + }; + + const absl::Time issue_ref_started_at = absl::Now(); + auto batch_payload_ref_or = issue_batch_payload_ref_record( + manifest, + /*payload=*/nullptr, + direction, + operation_id, + capability_expires_at, + consumer_daemon_id); + const absl::Duration issue_ref_elapsed = absl::Now() - issue_ref_started_at; + if (!batch_payload_ref_or.ok()) { + unregister_registered(); + return batch_payload_ref_or.status(); + } + { + absl::MutexLock lock(&mu_); + const auto it = batch_records_.find(batch_payload_ref_or->metadata.transport_id); + if (it == batch_records_.end()) { + unregister_registered(); + lifecycle_manager_->release_lease(batch_payload_ref_or->lease_id); + return absl::NotFoundError("segmented region source transport record is missing"); + } + it->second.communicator_export = *registration_or; + it->second.communicator_export_keepalives = std::move(keepalives); + it->second.communicator_export_requires_unregister = true; + } + VLOG(2) << "batch_payload_ref.communicator_export_summary" + << " direction=" << payload_direction_label(direction) << " operation_id=" << operation_id + << " realization=segmented_region_source" + << " transport_id=" << batch_payload_ref_or->metadata.transport_id + << " payload_bytes=" << batch_payload_ref_or->metadata.payload_size + << " remote_keys=" << registration_or->remote_memory_keys.size() + << " source_segments=" << source_segments.size() + << " register_ms=" << absl::ToDoubleMilliseconds(register_elapsed) + << " issue_ref_ms=" << absl::ToDoubleMilliseconds(issue_ref_elapsed) + << " total_ms=" << absl::ToDoubleMilliseconds(absl::Now() - export_started_at); + return BatchCommunicatorExport{ + .metadata = batch_payload_ref_or->metadata, + .batch_payload_ref = batch_payload_ref_or->batch_payload_ref, + .export_registration = *registration_or, + }; +} + absl::StatusOr PayloadTransportBroker::inspect_payload_ref( std::string_view payload_ref, absl::Time now, diff --git a/daemon/service/payload_transport_broker.h b/daemon/service/payload_transport_broker.h index 51ed6b42..02a87fed 100644 --- a/daemon/service/payload_transport_broker.h +++ b/daemon/service/payload_transport_broker.h @@ -105,6 +105,12 @@ class PayloadTransportBroker { BodyExportView export_view; }; + struct BatchCommunicatorRegionSourceSegment { + const void* data{nullptr}; + std::uint64_t size_bytes{0}; + std::shared_ptr keepalive; + }; + struct BatchPayloadSource { BatchRefMetadata metadata; std::shared_ptr source; @@ -185,6 +191,14 @@ class PayloadTransportBroker { absl::Time capability_expires_at = absl::InfiniteFuture(), std::string_view consumer_daemon_id = ""); + [[nodiscard]] absl::StatusOr issue_batch_payload_communicator_export( + const v2::BatchPayloadManifest& manifest, + absl::Span source_segments, + tensorcast::common::v1::PayloadRefDirection direction, + std::string_view operation_id = "", + absl::Time capability_expires_at = absl::InfiniteFuture(), + std::string_view consumer_daemon_id = ""); + [[nodiscard]] absl::StatusOr inspect_payload_ref( std::string_view payload_ref, absl::Time now, diff --git a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md index a989c8d2..f803ce07 100644 --- a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md +++ b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md @@ -121,21 +121,24 @@ RDMA realization follow-ons still in progress: mirror, while MTCP and staged fallback paths still materialize one full local pack payload per `transport_id` before per-item slicing. - `HomeBatchPutIfAbsent` accepts batch transports on the put path and consumes - eligible remote RDMA `communicator_source` transports through per-item - `SourceSlice` loaders staged by `BodyBackingManager::stage_body(...)`. The - first put-side cut intentionally keeps source-side staged-slab pack - construction and per-item home staging in place; it only removes the home - daemon's mandatory remote full-pack mirror. + eligible remote RDMA `communicator_source` transports without a home-daemon + full-pack mirror. - The put-side composite final-body staging cut is implemented for eligible `LAYOUT_AND_SIZE_ONLY` byte artifacts: the home daemon prepares unpublished retained body backings, lowers one remote pack source into a composite target layout, and lets the shared `0115` dataplane execute one batched direct-write materialization before authority join installation. +- The put-side source daemon now realizes eligible remote-home + `HOST_SHARED`/`LAYOUT_AND_SIZE_ONLY` source-layout batch-set transports as + segmented `communicator_source` exports over the original source spans. The + contiguous staged-slab realization remains the fallback for strict digest, + non-`HOST_SHARED`, capability-miss, and lifetime-gap cases. - `BodyHandle` now exposes the export-view API used by source-side segmented communicator export. -- The remaining RDMA follow-on is producer-side read-plan servicing: - source-side CPU byte-artifact slices are still copied from retained backing - into pinned staged response buffers before the remote RDMA reads them. +- The remaining get-side RDMA follow-on is producer-side read-plan servicing: + source-side CPU byte-artifact slices may still be copied from retained + backing into pinned staged response buffers before the remote RDMA reads + them. - NodeAgent preserves structured `manifest`, `publish`, `hydrate`, and `evict_local` artifact results over `ExecutePlan`. @@ -730,12 +733,19 @@ sequenceDiagram Put-path rules: 1. Source-side batching is grouped by remote home bucket first; a transport pack must not mix different home daemons. -2. The source daemon must not expose caller-accessible local regions directly as a remote fetch source after the local ingress RPC - returns. -3. Before issuing a batch capability on the put path, the source daemon must adopt or stage the bytes into daemon-owned - transport state. -4. The current implementation realizes transport state as one contiguous staged slab per pack. The wire contract does - not depend on that choice, but the live controller and broker paths do. +2. The source daemon must not expose caller-accessible local regions as a + remote fetch source unless it owns a transport-lifetime lease, pin, or + equivalent keepalive that survives the local ingress RPC and prevents the + exported bytes from being recycled before transport release. +3. Before issuing a batch capability on the put path, the source daemon must + adopt the bytes into daemon-owned transport-lifetime state. That state may + be a staged copy, or for an eligible `HOST_SHARED` no-pack path it may be + the original region span plus region lease, slot pin or generation token, + communicator export registration, and cleanup ownership. +4. The current implementation realizes eligible `HOST_SHARED` source-layout + packs as segmented communicator exports and keeps the contiguous staged slab + as the fallback realization. The wire contract does not depend on either + physical choice. 5. The home daemon still verifies each item's invariant and still installs routed join truth per artifact. 6. The home-daemon consume path is transport-dependent. Eligible remote RDMA `communicator_source` sources are consumed as per-item remote `SourceSlice` @@ -970,6 +980,174 @@ Observability rules: layout-and-size items drops toward zero because per-item `SeekableSourceLoader` staging is bypassed. +#### 5.5.6c Accepted put-side source no-pack segmented export + +After composite final-body staging, the remaining batch-set transport copy is +on the requester/source daemon: `BatchPutIfAbsentFromRegion` stages each +remote-home source-layout item into a transient forwarding body and then packs +those bodies into one contiguous host slab before issuing the +`communicator_source`. The accepted Step-3 realization removes that put-side +full-pack slab for eligible `HOST_SHARED` source-layout batches. + +Scope: + +1. This is a source-daemon optimization for remote-home + `BatchPutIfAbsentFromRegion`. It does not change `HomeBatchPutIfAbsent`, + routed authority, shard fencing, per-item outcomes, or the Phase-9 + composite final-body target path. +2. The home daemon still receives the same `BatchPayloadTransport` shape: a + manifest plus a `v2 communicator_source` opened with + `PAYLOAD_REF_DIRECTION_PUT`. The physical source may be a segmented export + over original source-layout spans rather than a staged slab. +3. Same-home puts, inline payloads, per-item `payload_ref` inputs, VRAM source + layouts, and transports that cannot prove source lifetime keep the existing + staging and fallback behavior. + +Logical planning semantics: + +1. Logical pack planning remains manifest-first. Each admitted item contributes + exactly one `BatchPayloadManifest` entry with the same `{artifact_id, + offset, length, digest}` contract used by staged packs. +2. Manifest offsets define a concatenated logical pack byte space. They do not + require the source daemon to allocate one contiguous physical slab. +3. The first no-pack cut should admit `LAYOUT_AND_SIZE_ONLY` source-layout + items. Modes that require the source daemon to compute or prove a payload + digest before home publication remain on the staged path until streaming or + post-write digest support is designed. +4. Pack grouping remains per remote home daemon and per shard task. A no-pack + pack must not mix different home daemons, mixed source kinds, or items whose + lifetime cannot be held by the same transport-lifetime contract. + +Eligibility: + +1. The target home is remote from the source daemon. +2. The peer advertises `v2` batch transport and segmented communicator export; + the local `PayloadTransportBroker` has segmented communicator export + enabled; and the local producer endpoint has a usable P2P port. +3. Every admitted entry comes directly from the validated `source_layout`. + Entries that already require inline bytes, per-item `payload_ref`, or a + transient `BodyHandle` are excluded from the no-pack group. +4. The source layout is pure CPU `HOST_SHARED`. `VRAM`, mixed memory kinds, and + non-exportable host mappings fall back before transport issue. +5. Each source-layout item resolves to a bounded exportable span whose length + exactly matches the manifest entry and the put invariant byte length. +6. Allocator-backed `HOST_SHARED` regions require `slot_index` and + `slot_generation` on every admitted offset, plus a successful slot pin or + equivalent generation-validated keepalive that prevents slot reuse until the + batch transport is safe to release. Missing, stale, or unpinnable tokens are + pre-issue fallback reasons. +7. Scratch-slab `HOST_SHARED` regions require a daemon-managed region lease + and an ownership or immutability proof that the source bytes cannot be + rewritten before transport release. If the current caller contract cannot + prove that lifetime, the controller must use the staged path. +8. Segment count, payload bytes, control-plane size, and remote-key count must + stay within the existing batch payload limits and any communicator-specific + descriptor budget. + +Source export seam: + +1. `ByteArtifactRegionLayout` should expose a region-source export view, or an + equivalent `SegmentExportView` abstraction shared with `BodyExportView`, + rather than forcing `BatchPutIfAbsentFromRegion` to create transient + forwarding bodies. +2. A region-source export view must carry the CPU address range, length, + `HOST_SHARED` class, region lease or attach keepalive, slot token and pin + when applicable, memory location, communicator export registration or a + registration request, and explicit unregister ownership. +3. `PayloadTransportBroker` should accept segmented source entries from either + retained-body export views or region-source export views and concatenate + their remote keys and buffer sizes into one logical + `BatchPayloadCommunicatorSource`. +4. If the broker registers raw `HOST_SHARED` region spans for this transport, + the batch record owns those registrations and must unregister them on issue + failure or transport expiry. If a future region export view reuses a + pre-existing export lease, that lease owns unregister and the batch record + must hold only the keepalive. A first implementation should avoid mixing + owned and externally owned registrations inside one batch record. + +Home consume compatibility: + +1. `RemoteKeySource` already interprets `remote_memory_keys[]` and + `buffer_sizes[]` as a segmented logical byte space. A manifest offset maps + through that concatenation in the same way whether the source is a staged + slab or original `HOST_SHARED` segments. +2. Phase-9 home composite staging therefore does not need a new mapping model: + it still maps manifest pack offsets into unpublished final body backing + offsets and delegates materialization to the shared `0115` dataplane. +3. Same-daemon local slicing remains an implementation detail and is not the + validation target for this RDMA no-pack path. + +Fallback and failure rules: + +1. Fallback to the current staged path must happen before issuing the segmented + transport. The fallback path remains + `stage_pending_body(kTransientForward)` plus `pack_batch_payload_entries`. +2. Pre-issue fallback covers peer capability misses, local config misses, + missing P2P endpoint, non-`HOST_SHARED` source layout, mixed source kinds, + strict digest, missing or unpinnable allocator slot tokens, scratch lifetime + uncertainty, segment-budget overflow, and export-registration failure. +3. After a segmented no-pack transport has been issued, the controller must not + silently retry the same operation through a staged slab using the same + possibly mutable source bytes. The current operation should fail affected + items using existing transport or RPC failure semantics; an upper-layer + retry may submit a fresh request. +4. Transport keepalives must outlive remote reads and be released only through + the broker's batch-record expiry or an equivalent explicit release path. + +Observability rules: + +1. Eligible no-pack source realization should log + `byte_artifact.batch_put_if_absent_from_region_pack_realization` with a + distinct mode such as `mode=segmented_region_export`, `staged_slab=false`, + `source_realization_mode=source_layout_host_shared`, `host_region_class`, + `pack_count`, `item_count`, `payload_bytes`, `source_segments`, + `remote_keys`, and registration ownership. +2. Pre-issue fallback should log a bounded reason such as + `peer_lacks_segmented_export`, `local_segmented_export_disabled`, + `producer_endpoint_unavailable`, `not_source_layout`, `not_host_shared`, + `vram_source`, `slot_token_missing`, `slot_pin_unavailable`, + `slot_generation_mismatch`, `scratch_lifetime_unproven`, + `segment_budget_exceeded`, `export_registration_failed`, or + `mixed_source_kind`. +3. The put summary should separately report staged-slab pack counts and bytes + versus segmented no-pack export counts and bytes. On the intended SGLang KV + path, staged batch-set pack bytes should collapse to zero while Phase-9 home + composite logs remain present. + +Phase-10 implementation status, first cut accepted on 2026-04-27: + +1. `ByteArtifactRegionLayout::open_host_shared_source_span(...)` is the + source-layout export seam. It admits only daemon-managed `HOST_SHARED` + items with a live region keepalive, rejects zero-length or non-host spans, + and requires allocator-backed offsets to carry both `slot_index` and + `slot_generation`. There is still no separate allocator slot-pin API; the + accepted first cut relies on the validated source-layout token plus the + daemon-managed region keepalive while the synchronous put RPC is in flight. +2. `PayloadTransportBroker::issue_batch_payload_communicator_export(...)` now + accepts raw `HOST_SHARED` region-source segments, registers those CPU spans + with the communicator, stores the region keepalives in the batch transport + record, and marks the remote keys as broker-owned so expiry and failure + cleanup unregister them. +3. Remote-home `BatchPutIfAbsentFromRegion` admits `LAYOUT_AND_SIZE_ONLY` + source-layout items into this no-pack path when the peer advertises v2 + segmented communicator export and the local P2P endpoint is usable. Strict + digest modes, non-`HOST_SHARED` sources, missing tokens, capability misses, + and export failures still fall back before issue to + `stage_pending_body(kTransientForward)` plus `pack_batch_payload_entries`. +4. The home daemon path is unchanged from Phase 9: the segmented source-layout + export still appears as one logical `RemoteKeySource` pack, and + `HomeBatchPutIfAbsent` applies it through composite final-body staging and + batched direct-write. +5. Validation run + `20260427-132749_tensorcast_tp2_workers2_prompts10` showed source worker + `worker_00` emitted `364` `mode=segmented_region_export` pack realization + records for `5145` items / `21579694080` bytes, with + `remote_batch_pack_count=0` and `remote_batch_pack_bytes=0`; home worker + `worker_01` still emitted `364` + `home_batch_put_if_absent_transport_apply_summary` records with + `read_mode=batched_direct_write`, + `materialize_mode=single_source_composite`, and no full-pack mirror. + #### 5.5.7 Implemented v2 communicator-backed realization `v2 communicator_source` is the current communicator-backed realization. It moves routed byte-artifact remote transport @@ -1002,6 +1180,9 @@ Rules: 5. Current producer `v2` realizations are transport-dependent: - eligible RDMA get paths may export one logical pack as segmented retained body views, + - eligible RDMA put paths may export one logical pack as segmented + `HOST_SHARED` source-layout views once the no-pack source realization in + 5.5.6c is implemented, - MTCP-compatible and fallback paths may still realize one staged host pack. 6. Current get-side remote `v2` consume paths open one communicator source per `transport_id`: @@ -1096,14 +1277,17 @@ Design intent: mandatory sink full-pack mirror on eligible paths, - put-side `HomeBatchPutIfAbsent` now removes the same mandatory home-daemon full-pack mirror for eligible RDMA `communicator_source` transports while - keeping source-side staged-slab pack construction in place, -- put-side `HomeBatchPutIfAbsent` may next consume the same shared `0115` - composite execution contract to batch-stage one remote pack into unpublished - final body backings for `LAYOUT_AND_SIZE_ONLY` items, while preserving - first-writer authority semantics, -- the remaining RDMA bottleneck is producer-side servicing: CPU source slices - are still copied from retained backing into pinned staged response buffers - before remote reads, + preserving source-side staged-slab pack construction, +- put-side `HomeBatchPutIfAbsent` now consumes the same shared `0115` + composite execution contract to batch-stage eligible remote packs into + unpublished final body backings for `LAYOUT_AND_SIZE_ONLY` items, while + preserving first-writer authority semantics, +- after put-side composite staging, the remaining put-side copy is the + requester/source daemon's staged full-pack slab for remote-home + source-layout batch-set transports, +- the remaining get-side RDMA bottleneck is producer-side servicing: CPU + source slices may still be copied from retained backing into pinned staged + response buffers before remote reads, - RDMA should therefore converge toward direct-readable source servicing for eligible retained CPU backings while keeping the existing shared `0115` sink and transport execution path, @@ -1146,10 +1330,12 @@ Normative rules: dataplane before authority join installation. 7. Source-side RDMA communicator export should continue to prefer no-pack-copy segmented export over daemon-owned pack slab realization. The producer may - expose one logical pack byte space by concatenating per-entry or per-backing - exported segments through `remote_memory_keys[]`, `buffer_sizes[]`, and - `total_payload_bytes`; it does not need to copy those bytes into one - daemon-owned slab first. + expose one logical pack byte space by concatenating per-entry or + per-backing exported segments through `remote_memory_keys[]`, + `buffer_sizes[]`, and `total_payload_bytes`; it does not need to copy those + bytes into one daemon-owned slab first. On the put path, this rule applies + only when the source daemon can hold the original `HOST_SHARED` + source-layout bytes with transport-lifetime leases or pins. 8. The next RDMA get-side follow-on is source servicing, not a new sink API. Eligible retained CPU backings should be served as direct-readable source segments in the read-plan response instead of first being copied into pinned staged @@ -1163,44 +1349,48 @@ Normative rules: staging-driven. They must not consume `FlowCreditLedger`, `StageLease`, or staged ACK-release semantics, and they must not be split merely because staged `buffers_per_flow` credit is exhausted. -10. Direct-source window sizing should instead be bounded by descriptor/control +11. Direct-source window sizing should instead be bounded by descriptor/control limits such as segment count, control payload bytes, and request budgeting. The accepted benchmark target is that one routed transport's `32` source segments fit in one direct-source response window and one sink `read_multi()` call, even if transport realization still posts many RDMA WRs internally. -11. RDMA producer hot paths may optionally retain a publish-time export-view +12. RDMA producer hot paths may optionally retain a publish-time export-view keepalive keyed by backing identity as an optimization hint for later direct-source servicing. This cache is RDMA-only, best-effort, and advisory: missing, expired, or invalidated retained exports must never change authority truth or manifest semantics, and they must not disable request- time export acquisition or staged fallback. -12. Publish-time retained exports must live outside `BackingRecord` snapshots. +13. Publish-time retained exports must live outside `BackingRecord` snapshots. Snapshot copies of backing metadata must not silently extend export lifetime; source-side preregistration is a separate bounded cache over previously acquired `BodyHandle` export views. -13. Publish-time retained-export cache lifetime must be explicitly bounded by +14. Publish-time retained-export cache lifetime must be explicitly bounded by TTL and live-entry/live-byte budgets, and it must be invalidated on backing lifecycle changes such as invalidation, rebind, prune, or replacement. -14. RDMA zero-copy in this design means "no mandatory pack copy, no mandatory +15. RDMA zero-copy in this design means "no mandatory pack copy, no mandatory source-side staging copy, and no mandatory full-pack mirror" when direct- write source and target paths exist. It does not remove item-scoped digests, item-scoped lowering, or per-item success and failure outcomes. -15. If a candidate pack or item cannot produce the required export view, cannot satisfy lifetime requirements, or +16. If a candidate pack or item cannot produce the required export view, cannot satisfy lifetime requirements, or resolves to a non-direct-write transport, the daemon may fall back per pack to the staged contiguous realization and the existing `grpc_chunk_ref` or staged `communicator_source` paths. MTCP-validated behavior must remain available. -16. The intended implementation order is: +17. The intended implementation order is: - sink-side no-mirror consume path first, because the lower dataplane already supports direct-write remote sources, - `BodyHandle` export-view API second, - source-side no-pack segmented communicator export third, - source-side direct-readable RDMA servicing fourth, - put-side home no-mirror consume as the first batch-set parity step, + - put-side composite final-body staging as the second batch-set parity + step, + - put-side source-layout no-pack segmented export as the third batch-set + parity step, - and source-side publish-time retained-export warming as an optional follow-on optimization on top of the same `BodyHandle` seam. -17. Session-scoped staging reuse remains complementary follow-on work after this transport-specific split. It must not +18. Session-scoped staging reuse remains complementary follow-on work after this transport-specific split. It must not be used as a reason to keep RDMA on the forced pack-plus-mirror path. -18. Shared composite direct-write and routed vectored pull semantics are +19. Shared composite direct-write and routed vectored pull semantics are defined by `0115`. `0087` owns byte-artifact authority plus the consumer- side and producer-side realization rules that decide whether a routed pack is staged or direct-readable; it does not own a transport-private RDMA sink @@ -1373,8 +1563,13 @@ Normative rules: 10. Slot tokens are caller-owned lifetime labels. TensorCast validates their presence and echoes them in per-item outcomes, but slot allocation, recycling, and stale-completion filtering remain caller responsibilities. -11. Host pinning is an optional performance policy on the caller mapping. It is - not part of `HOST_SHARED` correctness. + A remote no-pack put export is stricter: the source daemon must also hold a + transport-lifetime slot pin or equivalent generation-validated keepalive, + otherwise it must stage the bytes before export. +11. Host pinning is an optional performance policy for ordinary local + `HOST_SHARED` access. RDMA no-pack export still needs an explicit + communicator registration or reusable export lease for the exported span, + and that registration must have clear cleanup ownership. 12. `MaterializeIntoMappedTarget` remains a separate mapped-target path and currently rejects `HOST_SHARED` target layouts. `HOST_SHARED` direct-write is currently defined for byte-artifact batch ingress, not the generic From 36475ba8d7171b137964dfb642f6501523541f56 Mon Sep 17 00:00:00 2001 From: zhouyuhan Date: Mon, 27 Apr 2026 18:24:06 +0800 Subject: [PATCH 05/49] feat(byte-artifact): reuse stable HOST_SHARED MRs for put source views - Thread activated HOST_SHARED stable-backing refs and keepalives through source spans and segmented put-side communicator exports. - Add stable-backed lightweight source-view keys that preserve the existing BatchPayloadCommunicatorSource schema while splitting view-key cleanup from stable chunk MR ownership. - Teach READ_PLAN_REQUEST admission to resolve stable-backed source views on the requested rail through StableLocalBackingState::ensure_chunk(), bypassing broker-owned raw tensor registration. - Validate with targeted daemon/communicator tests, tensorcast daemon build, and the share-remote SGLang replay showing stable_backing_view exports with cached-token hits intact --- core/communicator/engine/engine.cc | 177 +++++++++++++++++- core/communicator/engine/engine.h | 38 ++++ core/communicator/engine/rdma_engine_test.cc | 129 +++++++++++++ .../communicator/engine/rdma_stage_fn_test.cc | 2 + .../store/components/communication_manager.cc | 63 +++++++ core/store/components/communication_manager.h | 10 + daemon/service/byte_artifact_region_layout.cc | 6 + daemon/service/byte_artifact_region_layout.h | 2 + ...artifact_region_layout_host_shared_test.cc | 6 + .../controllers/byte_artifact_controller.cc | 10 +- daemon/service/payload_transport_broker.cc | 62 +++++- daemon/service/payload_transport_broker.h | 5 + ...e-and-routed-byte-artifact-architecture.md | 157 +++++++++++++++- 13 files changed, 654 insertions(+), 13 deletions(-) diff --git a/core/communicator/engine/engine.cc b/core/communicator/engine/engine.cc index 7b156603..4e1cb827 100644 --- a/core/communicator/engine/engine.cc +++ b/core/communicator/engine/engine.cc @@ -1013,15 +1013,23 @@ StagingWindow::StageFn MakeStageFunction( v1::RdmaConfig::StagedRdmaBackend staged_backend, bool use_direct, ibv_mr* direct_mr, + std::shared_ptr direct_keepalive, std::shared_ptr source_stage_profile) { if (use_direct) { const uint64_t base_addr = tensor->get_uint64_addr(); const uint64_t tensor_bytes = tensor->get_bytes(); - return [tensor, ledger, request_key = std::move(request_key), base_addr, tensor_bytes, direct_mr]( + return [tensor, + ledger, + request_key = std::move(request_key), + base_addr, + tensor_bytes, + direct_mr, + direct_keepalive = std::move(direct_keepalive)]( uint64_t offset, uint32_t bytes, uint32_t /*segment_idx*/) -> absl::StatusOr { if (offset + bytes > tensor_bytes) { return absl::OutOfRangeError("direct RDMA stage exceeded tensor bounds"); } + (void)direct_keepalive; void* device_ptr = reinterpret_cast(base_addr + offset); StageLease::Metadata metadata; metadata.transport = StageTransport::kRdma; @@ -1956,6 +1964,72 @@ absl::Status Communicator::wait_for_tensor_reads_to_drain(const std::string& ten return absl::OkStatus(); } +std::shared_ptr Communicator::lookup_stable_local_backing_source_view( + const std::string& tensor_key) const { + absl::MutexLock lock(&stable_source_views_mu_); + auto it = stable_source_views_.find(tensor_key); + if (it == stable_source_views_.end()) { + return nullptr; + } + return it->second; +} + +absl::StatusOr Communicator::ensure_stable_local_backing_source_view_mr( + const std::string& tensor_key, + const std::shared_ptr& view, + const net_dev_t& dev) { + if (view == nullptr) { + return absl::InvalidArgumentError("stable-backed source view is required"); + } + if (dev == nullptr) { + return absl::InvalidArgumentError("stable-backed source view requires an RDMA device"); + } + std::shared_ptr state; + { + absl::MutexLock lock(&stable_local_backings_mu_); + auto it = stable_local_backings_.find(view->backing.backing_id); + if (it != stable_local_backings_.end()) { + state = it->second; + } + } + if (state == nullptr) { + return absl::FailedPreconditionError("stable-backed source view backing is not active"); + } + auto stable_use = state->acquire_use(); + if (stable_use == nullptr) { + return absl::FailedPreconditionError("stable-backed source view backing is retiring"); + } + + const uint32_t chunk_slots = std::max(1, stable_local_mr_reuse_chunk_slots_); + const bool prewarm_requested = state->prewarm_requested_enabled(); + const bool prewarm_complete = state->prewarm_complete_for_test(); + auto chunk_or = + state->ensure_chunk(dev, dev->get_rail_id(), view->backing.slot_bytes, chunk_slots, view->addr, view->bytes); + if (!chunk_or.ok()) { + return chunk_or.status(); + } + if (chunk_or->mr == nullptr) { + return absl::InternalError("stable-backed source view chunk MR is null"); + } + LOG(INFO) << "communicator.read_plan_stable_source_prepare" + << " tensor_key=" << tensor_key << " backing_id=" << view->backing.backing_id + << " rail_id=" << chunk_or->rail_id << " nic=" << chunk_or->nic_name + << " chunk_index=" << chunk_or->chunk_index << " cache_hit=" << chunk_or->cache_hit + << " waited_on_inflight=" << chunk_or->waited_on_inflight << " registered_now=" << chunk_or->registered_now + << " prewarm_requested=" << prewarm_requested << " prewarm_complete=" << prewarm_complete; + return StableSourceViewMrEnsureResult{ + .mr = chunk_or->mr, + .backing_use = stable_use, + .backing_id = view->backing.backing_id, + .chunk_index = chunk_or->chunk_index, + .cache_hit = chunk_or->cache_hit, + .waited_on_inflight = chunk_or->waited_on_inflight, + .registered_now = chunk_or->registered_now, + .prewarm_requested = prewarm_requested, + .prewarm_complete = prewarm_complete, + }; +} + std::shared_ptr Communicator::create_transfer_progress_state( std::string transfer_id, std::string request_key, @@ -2729,6 +2803,7 @@ absl::StatusOr> Communicator::admit_read_plan_r v1::RdmaConfig::StagedRdmaBackend staged_backend = v1::RdmaConfig::STAGED_RDMA_BACKEND_HOST_PINNED; bool direct_eligible = false; ibv_mr* direct_mr = nullptr; + std::shared_ptr direct_keepalive; }; std::vector resolved_slices; @@ -2752,6 +2827,7 @@ absl::StatusOr> Communicator::admit_read_plan_r } read_guards->push_back(*read_guard_or); tensor->wait_read_ready(); + auto stable_source_view = lookup_stable_local_backing_source_view(tensor_key); auto dev = task.request.rail_id >= 0 ? tensor->get_dev_by_rail(task.request.rail_id) : nullptr; if (dev == nullptr) { @@ -2762,6 +2838,15 @@ absl::StatusOr> Communicator::admit_read_plan_r } } const bool tensor_on_cpu = tensor->get_mem_type() == COMMUNICATE_ENGINE_DEV_CPU; + if (dev == nullptr && tensor_on_cpu && stable_source_view != nullptr) { + const int requested_rail = task.request.rail_id >= 0 + ? task.request.rail_id + : (session->dev != nullptr ? session->dev->get_rail_id() : -1); + dev = get_net_dev(COMMUNICATE_ENGINE_DEV_CPU, 0, tensor_key, requested_rail); + if (dev != nullptr) { + tensor->add_dev(dev); + } + } if (tensor_on_cpu && session->dev != nullptr && (dev == nullptr || session->dev->get_name() != dev->get_name() || session->dev->get_rail_id() != dev->get_rail_id())) { @@ -2770,6 +2855,9 @@ absl::StatusOr> Communicator::admit_read_plan_r << " selected_nic=" << (dev != nullptr ? dev->get_name() : "") << " session_rail=" << session->dev->get_rail_id() << " session_nic=" << session->dev->get_name(); dev = session->dev; + if (stable_source_view != nullptr && tensor->get_dev_by_rail(dev->get_rail_id()) == nullptr) { + tensor->add_dev(dev); + } } if (dev == nullptr) { return absl::FailedPreconditionError(absl::StrCat("read_plan tensor missing RDMA device: ", tensor_key)); @@ -2805,6 +2893,7 @@ absl::StatusOr> Communicator::admit_read_plan_r const bool direct_requested = tensor->direct_rdma_enabled(); bool direct_eligible = false; ibv_mr* direct_mr = nullptr; + std::shared_ptr direct_keepalive; DirectFallbackReason fallback_reason = DirectFallbackReason::kNone; if (direct_requested) { const bool direct_mem_supported = tensor_on_cpu || tensor->get_mem_type() == COMMUNICATE_ENGINE_DEV_GPU; @@ -2813,7 +2902,18 @@ absl::StatusOr> Communicator::admit_read_plan_r } else if (tensor->needs_staging()) { fallback_reason = DirectFallbackReason::kNeedsStaging; } else { - if (tensor_on_cpu) { + if (tensor_on_cpu && stable_source_view != nullptr) { + auto ensure_result_or = ensure_stable_local_backing_source_view_mr(tensor_key, stable_source_view, dev); + if (!ensure_result_or.ok()) { + return ensure_result_or.status(); + } + direct_mr = ensure_result_or->mr; + direct_keepalive = std::move(ensure_result_or->backing_use); + direct_eligible = direct_mr != nullptr; + if (!direct_eligible) { + fallback_reason = DirectFallbackReason::kMrUnavailable; + } + } else if (tensor_on_cpu) { auto ensure_result_or = ensure_tensor_registered_on_dev(tensor, dev); if (!ensure_result_or.ok()) { fallback_reason = DirectFallbackReason::kMrUnavailable; @@ -2863,6 +2963,7 @@ absl::StatusOr> Communicator::admit_read_plan_r .staged_backend = staged_backend, .direct_eligible = direct_eligible, .direct_mr = direct_mr, + .direct_keepalive = direct_keepalive, }); if (source.bytes > std::numeric_limits::max() - total_bytes) { return absl::InvalidArgumentError("read_plan source byte count overflow"); @@ -2927,6 +3028,7 @@ absl::StatusOr> Communicator::admit_read_plan_r source.staged_backend, use_direct, source.direct_mr, + source.direct_keepalive, session->source_stage_profile), .chunk_size = chunk_size, .zero_copy = use_direct, @@ -3274,6 +3376,72 @@ absl::Status Communicator::deactivate_stable_local_backing(std::string_view back return absl::OkStatus(); } +absl::Status Communicator::register_stable_local_backing_source_view(const StableLocalBackingSourceView& view) { + if (!enable_rdma_ || rdma_context_ == nullptr) { + return absl::FailedPreconditionError("stable-backed source views require RDMA"); + } + if (view.tensor_key.empty()) { + return absl::InvalidArgumentError("stable-backed source view requires tensor_key"); + } + if (view.addr == 0 || view.bytes == 0) { + return absl::InvalidArgumentError("stable-backed source view requires non-empty address range"); + } + if (view.backing.kind != StableLocalBackingKind::kHostSharedRegion || view.backing.backing_id.empty() || + view.backing.backing_base_addr == 0 || view.backing.backing_bytes == 0 || + view.backing.dev_type != COMMUNICATE_ENGINE_DEV_CPU || view.backing.slot_bytes == 0) { + return absl::InvalidArgumentError("stable-backed source view requires CPU HOST_SHARED backing with slot geometry"); + } + if (view.keepalive == nullptr) { + return absl::InvalidArgumentError("stable-backed source view requires keepalive"); + } + + std::shared_ptr state; + { + absl::MutexLock lock(&stable_local_backings_mu_); + auto it = stable_local_backings_.find(view.backing.backing_id); + if (it != stable_local_backings_.end()) { + state = it->second; + } + } + if (state == nullptr) { + return absl::FailedPreconditionError("stable-backed source view backing is not active"); + } + auto merge_status = state->merge_activation_backing(view.backing, nullptr); + if (!merge_status.ok()) { + return merge_status; + } + const uint32_t chunk_slots = std::max(1, stable_local_mr_reuse_chunk_slots_); + auto chunk_or = state->resolve_chunk_for_region(view.backing.slot_bytes, chunk_slots, view.addr, view.bytes); + if (!chunk_or.ok()) { + return chunk_or.status(); + } + + auto tensor = std::make_shared( + view.tensor_key, + view.addr, + view.bytes, + COMMUNICATE_ENGINE_DEV_CPU, + /*dev=*/nullptr); + tensor->set_read_ready(); + tensor->set_direct_rdma_enabled(true); + store_.register_tensor(tensor); + { + absl::MutexLock lock(&stable_source_views_mu_); + stable_source_views_[view.tensor_key] = + std::make_shared(StableLocalBackingSourceViewState{ + .addr = view.addr, + .bytes = view.bytes, + .backing = view.backing, + .keepalive = view.keepalive, + }); + } + VLOG(2) << "stable_local_backing.source_view_register" + << " key=" << view.tensor_key << " backing_id=" << view.backing.backing_id << " addr=0x" << std::hex + << view.addr << std::dec << " bytes=" << view.bytes << " slot_bytes=" << view.backing.slot_bytes + << " chunk_index=" << chunk_or->chunk_index << " chunk_bytes=" << chunk_or->chunk_bytes; + return absl::OkStatus(); +} + bool Communicator::stable_local_backing_supported_for_test() const { return enable_rdma_ && rdma_context_ != nullptr; } @@ -4743,6 +4911,7 @@ absl::Status Communicator::handle_rdma_read_request( staged_backend, use_direct, direct_mr, + nullptr, session->source_stage_profile); session->window = std::make_unique(*ledger_ptr, stage_fn, total_bytes, chunk_size, start_offset, window_segments); @@ -5058,6 +5227,10 @@ absl::Status Communicator::unregister_tensor(const std::string& tensor_key) { } else { store_.unregister_tensor(tensor_key); } + { + absl::MutexLock lock(&stable_source_views_mu_); + stable_source_views_.erase(tensor_key); + } { absl::MutexLock lock(&tensor_read_mu_); diff --git a/core/communicator/engine/engine.h b/core/communicator/engine/engine.h index 19350d72..2e8328fb 100644 --- a/core/communicator/engine/engine.h +++ b/core/communicator/engine/engine.h @@ -127,6 +127,16 @@ class Communicator { absl::Status deactivate_stable_local_backing(std::string_view backing_id); + struct StableLocalBackingSourceView { + std::string tensor_key; + uint64_t addr = 0; + uint64_t bytes = 0; + tensorcast::store::StableLocalBackingRef backing; + std::shared_ptr keepalive; + }; + + absl::Status register_stable_local_backing_source_view(const StableLocalBackingSourceView& view); + // Test-only stable-backing introspection. bool stable_local_backing_supported_for_test() const; bool stable_local_backing_active_for_test(std::string_view backing_id) const; @@ -294,6 +304,25 @@ class Communicator { double gate_wait_ms = 0.0; }; + struct StableSourceViewMrEnsureResult { + struct ibv_mr* mr = nullptr; + std::shared_ptr backing_use; + std::string backing_id; + uint64_t chunk_index = 0; + bool cache_hit = false; + bool waited_on_inflight = false; + bool registered_now = false; + bool prewarm_requested = false; + bool prewarm_complete = false; + }; + + struct StableLocalBackingSourceViewState { + uint64_t addr = 0; + uint64_t bytes = 0; + tensorcast::store::StableLocalBackingRef backing; + std::shared_ptr keepalive; + }; + struct TensorReadState { int inflight = 0; bool retiring = false; @@ -303,6 +332,12 @@ class Communicator { absl::StatusOr> acquire_tensor_read_lease(const std::string& tensor_key); void release_tensor_read_lease(const std::string& tensor_key); absl::Status wait_for_tensor_reads_to_drain(const std::string& tensor_key, absl::Duration timeout); + std::shared_ptr lookup_stable_local_backing_source_view( + const std::string& tensor_key) const; + absl::StatusOr ensure_stable_local_backing_source_view_mr( + const std::string& tensor_key, + const std::shared_ptr& view, + const transport::net_dev_t& dev); struct MtcpReadTask { channel_t channel; @@ -457,6 +492,9 @@ class Communicator { mutable absl::Mutex stable_local_backings_mu_; absl::flat_hash_map> stable_local_backings_ ABSL_GUARDED_BY(stable_local_backings_mu_); + mutable absl::Mutex stable_source_views_mu_; + absl::flat_hash_map> stable_source_views_ + ABSL_GUARDED_BY(stable_source_views_mu_); // Serialize channel creation to avoid duplicate control connections to same peer mutable absl::Mutex create_channel_mu_; diff --git a/core/communicator/engine/rdma_engine_test.cc b/core/communicator/engine/rdma_engine_test.cc index 4fb11ed9..3f1c9f7a 100644 --- a/core/communicator/engine/rdma_engine_test.cc +++ b/core/communicator/engine/rdma_engine_test.cc @@ -1684,6 +1684,135 @@ TEST_CASE( CommunicatorTestPeer::stop_workers(server); } +TEST_CASE( + "READ_PLAN_REQUEST stable-backed CPU source uses requested-rail chunk MR without raw source registration", + "[rdma][communicator][read_plan][stable_backing]") { + using tensorcast::communicator::base::CHANNEL_RDMA; + using tensorcast::communicator::base::COMMUNICATE_ENGINE_DEV_CPU; + + auto cfg = tensorcast::testing::make_tcp_communicator_config(/*enable_rdma=*/true); + cfg.mutable_rdma()->set_enable_stable_local_mr_reuse(true); + cfg.mutable_rdma()->set_stable_local_mr_reuse_chunk_slots(2); + auto pools = tensorcast::testing::make_test_pinned_staging_pools( + cfg.stager().buffers_per_flow(), + cfg.transport().tcp_conn_count(), + /*gpu_slice_bytes=*/(16ULL << 20), + /*cpu_slice_bytes=*/(4ULL << 20), + /*enable_rdma=*/true); + Communicator server(cfg, std::move(pools), /*channel_expire_sec=*/0); + if (!CommunicatorTestPeer::has_rdma_device(server)) { + CommunicatorTestPeer::stop_workers(server); + SUCCEED("Skipping stable-backed source READ_PLAN_REQUEST test: no RDMA net devices available"); + return; + } + + auto net_dev = CommunicatorTestPeer::rdma_context(server)->get_best_dev( + COMMUNICATE_ENGINE_DEV_CPU, -1, -1, "stable-source-view"); + REQUIRE(net_dev != nullptr); + + std::array buffer{}; + const uint64_t slot_addr = reinterpret_cast(buffer.data()) + 64; + tensorcast::store::StableLocalBackingRef backing{ + .kind = tensorcast::store::StableLocalBackingKind::kHostSharedRegion, + .backing_id = "region:test-stable-source-view", + .backing_base_addr = reinterpret_cast(buffer.data()), + .backing_bytes = static_cast(buffer.size()), + .slot_bytes = 64, + .dev_type = COMMUNICATE_ENGINE_DEV_CPU, + .dev_id = 0, + }; + REQUIRE(server.activate_stable_local_backing(backing, std::make_shared(1)).ok()); + REQUIRE( + CommunicatorTestPeer::stable_local_backing_chunk_count(server, backing.backing_id, net_dev->get_rail_id()) == 0); + + const std::string view_key = "stable-source-view-key"; + Communicator::StableLocalBackingSourceView source_view{ + .tensor_key = view_key, + .addr = slot_addr, + .bytes = 64, + .backing = backing, + .keepalive = std::make_shared(2), + }; + REQUIRE(server.register_stable_local_backing_source_view(source_view).ok()); + + int raw_lazy_events = 0; + ScopedLazySourceMrHook hook([&](communicator::engine::CommunicatorTestPeer::LazySourceMrTestEvent /*event*/, + std::string_view tensor_key, + int16_t /*rail_id*/) { + if (tensor_key == view_key) { + ++raw_lazy_events; + } + }); + + int sv[2]; + REQUIRE(::socketpair(AF_UNIX, SOCK_STREAM, 0, sv) == 0); + timeval timeout{}; + timeout.tv_sec = 2; + require_recv_timeout_or_env_restriction(sv[1], timeout); + + auto control_ctx = std::make_shared(); + struct sockaddr_in remote_addr{}; + remote_addr.sin_family = AF_INET; + remote_addr.sin_port = htons(65016); + remote_addr.sin_addr.s_addr = inet_addr("127.0.0.1"); + auto control_transport = + std::make_shared(control_ctx.get(), sv[0], remote_addr); + auto channel = std::make_shared( + control_transport, + CHANNEL_RDMA, + /*buffers_per_flow=*/1, + /*max_window_segments=*/1); + auto flow_state = channel->flow_state(); + REQUIRE(flow_state != nullptr); + + const uint32_t payload_size = sizeof(ProtoReadPlanRequestHeader) + sizeof(ProtoReadPlanSourceSlice); + auto request = std::make_shared(ENGINE_OP_READ_PLAN_REQUEST, payload_size); + auto* hdr = request->get_payload(); + hdr->transport_type = ENGINE_TRANSPORT_RDMA; + hdr->rail_id = net_dev->get_rail_id(); + hdr->request_id = 97; + hdr->num_source_slices = 1; + auto* slice = + reinterpret_cast(reinterpret_cast(hdr) + sizeof(ProtoReadPlanRequestHeader)); + communicator::misc::STRNCPY(slice->tensor_key, view_key, kMaxTensorNameLen); + slice->source_slice_index = 0; + slice->remote_offset = 0; + slice->bytes = 64; + + auto status = CommunicatorTestPeer::on_receive_request(server, channel, control_transport, request); + REQUIRE(status == tensorcast::communicator::misc::SUCCESS); + + ProtoHeader response_header{}; + REQUIRE( + ::recv(sv[1], &response_header, sizeof(response_header), MSG_WAITALL) == + static_cast(sizeof(response_header))); + REQUIRE(response_header.op == ENGINE_OP_READ_PLAN_RESPONSE_EX); + REQUIRE(response_header.size >= sizeof(ProtoReadPlanResponseExHeader)); + + std::vector payload(response_header.size); + REQUIRE(::recv(sv[1], payload.data(), payload.size(), MSG_WAITALL) == static_cast(payload.size())); + auto* response = reinterpret_cast(payload.data()); + REQUIRE(response->request_id == 97); + REQUIRE(response->staged == 0); + REQUIRE(response->zero_copy == 1); + REQUIRE(response->rail_id == net_dev->get_rail_id()); + REQUIRE(response->num_segments == 1); + const auto* seg = reinterpret_cast( + reinterpret_cast(response) + sizeof(ProtoReadPlanResponseExHeader)); + REQUIRE(seg->source_slice_index == 0); + REQUIRE(seg->addr == slot_addr); + REQUIRE(seg->bytes == 64); + REQUIRE(seg->rkey != 0); + CHECK(raw_lazy_events == 0); + CHECK( + CommunicatorTestPeer::stable_local_backing_chunk_count(server, backing.backing_id, net_dev->get_rail_id()) == 1); + + REQUIRE(server.unregister_tensor(view_key).ok()); + REQUIRE(server.deactivate_stable_local_backing(backing.backing_id).ok()); + ::close(sv[1]); + CommunicatorTestPeer::stop_workers(server); +} + TEST_CASE( "CPU direct-source lazy MR gate coalesces same-tensor concurrent admissions", "[rdma][communicator][read_plan]") { diff --git a/core/communicator/engine/rdma_stage_fn_test.cc b/core/communicator/engine/rdma_stage_fn_test.cc index a707f84f..b5053526 100644 --- a/core/communicator/engine/rdma_stage_fn_test.cc +++ b/core/communicator/engine/rdma_stage_fn_test.cc @@ -29,6 +29,7 @@ StagingWindow::StageFn MakeStageFunction( v1::RdmaConfig::StagedRdmaBackend staged_backend, bool use_direct, ::ibv_mr* direct_mr, + std::shared_ptr direct_keepalive, std::shared_ptr source_stage_profile); TEST_CASE("MakeStageFunction guards null stager", "[communicator][rdma]") { @@ -45,6 +46,7 @@ TEST_CASE("MakeStageFunction guards null stager", "[communicator][rdma]") { /*staged_backend=*/v1::RdmaConfig::STAGED_RDMA_BACKEND_HOST_PINNED, /*use_direct=*/false, /*direct_mr=*/nullptr, + /*direct_keepalive=*/nullptr, /*source_stage_profile=*/nullptr); auto lease_or = stage_fn(/*offset=*/0, /*bytes=*/1, /*segment_idx=*/0); diff --git a/core/store/components/communication_manager.cc b/core/store/components/communication_manager.cc index d13227a7..258a5122 100644 --- a/core/store/components/communication_manager.cc +++ b/core/store/components/communication_manager.cc @@ -187,6 +187,69 @@ absl::StatusOr CommunicationManager::register_memory( return info; } +absl::StatusOr CommunicationManager::register_stable_local_backing_source_views( + const std::vector& views) { + if (!is_enabled()) { + return absl::FailedPreconditionError("Communication engine not initialized"); + } + if (views.empty()) { + return absl::InvalidArgumentError("stable-backed source view registration requires at least one view"); + } + + const uint64_t registration_id = next_registration_id_.fetch_add(1, std::memory_order_relaxed); + ExportRegistration info; + info.location = common::memory::MemoryLocation::CPU; + info.device_id = -1; + info.comm_dev_type = communicator::base::COMMUNICATE_ENGINE_DEV_CPU; + info.buffer_addresses.reserve(views.size()); + info.buffer_sizes.reserve(views.size()); + info.remote_memory_keys.reserve(views.size()); + + std::vector registered_keys; + registered_keys.reserve(views.size()); + const auto cleanup_registered_keys = [&]() { + for (const auto& key : registered_keys) { + auto status = comm_engine_->unregister_tensor(key); + if (!status.ok()) { + LOG(WARNING) << "stable-backed source view cleanup failed" + << " key=" << key << " status=" << status; + } + } + }; + for (size_t index = 0; index < views.size(); ++index) { + const auto& view = views[index]; + if (view.address == 0 || view.size_bytes == 0) { + cleanup_registered_keys(); + return absl::InvalidArgumentError("stable-backed source view requires non-empty address range"); + } + if (view.keepalive == nullptr) { + cleanup_registered_keys(); + return absl::InvalidArgumentError("stable-backed source view requires keepalive"); + } + + const std::string key = absl::StrCat("stable_backing_view_", registration_id, "_", index, "_", view.address); + communicator::engine::Communicator::StableLocalBackingSourceView engine_view{ + .tensor_key = key, + .addr = view.address, + .bytes = static_cast(view.size_bytes), + .backing = view.backing, + .keepalive = view.keepalive, + }; + auto status = comm_engine_->register_stable_local_backing_source_view(engine_view); + if (!status.ok()) { + cleanup_registered_keys(); + return status; + } + registered_keys.push_back(key); + info.buffer_addresses.push_back(view.address); + info.buffer_sizes.push_back(view.size_bytes); + info.remote_memory_keys.push_back(key); + info.artifact_size += view.size_bytes; + } + + return info; +} + absl::Status CommunicationManager::activate_stable_local_backing( const store::StableLocalBackingRef& backing, std::shared_ptr keepalive) { diff --git a/core/store/components/communication_manager.h b/core/store/components/communication_manager.h index 34687bce..a26c9bcd 100644 --- a/core/store/components/communication_manager.h +++ b/core/store/components/communication_manager.h @@ -99,6 +99,16 @@ class CommunicationManager { const std::vector& buffer_sizes, int device_id); + struct StableLocalBackingSourceView { + uint64_t address = 0; + size_t size_bytes = 0; + store::StableLocalBackingRef backing; + std::shared_ptr keepalive; + }; + + absl::StatusOr register_stable_local_backing_source_views( + const std::vector& views); + absl::Status activate_stable_local_backing( const store::StableLocalBackingRef& backing, std::shared_ptr keepalive = nullptr); diff --git a/daemon/service/byte_artifact_region_layout.cc b/daemon/service/byte_artifact_region_layout.cc index 24e96cdf..ca098ef4 100644 --- a/daemon/service/byte_artifact_region_layout.cc +++ b/daemon/service/byte_artifact_region_layout.cc @@ -404,6 +404,10 @@ absl::StatusOr ByteArtifactRegio return absl::FailedPreconditionError("HOST_SHARED source span requires region keepalive"); } const void* item_base_ptr = static_cast(storage.base_ptr) + range.storage_local_offset; + std::optional stable_backing = storage.stable_backing; + if (stable_backing.has_value() && range.slot_token.has_value() && range.logical_length > 0) { + stable_backing->slot_bytes = range.logical_length; + } return HostSharedSourceSpan{ .data = item_base_ptr, .length = range.logical_length, @@ -411,6 +415,8 @@ absl::StatusOr ByteArtifactRegio .host_region_class = storage.host_region_class, .daemon_managed = storage.daemon_managed, .slot_token = range.slot_token, + .stable_backing = stable_backing, + .stable_backing_keepalive = storage.stable_backing_keepalive, .keepalive = storage.keepalive, }; } diff --git a/daemon/service/byte_artifact_region_layout.h b/daemon/service/byte_artifact_region_layout.h index 937147b8..782b479d 100644 --- a/daemon/service/byte_artifact_region_layout.h +++ b/daemon/service/byte_artifact_region_layout.h @@ -35,6 +35,8 @@ class ByteArtifactRegionLayout { IpcRegionRegistry::HostRegionClass host_region_class{IpcRegionRegistry::HostRegionClass::kNone}; bool daemon_managed{false}; std::optional slot_token; + std::optional stable_backing; + std::shared_ptr stable_backing_keepalive; std::shared_ptr keepalive; }; diff --git a/daemon/service/byte_artifact_region_layout_host_shared_test.cc b/daemon/service/byte_artifact_region_layout_host_shared_test.cc index c3a34056..545f0420 100644 --- a/daemon/service/byte_artifact_region_layout_host_shared_test.cc +++ b/daemon/service/byte_artifact_region_layout_host_shared_test.cc @@ -125,6 +125,12 @@ TEST_CASE( REQUIRE(source_span_or->slot_token->slot_generation.has_value()); REQUIRE(*source_span_or->slot_token->slot_index == 7); REQUIRE(*source_span_or->slot_token->slot_generation == 11); + REQUIRE(source_span_or->stable_backing.has_value()); + CHECK(source_span_or->stable_backing->backing_id == desc_or->region_id); + CHECK(source_span_or->stable_backing->backing_base_addr == reinterpret_cast(mapping_or->base_ptr)); + CHECK(source_span_or->stable_backing->backing_bytes == kStorageBytes); + CHECK(source_span_or->stable_backing->slot_bytes == kItemBytes); + REQUIRE(source_span_or->stable_backing_keepalive != nullptr); REQUIRE(source_span_or->keepalive != nullptr); REQUIRE(std::memcmp(source_span_or->data, buffer, sizeof(buffer)) == 0); } diff --git a/daemon/service/controllers/byte_artifact_controller.cc b/daemon/service/controllers/byte_artifact_controller.cc index 30956d9c..34b3efe1 100644 --- a/daemon/service/controllers/byte_artifact_controller.cc +++ b/daemon/service/controllers/byte_artifact_controller.cc @@ -920,6 +920,8 @@ acquire_segmented_region_source_segments( PayloadTransportBroker::BatchCommunicatorRegionSourceSegment{ .data = entry.source_span.data, .size_bytes = entry.source_span.length, + .stable_backing = entry.source_span.stable_backing, + .stable_backing_keepalive = entry.source_span.stable_backing_keepalive, .keepalive = entry.source_span.keepalive, }); } @@ -5695,12 +5697,16 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( << " holder_daemon_id=" << task.route.holder_daemon_id << " remote=true" << " mode=segmented_region_export" << " staged_slab=false" - << " source_realization_mode=source_layout_host_shared" + << " source_realization_mode=" + << (communicator_export_or->broker_owned_register ? "source_layout_host_shared" + : "source_layout_host_shared_stable_backing") << " host_region_class=" << host_region_class << " pack_count=1" << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size() << " source_segments=" << source_segments_or->size() << " remote_keys=" << communicator_export_or->export_registration.remote_memory_keys.size() - << " registration_ownership=broker_owned"; + << " registration_ownership=" << communicator_export_or->registration_ownership + << " mr_ownership=" << communicator_export_or->mr_ownership + << " broker_owned_register=" << communicator_export_or->broker_owned_register; LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_transport_emit" << " operation_id=" << operation_id << " shard_id=" << task.shard_id diff --git a/daemon/service/payload_transport_broker.cc b/daemon/service/payload_transport_broker.cc index 5d20314b..6e450f58 100644 --- a/daemon/service/payload_transport_broker.cc +++ b/daemon/service/payload_transport_broker.cc @@ -1946,9 +1946,14 @@ absl::StatusOr PayloadTransport std::vector buffer_addresses; std::vector buffer_sizes; std::vector> keepalives; + std::vector stable_source_views; + std::vector> stable_keepalives; + bool stable_backing_candidate = true; buffer_addresses.reserve(source_segments.size()); buffer_sizes.reserve(source_segments.size()); keepalives.reserve(source_segments.size()); + stable_source_views.reserve(source_segments.size()); + stable_keepalives.reserve(source_segments.size()); std::uint64_t total_payload_bytes = 0; for (int entry_index = 0; entry_index < manifest.entries_size(); ++entry_index) { @@ -1975,6 +1980,20 @@ absl::StatusOr PayloadTransport buffer_addresses.push_back(const_cast(segment.data)); buffer_sizes.push_back(static_cast(segment.size_bytes)); keepalives.push_back(segment.keepalive); + if (segment.stable_backing.has_value() && segment.stable_backing_keepalive != nullptr) { + auto combined_keepalive = std::make_shared>>( + std::vector>{segment.keepalive, segment.stable_backing_keepalive}); + stable_keepalives.push_back(combined_keepalive); + stable_source_views.push_back( + store::components::CommunicationManager::StableLocalBackingSourceView{ + .address = reinterpret_cast(segment.data), + .size_bytes = static_cast(segment.size_bytes), + .backing = *segment.stable_backing, + .keepalive = combined_keepalive, + }); + } else { + stable_backing_candidate = false; + } } if (total_payload_bytes != manifest.total_size()) { return absl::FailedPreconditionError("segmented region source payload size mismatch"); @@ -1982,7 +2001,24 @@ absl::StatusOr PayloadTransport const absl::Time export_started_at = absl::Now(); const absl::Time register_started_at = absl::Now(); - auto registration_or = options_.comm_manager->register_memory(buffer_addresses, buffer_sizes, /*device_id=*/-1); + absl::StatusOr registration_or = + absl::FailedPreconditionError("stable-backed source view export was not attempted"); + bool stable_backing_export = false; + absl::Status stable_backing_status = absl::FailedPreconditionError("stable-backed source view export not eligible"); + if (stable_backing_candidate && stable_source_views.size() == source_segments.size()) { + registration_or = options_.comm_manager->register_stable_local_backing_source_views(stable_source_views); + stable_backing_status = registration_or.ok() ? absl::OkStatus() : registration_or.status(); + stable_backing_export = registration_or.ok(); + if (!registration_or.ok()) { + LOG(INFO) << "batch_payload_ref.stable_backing_source_view_fallback" + << " direction=" << payload_direction_label(direction) << " operation_id=" << operation_id + << " source_segments=" << source_segments.size() << " payload_bytes=" << manifest.total_size() + << " status=" << registration_or.status(); + } + } + if (!registration_or.ok()) { + registration_or = options_.comm_manager->register_memory(buffer_addresses, buffer_sizes, /*device_id=*/-1); + } const absl::Duration register_elapsed = absl::Now() - register_started_at; if (!registration_or.ok()) { return registration_or.status(); @@ -1993,7 +2029,11 @@ absl::StatusOr PayloadTransport const auto unregister_registered = [&]() { for (const auto& tensor_key : registration_or->remote_memory_keys) { - (void)options_.comm_manager->get_engine().unregister_tensor(tensor_key); + auto status = options_.comm_manager->get_engine().unregister_tensor(tensor_key); + if (!status.ok()) { + LOG(WARNING) << "batch_payload_ref.segmented_region_source_unregister_failed" + << " tensor_key=" << tensor_key << " status=" << status; + } } }; @@ -2019,16 +2059,23 @@ absl::StatusOr PayloadTransport return absl::NotFoundError("segmented region source transport record is missing"); } it->second.communicator_export = *registration_or; - it->second.communicator_export_keepalives = std::move(keepalives); + it->second.communicator_export_keepalives = + stable_backing_export ? std::move(stable_keepalives) : std::move(keepalives); it->second.communicator_export_requires_unregister = true; } + const std::string_view registration_ownership = + stable_backing_export ? std::string_view("stable_backing_view") : std::string_view("broker_owned"); + const std::string_view mr_ownership = + stable_backing_export ? std::string_view("stable_backing") : std::string_view("broker_owned"); VLOG(2) << "batch_payload_ref.communicator_export_summary" - << " direction=" << payload_direction_label(direction) << " operation_id=" << operation_id - << " realization=segmented_region_source" + << " direction=" << payload_direction_label(direction) << " operation_id=" << operation_id << " realization=" + << (stable_backing_export ? "segmented_region_source_stable_backing" : "segmented_region_source") << " transport_id=" << batch_payload_ref_or->metadata.transport_id << " payload_bytes=" << batch_payload_ref_or->metadata.payload_size << " remote_keys=" << registration_or->remote_memory_keys.size() - << " source_segments=" << source_segments.size() + << " source_segments=" << source_segments.size() << " registration_ownership=" << registration_ownership + << " mr_ownership=" << mr_ownership << " stable_backing_candidate=" << stable_backing_candidate + << " stable_backing_status=" << stable_backing_status << " register_ms=" << absl::ToDoubleMilliseconds(register_elapsed) << " issue_ref_ms=" << absl::ToDoubleMilliseconds(issue_ref_elapsed) << " total_ms=" << absl::ToDoubleMilliseconds(absl::Now() - export_started_at); @@ -2036,6 +2083,9 @@ absl::StatusOr PayloadTransport .metadata = batch_payload_ref_or->metadata, .batch_payload_ref = batch_payload_ref_or->batch_payload_ref, .export_registration = *registration_or, + .registration_ownership = std::string(registration_ownership), + .mr_ownership = std::string(mr_ownership), + .broker_owned_register = !stable_backing_export, }; } diff --git a/daemon/service/payload_transport_broker.h b/daemon/service/payload_transport_broker.h index 02a87fed..2bd15201 100644 --- a/daemon/service/payload_transport_broker.h +++ b/daemon/service/payload_transport_broker.h @@ -99,6 +99,9 @@ class PayloadTransportBroker { BatchRefMetadata metadata; std::string batch_payload_ref; store::ExportRegistration export_registration; + std::string registration_ownership; + std::string mr_ownership; + bool broker_owned_register{false}; }; struct BatchCommunicatorSourceSegment { @@ -108,6 +111,8 @@ class PayloadTransportBroker { struct BatchCommunicatorRegionSourceSegment { const void* data{nullptr}; std::uint64_t size_bytes{0}; + std::optional stable_backing; + std::shared_ptr stable_backing_keepalive; std::shared_ptr keepalive; }; diff --git a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md index f803ce07..f021608c 100644 --- a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md +++ b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md @@ -4,7 +4,7 @@ title: Unified Artifact Runtime and Routed Byte Artifact Architecture status: implemented areas: ["daemon", "sdk", "global_store", "proto", "core", "integrations", "docs"] created: 2026-03-08 -last_updated: 2026-04-26 +last_updated: 2026-04-27 related_code: - docs/designs/0017-client-generated-artifact-id.md - docs/designs/0056-programmable-framework-adv.md @@ -37,9 +37,12 @@ related_code: - daemon/state/ipc_region_registry.h - daemon/state/ipc_region_registry.cc - daemon/service/payload_transport_broker.h + - daemon/service/payload_transport_broker.cc - daemon/state/daemon_kernel.cc - core/store/communication_types.h - core/store/components/communication_manager.h + - core/store/components/communication_manager.cc + - core/store/materialization/contracts/stable_local_backing.h - core/store/materialization/dataplane/sinks/target_layout_host_sink.h - core/store/materialization/dataplane/sinks/target_layout_host_sink.cc - core/store/materialization/dataplane/loaders/p2p_loader.cc @@ -133,6 +136,10 @@ RDMA realization follow-ons still in progress: segmented `communicator_source` exports over the original source spans. The contiguous staged-slab realization remains the fallback for strict digest, non-`HOST_SHARED`, capability-miss, and lifetime-gap cases. +- The next accepted put-side source optimization keeps that no-pack shape but + changes source MR ownership from per-batch broker-owned raw registrations to + stable-backed `HOST_SHARED` source-view keys whose chunk MRs are owned by + the activated stable backing and resolved on the requested RDMA rail. - `BodyHandle` now exposes the export-view API used by source-side segmented communicator export. - The remaining get-side RDMA follow-on is producer-side read-plan servicing: @@ -1148,6 +1155,146 @@ Phase-10 implementation status, first cut accepted on 2026-04-27: `read_mode=batched_direct_write`, `materialize_mode=single_source_composite`, and no full-pack mirror. +#### 5.5.6d Accepted put-side source stable-backed MR reuse + +Phase 10 removes the put-side full-pack slab, but the first-cut source-region +export still creates broker-owned raw registrations for each batch segment. +The accepted Phase 11 optimization keeps the same logical manifest and +segmented `communicator_source` wire shape while changing MR ownership to the +source daemon's activated `HOST_SHARED` stable backing. This is an MR reuse +optimization, not a new batch-set authority or digest semantic. + +Scope: + +1. This applies only to eligible remote-home + `BatchPutIfAbsentFromRegion` no-pack source-layout items whose bytes live + in daemon-managed CPU `HOST_SHARED` regions with a stable backing. +2. It does not change `HomeBatchPutIfAbsent`, shard routing, manifest + offsets, per-item outcomes, Phase-9 composite final-body staging, or + Phase-10 staged-slab fallback semantics. +3. The first accepted fallback order is: + stable-backed source-view export, then Phase-10 broker-owned raw segmented + region export when the miss is known before transport issue, then the + staged full-pack realization for non-exportable source layouts. +4. Once a stable-backed transport has been issued, the source daemon must not + silently fall back to raw per-view registration in read-plan admission. A + requested-rail stable-backed ensure failure is a transport failure for the + current operation. + +Stable backing exposure on source spans: + +1. `ByteArtifactRegionLayout::HostSharedSourceSpan` must carry the source + backing metadata needed to prove stable-backed export eligibility: + `StableLocalBackingRef`, stable-backing keepalive, region id, host-region + class, source address, length, and backing-local offset or enough + information to derive it from `backing_base_addr`. +2. Source-layout validation already activates local stable backings through + `ExternalTargetAccessService::validate_local_source_layout(...)`; Phase 11 + consumes that activated backing instead of adding a second preregistration + surface. +3. The accepted eligible backing shape is + `StableLocalBackingKind::kHostSharedRegion` over CPU memory, with the span + fully contained in the backing. For allocator-backed slots, `slot_bytes` + must match the item length used for chunk geometry, and the source span + still requires `slot_index` and `slot_generation`. +4. Missing stable backing, missing stable-backing keepalive, cross-backing + groups, invalid geometry, strict digest requirements, or source lifetime + gaps remain pre-issue fallback reasons. + +Stable-backed export API: + +1. The communicator boundary should expose a stable-backed source-view API + such as `export_stable_local_backing_views(...)` or an equivalent + `prepare_stable_local_backing_source_views(...)` seam. The API accepts the + manifest-aligned source spans and returns an `ExportRegistration`-compatible + view with `remote_memory_keys[]`, `buffer_sizes[]`, memory location, and + transport keepalives. +2. The API must not call the generic broker-owned + `register_memory(span_addr, span_bytes)` path for eligible spans. It uses + `StableLocalBackingState` chunk records keyed by + `stable_backing_id + rail_id + chunk_index`. +3. Stable-backed lazy chunk registration is allowed only as a stable backing + cache fill: the resulting MR is owned by `StableLocalBackingState`, is + reusable by later batch get/set operations that touch the same backing + chunk and rail, and is deregistered only when the backing is deactivated. +4. The batch transport record may own source-view key cleanup, but it must not + own or deregister the stable chunk MR. The implementation should make this + split explicit rather than overloading one "registration owned" boolean. + +Remote key representation: + +1. The first implementation should keep the existing + `BatchPayloadCommunicatorSource` schema. Each source item is advertised as + a lightweight remote key whose `buffer_sizes[]` entry is exactly the item + length; `RemoteKeySource` continues to see one concatenated logical pack. +2. A lightweight source-view key is a transport-scoped communicator tensor or + equivalent lookup record over the item span. Its base address and length + are item-local, but its RDMA MR is borrowed from the stable backing chunk. +3. Unregistering the view key on transport expiry removes only the lookup + record and releases keepalives. It must not deregister the borrowed chunk + MR; the borrowed MR must be installed with non-owning semantics. +4. A future protocol could advertise backing id plus chunk offset directly, + but Phase 11 intentionally avoids proto changes and preserves the existing + `remote_memory_keys[]` / `buffer_sizes[]` lowering. + +Requested-rail RDMA behavior: + +1. The stable backing cache is already per rail: prewarm and lazy ensure store + chunk MRs under `stable_backing_id + rail_id + chunk_index`. +2. Phase 11 should use the requested-rail design. Export time creates + stable-backed view metadata; it does not have to eagerly `ensure_chunk(...)` + for every visible rail. +3. Source-side `READ_PLAN_REQUEST` admission already receives the requested + or selected `rail_id`. When the advertised remote key resolves to a + stable-backed source view, admission resolves only that rail by calling the + stable backing chunk ensure path and uses the returned chunk MR for the + direct-source response. +4. Admission must not route stable-backed views through the ordinary + `ensure_tensor_registered_on_dev(...)` raw tensor path, because that would + recreate per-batch/per-segment MR registrations. +5. If prewarm has completed, requested-rail resolution should be a chunk-cache + hit. If not, the allowed fallback is a stable-backed lazy chunk registration + owned by the backing cache, not a broker-owned view registration. + +Controller integration and observability: + +1. Remote-home `BatchPutIfAbsentFromRegion` should prefer stable-backed + source-view export for admitted Phase-10 source-layout packs. The raw + broker-owned segmented region export remains a pre-issue compatibility + fallback for non-stable-backed sources, and staged slabs remain the final + compatibility fallback. +2. The intended SGLang KV path should be fully stable-backed: source logs + should keep `mode=segmented_region_export` and `staged_slab=false`, but + change `source_realization_mode` to a stable-backed value and report + `mr_ownership=stable_backing` or equivalent. +3. New counters should distinguish source-view key count from stable chunk MR + ownership, requested-rail chunk-cache hits, stable-backed lazy chunk + registrations, raw-register fallback count, and staged-slab fallback count. +4. Path validation should prove that the source side no longer emits + broker-owned raw region registrations for intended SGLang `HOST_SHARED` + batch-set transports while home-side `read_mode=batched_direct_write` and + `materialize_mode=single_source_composite` remain unchanged. + +Phase-11 implementation status, first cut accepted on 2026-04-27: + +1. `HOST_SHARED` source spans now carry the activated + `StableLocalBackingRef` plus stable-backing keepalive into the put-side + segmented region export path. +2. The communicator exposes stable-backed lightweight source-view keys. Those + keys are unregistered with the batch transport record, while borrowed chunk + MRs remain owned by `StableLocalBackingState`. +3. Source-side `READ_PLAN_REQUEST` admission detects stable-backed view keys, + resolves only the requested rail through `ensure_chunk(...)`, and bypasses + the ordinary raw tensor lazy-registration path. +4. The SGLang share-remote Phase-11 replay showed all intended remote-home + batch-set packs using + `source_realization_mode=source_layout_host_shared_stable_backing`, + `registration_ownership=stable_backing_view`, and + `mr_ownership=stable_backing`, with no broker-owned raw registrations for + those packs and with home consume still on + `read_mode=batched_direct_write` / + `materialize_mode=single_source_composite`. + #### 5.5.7 Implemented v2 communicator-backed realization `v2 communicator_source` is the current communicator-backed realization. It moves routed byte-artifact remote transport @@ -1181,8 +1328,10 @@ Rules: - eligible RDMA get paths may export one logical pack as segmented retained body views, - eligible RDMA put paths may export one logical pack as segmented - `HOST_SHARED` source-layout views once the no-pack source realization in - 5.5.6c is implemented, + `HOST_SHARED` source-layout views; Phase 10 first realizes those views + through broker-owned raw registrations, while Phase 11 upgrades eligible + stable-backed spans to source-view keys that borrow chunk MRs from the + `HOST_SHARED` stable backing, - MTCP-compatible and fallback paths may still realize one staged host pack. 6. Current get-side remote `v2` consume paths open one communicator source per `transport_id`: @@ -1386,6 +1535,8 @@ Normative rules: step, - put-side source-layout no-pack segmented export as the third batch-set parity step, + - put-side source-layout stable-backed MR reuse as the fourth batch-set + parity step, - and source-side publish-time retained-export warming as an optional follow-on optimization on top of the same `BodyHandle` seam. 18. Session-scoped staging reuse remains complementary follow-on work after this transport-specific split. It must not From 72bbc10eab63145db139b2d00aba6e117c146e27 Mon Sep 17 00:00:00 2001 From: zhouyuhan Date: Mon, 27 Apr 2026 18:45:53 +0800 Subject: [PATCH 06/49] chore(byte-artifact): downgrade high-frequency observability logs to VLOG(2) --- core/communicator/engine/engine.cc | 12 +- daemon/BUILD | 1 + .../controllers/byte_artifact_controller.cc | 251 +++++++++--------- .../grpc_service_impl_batch_runtime_test.cc | 18 ++ daemon/service/payload_transport_broker.cc | 8 +- ...e-and-routed-byte-artifact-architecture.md | 3 +- 6 files changed, 155 insertions(+), 138 deletions(-) diff --git a/core/communicator/engine/engine.cc b/core/communicator/engine/engine.cc index 4e1cb827..41281160 100644 --- a/core/communicator/engine/engine.cc +++ b/core/communicator/engine/engine.cc @@ -2011,12 +2011,12 @@ absl::StatusOr Communicator::ensur if (chunk_or->mr == nullptr) { return absl::InternalError("stable-backed source view chunk MR is null"); } - LOG(INFO) << "communicator.read_plan_stable_source_prepare" - << " tensor_key=" << tensor_key << " backing_id=" << view->backing.backing_id - << " rail_id=" << chunk_or->rail_id << " nic=" << chunk_or->nic_name - << " chunk_index=" << chunk_or->chunk_index << " cache_hit=" << chunk_or->cache_hit - << " waited_on_inflight=" << chunk_or->waited_on_inflight << " registered_now=" << chunk_or->registered_now - << " prewarm_requested=" << prewarm_requested << " prewarm_complete=" << prewarm_complete; + VLOG(2) << "communicator.read_plan_stable_source_prepare" + << " tensor_key=" << tensor_key << " backing_id=" << view->backing.backing_id + << " rail_id=" << chunk_or->rail_id << " nic=" << chunk_or->nic_name + << " chunk_index=" << chunk_or->chunk_index << " cache_hit=" << chunk_or->cache_hit + << " waited_on_inflight=" << chunk_or->waited_on_inflight << " registered_now=" << chunk_or->registered_now + << " prewarm_requested=" << prewarm_requested << " prewarm_complete=" << prewarm_complete; return StableSourceViewMrEnsureResult{ .mr = chunk_or->mr, .backing_use = stable_use, diff --git a/daemon/BUILD b/daemon/BUILD index e5cfe650..0f2553f0 100644 --- a/daemon/BUILD +++ b/daemon/BUILD @@ -2146,6 +2146,7 @@ cc_test( "//core/store:testing_recording_global_store_client", "//core/testing:test_helpers", "//proto/tensorcast/daemon/v2:daemon_grpc_cc", + "@abseil-cpp//absl/log:globals", "@abseil-cpp//absl/log:log_entry", "@abseil-cpp//absl/log:log_sink", "@abseil-cpp//absl/log:log_sink_registry", diff --git a/daemon/service/controllers/byte_artifact_controller.cc b/daemon/service/controllers/byte_artifact_controller.cc index 34b3efe1..93362972 100644 --- a/daemon/service/controllers/byte_artifact_controller.cc +++ b/daemon/service/controllers/byte_artifact_controller.cc @@ -2114,15 +2114,14 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( ++timing_stats.remote_communicator_source_batched_direct_write_count; } } - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_open" - << " operation_id=" << operation_id << " transport_id=" << transport_id - << " kind=communicator_source" - << " remote=" << eligibility.remote - << " item_count=" << transport_it->second->manifest().entries_size() - << " payload_bytes=" << transport_it->second->manifest().total_size() - << " source_direct_write_at=" << eligibility.source_supports_direct_write - << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write - << " resolve_ms=" << absl::ToDoubleMilliseconds(absl::Now() - resolve_started_at); + VLOG(2) << "byte_artifact.home_batch_put_if_absent_transport_open" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=" << eligibility.remote + << " item_count=" << transport_it->second->manifest().entries_size() + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write + << " resolve_ms=" << absl::ToDoubleMilliseconds(absl::Now() - resolve_started_at); } source_kind = resolved_it->second.remote ? store::loading::MaterializationSource::kP2P : store::loading::MaterializationSource::kLocalReplica; @@ -2142,18 +2141,18 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( ++timing_stats.remote_direct_slice_transport_count; timing_stats.remote_direct_slice_bytes += transport_it->second->manifest().total_size(); timing_stats.remote_direct_slice_items += transport_it->second->manifest().entries_size(); - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_read_mode" - << " operation_id=" << operation_id << " transport_id=" << transport_id - << " kind=communicator_source" - << " remote=true" - << " read_mode=direct_remote_slice" - << " realization=source_slice_loader" - << " payload_bytes=" << transport_it->second->manifest().total_size() - << " item_count=" << transport_it->second->manifest().entries_size() - << " source_direct_write_at=" << eligibility.source_supports_direct_write - << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write - << " mirror_ms=0" - << " subsequent_item_slices_local=false"; + VLOG(2) << "byte_artifact.home_batch_put_if_absent_transport_read_mode" + << " operation_id=" << operation_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " remote=true" + << " read_mode=direct_remote_slice" + << " realization=source_slice_loader" + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " item_count=" << transport_it->second->manifest().entries_size() + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write + << " mirror_ms=0" + << " subsequent_item_slices_local=false"; } auto& source_mutex = remote_direct_source_mutexes[transport_id]; if (source_mutex == nullptr) { @@ -2183,18 +2182,18 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( timing_stats.remote_mirror_bytes += transport_it->second->manifest().total_size(); timing_stats.remote_full_pack_mirror_items += transport_it->second->manifest().entries_size(); mirrored_it = mirrored_remote_batch_payloads.emplace(transport_id, std::move(*mirrored_or)).first; - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_mirror" - << " operation_id=" << operation_id << " transport_id=" << transport_id - << " kind=communicator_source" - << " remote=true" - << " read_mode=full_pack" - << " realization=full_pack_mirror" - << " payload_bytes=" << transport_it->second->manifest().total_size() - << " item_count=" << transport_it->second->manifest().entries_size() - << " source_direct_write_at=" << eligibility.source_supports_direct_write - << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write - << " mirror_ms=" << absl::ToDoubleMilliseconds(mirror_elapsed) - << " subsequent_item_slices_local=true"; + VLOG(2) << "byte_artifact.home_batch_put_if_absent_transport_mirror" + << " operation_id=" << operation_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " remote=true" + << " read_mode=full_pack" + << " realization=full_pack_mirror" + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " item_count=" << transport_it->second->manifest().entries_size() + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write + << " mirror_ms=" << absl::ToDoubleMilliseconds(mirror_elapsed) + << " subsequent_item_slices_local=true"; } if (item.batch_payload_slice().offset() + item.batch_payload_slice().length() > mirrored_it->second->size()) { @@ -2367,34 +2366,33 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( return; } const auto eligibility = classify_put_remote_communicator_source(resolved_it->second); - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" - << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" - << " remote=true" - << " read_mode=direct_remote_slice" - << " materialize_mode=per_item" - << " stage_mode=source_slice_loader" - << " batched_direct_write=false" - << " source_count=1" - << " mapping_segments=0" - << " item_count=" << indices.size() << " item_bytes=0" - << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" - << " fallback_reason=" << reason; + VLOG(2) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=true" + << " read_mode=direct_remote_slice" + << " materialize_mode=per_item" + << " stage_mode=source_slice_loader" + << " batched_direct_write=false" + << " source_count=1" + << " mapping_segments=0" + << " item_count=" << indices.size() << " item_bytes=0" + << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" + << " fallback_reason=" << reason; if (direct_remote_batch_payloads.emplace(transport_id).second) { ++timing_stats.remote_direct_slice_transport_count; timing_stats.remote_direct_slice_bytes += transport_it->second->manifest().total_size(); timing_stats.remote_direct_slice_items += transport_it->second->manifest().entries_size(); - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_read_mode" - << " operation_id=" << operation_id << " transport_id=" << transport_id - << " kind=communicator_source" - << " remote=true" - << " read_mode=direct_remote_slice" - << " realization=source_slice_loader" - << " payload_bytes=" << transport_it->second->manifest().total_size() - << " item_count=" << transport_it->second->manifest().entries_size() - << " source_direct_write_at=" << eligibility.source_supports_direct_write - << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write - << " mirror_ms=0" - << " subsequent_item_slices_local=false"; + VLOG(2) << "byte_artifact.home_batch_put_if_absent_transport_read_mode" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=true" + << " read_mode=direct_remote_slice" + << " realization=source_slice_loader" + << " payload_bytes=" << transport_it->second->manifest().total_size() + << " item_count=" << transport_it->second->manifest().entries_size() + << " source_direct_write_at=" << eligibility.source_supports_direct_write + << " source_batched_direct_write_at=" << eligibility.source_supports_batched_direct_write + << " mirror_ms=0" + << " subsequent_item_slices_local=false"; } auto& source_mutex = remote_direct_source_mutexes[transport_id]; if (source_mutex == nullptr) { @@ -2498,19 +2496,19 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( const absl::Duration composite_elapsed = absl::Now() - composite_started_at; timing_stats.remote_composite_stage_elapsed += composite_elapsed; if (!composite_or.ok()) { - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" - << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" - << " remote=true" - << " read_mode=batched_direct_write" - << " materialize_mode=single_source_composite" - << " stage_mode=composite_final_body" - << " batched_direct_write=true" - << " source_count=1" - << " mapping_segments=" << indices.size() << " item_count=" << indices.size() - << " item_bytes=" << item_bytes - << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" - << " materialize_ms=" << absl::ToDoubleMilliseconds(composite_elapsed) << " outcome=failed" - << " status=" << composite_or.status(); + VLOG(2) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=true" + << " read_mode=batched_direct_write" + << " materialize_mode=single_source_composite" + << " stage_mode=composite_final_body" + << " batched_direct_write=true" + << " source_count=1" + << " mapping_segments=" << indices.size() << " item_count=" << indices.size() + << " item_bytes=" << item_bytes + << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" + << " materialize_ms=" << absl::ToDoubleMilliseconds(composite_elapsed) << " outcome=failed" + << " status=" << composite_or.status(); for (const int index : indices) { deferred_outcomes[index] = make_outcome( prepared_items[static_cast(index)].artifact_id, @@ -2538,20 +2536,20 @@ grpc::Status ByteArtifactController::home_batch_put_if_absent( if (composite_or->materialize_result.direct_write_supported) { ++timing_stats.remote_composite_batched_direct_write_count; } - LOG(INFO) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" - << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" - << " remote=true" - << " read_mode=batched_direct_write" - << " materialize_mode=single_source_composite" - << " stage_mode=composite_final_body" - << " batched_direct_write=true" - << " source_count=1" - << " mapping_segments=" << indices.size() << " item_count=" << indices.size() - << " item_bytes=" << item_bytes - << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" - << " materialize_ms=" << absl::ToDoubleMilliseconds(composite_elapsed) - << " direct_write_supported=" << composite_or->materialize_result.direct_write_supported - << " fallback_reason=none"; + VLOG(2) << "byte_artifact.home_batch_put_if_absent_transport_apply_summary" + << " operation_id=" << operation_id << " transport_id=" << transport_id << " kind=communicator_source" + << " remote=true" + << " read_mode=batched_direct_write" + << " materialize_mode=single_source_composite" + << " stage_mode=composite_final_body" + << " batched_direct_write=true" + << " source_count=1" + << " mapping_segments=" << indices.size() << " item_count=" << indices.size() + << " item_bytes=" << item_bytes + << " transport_payload_bytes=" << transport_it->second->manifest().total_size() << " mirror_ms=0" + << " materialize_ms=" << absl::ToDoubleMilliseconds(composite_elapsed) + << " direct_write_supported=" << composite_or->materialize_result.direct_write_supported + << " fallback_reason=none"; for (std::size_t local_index = 0; local_index < indices.size(); ++local_index) { auto& prepared_item = prepared_items[static_cast(indices[local_index])]; prepared_item.staged_body = std::move(composite_or->staged_bodies[local_index]); @@ -5692,28 +5690,28 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( task_stats.remote_batch_segmented_region_export_item_count += pack.source_indices.size(); task_stats.remote_batch_segmented_region_export_bytes += pack.manifest.total_size(); - LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_pack_realization" - << " operation_id=" << operation_id << " shard_id=" << task.shard_id - << " holder_daemon_id=" << task.route.holder_daemon_id << " remote=true" - << " mode=segmented_region_export" - << " staged_slab=false" - << " source_realization_mode=" - << (communicator_export_or->broker_owned_register ? "source_layout_host_shared" - : "source_layout_host_shared_stable_backing") - << " host_region_class=" << host_region_class << " pack_count=1" - << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size() - << " source_segments=" << source_segments_or->size() - << " remote_keys=" << communicator_export_or->export_registration.remote_memory_keys.size() - << " registration_ownership=" << communicator_export_or->registration_ownership - << " mr_ownership=" << communicator_export_or->mr_ownership - << " broker_owned_register=" << communicator_export_or->broker_owned_register; - - LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_transport_emit" - << " operation_id=" << operation_id << " shard_id=" << task.shard_id - << " holder_daemon_id=" << task.route.holder_daemon_id << " transport_id=" << transport_id - << " kind=communicator_source" - << " source_realization_mode=segmented_region_export" - << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size(); + VLOG(2) << "byte_artifact.batch_put_if_absent_from_region_pack_realization" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " remote=true" + << " mode=segmented_region_export" + << " staged_slab=false" + << " source_realization_mode=" + << (communicator_export_or->broker_owned_register ? "source_layout_host_shared" + : "source_layout_host_shared_stable_backing") + << " host_region_class=" << host_region_class << " pack_count=1" + << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size() + << " source_segments=" << source_segments_or->size() + << " remote_keys=" << communicator_export_or->export_registration.remote_memory_keys.size() + << " registration_ownership=" << communicator_export_or->registration_ownership + << " mr_ownership=" << communicator_export_or->mr_ownership + << " broker_owned_register=" << communicator_export_or->broker_owned_register; + + VLOG(2) << "byte_artifact.batch_put_if_absent_from_region_transport_emit" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " transport_id=" << transport_id + << " kind=communicator_source" + << " source_realization_mode=segmented_region_export" + << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size(); for (std::size_t pack_index = 0; pack_index < pack.source_indices.size(); ++pack_index) { auto slice = pack.slices[pack_index]; @@ -5771,12 +5769,12 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( task_stats.remote_batch_pack_count += packs_or->size(); task_stats.remote_batch_pack_item_count += packed_items; task_stats.remote_batch_pack_bytes += packed_bytes; - LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_pack_realization" - << " operation_id=" << operation_id << " shard_id=" << task.shard_id - << " holder_daemon_id=" << task.route.holder_daemon_id << " remote=true" - << " mode=staged_slab" - << " pack_count=" << packs_or->size() << " item_count=" << packed_items - << " payload_bytes=" << packed_bytes << " pack_ms=" << absl::ToDoubleMilliseconds(pack_elapsed); + VLOG(2) << "byte_artifact.batch_put_if_absent_from_region_pack_realization" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " remote=true" + << " mode=staged_slab" + << " pack_count=" << packs_or->size() << " item_count=" << packed_items + << " payload_bytes=" << packed_bytes << " pack_ms=" << absl::ToDoubleMilliseconds(pack_elapsed); for (auto& pack : *packs_or) { const bool use_communicator_transport = peer_transport_support.supports_v2() && d_.payload_transport_broker.batch_transport_communicator_enabled(); @@ -5873,13 +5871,12 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( task_stats.remote_batch_transport_item_count += pack.source_indices.size(); task_stats.remote_batch_transport_bytes += pack.manifest.total_size(); - LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_transport_emit" - << " operation_id=" << operation_id << " shard_id=" << task.shard_id - << " holder_daemon_id=" << task.route.holder_daemon_id << " transport_id=" << transport_id - << " kind=" << (emitted_communicator_transport ? "communicator_source" : "grpc_chunk_ref") - << " source_realization_mode=staged_slab" - << " item_count=" << pack.source_indices.size() - << " payload_bytes=" << pack.manifest.total_size(); + VLOG(2) << "byte_artifact.batch_put_if_absent_from_region_transport_emit" + << " operation_id=" << operation_id << " shard_id=" << task.shard_id + << " holder_daemon_id=" << task.route.holder_daemon_id << " transport_id=" << transport_id + << " kind=" << (emitted_communicator_transport ? "communicator_source" : "grpc_chunk_ref") + << " source_realization_mode=staged_slab" + << " item_count=" << pack.source_indices.size() << " payload_bytes=" << pack.manifest.total_size(); for (std::size_t pack_index = 0; pack_index < pack.source_indices.size(); ++pack_index) { auto slice = pack.slices[pack_index]; @@ -6025,14 +6022,14 @@ grpc::Status ByteArtifactController::batch_put_if_absent_from_region( for (;;) { auto rpc_result = dispatch_remote_home_batch(std::move(prepared_remote_batch)); task_stats.remote_home_rpc_elapsed += rpc_result.rpc_elapsed; - LOG(INFO) << "byte_artifact.batch_put_if_absent_from_region_home_rpc_result" - << " operation_id=" << operation_id << " shard_id=" << rpc_result.request.shard_id - << " holder_daemon_id=" << rpc_result.request.holder_daemon_id - << " requested_items=" << rpc_result.request.home_req.items_size() - << " attempt=" << rpc_result.request.attempt - << " rpc_ms=" << absl::ToDoubleMilliseconds(rpc_result.rpc_elapsed) - << " status_ok=" << rpc_result.status.ok() << " status_code=" << rpc_result.status.error_code() - << " status_message=" << rpc_result.status.error_message(); + VLOG(2) << "byte_artifact.batch_put_if_absent_from_region_home_rpc_result" + << " operation_id=" << operation_id << " shard_id=" << rpc_result.request.shard_id + << " holder_daemon_id=" << rpc_result.request.holder_daemon_id + << " requested_items=" << rpc_result.request.home_req.items_size() + << " attempt=" << rpc_result.request.attempt + << " rpc_ms=" << absl::ToDoubleMilliseconds(rpc_result.rpc_elapsed) + << " status_ok=" << rpc_result.status.ok() << " status_code=" << rpc_result.status.error_code() + << " status_message=" << rpc_result.status.error_message(); if (!rpc_result.status.ok()) { for (const auto& slot : rpc_result.request.outcome_slots) { diff --git a/daemon/service/grpc_service_impl_batch_runtime_test.cc b/daemon/service/grpc_service_impl_batch_runtime_test.cc index b65e9b60..d835e2f5 100644 --- a/daemon/service/grpc_service_impl_batch_runtime_test.cc +++ b/daemon/service/grpc_service_impl_batch_runtime_test.cc @@ -12,6 +12,7 @@ #include #include +#include "absl/log/globals.h" #include "absl/log/log_entry.h" #include "absl/log/log_sink.h" #include "absl/log/log_sink_registry.h" @@ -210,6 +211,21 @@ class ScopedCollectingLogSink { CollectingLogSink& sink_; }; +class ScopedVLogLevel { + public: + explicit ScopedVLogLevel(int level) : previous_level_(absl::SetGlobalVLogLevel(level)) {} + + ~ScopedVLogLevel() { + absl::SetGlobalVLogLevel(previous_level_); + } + + ScopedVLogLevel(const ScopedVLogLevel&) = delete; + ScopedVLogLevel& operator=(const ScopedVLogLevel&) = delete; + + private: + int previous_level_; +}; + static Topology make_pcie_batch_payload_topology( std::string_view local_endpoint_id, std::string_view remote_endpoint_id) { @@ -1922,6 +1938,7 @@ TEST_CASE( HomeBatchPutIfAbsentResponse put_resp; grpc::ServerContext put_ctx; { + ScopedVLogLevel scoped_vlog(/*level=*/2); ScopedCollectingLogSink scoped_sink(sink); REQUIRE(home->service().HomeBatchPutIfAbsent(&put_ctx, &put_req, &put_resp).ok()); } @@ -2039,6 +2056,7 @@ TEST_CASE( HomeBatchPutIfAbsentResponse put_resp; grpc::ServerContext put_ctx; { + ScopedVLogLevel scoped_vlog(/*level=*/2); ScopedCollectingLogSink scoped_sink(sink); REQUIRE(home->service().HomeBatchPutIfAbsent(&put_ctx, &put_req, &put_resp).ok()); } diff --git a/daemon/service/payload_transport_broker.cc b/daemon/service/payload_transport_broker.cc index 6e450f58..0163bb26 100644 --- a/daemon/service/payload_transport_broker.cc +++ b/daemon/service/payload_transport_broker.cc @@ -2010,10 +2010,10 @@ absl::StatusOr PayloadTransport stable_backing_status = registration_or.ok() ? absl::OkStatus() : registration_or.status(); stable_backing_export = registration_or.ok(); if (!registration_or.ok()) { - LOG(INFO) << "batch_payload_ref.stable_backing_source_view_fallback" - << " direction=" << payload_direction_label(direction) << " operation_id=" << operation_id - << " source_segments=" << source_segments.size() << " payload_bytes=" << manifest.total_size() - << " status=" << registration_or.status(); + VLOG(2) << "batch_payload_ref.stable_backing_source_view_fallback" + << " direction=" << payload_direction_label(direction) << " operation_id=" << operation_id + << " source_segments=" << source_segments.size() << " payload_bytes=" << manifest.total_size() + << " status=" << registration_or.status(); } } if (!registration_or.ok()) { diff --git a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md index f021608c..bb62b2a3 100644 --- a/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md +++ b/docs/designs/0087-unified-artifact-runtime-and-routed-byte-artifact-architecture.md @@ -1614,7 +1614,8 @@ Rules: exportable producer memory instead of RPC payload slicing, but logical pack shape still remains bounded by `max_batch_payload_bytes` and `max_batch_items`. 3. Current observability is log-first rather than metrics-first. The live implementation emits structured - `INFO` or `VLOG(1)` summaries such as: + `INFO` or `VLOG(1/2)` summaries; high-frequency put-side RDMA transport probes are `VLOG(2)` so production + runs do not print them unless verbose logging is explicitly enabled. Examples include: - `byte_artifact.home_batch_get_timing_summary` - `byte_artifact.home_batch_get_response_shape` - `byte_artifact.batch_get_into_region_home_rpc_result` From 431b5597a8bbe380fb727322a7a2cbab9da51866 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 11:56:23 +0800 Subject: [PATCH 07/49] docs: design weight broadcast scheduling --- ...trol-plane-coordinated-weight-broadcast.md | 364 ++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 docs/designs/0116-control-plane-coordinated-weight-broadcast.md diff --git a/docs/designs/0116-control-plane-coordinated-weight-broadcast.md b/docs/designs/0116-control-plane-coordinated-weight-broadcast.md new file mode 100644 index 00000000..0137da30 --- /dev/null +++ b/docs/designs/0116-control-plane-coordinated-weight-broadcast.md @@ -0,0 +1,364 @@ +--- +slug: control-plane-coordinated-weight-broadcast +title: Control-Plane Coordinated Weight Broadcast +status: proposed +areas: ["sdk", "daemon", "core", "global_store", "docs"] +created: 2026-04-29 +last_updated: 2026-04-29 +related_code: + - tensorcast/api/context.py + - tensorcast/api/store/artifact.py + - tensorcast/api/_materialize.py + - tensorcast/daemon_ctl.py + - proto/tensorcast/daemon/v2/store_daemon.proto + - proto/tensorcast/global_store/v1/global_store.proto + - daemon/service/controllers/replica_materialization_service.cc + - daemon/service/controllers/materialization_policy_utils.cc + - core/store/materialization/contracts/loading_spec.h + - core/store/components/global_store_client.cc + - core/store/runtime/ingestion/materialization_facade.cc + - tensorcast/global_store/services/transport_service.py + - tensorcast/global_store/repositories/replica_repository.py + - tensorcast/global_store/repositories/transport_repository.py + - docs/plans/0083-group-aware-transport-scheduling.md +links: + dependencies: + - ../plans/0083-group-aware-transport-scheduling.md + - ./0108-tensor-aware-materialization-strategy-plane.md + - ./0114-collective-first-binding-realization-for-tp-serving-startup.md +--- + +# Summary + +TensorCast already has the control-plane pieces required for a soft model +weight broadcast: + +- `RequestReplicaTransport` requires request idempotency, +- Global Store queues transport requests through group dispatch, +- group dispatch tracks group progress and source spread, +- replica selection filters liveness, availability, and capacity, +- and successful materialization registers a new daemon-local replica that can + later serve P2P requests. + +The missing Phase 1 link is the daemon-owned prefetch path. `Artifact.prefetch` +currently allocates a daemon replica through `MaterializeReplicaRequest`, but +that request does not carry transport request ids or scheduling group hints. +The result is that simultaneous model-weight prefetches cannot reliably enter +Global Store as one transport group, even though the scheduler already knows +how to dispatch such a group. + +This design adds a first-class transport scheduling hint to the SDK and daemon +materialization API, then reuses the existing Global Store scheduler and P2P +data plane. It deliberately does not introduce an independent broadcast +control plane and does not move model-weight distribution to NCCL. + +```mermaid +flowchart LR + A["Model version prefetch
same broadcast group"] --> B["SDK CallContext
transport group + request id"] + B --> C["Store Daemon
MaterializeReplicaRequest"] + C --> D["StoreEngine
MaterializeHints"] + D --> E["Global Store
group dispatch"] + E --> F["Selected source replica
source spread"] + F --> G["P2P materialization
RDMA / MTCP / routing"] + G --> H["Target daemon
registers replica"] + H --> I["New source candidate
for later requests"] +``` + +# Goals / Non-Goals + +## Goals + +- Let one model version prefetch fanout produce stable transport request ids + and a shared `weight_broadcast` scheduling group. +- Preserve `replica_uuid` as a pure daemon replica/session identifier. +- Reuse existing Global Store group dispatch, pending request queues, + idempotency checks, source spread, and completion outcome accounting. +- Reuse existing P2P materialization and replica registration/export paths. +- Keep ordinary `tensorcast.artifact(...).tensor_dict()` and ungrouped + `Artifact.prefetch()` behavior unchanged. +- Provide an API shape that can evolve into explicit broadcast sessions and + tree plans without rewriting the data plane. + +## Non-Goals + +- Do not implement strict parent-child tree scheduling in Phase 1. +- Do not add chunk-level pipeline broadcast in Phase 1. +- Do not introduce NCCL as the cross-cluster model-weight broadcast control + plane. +- Do not add a Global Store schema migration for Phase 1. +- Do not change `MaterializeIntoTarget`, binding, or mapped-binding transport + semantics in the first cut. + +# Prior Constraints Reviewed + +## Group-aware transport scheduling + +`docs/plans/0083-group-aware-transport-scheduling.md` and the current Global +Store code already define queue-based dispatch, group fairness, completion +bias, starvation aging, and source spread. This design keeps that scheduler as +the Phase 1 control-plane primitive instead of adding a parallel scheduler. + +## Strategy-plane ownership + +`0108` keeps semantic materialization strategy in the daemon/core +materialization path. This design follows that boundary: the SDK only sends +request-level coordination hints, while Store Daemon and StoreEngine decide how +to acquire sources and execute the transfer. + +## Same-host collectives + +`0114` focuses on same-host collective-first binding realization. That work is +complementary but not a cluster-level broadcast replacement. This design keeps +cluster-wide model-weight dissemination under Global Store coordination and +leaves NCCL or local collectives as later locality-specific optimizations. + +## Replica identity purity + +Current API docs state that `replica_uuid` remains a pure operation/session id. +This design keeps that rule. Transport group metadata must travel in explicit +transport hint fields, not in `replica_uuid`. + +# Architecture & Interfaces + +## SDK context + +Add a typed transport scheduling context: + +```python +@dataclass(frozen=True, slots=True) +class TransportSchedulingGroup: + group_id: str + group_kind: str + total_parts: int + part_id: str + priority: int = 0 + epoch: int = 0 + request_id: str | None = None +``` + +`CallContext` gains: + +```python +transport_group: TransportSchedulingGroup | None = None +``` + +`tensorcast.context(...)` accepts the same optional argument. Existing +`ctx.tags["tc.transport.group.*"]` remains a compatibility path because +binding and mapped target paths already use those tags to build +`operation_id#tcg:...` strings. The typed field is the preferred API for new +prefetch callers. + +For model weights, callers should use: + +```python +ctx = tensorcast.context( + idempotency_key="load:model-a:v42", + transport_group=tensorcast.TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=128, + part_id="daemon-17", + ), +) +store.artifact(artifact_id=artifact_id).prefetch(device="cuda:0", ctx=ctx) +``` + +When a group exists and no explicit transport request id is provided, SDK code +generates a deterministic id from artifact id, selection hash, daemon id, +device id, group kind, group id, epoch, and part id. This makes retries +idempotent while avoiding collisions between different target daemons or +devices. + +## Daemon proto + +Extend `MaterializeReplicaRequest` with explicit transport scheduling fields: + +```proto +message TransportSchedulingGroupHint { + string group_id = 1; + string group_kind = 2; + uint32 total_parts = 3; + string part_id = 4; + uint32 priority = 5; + uint64 epoch = 6; +} + +message MaterializeReplicaRequest { + // existing fields... + string transport_request_id = 21; + TransportSchedulingGroupHint transport_scheduling_group = 22; +} +``` + +The daemon proto owns this request shape so `store_daemon.proto` does not need +to import the Global Store service proto just for one scheduling hint. The C++ +daemon maps it into `store::loading::TransportSchedulingGroupHint`, which is +already the type consumed by `MaterializeHints`. + +## DaemonCtl and SDK materialization + +`DaemonCtl.materialize_by_artifact_id_v2(...)` receives +`transport_request_id` and `transport_scheduling_group` keyword arguments and +copies them into `MaterializeReplicaRequest`. + +`materialize_artifact_v2(...)` resolves the transport hints from `CallContext` +and passes them through `DaemonCtl`. `Artifact.prefetch()` continues to call +`pipeline.materialize_subset(...)`, but grouped prefetches now reach the +daemon as explicit transport hints. + +Ungrouped materialization sends no hints. Global Store request ids for +ordinary paths continue to be generated by existing C++ fallback logic. + +## Store Daemon to StoreEngine + +`ReplicaMaterializationService::materialize_replica(...)` reads the new request +fields and applies them to `MaterializeHints` before calling +`StoreEngine::materialize_replica(...)`. + +The existing request path already passes `MaterializeHints` into: + +- `MaterializeOrchestrator`, +- `MaterializationFacade`, +- `GlobalStoreClient::request_replica_transport(...)`, +- `GlobalStoreClient::request_view_transport(...)`. + +Those paths already convert `TransportSchedulingGroupHint` into +`RequestReplicaTransportRequest.scheduling_group` and pass +`transport_request_id` as `request_id`. + +## Global Store behavior + +Phase 1 does not change Global Store schema or scheduler semantics. Grouped +prefetch requests enter the existing queue and are dispatched by: + +- group fairness floor, +- completion bias, +- starvation aging, +- source-balance scoring, +- group source spread, +- heartbeat and accepting-new-request filters, +- replica and worker concurrency limits. + +Only `TRANSPORT_COMPLETION_OUTCOME_SUCCESS` contributes to group progress, as +the current transport service already requires. + +## Replica export + +Prefetch remains daemon-owned and uses `LEASE_MODE_NO_LEASE`. To make successful +prefetch replicas useful as later sources, callers may request +`GetArtifactOptions(export_policy="auto")` or `force` where appropriate. The +daemon and StoreEngine continue to own export eligibility and remote memory key +registration. + +# Future Tree Broadcast + +Phase 2 can add explicit Global Store concepts without replacing the Phase 1 +data path: + +```text +BroadcastSession(session_id, artifact_id, view_id, epoch, fanout, state) +BroadcastEdge(session_id, parent_worker_id, child_worker_id, level, state, attempt) +``` + +The scheduler would select parent-child edges, then instruct each child to run +the same P2P materialization path with a preferred parent. A child only enters +the parent pool after materialization succeeds and exportable transport +metadata is registered. + +Phase 3 can make parent selection topology-aware. Phase 4 can consider +chunk-level pipeline forwarding, which requires partial residency and +per-chunk verification state and is intentionally outside Phase 1. + +# Naming Compliance + +| Interface | Language | Compliance | +| --- | --- | --- | +| `TransportSchedulingGroup` | Python class | PascalCase, matching existing dataclass style such as `CollectiveLoadGroup`. | +| `transport_group` | Python field | snake_case field name. | +| `transport_request_id` | Python/proto/C++ field | snake_case field name. | +| `TransportSchedulingGroupHint` | Proto/C++ message mapping | PascalCase type name. | +| `resolve_transport_scheduling_hints` | Python helper | snake_case function name. | +| `apply_materialize_replica_transport_hints` | C++ helper | snake_case function name. | + +# Schema Changes + +Phase 1 has no Global Store database schema changes. It only extends the daemon +RPC request shape and threads already-existing scheduler metadata into existing +Global Store transport tables. + +Phase 2 `BroadcastSession` and `BroadcastEdge` would require a separate design +or plan section with migrations and recovery semantics before implementation. + +# Error Model + +- Invalid typed groups fail client-side where possible: + - `group_kind`, `group_id`, and `part_id` must be non-empty, + - `total_parts` must be positive, + - `priority` and `epoch` must be non-negative. +- If both typed `transport_group` and legacy `ctx.tags` group keys are present, + the typed group wins for `MaterializeReplica` prefetch. This avoids + ambiguous request ids. +- A repeated `transport_request_id` with a different payload is rejected by the + existing Global Store idempotency checks. +- Failed, expired, or cancelled transports release capacity but do not count as + group success. +- If no exportable source is available, materialization follows existing + fallback behavior, including MTCP/disk fallback when allowed by source policy. + +# Compatibility & Acceptance Criteria + +## Compatibility + +- Existing callers that do not pass `transport_group` see no behavior change. +- `tensor_dict`, `tensor_dict_into`, `Binding.swap`, and mapped binding paths + keep their current transport hint behavior. +- Existing `ctx.tags` group keys remain recognized for compatibility. +- The new proto fields are additive and optional. + +## Phase 1 Acceptance Criteria + +- Multiple daemon prefetches for the same model version can carry the same + `group_kind`, `group_id`, `epoch`, and `total_parts`. +- Each target daemon can carry a distinct `part_id` and stable + `transport_request_id`. +- Global Store `pending_transport_requests` and `artifact_transports` show the + requests in one group. +- Source selection uses existing group source spread instead of concentrating + all requests on one root when alternatives are available. +- Successful target materialization still registers local replicas. +- Exportable replicas still publish remote memory keys when export policy and + local state allow it. +- Failed, expired, or cancelled transports do not advance group success. +- Single-node, no-RDMA, and ordinary ungrouped materialization paths keep + working through existing fallback paths. + +# Testing + +- Python SDK unit tests: + - typed `TransportSchedulingGroup` validates fields, + - `Artifact.prefetch()` forwards transport hints to the materialization + pipeline, + - deterministic transport request ids are stable for the same group part, + - ungrouped prefetch sends no hint. +- Python daemon client tests: + - `DaemonCtl.materialize_by_artifact_id_v2()` fills the new proto fields. +- C++ daemon tests: + - `MaterializeReplicaRequest.transport_scheduling_group` becomes + `MaterializeHints.transport_scheduling_group`. + - `transport_request_id` becomes `MaterializeHints.transport_request_id`. +- Global Store regression tests: + - existing group dispatch, request idempotency, group progress, and source + spread tests continue to pass. + +# References + +- `tensorcast/global_store/services/transport_service.py` +- `tensorcast/global_store/repositories/replica_repository.py` +- `tensorcast/global_store/repositories/transport_repository.py` +- `core/store/materialization/contracts/loading_spec.h` +- `core/store/components/global_store_client.cc` +- `core/store/runtime/ingestion/materialization_facade.cc` +- `daemon/service/controllers/replica_materialization_service.cc` +- `tensorcast/api/store/artifact.py` +- `proto/tensorcast/daemon/v2/store_daemon.proto` From 83a1584e225dc17a425206cc0915410c34e21272 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:10:05 +0800 Subject: [PATCH 08/49] docs: plan weight broadcast scheduling --- ...trol-plane-coordinated-weight-broadcast.md | 8 + ...trol-plane-coordinated-weight-broadcast.md | 1119 +++++++++++++++++ 2 files changed, 1127 insertions(+) create mode 100644 docs/plans/0116-control-plane-coordinated-weight-broadcast.md diff --git a/docs/designs/0116-control-plane-coordinated-weight-broadcast.md b/docs/designs/0116-control-plane-coordinated-weight-broadcast.md index 0137da30..44312d3e 100644 --- a/docs/designs/0116-control-plane-coordinated-weight-broadcast.md +++ b/docs/designs/0116-control-plane-coordinated-weight-broadcast.md @@ -9,9 +9,12 @@ related_code: - tensorcast/api/context.py - tensorcast/api/store/artifact.py - tensorcast/api/_materialize.py + - tensorcast/api/plan/plan.py - tensorcast/daemon_ctl.py - proto/tensorcast/daemon/v2/store_daemon.proto - proto/tensorcast/global_store/v1/global_store.proto + - proto/tensorcast/plan/v1/plan.proto + - tensorcast/node_agent/executor.py - daemon/service/controllers/replica_materialization_service.cc - daemon/service/controllers/materialization_policy_utils.cc - core/store/materialization/contracts/loading_spec.h @@ -22,6 +25,7 @@ related_code: - tensorcast/global_store/repositories/transport_repository.py - docs/plans/0083-group-aware-transport-scheduling.md links: + plan: ../plans/0116-control-plane-coordinated-weight-broadcast.md dependencies: - ../plans/0083-group-aware-transport-scheduling.md - ./0108-tensor-aware-materialization-strategy-plane.md @@ -148,6 +152,10 @@ binding and mapped target paths already use those tags to build `operation_id#tcg:...` strings. The typed field is the preferred API for new prefetch callers. +Programmable plans serialize the same typed group through `plan.v1.CallContext` +so node-agent prefetch steps can preserve group membership when a cluster plan +fans out work to multiple daemons. + For model weights, callers should use: ```python diff --git a/docs/plans/0116-control-plane-coordinated-weight-broadcast.md b/docs/plans/0116-control-plane-coordinated-weight-broadcast.md new file mode 100644 index 00000000..3ad30b77 --- /dev/null +++ b/docs/plans/0116-control-plane-coordinated-weight-broadcast.md @@ -0,0 +1,1119 @@ +--- +slug: control-plane-coordinated-weight-broadcast +title: Control-Plane Coordinated Weight Broadcast Implementation Plan +links: + design: ../designs/0116-control-plane-coordinated-weight-broadcast.md +--- + +# Control-Plane Coordinated Weight Broadcast Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Thread model-weight prefetch transport request ids and group hints from the Python SDK through Store Daemon into existing Global Store group dispatch. + +**Architecture:** Add typed SDK transport scheduling context, additive daemon proto fields, DaemonCtl forwarding, and C++ daemon mapping into existing `MaterializeHints`. The Global Store scheduler and P2P data plane remain unchanged; Phase 1 only makes daemon-owned `Artifact.prefetch()` visible to the existing group dispatcher. + +**Tech Stack:** Python SDK, pydantic options, protobuf/buf generation, C++ Store Daemon controllers, StoreEngine `MaterializeHints`, pytest, Bazel/Catch2. + +--- + +# Current State & Grounding + +- Branch: `runze/broadcast-weight`. +- Design: `docs/designs/0116-control-plane-coordinated-weight-broadcast.md`. +- Existing Global Store fields: `proto/tensorcast/global_store/v1/global_store.proto` already defines `RequestReplicaTransportRequest.request_id` and `TransportSchedulingGroup`. +- Existing core hints: `core/store/materialization/contracts/loading_spec.h` already defines `TransportSchedulingGroupHint` and `MaterializeHints.transport_request_id`. +- Existing C++ transport client: `core/store/components/global_store_client.cc` already copies hints into `RequestReplicaTransportRequest`. +- Current gap: `proto/tensorcast/daemon/v2/store_daemon.proto::MaterializeReplicaRequest` has no transport hint fields, so `Artifact.prefetch()` cannot send group hints to Store Daemon. +- Current prefetch entrypoint: `tensorcast/api/store/artifact.py::Artifact.prefetch()`. +- Current SDK materialization path: `tensorcast/api/_materialize.py::materialize_artifact_v2()` to `tensorcast/daemon_ctl.py::DaemonCtl.materialize_by_artifact_id_v2()`. +- Current daemon materialization path: `daemon/service/controllers/replica_materialization_service.cc::materialize_replica()`. +- Existing dirty worktree before this plan includes generated proto files and `pyproject.toml`; implementation must not revert or stage unrelated pre-existing changes. + +# Files + +- Modify: `proto/tensorcast/daemon/v2/store_daemon.proto` +- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py` +- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi` +- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py` if Buf rewrites it +- Modify: `tensorcast/api/context.py` +- Modify: `tensorcast/api/__init__.py` +- Modify: `tensorcast/__init__.py` +- Modify: `tensorcast/api/store/artifact.py` +- Modify: `tensorcast/api/_materialize.py` +- Modify: `tensorcast/daemon_ctl.py` +- Modify: `daemon/service/controllers/materialization_policy_utils.h` +- Modify: `daemon/service/controllers/materialization_policy_utils.cc` +- Modify: `daemon/service/controllers/replica_materialization_service.cc` +- Modify: `daemon/service/materialization_policy_utils_test.cc` +- Test: `tests/python/api/test_prefetch_operation.py` +- Test: add `tests/python/api/test_daemon_ctl_transport_hints.py` +- Modify: `docs/designs/0116-control-plane-coordinated-weight-broadcast.md` + +# Phases & Milestones + +- [ ] Phase 1: Add public SDK transport group context and deterministic prefetch hint resolution. +- [ ] Phase 2: Add daemon proto fields and regenerate Python stubs. +- [ ] Phase 3: Forward transport hints through DaemonCtl and Store Daemon. +- [ ] Phase 4: Verify SDK, daemon, and Global Store regressions. + +### Task 1: SDK Transport Group Context + +**Files:** +- Modify: `tensorcast/api/context.py` +- Modify: `tensorcast/api/__init__.py` +- Modify: `tensorcast/__init__.py` +- Test: `tests/python/api/test_prefetch_operation.py` + +- [ ] **Step 1: Write failing validation and export tests** + +Add these imports near the top of `tests/python/api/test_prefetch_operation.py`: + +```python +from tensorcast.api.context import TransportSchedulingGroup +``` + +Append these tests: + +```python +def test_transport_scheduling_group_rejects_invalid_values() -> None: + invalid_cases = [ + {"group_kind": "", "group_id": "model:v1", "total_parts": 2, "part_id": "d0"}, + {"group_kind": "weight_broadcast", "group_id": "", "total_parts": 2, "part_id": "d0"}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 0, "part_id": "d0"}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 2, "part_id": ""}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 2, "part_id": "d0", "priority": -1}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 2, "part_id": "d0", "epoch": -1}, + ] + + for kwargs in invalid_cases: + try: + TransportSchedulingGroup(**kwargs) + except ValueError: + continue + raise AssertionError(f"expected invalid transport group: {kwargs}") + + +def test_context_accepts_typed_transport_group() -> None: + group = tc.TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=8, + part_id="daemon-3", + ) + + ctx = tc.context(request_id="req-1", transport_group=group) + + assert ctx.transport_group == group + assert tc.TransportSchedulingGroup is TransportSchedulingGroup +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: + +```bash +pytest tests/python/api/test_prefetch_operation.py::test_transport_scheduling_group_rejects_invalid_values tests/python/api/test_prefetch_operation.py::test_context_accepts_typed_transport_group -v +``` + +Expected: FAIL because `TransportSchedulingGroup` is not defined/exported. + +- [ ] **Step 3: Implement `TransportSchedulingGroup`** + +In `tensorcast/api/context.py`, add this dataclass after `CollectiveLoadGroup`: + +```python +@dataclass(frozen=True, slots=True) +class TransportSchedulingGroup: + """Control-plane transport scheduling group for coordinated P2P source selection.""" + + group_id: str + group_kind: str + total_parts: int + part_id: str + priority: int = 0 + epoch: int = 0 + request_id: str | None = None + + def __post_init__(self) -> None: + group_kind = str(self.group_kind).strip() + group_id = str(self.group_id).strip() + part_id = str(self.part_id).strip() + total_parts = int(self.total_parts) + priority = int(self.priority) + epoch = int(self.epoch) + request_id = None if self.request_id is None else str(self.request_id).strip() + if not group_kind: + raise ValueError("TransportSchedulingGroup.group_kind must be non-empty") + if not group_id: + raise ValueError("TransportSchedulingGroup.group_id must be non-empty") + if total_parts <= 0: + raise ValueError("TransportSchedulingGroup.total_parts must be positive") + if not part_id: + raise ValueError("TransportSchedulingGroup.part_id must be non-empty") + if priority < 0: + raise ValueError("TransportSchedulingGroup.priority must be non-negative") + if epoch < 0: + raise ValueError("TransportSchedulingGroup.epoch must be non-negative") + object.__setattr__(self, "group_kind", group_kind) + object.__setattr__(self, "group_id", group_id) + object.__setattr__(self, "total_parts", total_parts) + object.__setattr__(self, "part_id", part_id) + object.__setattr__(self, "priority", priority) + object.__setattr__(self, "epoch", epoch) + object.__setattr__(self, "request_id", request_id or None) +``` + +Add `transport_group` to `CallContext`: + +```python +transport_group: TransportSchedulingGroup | None = None +``` + +Add `transport_group` to `context(...)` and pass it into `CallContext(...)`: + +```python +transport_group: TransportSchedulingGroup | None = None, +``` + +```python +transport_group=transport_group, +``` + +Add `"TransportSchedulingGroup"` to `__all__`. + +- [ ] **Step 4: Export the new type** + +In `tensorcast/api/__init__.py`, import and export `TransportSchedulingGroup` from `tensorcast.api.context`. + +In `tensorcast/__init__.py`, add: + +```python +"TransportSchedulingGroup": ("tensorcast.api", "TransportSchedulingGroup"), +``` + +Add `TransportSchedulingGroup` to the `TYPE_CHECKING` import list and `__all__`. + +- [ ] **Step 5: Run tests to verify they pass** + +Run: + +```bash +pytest tests/python/api/test_prefetch_operation.py::test_transport_scheduling_group_rejects_invalid_values tests/python/api/test_prefetch_operation.py::test_context_accepts_typed_transport_group -v +``` + +Expected: PASS. + +- [ ] **Step 6: Commit** + +Run: + +```bash +git add tensorcast/api/context.py tensorcast/api/__init__.py tensorcast/__init__.py tests/python/api/test_prefetch_operation.py +git commit -m "feat(sdk): add transport scheduling group context" +``` + +### Task 2: SDK Prefetch Hint Resolution + +**Files:** +- Modify: `tensorcast/api/store/artifact.py` +- Test: `tests/python/api/test_prefetch_operation.py` + +- [ ] **Step 1: Write failing prefetch forwarding tests** + +Append these tests to `tests/python/api/test_prefetch_operation.py`: + +```python +def test_prefetch_forwards_typed_transport_group_hint() -> None: + store = _Store() + artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") + group = tc.TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=16, + part_id="daemon-1", + priority=7, + request_id="explicit-transport-req", + ) + + artifact.prefetch(device="cuda:0", ctx=tc.context(transport_group=group)) + + call = store._materialization.calls[0] + assert call["transport_request_id"] == "explicit-transport-req" + forwarded = call["transport_scheduling_group"] + assert forwarded.group_kind == "weight_broadcast" + assert forwarded.group_id == "model-a:v42" + assert forwarded.epoch == 42 + assert forwarded.total_parts == 16 + assert forwarded.part_id == "daemon-1" + assert forwarded.priority == 7 + + +def test_prefetch_derives_stable_transport_request_id_for_group() -> None: + group = tc.TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=16, + part_id="daemon-1", + ) + ctx = tc.context(transport_group=group) + first_store = _Store() + second_store = _Store() + first = Artifact(store_ref=weakref.ref(first_store), artifact_id="aid") + second = Artifact(store_ref=weakref.ref(second_store), artifact_id="aid") + + first.prefetch(device="cuda:0", ctx=ctx) + second.prefetch(device="cuda:0", ctx=ctx) + + first_request_id = first_store._materialization.calls[0]["transport_request_id"] + second_request_id = second_store._materialization.calls[0]["transport_request_id"] + assert first_request_id + assert first_request_id == second_request_id + assert first_request_id.startswith("prefetch:") + + +def test_prefetch_without_group_sends_no_transport_hint() -> None: + store = _Store() + artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") + + artifact.prefetch(device="cuda:0") + + call = store._materialization.calls[0] + assert call["transport_request_id"] is None + assert call["transport_scheduling_group"] is None +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: + +```bash +pytest tests/python/api/test_prefetch_operation.py::test_prefetch_forwards_typed_transport_group_hint tests/python/api/test_prefetch_operation.py::test_prefetch_derives_stable_transport_request_id_for_group tests/python/api/test_prefetch_operation.py::test_prefetch_without_group_sends_no_transport_hint -v +``` + +Expected: FAIL because `Artifact.prefetch()` does not pass `transport_request_id` or `transport_scheduling_group`. + +- [ ] **Step 3: Add transport hint helper** + +In `tensorcast/api/store/artifact.py`, add `TransportSchedulingGroup` to the context import: + +```python +from tensorcast.api.context import CallContext, TransportSchedulingGroup +``` + +Add this helper near `_build_transport_operation_id(...)`: + +```python +def _transport_group_from_ctx_tags(ctx: CallContext | None) -> TransportSchedulingGroup | None: + if ctx is None or not ctx.tags: + return None + group_kind = _read_context_tag_str(ctx.tags, _TRANSPORT_GROUP_KIND_TAG) + group_id = _read_context_tag_str(ctx.tags, _TRANSPORT_GROUP_ID_TAG) + part_id = _read_context_tag_str(ctx.tags, _TRANSPORT_GROUP_PART_ID_TAG) + total_parts = _read_context_tag_int(ctx.tags, _TRANSPORT_GROUP_TOTAL_PARTS_TAG, default=0) + if not (group_kind and group_id and part_id and total_parts > 0): + return None + return TransportSchedulingGroup( + group_kind=group_kind, + group_id=group_id, + total_parts=total_parts, + part_id=part_id, + priority=_read_context_tag_int(ctx.tags, _TRANSPORT_GROUP_PRIORITY_TAG, default=0), + epoch=_read_context_tag_int(ctx.tags, _TRANSPORT_GROUP_EPOCH_TAG, default=0), + request_id=_read_context_tag_str(ctx.tags, _TRANSPORT_REQUEST_ID_TAG) or None, + ) + + +def _resolve_prefetch_transport_hints( + *, + ctx: CallContext | None, + daemon_id: str, + artifact_id: str, + selection_hash: str, + logical_layout_hash: str, + device_id: int, + device_uuid: str, +) -> tuple[str | None, TransportSchedulingGroup | None]: + group = (ctx.transport_group if ctx is not None else None) or _transport_group_from_ctx_tags(ctx) + if group is None: + return None, None + if group.request_id: + return group.request_id, group + digest = hashlib.sha256() + digest.update(b"tensorcast.prefetch.transport.v1") + for value in ( + daemon_id, + artifact_id, + logical_layout_hash, + selection_hash, + str(device_id), + device_uuid, + group.group_kind, + group.group_id, + str(group.epoch), + str(group.total_parts), + group.part_id, + ): + digest.update(b"|") + digest.update(str(value).encode("utf-8")) + return f"prefetch:{digest.hexdigest()}", group +``` + +- [ ] **Step 4: Thread hints from `Artifact.prefetch()`** + +In `Artifact.prefetch()`, after `device_uuid = device_uuid_for(device_id)` is available for deterministic operation id generation, keep a local `device_uuid_value` for both deterministic replica id and transport id: + +```python +device_uuid_value = device_uuid_for(device_id) +``` + +Use `device_uuid_value` in the existing action fingerprint instead of recomputing `device_uuid`. + +Before calling `pipeline.materialize_subset(...)`, add: + +```python +transport_request_id, transport_scheduling_group = _resolve_prefetch_transport_hints( + ctx=ctx, + daemon_id=daemon_id, + artifact_id=artifact_id, + selection_hash=selection_hash, + logical_layout_hash=bytes(selection.logical_layout_hash).hex(), + device_id=device_id, + device_uuid=device_uuid_value, +) +``` + +Pass these kwargs into `pipeline.materialize_subset(...)`: + +```python +transport_request_id=transport_request_id, +transport_scheduling_group=transport_scheduling_group, +``` + +- [ ] **Step 5: Run tests to verify they pass** + +Run: + +```bash +pytest tests/python/api/test_prefetch_operation.py -v +``` + +Expected: PASS. + +- [ ] **Step 6: Commit** + +Run: + +```bash +git add tensorcast/api/store/artifact.py tests/python/api/test_prefetch_operation.py +git commit -m "feat(sdk): derive prefetch transport group hints" +``` + +### Task 3: Daemon Proto and Python Generation + +**Files:** +- Modify: `proto/tensorcast/daemon/v2/store_daemon.proto` +- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py` +- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi` +- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py` if Buf rewrites it + +- [ ] **Step 1: Add daemon transport group proto fields** + +In `proto/tensorcast/daemon/v2/store_daemon.proto`, add this message before `MaterializeReplicaRequest`: + +```proto +message TransportSchedulingGroupHint { + string group_id = 1; + string group_kind = 2; + uint32 total_parts = 3; + string part_id = 4; + uint32 priority = 5; + uint64 epoch = 6; +} +``` + +Add these fields to `MaterializeReplicaRequest` after `serving_artifact_policy = 20;`: + +```proto + string transport_request_id = 21; + TransportSchedulingGroupHint transport_scheduling_group = 22; +``` + +- [ ] **Step 2: Format and regenerate protos** + +Run: + +```bash +bazel run @rules_buf_toolchains//:buf -- format ./proto -w +bash tools/build_proto_python.sh +``` + +Expected: `store_daemon_pb2.py` and `store_daemon_pb2.pyi` include `TransportSchedulingGroupHint`, `transport_request_id`, and `transport_scheduling_group`. + +- [ ] **Step 3: Inspect generated changes carefully** + +Run: + +```bash +git status --short proto/tensorcast/daemon/v2/store_daemon.proto proto/gen/python/tensorcast/daemon/v2 +git diff -- proto/tensorcast/daemon/v2/store_daemon.proto proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py +``` + +Expected: diffs are limited to daemon v2 proto generation for the added fields. Do not stage unrelated pre-existing generated proto directories. + +- [ ] **Step 4: Commit** + +Run: + +```bash +git add proto/tensorcast/daemon/v2/store_daemon.proto proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py +git commit -m "feat(proto): add materialize replica transport hints" +``` + +If `store_daemon_pb2_grpc.py` has no diff, omit it from `git add`. + +### Task 4: DaemonCtl and Materialization Pipeline Forwarding + +**Files:** +- Modify: `tensorcast/api/_materialize.py` +- Modify: `tensorcast/daemon_ctl.py` +- Test: add `tests/python/api/test_daemon_ctl_transport_hints.py` +- Test: `tests/python/api/test_prefetch_operation.py` + +- [ ] **Step 1: Write failing DaemonCtl request construction test** + +Create `tests/python/api/test_daemon_ctl_transport_hints.py`: + +```python +# Copyright (c) 2026, TensorCast Team. + +from __future__ import annotations + +from tensorcast.daemon_ctl import DaemonCtl +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.daemon.v2 import store_daemon_pb2 + + +class _FakeUnary: + _method = b"/tensorcast.daemon.v2.StoreDaemonService/MaterializeReplica" + + def __init__(self) -> None: + self.requests: list[store_daemon_pb2.MaterializeReplicaRequest] = [] + + def __call__(self, request, timeout=None): # noqa: ANN001, ANN204 + del timeout + self.requests.append(request) + response = store_daemon_pb2.MaterializeReplicaResponse() + response.status = ( + store_daemon_pb2.MaterializeReplicaStatus.MATERIALIZE_REPLICA_STATUS_ALLOCATED + ) + response.ticket.replica_uuid = request.replica_uuid + return response + + +class _FakeStub: + def __init__(self) -> None: + self.MaterializeReplica = _FakeUnary() + + +def test_daemon_ctl_forwards_materialize_transport_hints(monkeypatch) -> None: # noqa: ANN001 + monkeypatch.setattr(DaemonCtl, "_create_channel", lambda self, addr: None) + ctl = DaemonCtl("fake-daemon") + fake_stub = _FakeStub() + ctl.stub_v2 = fake_stub + ctl.stub = fake_stub + monkeypatch.setattr(ctl, "_get_effective_pid", lambda: 123) + monkeypatch.setattr(ctl, "_unary_call", lambda method, request, **kwargs: method(request, timeout=kwargs.get("timeout"))) + selection = common_pb2.ArtifactSelection(artifact_id="aid") + group = store_daemon_pb2.TransportSchedulingGroupHint( + group_kind="weight_broadcast", + group_id="model-a:v42", + total_parts=16, + part_id="daemon-1", + priority=7, + epoch=42, + ) + + ctl.materialize_by_artifact_id_v2( + selection=selection, + replica_uuid="replica-1", + device_uuid="device-uuid", + wait_for_completion=False, + return_response=True, + transport_request_id="transport-req-1", + transport_scheduling_group=group, + ) + + request = fake_stub.MaterializeReplica.requests[0] + assert request.transport_request_id == "transport-req-1" + assert request.transport_scheduling_group.group_kind == "weight_broadcast" + assert request.transport_scheduling_group.group_id == "model-a:v42" + assert request.transport_scheduling_group.total_parts == 16 + assert request.transport_scheduling_group.part_id == "daemon-1" + assert request.transport_scheduling_group.priority == 7 + assert request.transport_scheduling_group.epoch == 42 +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: + +```bash +pytest tests/python/api/test_daemon_ctl_transport_hints.py -v +``` + +Expected: FAIL because `DaemonCtl.materialize_by_artifact_id_v2()` does not accept transport hint kwargs. + +- [ ] **Step 3: Add DaemonCtl kwargs and request copying** + +In all three overloads and the implementation of `DaemonCtl.materialize_by_artifact_id_v2(...)`, add: + +```python +transport_request_id: str | None = None, +transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint | None = None, +``` + +After constructing `MaterializeReplicaRequest`, add: + +```python +if transport_request_id: + request.transport_request_id = str(transport_request_id) +if transport_scheduling_group is not None: + request.transport_scheduling_group.CopyFrom(transport_scheduling_group) +``` + +- [ ] **Step 4: Forward hints from `materialize_artifact_v2()`** + +In `tensorcast/api/_materialize.py`, import `TransportSchedulingGroup`: + +```python +from tensorcast.api.context import CallContext, CollectiveLoadGroup, TransportSchedulingGroup +``` + +Add optional parameters to `materialize_artifact_v2(...)`: + +```python +transport_request_id: str | None = None, +transport_scheduling_group: TransportSchedulingGroup | None = None, +``` + +Before the `client.materialize_by_artifact_id_v2(...)` call, convert the SDK group: + +```python +transport_group_proto = None +if transport_scheduling_group is not None: + transport_group_proto = store_daemon_pb2.TransportSchedulingGroupHint( + group_id=transport_scheduling_group.group_id, + group_kind=transport_scheduling_group.group_kind, + total_parts=int(transport_scheduling_group.total_parts), + part_id=transport_scheduling_group.part_id, + priority=int(transport_scheduling_group.priority), + epoch=int(transport_scheduling_group.epoch), + ) +``` + +Pass: + +```python +transport_request_id=transport_request_id, +transport_scheduling_group=transport_group_proto, +``` + +- [ ] **Step 5: Run SDK forwarding tests** + +Run: + +```bash +pytest tests/python/api/test_daemon_ctl_transport_hints.py tests/python/api/test_prefetch_operation.py -v +``` + +Expected: PASS. + +- [ ] **Step 6: Commit** + +Run: + +```bash +git add tensorcast/api/_materialize.py tensorcast/daemon_ctl.py tests/python/api/test_daemon_ctl_transport_hints.py tests/python/api/test_prefetch_operation.py +git commit -m "feat(sdk): forward prefetch transport hints to daemon" +``` + +### Task 5: C++ Daemon Hint Mapping + +**Files:** +- Modify: `daemon/service/controllers/materialization_policy_utils.h` +- Modify: `daemon/service/controllers/materialization_policy_utils.cc` +- Modify: `daemon/service/controllers/replica_materialization_service.cc` +- Modify: `daemon/service/materialization_policy_utils_test.cc` + +- [ ] **Step 1: Write failing C++ mapping tests** + +In `daemon/service/materialization_policy_utils_test.cc`, add this using declaration: + +```cpp +using tensorcast::daemon::materialization_policy::resolve_transport_scheduling_group_hint; +``` + +Append these test cases: + +```cpp +TEST_CASE( + "MaterializeReplica transport scheduling group maps to loading hint", + "[daemon][materialization][policy]") { + v2::TransportSchedulingGroupHint proto; + proto.set_group_kind("weight_broadcast"); + proto.set_group_id("model-a:v42"); + proto.set_total_parts(16); + proto.set_part_id("daemon-1"); + proto.set_priority(7); + proto.set_epoch(42); + + auto hint_or = resolve_transport_scheduling_group_hint(proto); + + REQUIRE(hint_or.has_value()); + CHECK(hint_or->group_kind == "weight_broadcast"); + CHECK(hint_or->group_id == "model-a:v42"); + CHECK(hint_or->total_parts == 16); + CHECK(hint_or->part_id == "daemon-1"); + CHECK(hint_or->priority == 7); + CHECK(hint_or->epoch == 42); +} + +TEST_CASE( + "MaterializeReplica transport scheduling group rejects incomplete values", + "[daemon][materialization][policy]") { + v2::TransportSchedulingGroupHint proto; + proto.set_group_kind("weight_broadcast"); + proto.set_group_id("model-a:v42"); + proto.set_total_parts(0); + proto.set_part_id("daemon-1"); + + auto hint_or = resolve_transport_scheduling_group_hint(proto); + + CHECK(!hint_or.has_value()); +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: + +```bash +bazel test //daemon:materialization_policy_utils_test --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: FAIL because `resolve_transport_scheduling_group_hint` does not exist. + +- [ ] **Step 3: Add C++ mapping helper** + +In `daemon/service/controllers/materialization_policy_utils.h`, declare: + +```cpp +std::optional resolve_transport_scheduling_group_hint( + const v2::TransportSchedulingGroupHint& group); +``` + +In `daemon/service/controllers/materialization_policy_utils.cc`, implement: + +```cpp +std::optional resolve_transport_scheduling_group_hint( + const v2::TransportSchedulingGroupHint& group) { + if (group.group_kind().empty() || group.group_id().empty() || group.part_id().empty() || group.total_parts() == 0) { + return std::nullopt; + } + store::loading::TransportSchedulingGroupHint out; + out.group_kind = group.group_kind(); + out.group_id = group.group_id(); + out.total_parts = group.total_parts(); + out.part_id = group.part_id(); + out.priority = group.priority(); + out.epoch = group.epoch(); + return out; +} +``` + +- [ ] **Step 4: Apply request fields in `materialize_replica()`** + +In `daemon/service/controllers/replica_materialization_service.cc`, add this using declaration near the existing materialization policy aliases: + +```cpp +using materialization_policy::resolve_transport_scheduling_group_hint; +``` + +After `apply_request_context_to_hints(request_context, &hints);`, add: + +```cpp +if (!req.transport_request_id().empty()) { + hints.transport_request_id = req.transport_request_id(); +} +if (req.has_transport_scheduling_group()) { + hints.transport_scheduling_group = + resolve_transport_scheduling_group_hint(req.transport_scheduling_group()); +} +``` + +- [ ] **Step 5: Run C++ mapping test** + +Run: + +```bash +bazel test //daemon:materialization_policy_utils_test --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: PASS. + +- [ ] **Step 6: Commit** + +Run: + +```bash +git add daemon/service/controllers/materialization_policy_utils.h daemon/service/controllers/materialization_policy_utils.cc daemon/service/controllers/replica_materialization_service.cc daemon/service/materialization_policy_utils_test.cc +git commit -m "feat(daemon): map materialize transport hints" +``` + +### Task 6: Plan Context and Node-Agent Prefetch Propagation + +**Files:** +- Modify: `proto/tensorcast/plan/v1/plan.proto` +- Generate/modify: `proto/gen/python/tensorcast/plan/v1/plan_pb2.py` +- Generate/modify: `proto/gen/python/tensorcast/plan/v1/plan_pb2.pyi` +- Modify: `tensorcast/api/plan/plan.py` +- Modify: `tensorcast/node_agent/executor.py` +- Test: `tests/python/api/test_plan_spec.py` +- Test: `tests/python/node_agent/test_plan_execution.py` + +- [ ] **Step 1: Write failing plan serialization test** + +In `tests/python/api/test_plan_spec.py`, add a test: + +```python +def test_plan_context_serializes_transport_group() -> None: + ctx = CallContext( + request_id="req-weight-broadcast", + idempotency_key="idem-weight-broadcast", + transport_group=TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=16, + part_id="daemon-1", + priority=3, + request_id="transport-req-1", + ), + ) + + spec = Plan(ctx).to_spec() + + assert spec.context.transport_group.group_kind == "weight_broadcast" + assert spec.context.transport_group.group_id == "model-a:v42" + assert spec.context.transport_group.epoch == 42 + assert spec.context.transport_group.total_parts == 16 + assert spec.context.transport_group.part_id == "daemon-1" + assert spec.context.transport_group.priority == 3 + assert spec.context.transport_group.request_id == "transport-req-1" +``` + +Add imports if missing: + +```python +from tensorcast.api.context import TransportSchedulingGroup +``` + +- [ ] **Step 2: Write failing node-agent propagation test** + +In `tests/python/node_agent/test_plan_execution.py`, extend `_DaemonStub.__init__`: + +```python +self.transport_request_ids: list[str | None] = [] +self.transport_groups: list[object] = [] +``` + +Extend `_DaemonStub.materialize_by_artifact_id_v2(...)`: + +```python +self.transport_request_ids.append(kwargs.get("transport_request_id")) +self.transport_groups.append(kwargs.get("transport_scheduling_group")) +``` + +Add: + +```python +def test_node_agent_prefetch_forwards_transport_group_from_plan_context() -> None: + daemon = _DaemonStub() + executor = NodeAgentExecutor(client=daemon) + plan = plan_pb2.PlanSpec() + plan.context.request_id = "req-weight-broadcast" + plan.context.idempotency_key = "idem-weight-broadcast" + plan.context.transport_group.group_kind = "weight_broadcast" + plan.context.transport_group.group_id = "model-a:v42" + plan.context.transport_group.epoch = 42 + plan.context.transport_group.total_parts = 16 + plan.context.transport_group.part_id = "daemon-1" + plan.context.transport_group.priority = 3 + plan.context.transport_group.request_id = "transport-req-1" + step = plan.steps.add() + step.step_id = "prefetch-1" + step.target.target_type = plan_pb2.TARGET_TYPE_WORKER + step.target.target_id = "daemon-1" + step.action.prefetch.selection.CopyFrom(_selection()) + step.action.prefetch.device_id = 0 + + result = executor.execute(plan) + + assert result.steps[0].status.state == node_agent_pb2.OPERATION_STATE_SUCCESS + assert daemon.transport_request_ids == ["transport-req-1"] + group = daemon.transport_groups[0] + assert group.group_kind == "weight_broadcast" + assert group.group_id == "model-a:v42" + assert group.total_parts == 16 + assert group.part_id == "daemon-1" + assert group.priority == 3 + assert group.epoch == 42 +``` + +- [ ] **Step 3: Run tests to verify they fail** + +Run: + +```bash +pytest tests/python/api/test_plan_spec.py::test_plan_context_serializes_transport_group tests/python/node_agent/test_plan_execution.py::test_node_agent_prefetch_forwards_transport_group_from_plan_context -v +``` + +Expected: FAIL because `plan.v1.CallContext` does not have `transport_group`. + +- [ ] **Step 4: Add plan proto fields** + +In `proto/tensorcast/plan/v1/plan.proto`, add: + +```proto +message TransportSchedulingGroup { + string group_id = 1; + string group_kind = 2; + uint32 total_parts = 3; + string part_id = 4; + uint32 priority = 5; + uint64 epoch = 6; + string request_id = 7; +} +``` + +Add to `message CallContext`: + +```proto + TransportSchedulingGroup transport_group = 6; +``` + +Run: + +```bash +bazel run @rules_buf_toolchains//:buf -- format ./proto -w +bash tools/build_proto_python.sh +``` + +- [ ] **Step 5: Serialize plan context transport group** + +In `tensorcast/api/plan/plan.py`, import `TransportSchedulingGroup` where `CallContext` is imported if needed. In `_call_context_proto()`, add: + +```python +if self._ctx.transport_group is not None: + group = self._ctx.transport_group + proto.transport_group.group_id = group.group_id + proto.transport_group.group_kind = group.group_kind + proto.transport_group.total_parts = int(group.total_parts) + proto.transport_group.part_id = group.part_id + proto.transport_group.priority = int(group.priority) + proto.transport_group.epoch = int(group.epoch) + if group.request_id: + proto.transport_group.request_id = group.request_id +``` + +In `_action_context(...)`, preserve the plan-level group: + +```python +transport_group=self._ctx.transport_group, +``` + +- [ ] **Step 6: Node-agent converts plan group to daemon group** + +In `tensorcast/node_agent/executor.py`, add helper: + +```python +def _transport_group_from_plan_context( + call_ctx: CallContext, +) -> store_daemon_pb2.TransportSchedulingGroupHint | None: + group = call_ctx.transport_group + if group is None: + return None + return store_daemon_pb2.TransportSchedulingGroupHint( + group_id=group.group_id, + group_kind=group.group_kind, + total_parts=int(group.total_parts), + part_id=group.part_id, + priority=int(group.priority), + epoch=int(group.epoch), + ) +``` + +In `tensorcast/node_agent/executor.py::_call_context_from_proto(...)`, include `transport_group` in the returned `CallContext`: + +```python +transport_group = None +if plan.context.HasField("transport_group") and plan.context.transport_group.group_kind: + proto_group = plan.context.transport_group + transport_group = TransportSchedulingGroup( + group_id=proto_group.group_id, + group_kind=proto_group.group_kind, + total_parts=int(proto_group.total_parts), + part_id=proto_group.part_id, + priority=int(proto_group.priority), + epoch=int(proto_group.epoch), + request_id=proto_group.request_id or None, + ) +``` + +Pass transport hints in `_materialize_selection(...)` by adding parameters: + +```python +transport_request_id: str | None = None, +transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint | None = None, +``` + +Forward to `self._client.materialize_by_artifact_id_v2(...)`: + +```python +transport_request_id=transport_request_id, +transport_scheduling_group=transport_scheduling_group, +``` + +In `_prefetch(...)`, call `_materialize_selection(...)` with: + +```python +transport_request_id=call_ctx.transport_group.request_id if call_ctx.transport_group else None, +transport_scheduling_group=_transport_group_from_plan_context(call_ctx), +``` + +In `_action_context(...)` and `_instance_action_context(...)`, preserve the existing group when creating derived action contexts: + +```python +transport_group=call_ctx.transport_group, +``` + +- [ ] **Step 7: Run plan and node-agent tests** + +Run: + +```bash +pytest tests/python/api/test_plan_spec.py::test_plan_context_serializes_transport_group tests/python/node_agent/test_plan_execution.py::test_node_agent_prefetch_forwards_transport_group_from_plan_context -v +``` + +Expected: PASS. + +- [ ] **Step 8: Commit** + +Run: + +```bash +git add proto/tensorcast/plan/v1/plan.proto proto/gen/python/tensorcast/plan/v1/plan_pb2.py proto/gen/python/tensorcast/plan/v1/plan_pb2.pyi proto/gen/python/tensorcast/plan/v1/plan_pb2_grpc.py tensorcast/api/plan/plan.py tensorcast/node_agent/executor.py tests/python/api/test_plan_spec.py tests/python/node_agent/test_plan_execution.py +git commit -m "feat(plan): propagate prefetch transport groups" +``` + +If `plan_pb2_grpc.py` has no diff, omit it from `git add`. + +### Task 7: Documentation, Verification, and Cleanup + +**Files:** +- Modify: `docs/designs/0116-control-plane-coordinated-weight-broadcast.md` +- Modify: `docs/plans/0116-control-plane-coordinated-weight-broadcast.md` +- Modify: `tensorcast/api/README.md` +- Modify: `tensorcast/api/store/README.md` + +- [ ] **Step 1: Link design to plan** + +In `docs/designs/0116-control-plane-coordinated-weight-broadcast.md`, add under `links:`: + +```yaml + plan: ../plans/0116-control-plane-coordinated-weight-broadcast.md +``` + +- [ ] **Step 2: Update API docs for the public typed context** + +In `tensorcast/api/README.md`, add a short bullet under "Programmable Control-Plane Primitives": + +```markdown +- `CallContext.transport_group=TransportSchedulingGroup(...)` carries explicit transport scheduling metadata for coordinated prefetch fanout. Use `group_kind="weight_broadcast"` and a stable `group_id` such as a model version to let Global Store group dispatch spread source selection. +``` + +In `tensorcast/api/store/README.md`, extend the prefetch section with: + +```markdown +Grouped model-weight prefetch can pass `ctx=CallContext(transport_group=TransportSchedulingGroup(...))`; the daemon forwards the group to Global Store transport scheduling while keeping `replica_uuid` as a pure daemon replica/session id. +``` + +- [ ] **Step 3: Run focused Python tests** + +Run: + +```bash +pytest tests/python/api/test_prefetch_operation.py tests/python/api/test_daemon_ctl_transport_hints.py tests/python/api/test_plan_spec.py::test_plan_context_serializes_transport_group tests/python/node_agent/test_plan_execution.py::test_node_agent_prefetch_forwards_transport_group_from_plan_context -v +``` + +Expected: PASS. + +- [ ] **Step 4: Run focused C++ test** + +Run: + +```bash +bazel test //daemon:materialization_policy_utils_test --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: PASS. + +- [ ] **Step 5: Run Global Store scheduler regression tests** + +Run: + +```bash +pytest tests/python/global_store/test_services.py::test_group_dispatch_refreshes_group_source_counts_per_pending tests/python/global_store/test_services.py::test_group_dispatch_rejects_duplicate_group_part_id tests/python/global_store/test_services.py::test_transport_service_group_progress_counts_success_only -v +``` + +If a test name has drifted, run: + +```bash +pytest tests/python/global_store/test_services.py -k "group_dispatch or group_progress or group_source" -v +``` + +Expected: PASS. + +- [ ] **Step 6: Check generated and unrelated dirty files** + +Run: + +```bash +git status --short +git diff --check +``` + +Expected: only this feature's tracked files are modified or staged. Pre-existing unrelated `pyproject.toml` and generated proto dirt must remain unstaged unless this implementation changed the same file for this feature. + +- [ ] **Step 7: Commit docs and verification updates** + +Run: + +```bash +git add docs/designs/0116-control-plane-coordinated-weight-broadcast.md docs/plans/0116-control-plane-coordinated-weight-broadcast.md tensorcast/api/README.md tensorcast/api/store/README.md +git commit -m "docs: document weight broadcast prefetch hints" +``` + +# Acceptance Checks + +- [ ] `Artifact.prefetch()` without `CallContext.transport_group` sends no transport hint. +- [ ] `Artifact.prefetch()` with typed transport group sends stable `transport_request_id` and complete group metadata. +- [ ] `DaemonCtl.materialize_by_artifact_id_v2()` copies transport hints into `MaterializeReplicaRequest`. +- [ ] `ReplicaMaterializationService::materialize_replica()` copies transport hints into `MaterializeHints`. +- [ ] Existing Global Store group dispatch tests still pass. +- [ ] `replica_uuid` is not overloaded with `#tcg:` metadata for prefetch. +- [ ] No unrelated pre-existing dirty files are reverted or staged. From 9b839ecc59f088d65761c02851b96da78bce0150 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:12:28 +0800 Subject: [PATCH 09/49] feat(sdk): add transport scheduling group context --- tensorcast/__init__.py | 3 ++ tensorcast/api/__init__.py | 2 + tensorcast/api/context.py | 45 +++++++++++++++++++++ tests/python/api/test_prefetch_operation.py | 34 ++++++++++++++++ 4 files changed, 84 insertions(+) diff --git a/tensorcast/__init__.py b/tensorcast/__init__.py index 68d71725..60b01a0b 100644 --- a/tensorcast/__init__.py +++ b/tensorcast/__init__.py @@ -143,6 +143,7 @@ def _install_c_extension_bootstrap() -> None: ), "CallContext": ("tensorcast.api", "CallContext"), "CollectiveLoadGroup": ("tensorcast.api", "CollectiveLoadGroup"), + "TransportSchedulingGroup": ("tensorcast.api", "TransportSchedulingGroup"), "GovernanceContext": ("tensorcast.api", "GovernanceContext"), "DirectorySnapshot": ("tensorcast.api", "DirectorySnapshot"), "CapabilityDirectoryClient": ( @@ -423,6 +424,7 @@ def __dir__() -> list[str]: CapabilityDirectoryClient, CapabilityDirectoryOptions, CollectiveLoadGroup, + TransportSchedulingGroup, DirectorySnapshot, ExecutionTopologyContext, GetArtifactOptions, @@ -520,6 +522,7 @@ def __dir__() -> list[str]: "binding_realization_plan_to_proto", "CallContext", "CollectiveLoadGroup", + "TransportSchedulingGroup", "ExecutionDiagnostics", "BindingUpdateEpoch", "HashBackend", diff --git a/tensorcast/api/__init__.py b/tensorcast/api/__init__.py index c9cd34ea..9cccd544 100644 --- a/tensorcast/api/__init__.py +++ b/tensorcast/api/__init__.py @@ -31,6 +31,7 @@ GovernanceContext, QosClass, SpanAttributeValue, + TransportSchedulingGroup, context, ) from tensorcast.api.directory import ( @@ -198,6 +199,7 @@ "DirectorySnapshot", "FinalizeClass", "GovernanceContext", + "TransportSchedulingGroup", "InstanceExecutionRoute", "Operation", "OperationError", diff --git a/tensorcast/api/context.py b/tensorcast/api/context.py index 42a31189..34828e35 100644 --- a/tensorcast/api/context.py +++ b/tensorcast/api/context.py @@ -18,6 +18,47 @@ class CollectiveLoadGroup: rank: int +@dataclass(frozen=True, slots=True) +class TransportSchedulingGroup: + """Control-plane transport scheduling group for coordinated P2P source selection.""" + + group_id: str + group_kind: str + total_parts: int + part_id: str + priority: int = 0 + epoch: int = 0 + request_id: str | None = None + + def __post_init__(self) -> None: + group_kind = str(self.group_kind).strip() + group_id = str(self.group_id).strip() + part_id = str(self.part_id).strip() + total_parts = int(self.total_parts) + priority = int(self.priority) + epoch = int(self.epoch) + request_id = None if self.request_id is None else str(self.request_id).strip() + if not group_kind: + raise ValueError("TransportSchedulingGroup.group_kind must be non-empty") + if not group_id: + raise ValueError("TransportSchedulingGroup.group_id must be non-empty") + if total_parts <= 0: + raise ValueError("TransportSchedulingGroup.total_parts must be positive") + if not part_id: + raise ValueError("TransportSchedulingGroup.part_id must be non-empty") + if priority < 0: + raise ValueError("TransportSchedulingGroup.priority must be non-negative") + if epoch < 0: + raise ValueError("TransportSchedulingGroup.epoch must be non-negative") + object.__setattr__(self, "group_kind", group_kind) + object.__setattr__(self, "group_id", group_id) + object.__setattr__(self, "total_parts", total_parts) + object.__setattr__(self, "part_id", part_id) + object.__setattr__(self, "priority", priority) + object.__setattr__(self, "epoch", epoch) + object.__setattr__(self, "request_id", request_id or None) + + @dataclass(frozen=True, slots=True) class GovernanceContext: """Typed low-cardinality governance hints propagated with a plan.""" @@ -37,6 +78,7 @@ class CallContext: idempotency_key: str | None = None tags: Mapping[str, SpanAttributeValue] | None = None collective: CollectiveLoadGroup | None = None + transport_group: TransportSchedulingGroup | None = None governance: GovernanceContext | None = None @@ -48,6 +90,7 @@ def context( idempotency_key: str | None = None, tags: Mapping[str, SpanAttributeValue] | None = None, collective: CollectiveLoadGroup | None = None, + transport_group: TransportSchedulingGroup | None = None, governance: GovernanceContext | None = None, ) -> CallContext: return CallContext( @@ -57,6 +100,7 @@ def context( idempotency_key=idempotency_key, tags=tags, collective=collective, + transport_group=transport_group, governance=governance, ) @@ -67,5 +111,6 @@ def context( "GovernanceContext", "QosClass", "SpanAttributeValue", + "TransportSchedulingGroup", "context", ] diff --git a/tests/python/api/test_prefetch_operation.py b/tests/python/api/test_prefetch_operation.py index d6315a9f..35077e00 100644 --- a/tests/python/api/test_prefetch_operation.py +++ b/tests/python/api/test_prefetch_operation.py @@ -10,6 +10,7 @@ import tensorcast as tc from tensorcast.api._materialize import MaterializationPayload from tensorcast.api._device import device_uuid_for +from tensorcast.api.context import TransportSchedulingGroup from tensorcast.api.store.artifact import Artifact from tensorcast.common.selection_identity import ( compute_selection_hash, @@ -91,6 +92,39 @@ def __init__(self) -> None: self.closed = False +def test_transport_scheduling_group_rejects_invalid_values() -> None: + invalid_cases = [ + {"group_kind": "", "group_id": "model:v1", "total_parts": 2, "part_id": "d0"}, + {"group_kind": "weight_broadcast", "group_id": "", "total_parts": 2, "part_id": "d0"}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 0, "part_id": "d0"}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 2, "part_id": ""}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 2, "part_id": "d0", "priority": -1}, + {"group_kind": "weight_broadcast", "group_id": "model:v1", "total_parts": 2, "part_id": "d0", "epoch": -1}, + ] + + for kwargs in invalid_cases: + try: + TransportSchedulingGroup(**kwargs) + except ValueError: + continue + raise AssertionError(f"expected invalid transport group: {kwargs}") + + +def test_context_accepts_typed_transport_group() -> None: + group = tc.TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=8, + part_id="daemon-3", + ) + + ctx = tc.context(request_id="req-1", transport_group=group) + + assert ctx.transport_group == group + assert tc.TransportSchedulingGroup is TransportSchedulingGroup + + def test_prefetch_uses_deterministic_operation_id() -> None: store = _Store() artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") From 0ef397845a7581aa2a0a325435b24cdc529d7145 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:16:12 +0800 Subject: [PATCH 10/49] feat(sdk): derive prefetch transport group hints --- tensorcast/api/store/artifact.py | 104 +++++++++++++++++++- tests/python/api/test_prefetch_operation.py | 71 ++++++++++++- 2 files changed, 168 insertions(+), 7 deletions(-) diff --git a/tensorcast/api/store/artifact.py b/tensorcast/api/store/artifact.py index df5e6331..a51158c7 100644 --- a/tensorcast/api/store/artifact.py +++ b/tensorcast/api/store/artifact.py @@ -15,6 +15,7 @@ from datetime import timezone from typing import ( TYPE_CHECKING, + Callable, Mapping, Sequence, SupportsIndex, @@ -32,7 +33,7 @@ _coerce_slice_spec, build_view_spec, ) -from tensorcast.api.context import CallContext +from tensorcast.api.context import CallContext, TransportSchedulingGroup from tensorcast.api.operation import ( DaemonReplicaOperation, Operation, @@ -523,6 +524,78 @@ def _build_transport_operation_id( return base_operation_id +def _transport_group_from_ctx_tags( + ctx: CallContext | None, +) -> TransportSchedulingGroup | None: + if ctx is None or not ctx.tags: + return None + group_kind = _read_context_tag_str(ctx.tags, _TRANSPORT_GROUP_KIND_TAG) + group_id = _read_context_tag_str(ctx.tags, _TRANSPORT_GROUP_ID_TAG) + part_id = _read_context_tag_str(ctx.tags, _TRANSPORT_GROUP_PART_ID_TAG) + total_parts = _read_context_tag_int( + ctx.tags, + _TRANSPORT_GROUP_TOTAL_PARTS_TAG, + default=0, + ) + if not (group_kind and group_id and part_id and total_parts > 0): + return None + return TransportSchedulingGroup( + group_kind=group_kind, + group_id=group_id, + total_parts=total_parts, + part_id=part_id, + priority=_read_context_tag_int( + ctx.tags, + _TRANSPORT_GROUP_PRIORITY_TAG, + default=0, + ), + epoch=_read_context_tag_int( + ctx.tags, + _TRANSPORT_GROUP_EPOCH_TAG, + default=0, + ), + request_id=_read_context_tag_str(ctx.tags, _TRANSPORT_REQUEST_ID_TAG) or None, + ) + + +def _resolve_prefetch_transport_hints( + *, + ctx: CallContext | None, + daemon_id: str, + artifact_id: str, + selection_hash: str, + logical_layout_hash: str, + device_id: int, + device_uuid_factory: Callable[[], str], +) -> tuple[str | None, TransportSchedulingGroup | None]: + group = ( + ctx.transport_group if ctx is not None else None + ) or _transport_group_from_ctx_tags(ctx) + if group is None: + return None, None + if group.request_id: + return group.request_id, group + device_uuid = device_uuid_factory() + digest = hashlib.sha256() + digest.update(b"tensorcast.prefetch.transport.v1") + for value in ( + daemon_id, + artifact_id, + logical_layout_hash, + selection_hash, + str(device_id), + device_uuid, + group.group_kind, + group.group_id, + str(group.epoch), + str(group.total_parts), + group.part_id, + ): + digest.update(b"|") + digest.update(str(value).encode("utf-8")) + return f"prefetch:{digest.hexdigest()}", group + + def _register_client_binding( *, runtime: "StoreRuntimeContext", @@ -2136,18 +2209,27 @@ def prefetch( selection = self._build_artifact_selection() view_id = view_id_hint or selection.view_id selection_hash = bytes(selection.selection_hash).hex() + logical_layout_hash = bytes(selection.logical_layout_hash).hex() + device_uuid_value: str | None = None + + def _device_uuid_value() -> str: + nonlocal device_uuid_value + if device_uuid_value is None: + device_uuid_value = ( + "" if device_id == CPU_DEVICE_ID else device_uuid_for(device_id) + ) + return device_uuid_value deterministic_replica_uuid: str | None = None if ctx is not None and ctx.idempotency_key: - logical_layout_hash = bytes(selection.logical_layout_hash).hex() - device_uuid = device_uuid_for(device_id) + resolved_device_uuid = _device_uuid_value() ns = uuid.uuid5(uuid.NAMESPACE_DNS, "tensorcast.op.v1") idempotency_key_hex = hashlib.sha256( ctx.idempotency_key.encode("utf-8") ).hexdigest() action_fingerprint = ( f"prefetch|daemon={daemon_id}|artifact={artifact_id}|layout={logical_layout_hash}" - f"|selection={selection_hash}|device={device_id}|device_uuid={device_uuid}|lease=NO_LEASE|v2" + f"|selection={selection_hash}|device={device_id}|device_uuid={resolved_device_uuid}|lease=NO_LEASE|v2" ) deterministic_replica_uuid = str( uuid.uuid5(ns, f"{idempotency_key_hex}|{action_fingerprint}") @@ -2164,6 +2246,18 @@ def prefetch( if not replica_uuid: replica_uuid = uuid.uuid4().hex + transport_request_id, transport_scheduling_group = ( + _resolve_prefetch_transport_hints( + ctx=ctx, + daemon_id=daemon_id, + artifact_id=artifact_id, + selection_hash=selection_hash, + logical_layout_hash=logical_layout_hash, + device_id=device_id, + device_uuid_factory=_device_uuid_value, + ) + ) + payload, _ = pipeline.materialize_subset( artifact_id=artifact_id, key=None, @@ -2178,6 +2272,8 @@ def prefetch( options=opts, ctx=ctx, lease_mode=store_daemon_pb2.LeaseMode.LEASE_MODE_NO_LEASE, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) self._update_metadata_from_payload(payload, runtime) operation_id = payload.ticket_replica_uuid or payload.replica_uuid or "" diff --git a/tests/python/api/test_prefetch_operation.py b/tests/python/api/test_prefetch_operation.py index 35077e00..9a7a5979 100644 --- a/tests/python/api/test_prefetch_operation.py +++ b/tests/python/api/test_prefetch_operation.py @@ -3,13 +3,13 @@ from __future__ import annotations import hashlib +import importlib import uuid import weakref from typing import Any import tensorcast as tc from tensorcast.api._materialize import MaterializationPayload -from tensorcast.api._device import device_uuid_for from tensorcast.api.context import TransportSchedulingGroup from tensorcast.api.store.artifact import Artifact from tensorcast.common.selection_identity import ( @@ -17,6 +17,8 @@ ) from tensorcast.proto.daemon.v2 import store_daemon_pb2 +artifact_module = importlib.import_module("tensorcast.api.store.artifact") + class _Client: def get_artifact_index_by_id(self, artifact_id: str) -> bytes: @@ -125,7 +127,8 @@ def test_context_accepts_typed_transport_group() -> None: assert tc.TransportSchedulingGroup is TransportSchedulingGroup -def test_prefetch_uses_deterministic_operation_id() -> None: +def test_prefetch_uses_deterministic_operation_id(monkeypatch) -> None: # noqa: ANN001 + monkeypatch.setattr(artifact_module, "device_uuid_for", lambda _device_id: "uuid-0") store = _Store() artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") ctx = tc.context(request_id="req-1", idempotency_key="idem-1") @@ -138,7 +141,7 @@ def test_prefetch_uses_deterministic_operation_id() -> None: view_subset_hash=None, ).hex() logical_layout_hash = artifact._build_artifact_selection().logical_layout_hash.hex() - device_uuid = device_uuid_for(0) + device_uuid = "uuid-0" action_fingerprint = ( f"prefetch|daemon={daemon_id}|artifact=aid|layout={logical_layout_hash}" f"|selection={selection_hash}|device=0|device_uuid={device_uuid}|lease=NO_LEASE|v2" @@ -170,3 +173,65 @@ def test_prefetch_without_ctx_generates_operation_id() -> None: replica_uuid = str(store._materialization.calls[0]["replica_uuid"] or "") assert replica_uuid assert op.operation_id == replica_uuid + + +def test_prefetch_forwards_typed_transport_group_hint() -> None: + store = _Store() + artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") + group = tc.TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=16, + part_id="daemon-1", + priority=7, + request_id="explicit-transport-req", + ) + + artifact.prefetch(device="cuda:0", ctx=tc.context(transport_group=group)) + + call = store._materialization.calls[0] + assert call["transport_request_id"] == "explicit-transport-req" + forwarded = call["transport_scheduling_group"] + assert forwarded.group_kind == "weight_broadcast" + assert forwarded.group_id == "model-a:v42" + assert forwarded.epoch == 42 + assert forwarded.total_parts == 16 + assert forwarded.part_id == "daemon-1" + assert forwarded.priority == 7 + + +def test_prefetch_derives_stable_transport_request_id_for_group(monkeypatch) -> None: # noqa: ANN001 + monkeypatch.setattr(artifact_module, "device_uuid_for", lambda _device_id: "uuid-0") + group = tc.TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + epoch=42, + total_parts=16, + part_id="daemon-1", + ) + ctx = tc.context(transport_group=group) + first_store = _Store() + second_store = _Store() + first = Artifact(store_ref=weakref.ref(first_store), artifact_id="aid") + second = Artifact(store_ref=weakref.ref(second_store), artifact_id="aid") + + first.prefetch(device="cuda:0", ctx=ctx) + second.prefetch(device="cuda:0", ctx=ctx) + + first_request_id = first_store._materialization.calls[0]["transport_request_id"] + second_request_id = second_store._materialization.calls[0]["transport_request_id"] + assert first_request_id + assert first_request_id == second_request_id + assert first_request_id.startswith("prefetch:") + + +def test_prefetch_without_group_sends_no_transport_hint() -> None: + store = _Store() + artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") + + artifact.prefetch(device="cuda:0") + + call = store._materialization.calls[0] + assert call["transport_request_id"] is None + assert call["transport_scheduling_group"] is None From d3e1cf741c0b9fe4d449d0a13417ea91cb822eb3 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:21:36 +0800 Subject: [PATCH 11/49] docs: align broadcast plan with ignored proto outputs --- ...trol-plane-coordinated-weight-broadcast.md | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/docs/plans/0116-control-plane-coordinated-weight-broadcast.md b/docs/plans/0116-control-plane-coordinated-weight-broadcast.md index 3ad30b77..935cf759 100644 --- a/docs/plans/0116-control-plane-coordinated-weight-broadcast.md +++ b/docs/plans/0116-control-plane-coordinated-weight-broadcast.md @@ -33,9 +33,9 @@ links: # Files - Modify: `proto/tensorcast/daemon/v2/store_daemon.proto` -- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py` -- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi` -- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py` if Buf rewrites it +- Generate locally: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py` (ignored by git) +- Generate locally: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi` (ignored by git) +- Generate locally: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py` (ignored by git) - Modify: `tensorcast/api/context.py` - Modify: `tensorcast/api/__init__.py` - Modify: `tensorcast/__init__.py` @@ -416,9 +416,9 @@ git commit -m "feat(sdk): derive prefetch transport group hints" **Files:** - Modify: `proto/tensorcast/daemon/v2/store_daemon.proto` -- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py` -- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi` -- Generate/modify: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py` if Buf rewrites it +- Generate locally: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py` (ignored by git) +- Generate locally: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi` (ignored by git) +- Generate locally: `proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py` (ignored by git) - [ ] **Step 1: Add daemon transport group proto fields** @@ -451,7 +451,7 @@ bazel run @rules_buf_toolchains//:buf -- format ./proto -w bash tools/build_proto_python.sh ``` -Expected: `store_daemon_pb2.py` and `store_daemon_pb2.pyi` include `TransportSchedulingGroupHint`, `transport_request_id`, and `transport_scheduling_group`. +Expected: local generated `store_daemon_pb2.py` and `store_daemon_pb2.pyi` include `TransportSchedulingGroupHint`, `transport_request_id`, and `transport_scheduling_group`. These generated files are ignored by `.gitignore`; they are validation artifacts, not files to force-add. - [ ] **Step 3: Inspect generated changes carefully** @@ -462,19 +462,17 @@ git status --short proto/tensorcast/daemon/v2/store_daemon.proto proto/gen/pytho git diff -- proto/tensorcast/daemon/v2/store_daemon.proto proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py ``` -Expected: diffs are limited to daemon v2 proto generation for the added fields. Do not stage unrelated pre-existing generated proto directories. +Expected: tracked diffs are limited to `proto/tensorcast/daemon/v2/store_daemon.proto`. Generated daemon files may appear under ignored or untracked paths for local validation. Do not stage unrelated pre-existing generated proto directories. - [ ] **Step 4: Commit** Run: ```bash -git add proto/tensorcast/daemon/v2/store_daemon.proto proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.py proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2.pyi proto/gen/python/tensorcast/daemon/v2/store_daemon_pb2_grpc.py +git add proto/tensorcast/daemon/v2/store_daemon.proto git commit -m "feat(proto): add materialize replica transport hints" ``` -If `store_daemon_pb2_grpc.py` has no diff, omit it from `git add`. - ### Task 4: DaemonCtl and Materialization Pipeline Forwarding **Files:** @@ -778,8 +776,8 @@ git commit -m "feat(daemon): map materialize transport hints" **Files:** - Modify: `proto/tensorcast/plan/v1/plan.proto` -- Generate/modify: `proto/gen/python/tensorcast/plan/v1/plan_pb2.py` -- Generate/modify: `proto/gen/python/tensorcast/plan/v1/plan_pb2.pyi` +- Generate locally: `proto/gen/python/tensorcast/plan/v1/plan_pb2.py` (ignored by git) +- Generate locally: `proto/gen/python/tensorcast/plan/v1/plan_pb2.pyi` (ignored by git) - Modify: `tensorcast/api/plan/plan.py` - Modify: `tensorcast/node_agent/executor.py` - Test: `tests/python/api/test_plan_spec.py` @@ -1016,12 +1014,10 @@ Expected: PASS. Run: ```bash -git add proto/tensorcast/plan/v1/plan.proto proto/gen/python/tensorcast/plan/v1/plan_pb2.py proto/gen/python/tensorcast/plan/v1/plan_pb2.pyi proto/gen/python/tensorcast/plan/v1/plan_pb2_grpc.py tensorcast/api/plan/plan.py tensorcast/node_agent/executor.py tests/python/api/test_plan_spec.py tests/python/node_agent/test_plan_execution.py +git add proto/tensorcast/plan/v1/plan.proto tensorcast/api/plan/plan.py tensorcast/node_agent/executor.py tests/python/api/test_plan_spec.py tests/python/node_agent/test_plan_execution.py git commit -m "feat(plan): propagate prefetch transport groups" ``` -If `plan_pb2_grpc.py` has no diff, omit it from `git add`. - ### Task 7: Documentation, Verification, and Cleanup **Files:** From c7d619c3446cef10cb38cf984d2228e8bf74cfc3 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:21:51 +0800 Subject: [PATCH 12/49] feat(proto): add materialize replica transport hints --- proto/tensorcast/daemon/v2/store_daemon.proto | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/proto/tensorcast/daemon/v2/store_daemon.proto b/proto/tensorcast/daemon/v2/store_daemon.proto index 2fae6025..9503edb7 100644 --- a/proto/tensorcast/daemon/v2/store_daemon.proto +++ b/proto/tensorcast/daemon/v2/store_daemon.proto @@ -515,6 +515,15 @@ message ReplicaOperationStatus { optional ReplicaOperationError error = 5; } +message TransportSchedulingGroupHint { + string group_id = 1; + string group_kind = 2; + uint32 total_parts = 3; + string part_id = 4; + uint32 priority = 5; + uint64 epoch = 6; +} + message MaterializeReplicaRequest { tensorcast.common.v1.ArtifactSelection selection = 1; reserved 2, 9; @@ -546,6 +555,8 @@ message MaterializeReplicaRequest { // Explicit collective group hint for same-host coordinated disk->GPU loads. optional CollectiveLoadGroup collective_load_group = 19; ServingArtifactRuntimePolicy serving_artifact_policy = 20; + string transport_request_id = 21; + TransportSchedulingGroupHint transport_scheduling_group = 22; } message CollectiveLoadGroup { From f4c34ab101547fb761653476aa22cf2af85f6a84 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 12:27:00 +0800 Subject: [PATCH 13/49] feat(sdk): forward prefetch transport hints to daemon --- tensorcast/api/_materialize.py | 28 +++- tensorcast/api/store/materialization.py | 24 +++- tensorcast/daemon_ctl.py | 16 +++ .../api/test_daemon_ctl_transport_hints.py | 131 ++++++++++++++++++ .../api/test_materialization_pipeline_v2.py | 41 +++++- 5 files changed, 237 insertions(+), 3 deletions(-) create mode 100644 tests/python/api/test_daemon_ctl_transport_hints.py diff --git a/tensorcast/api/_materialize.py b/tensorcast/api/_materialize.py index bf39a148..9fbac768 100644 --- a/tensorcast/api/_materialize.py +++ b/tensorcast/api/_materialize.py @@ -28,7 +28,11 @@ from tensorcast.api._errors import DaemonUnavailable, IndexParseError from tensorcast.api._runtime import apply_client_load_defaults_if_present from tensorcast.api._utils import new_uuid -from tensorcast.api.context import CallContext, CollectiveLoadGroup +from tensorcast.api.context import ( + CallContext, + CollectiveLoadGroup, + TransportSchedulingGroup, +) from tensorcast.common.selection_contract import ( build_artifact_selection, compute_selected_index_bytes, @@ -208,6 +212,21 @@ def _resolve_collective_load_group( return group +def _transport_group_to_daemon_proto( + group: TransportSchedulingGroup | None, +) -> store_daemon_pb2.TransportSchedulingGroupHint | None: + if group is None: + return None + return store_daemon_pb2.TransportSchedulingGroupHint( + group_id=group.group_id, + group_kind=group.group_kind, + total_parts=int(group.total_parts), + part_id=group.part_id, + priority=int(group.priority), + epoch=int(group.epoch), + ) + + def _build_artifact_selection( *, artifact_id: str, @@ -267,6 +286,8 @@ def materialize_artifact_v2( ctx: CallContext | None = None, timeout_s: float | None = None, lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, + transport_request_id: str | None = None, + transport_scheduling_group: TransportSchedulingGroup | None = None, ) -> MaterializationPayload: if artifact_id is not None and key is not None: raise ValueError("Exactly one of artifact_id or key must be provided") @@ -369,6 +390,9 @@ def materialize_artifact_v2( ) replica_uuid_value = replica_uuid or new_uuid() collective_load_group = _resolve_collective_load_group(ctx) + transport_group_proto = _transport_group_to_daemon_proto( + transport_scheduling_group + ) request_device_uuid = ( "" @@ -391,6 +415,8 @@ def materialize_artifact_v2( target_device_type=target_device_type, lease_mode=lease_mode, collective_load_group=collective_load_group, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_group_proto, timeout_s=effective_timeout_s, timing_out=materialize_timing, ) diff --git a/tensorcast/api/store/materialization.py b/tensorcast/api/store/materialization.py index 962ed98b..8bb772b8 100644 --- a/tensorcast/api/store/materialization.py +++ b/tensorcast/api/store/materialization.py @@ -32,7 +32,7 @@ MaterializationPayload, materialize_artifact_v2, ) -from tensorcast.api.context import CallContext +from tensorcast.api.context import CallContext, TransportSchedulingGroup from tensorcast.api.store.async_ops import ArtifactFuture, TrackedExecutor from tensorcast.api.store.cache import ArtifactCacheEntry from tensorcast.api.store.common import ( @@ -380,6 +380,8 @@ def materialize_subset( options: GetArtifactOptions | None = None, ctx: CallContext | None = None, lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, + transport_request_id: str | None = None, + transport_scheduling_group: TransportSchedulingGroup | None = None, ) -> tuple[MaterializationPayload, int]: return self._perform_get_with_retry( method="get", @@ -398,6 +400,8 @@ def materialize_subset( allow_cpu=True, ctx=ctx, lease_mode=lease_mode, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) def get_view( @@ -1816,6 +1820,8 @@ def _materialize( view_data_hash: str | None = None, view_index_hint: bytes | None = None, replica_uuid: str | None = None, + transport_request_id: str | None = None, + transport_scheduling_group: TransportSchedulingGroup | None = None, ) -> MaterializationPayload: return self._materialize_payload( artifact_id=artifact_id, @@ -1835,6 +1841,8 @@ def _materialize( view_data_hash=view_data_hash, view_index_hint=view_index_hint, replica_uuid=replica_uuid, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) def _materialize_payload( @@ -1857,6 +1865,8 @@ def _materialize_payload( view_data_hash: str | None = None, view_index_hint: bytes | None = None, replica_uuid: str | None = None, + transport_request_id: str | None = None, + transport_scheduling_group: TransportSchedulingGroup | None = None, ) -> MaterializationPayload: client = self._runtime.ensure_client() resolved_artifact_id = artifact_id @@ -1908,6 +1918,8 @@ def _materialize_payload( ctx=ctx, timeout_s=timeout_s, lease_mode=lease_mode, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) disallowed_sources: set[store_daemon_pb2.MaterializationSource] = set() if not allow_p2p: @@ -2074,6 +2086,8 @@ def _perform_get_with_retry( allow_cpu: bool = False, ctx: CallContext | None = None, lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, + transport_request_id: str | None = None, + transport_scheduling_group: TransportSchedulingGroup | None = None, ) -> tuple[MaterializationPayload, int]: options_snapshot = self._build_get_options(options_override) retrieval_policy = options_snapshot.source or RetrievalPolicy() @@ -2229,6 +2243,8 @@ def record_outcome(status: str) -> None: lease_mode=lease_mode, ctx=ctx, timeout_s=rpc_timeout_s, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) summary = self._summarize_materialized(materialized, tensor_names) selection_label = summary["selection"] @@ -2329,6 +2345,8 @@ def record_outcome(status: str) -> None: lease_mode=lease_mode, ctx=ctx, timeout_s=wait_timeout_s, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) except Exception as exc: # noqa: BLE001 error = map_materialization_error(exc) @@ -2460,6 +2478,8 @@ def _attempt_get( view_index_hint: bytes | None = None, replica_uuid: str | None = None, allow_cpu: bool = False, + transport_request_id: str | None = None, + transport_scheduling_group: TransportSchedulingGroup | None = None, ) -> tuple[MaterializationPayload, int]: artifact_id, key = self._resolve_identifiers(artifact_id, key) options = self._build_get_options(options_override) @@ -2483,6 +2503,8 @@ def _attempt_get( ctx=ctx, timeout_s=timeout_s, lease_mode=lease_mode, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) except Exception as exc: # noqa: BLE001 if "selection.logical_layout_hash does not match resolved selection" in str( diff --git a/tensorcast/daemon_ctl.py b/tensorcast/daemon_ctl.py index c42bf9a3..4596057e 100644 --- a/tensorcast/daemon_ctl.py +++ b/tensorcast/daemon_ctl.py @@ -1988,6 +1988,9 @@ def materialize_by_artifact_id_v2( target_device_type: store_daemon_pb2.DeviceType = store_daemon_pb2.DeviceType.DEVICE_TYPE_GPU, lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, collective_load_group: store_daemon_pb2.CollectiveLoadGroup | None = None, + transport_request_id: str | None = None, + transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint + | None = None, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> store_daemon_pb2.MaterializeReplicaResponse: ... @@ -2011,6 +2014,9 @@ def materialize_by_artifact_id_v2( target_device_type: store_daemon_pb2.DeviceType = store_daemon_pb2.DeviceType.DEVICE_TYPE_GPU, lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, collective_load_group: store_daemon_pb2.CollectiveLoadGroup | None = None, + transport_request_id: str | None = None, + transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint + | None = None, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> tuple[bytes, store_daemon_pb2.MaterializeReplicaStatus]: ... @@ -2033,6 +2039,9 @@ def materialize_by_artifact_id_v2( target_device_type: store_daemon_pb2.DeviceType = store_daemon_pb2.DeviceType.DEVICE_TYPE_GPU, lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, collective_load_group: store_daemon_pb2.CollectiveLoadGroup | None = None, + transport_request_id: str | None = None, + transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint + | None = None, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> bytes: ... @@ -2054,6 +2063,9 @@ def materialize_by_artifact_id_v2( target_device_type: store_daemon_pb2.DeviceType = store_daemon_pb2.DeviceType.DEVICE_TYPE_GPU, lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, collective_load_group: store_daemon_pb2.CollectiveLoadGroup | None = None, + transport_request_id: str | None = None, + transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint + | None = None, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> ( @@ -2090,6 +2102,10 @@ def materialize_by_artifact_id_v2( ) if collective_load_group is not None: request.collective_load_group.CopyFrom(collective_load_group) + if transport_request_id: + request.transport_request_id = str(transport_request_id) + if transport_scheduling_group is not None: + request.transport_scheduling_group.CopyFrom(transport_scheduling_group) if wait_for_shared_disk_ms: request.wait_for_shared_disk_ms = int(wait_for_shared_disk_ms) request.source_policy.CopyFrom(resolved_source_policy) diff --git a/tests/python/api/test_daemon_ctl_transport_hints.py b/tests/python/api/test_daemon_ctl_transport_hints.py new file mode 100644 index 00000000..7fd43bf6 --- /dev/null +++ b/tests/python/api/test_daemon_ctl_transport_hints.py @@ -0,0 +1,131 @@ +# Copyright (c) 2026, TensorCast Team. + +from __future__ import annotations + +import torch + +from tensorcast.daemon_ctl import DaemonCtl +from tensorcast.api._config import GetArtifactOptions +from tensorcast.api._materialize import materialize_artifact_v2 +from tensorcast.api.context import TransportSchedulingGroup +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.daemon.v2 import store_daemon_pb2 + + +class _FakeUnary: + _method = b"/tensorcast.daemon.v2.StoreDaemonService/MaterializeReplica" + + def __init__(self) -> None: + self.requests: list[store_daemon_pb2.MaterializeReplicaRequest] = [] + + def __call__(self, request, timeout=None): # noqa: ANN001, ANN204 + del timeout + self.requests.append(request) + response = store_daemon_pb2.MaterializeReplicaResponse() + response.status = store_daemon_pb2.MaterializeReplicaStatus.MATERIALIZE_REPLICA_STATUS_ALLOCATED + response.ticket.replica_uuid = request.replica_uuid + return response + + +class _FakeStub: + def __init__(self) -> None: + self.MaterializeReplica = _FakeUnary() + + +def test_daemon_ctl_forwards_materialize_transport_hints(monkeypatch) -> None: # noqa: ANN001 + ctl = DaemonCtl.__new__(DaemonCtl) + ctl.server_address = "fake-daemon" + fake_stub = _FakeStub() + ctl.stub_v2 = fake_stub + ctl.stub = fake_stub + monkeypatch.setattr(ctl, "_get_effective_pid", lambda: 123) + monkeypatch.setattr( + ctl, + "_unary_call", + lambda method, request, **kwargs: method( + request, + timeout=kwargs.get("timeout"), + ), + ) + selection = common_pb2.ArtifactSelection(artifact_id="aid") + group = store_daemon_pb2.TransportSchedulingGroupHint( + group_kind="weight_broadcast", + group_id="model-a:v42", + total_parts=16, + part_id="daemon-1", + priority=7, + epoch=42, + ) + + ctl.materialize_by_artifact_id_v2( + selection=selection, + replica_uuid="replica-1", + device_uuid="device-uuid", + wait_for_completion=False, + return_response=True, + transport_request_id="transport-req-1", + transport_scheduling_group=group, + ) + + request = fake_stub.MaterializeReplica.requests[0] + assert request.transport_request_id == "transport-req-1" + assert request.transport_scheduling_group.group_kind == "weight_broadcast" + assert request.transport_scheduling_group.group_id == "model-a:v42" + assert request.transport_scheduling_group.total_parts == 16 + assert request.transport_scheduling_group.part_id == "daemon-1" + assert request.transport_scheduling_group.priority == 7 + assert request.transport_scheduling_group.epoch == 42 + + +class _FakeMaterializeClient: + def __init__(self) -> None: + self.calls: list[dict[str, object]] = [] + + def get_artifact_index_by_id(self, artifact_id: str) -> bytes: + del artifact_id + return b"{}" + + def materialize_by_artifact_id_v2(self, **kwargs): + self.calls.append(kwargs) + response = store_daemon_pb2.MaterializeReplicaResponse() + response.status = store_daemon_pb2.MaterializeReplicaStatus.MATERIALIZE_REPLICA_STATUS_ALLOCATED + response.artifact_id = "aid" + response.canonical_index_bytes = b"{}" + return response + + +def test_materialize_artifact_v2_converts_transport_group_to_daemon_proto() -> None: + client = _FakeMaterializeClient() + group = TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + total_parts=16, + part_id="daemon-1", + priority=7, + epoch=42, + ) + + materialize_artifact_v2( + client=client, + daemon_address="daemon", + device_id=torch.device("cpu"), + artifact_id="aid", + key=None, + options=GetArtifactOptions( + wait_for_completion=False, + enable_verification=False, + ), + transport_request_id="transport-req-1", + transport_scheduling_group=group, + ) + + request = client.calls[0] + assert request["transport_request_id"] == "transport-req-1" + forwarded = request["transport_scheduling_group"] + assert isinstance(forwarded, store_daemon_pb2.TransportSchedulingGroupHint) + assert forwarded.group_kind == "weight_broadcast" + assert forwarded.group_id == "model-a:v42" + assert forwarded.total_parts == 16 + assert forwarded.part_id == "daemon-1" + assert forwarded.priority == 7 + assert forwarded.epoch == 42 diff --git a/tests/python/api/test_materialization_pipeline_v2.py b/tests/python/api/test_materialization_pipeline_v2.py index 031ff4c2..68b1e2d1 100644 --- a/tests/python/api/test_materialization_pipeline_v2.py +++ b/tests/python/api/test_materialization_pipeline_v2.py @@ -17,7 +17,11 @@ TensorPayloadDescriptor, _resolve_collective_load_group, ) -from tensorcast.api.context import CallContext, CollectiveLoadGroup +from tensorcast.api.context import ( + CallContext, + CollectiveLoadGroup, + TransportSchedulingGroup, +) from tensorcast.api.store.cache import ArtifactCache, ArtifactCacheEntry from tensorcast.api.store.materialization import MaterializationPipeline from tensorcast.api.store.retry import build_retry_policies @@ -457,3 +461,38 @@ def test_materialize_subset_preserves_generation(): runtime.close() assert materialized.generation == 5 + + +def test_materialize_subset_forwards_transport_hints(): + runtime = _RuntimeStub() + views = ViewOrchestrator(runtime) + pipeline = MaterializationPipeline(runtime, views) + payload = _make_payload({"a": torch.ones(1)}, replica_uuid="transport") + captured: dict[str, object] = {} + group = TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + total_parts=16, + part_id="daemon-1", + priority=7, + epoch=42, + ) + + def fake_materialize(**kwargs): + captured.update(kwargs) + return payload + + pipeline.set_materialize_fn(fake_materialize) + materialized, _ = pipeline.materialize_subset( + artifact_id="aid", + key=None, + device=0, + tensor_names=None, + transport_request_id="transport-req-1", + transport_scheduling_group=group, + ) + runtime.close() + + assert materialized.replica_uuid == "transport" + assert captured["transport_request_id"] == "transport-req-1" + assert captured["transport_scheduling_group"] == group From 9529a15b336807a41136945afe493551ed8b5955 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 13:13:01 +0800 Subject: [PATCH 14/49] feat(daemon): map materialize transport hints --- .../materialization_policy_utils.cc | 16 ++++++++++++++ .../materialization_policy_utils.h | 3 +++ .../replica_materialization_service.cc | 11 ++++++++++ .../materialization_policy_utils_test.cc | 21 +++++++++++++++++++ 4 files changed, 51 insertions(+) diff --git a/daemon/service/controllers/materialization_policy_utils.cc b/daemon/service/controllers/materialization_policy_utils.cc index c65eb812..b247ba7a 100644 --- a/daemon/service/controllers/materialization_policy_utils.cc +++ b/daemon/service/controllers/materialization_policy_utils.cc @@ -106,6 +106,22 @@ std::optional resolve_collective_group_ }; } +std::optional resolve_transport_scheduling_group_hint( + const v2::TransportSchedulingGroupHint* group) { + if (group == nullptr || group->group_kind().empty() || group->group_id().empty() || + group->part_id().empty() || group->total_parts() == 0) { + return std::nullopt; + } + return store::loading::TransportSchedulingGroupHint{ + .group_kind = group->group_kind(), + .group_id = group->group_id(), + .total_parts = group->total_parts(), + .part_id = group->part_id(), + .priority = group->priority(), + .epoch = group->epoch(), + }; +} + absl::StatusOr resolve_source_execution_topology( const v2::SourceExecutionTopology* topology) { ExecutionTopologyContext execution_topology; diff --git a/daemon/service/controllers/materialization_policy_utils.h b/daemon/service/controllers/materialization_policy_utils.h index 6cc31a0c..1f51d2f9 100644 --- a/daemon/service/controllers/materialization_policy_utils.h +++ b/daemon/service/controllers/materialization_policy_utils.h @@ -54,6 +54,9 @@ absl::StatusOr resolve_retrieval_policy_compat(const v2::Source std::optional resolve_collective_group_hint( const v2::CollectiveLoadGroup* group); +std::optional resolve_transport_scheduling_group_hint( + const v2::TransportSchedulingGroupHint* group); + absl::StatusOr resolve_source_execution_topology(const v2::SourceExecutionTopology* topology); absl::StatusOr resolve_collective_policy( diff --git a/daemon/service/controllers/replica_materialization_service.cc b/daemon/service/controllers/replica_materialization_service.cc index 8cb7711d..697e8e22 100644 --- a/daemon/service/controllers/replica_materialization_service.cc +++ b/daemon/service/controllers/replica_materialization_service.cc @@ -57,6 +57,7 @@ using materialization_policy::convert_view_spec; using materialization_policy::NormalizedMaterializationRequestContext; using materialization_policy::resolve_collective_group_hint; using materialization_policy::resolve_materialization_request_context; +using materialization_policy::resolve_transport_scheduling_group_hint; using materialization_policy::resolve_transform_placement; using materialization_policy::to_hint_export_policy; using materialization_post_seal::check_post_seal_view_reuse_safe; @@ -684,6 +685,16 @@ grpc::Status ReplicaMaterializationService::materialize_replica( hints.verify = verify_checksums ? store::loading::MaterializeHints::Verify::CHECKSUM : store::loading::MaterializeHints::Verify::NONE; apply_request_context_to_hints(request_context, &hints); + if (!req.transport_request_id().empty()) { + hints.transport_request_id = req.transport_request_id(); + } + if (req.has_transport_scheduling_group()) { + auto group_hint = + resolve_transport_scheduling_group_hint(&req.transport_scheduling_group()); + if (group_hint.has_value()) { + hints.transport_scheduling_group = std::move(*group_hint); + } + } if (prefer_direct_disk_for_local_import) { hints.set_retrieval_policy( store::loading::RetrievalPolicy{ diff --git a/daemon/service/materialization_policy_utils_test.cc b/daemon/service/materialization_policy_utils_test.cc index 2fafe615..2e23d824 100644 --- a/daemon/service/materialization_policy_utils_test.cc +++ b/daemon/service/materialization_policy_utils_test.cc @@ -7,6 +7,7 @@ namespace { using tensorcast::daemon::materialization_policy::default_collective_policy_for_mapped_target; +using tensorcast::daemon::materialization_policy::resolve_transport_scheduling_group_hint; using tensorcast::store::loading::CollectiveLoadGroupHint; using tensorcast::store::loading::ExecutionTopologyContext; namespace v2 = tensorcast::daemon::v2; @@ -31,4 +32,24 @@ TEST_CASE( v2::CollectivePolicy::COLLECTIVE_POLICY_DISABLE_COLLECTIVE); } +TEST_CASE("Transport scheduling group hint maps daemon proto", "[daemon][materialization][policy]") { + v2::TransportSchedulingGroupHint proto; + proto.set_group_kind("weight_broadcast"); + proto.set_group_id("model-a:v42"); + proto.set_total_parts(16); + proto.set_part_id("daemon-1"); + proto.set_priority(7); + proto.set_epoch(42); + + auto hint = resolve_transport_scheduling_group_hint(&proto); + + REQUIRE(hint.has_value()); + CHECK(hint->group_kind == "weight_broadcast"); + CHECK(hint->group_id == "model-a:v42"); + CHECK(hint->total_parts == 16); + CHECK(hint->part_id == "daemon-1"); + CHECK(hint->priority == 7); + CHECK(hint->epoch == 42); +} + } // namespace From f232f05ffacc2c47773847eefb18acb31b74038b Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 13:22:14 +0800 Subject: [PATCH 15/49] feat(plan): propagate transport group hints --- proto/tensorcast/plan/v1/plan.proto | 11 ++ tensorcast/api/plan/plan.py | 14 +++ tensorcast/node_agent/executor.py | 106 +++++++++++++++++- tests/python/api/test_plan_spec.py | 33 +++++- .../python/node_agent/test_plan_execution.py | 66 ++++++++++- 5 files changed, 226 insertions(+), 4 deletions(-) diff --git a/proto/tensorcast/plan/v1/plan.proto b/proto/tensorcast/plan/v1/plan.proto index 1960a09e..b2b9f7ab 100644 --- a/proto/tensorcast/plan/v1/plan.proto +++ b/proto/tensorcast/plan/v1/plan.proto @@ -11,6 +11,17 @@ message CallContext { optional uint64 deadline_ms = 3; string idempotency_key = 4; map tags = 5; + TransportSchedulingGroup transport_group = 6; +} + +message TransportSchedulingGroup { + string group_id = 1; + string group_kind = 2; + uint32 total_parts = 3; + string part_id = 4; + uint32 priority = 5; + uint64 epoch = 6; + string request_id = 7; } message GovernanceContext { diff --git a/tensorcast/api/plan/plan.py b/tensorcast/api/plan/plan.py index 89b2728b..5176ba4d 100644 --- a/tensorcast/api/plan/plan.py +++ b/tensorcast/api/plan/plan.py @@ -1258,6 +1258,17 @@ def _call_context_proto(self) -> plan_pb2.CallContext: proto.deadline_ms = int(self._ctx.deadline_ms) if self._ctx.tags: proto.tags.update({str(k): str(v) for k, v in self._ctx.tags.items()}) + if self._ctx.transport_group is not None: + group = self._ctx.transport_group + transport_group = proto.transport_group + transport_group.group_kind = group.group_kind + transport_group.group_id = group.group_id + transport_group.total_parts = int(group.total_parts) + transport_group.part_id = group.part_id + transport_group.priority = int(group.priority) + transport_group.epoch = int(group.epoch) + if group.request_id: + transport_group.request_id = group.request_id return proto def _governance_proto(self) -> plan_pb2.GovernanceContext | None: @@ -1399,6 +1410,9 @@ def _action_context( deadline_ms=self._ctx.deadline_ms, idempotency_key=derived_key, tags=self._ctx.tags, + collective=self._ctx.collective, + transport_group=self._ctx.transport_group, + governance=self._ctx.governance, ) def _execute_step( diff --git a/tensorcast/node_agent/executor.py b/tensorcast/node_agent/executor.py index 3b23c8dd..91dd4d8a 100644 --- a/tensorcast/node_agent/executor.py +++ b/tensorcast/node_agent/executor.py @@ -13,7 +13,7 @@ from tensorcast.api._device import CPU_DEVICE_ID, device_uuid_for from tensorcast.api._errors import DeviceMismatch from tensorcast.api._view_ops import NarrowOp, TransposeOp, ViewSpecBuildResult -from tensorcast.api.context import CallContext +from tensorcast.api.context import CallContext, TransportSchedulingGroup from tensorcast.api.errors import ArtifactError from tensorcast.api.operation import OperationError, OperationStatus from tensorcast.api.plan.artifact_set import ( @@ -150,15 +150,87 @@ def _call_context_from_proto(ctx: plan_pb2.CallContext) -> CallContext: qos = "realtime" elif ctx.qos == plan_pb2.QOS_CLASS_BACKGROUND: qos = "background" + transport_group = ( + _transport_group_from_proto(ctx.transport_group) + if ctx.HasField("transport_group") + else None + ) return CallContext( request_id=str(ctx.request_id), qos=qos, deadline_ms=int(ctx.deadline_ms) if ctx.HasField("deadline_ms") else None, idempotency_key=str(ctx.idempotency_key) if ctx.idempotency_key else None, tags=dict(ctx.tags) if ctx.tags else None, + transport_group=transport_group, + ) + + +def _transport_group_from_proto( + proto: plan_pb2.TransportSchedulingGroup, +) -> TransportSchedulingGroup | None: + if ( + not proto.group_kind + or not proto.group_id + or not proto.part_id + or proto.total_parts <= 0 + ): + return None + return TransportSchedulingGroup( + group_kind=str(proto.group_kind), + group_id=str(proto.group_id), + total_parts=int(proto.total_parts), + part_id=str(proto.part_id), + priority=int(proto.priority), + epoch=int(proto.epoch), + request_id=str(proto.request_id) if proto.request_id else None, + ) + + +def _transport_group_to_daemon_proto( + group: TransportSchedulingGroup | None, +) -> store_daemon_pb2.TransportSchedulingGroupHint | None: + if group is None: + return None + return store_daemon_pb2.TransportSchedulingGroupHint( + group_kind=group.group_kind, + group_id=group.group_id, + total_parts=int(group.total_parts), + part_id=group.part_id, + priority=int(group.priority), + epoch=int(group.epoch), ) +def _transport_request_id_for_selection( + *, + group: TransportSchedulingGroup, + daemon_id: str, + selection: common_pb2.ArtifactSelection, + device_id: int, + device_uuid: str, +) -> str: + if group.request_id: + return group.request_id + digest = hashlib.sha256() + digest.update(b"tensorcast.node_agent.prefetch.transport.v1") + for value in ( + daemon_id, + selection.artifact_id, + selection.logical_layout_hash.hex(), + selection.selection_hash.hex(), + str(device_id), + device_uuid, + group.group_kind, + group.group_id, + str(group.epoch), + str(group.total_parts), + group.part_id, + ): + digest.update(b"|") + digest.update(str(value).encode("utf-8")) + return f"prefetch:{digest.hexdigest()}" + + def _ctx_timeout_s(ctx: CallContext) -> float | None: if ctx.deadline_ms is None: return None @@ -556,6 +628,7 @@ def _action_context( deadline_ms=call_ctx.deadline_ms, idempotency_key=derived_key, tags=call_ctx.tags, + transport_group=call_ctx.transport_group, ) def _instance_action_context( @@ -583,6 +656,7 @@ def _instance_action_context( deadline_ms=call_ctx.deadline_ms, idempotency_key=derived_key, tags=call_ctx.tags, + transport_group=call_ctx.transport_group, ) def _artifact_from_selection( @@ -998,9 +1072,18 @@ def _prefetch( ) ns = uuid.uuid5(uuid.NAMESPACE_DNS, "tensorcast.op.v1") replica_uuid = str(uuid.uuid5(ns, action_key)) + scoped_ctx = CallContext( + request_id=call_ctx.request_id, + qos=call_ctx.qos, + deadline_ms=call_ctx.deadline_ms, + idempotency_key=action_key, + tags=call_ctx.tags, + transport_group=call_ctx.transport_group, + ) else: replica_uuid = uuid.uuid4().hex - timeout_s = _ctx_timeout_s(call_ctx) + scoped_ctx = call_ctx + timeout_s = _ctx_timeout_s(scoped_ctx) if timeout_s is not None and timeout_s <= 0: return NodeAgentStepResult( step_id=step.step_id, @@ -1019,6 +1102,7 @@ def _prefetch( replica_uuid=replica_uuid, timeout_s=timeout_s, wait_for_completion=False, + call_ctx=scoped_ctx, ) except DeviceMismatch as exc: return NodeAgentStepResult( @@ -1131,6 +1215,7 @@ def _prefetch_set( replica_uuid=replica_uuid, timeout_s=timeout_s, wait_for_completion=True, + call_ctx=scoped_ctx, ) item_status = OperationStatus( state="success", @@ -1180,6 +1265,7 @@ def _materialize_selection( replica_uuid: str, timeout_s: float | None, wait_for_completion: bool, + call_ctx: CallContext | None = None, ) -> None: if device_id == CPU_DEVICE_ID: target_device_type = store_daemon_pb2.DeviceType.DEVICE_TYPE_CPU @@ -1187,6 +1273,21 @@ def _materialize_selection( else: target_device_type = store_daemon_pb2.DeviceType.DEVICE_TYPE_GPU device_uuid = device_uuid_for(device_id) + transport_kwargs: dict[str, object] = {} + transport_group = call_ctx.transport_group if call_ctx is not None else None + if transport_group is not None: + transport_kwargs["transport_request_id"] = ( + _transport_request_id_for_selection( + group=transport_group, + daemon_id=self._daemon_id, + selection=selection, + device_id=device_id, + device_uuid=device_uuid, + ) + ) + transport_group_proto = _transport_group_to_daemon_proto(transport_group) + if transport_group_proto is not None: + transport_kwargs["transport_scheduling_group"] = transport_group_proto self._client.materialize_by_artifact_id_v2( selection=selection, replica_uuid=replica_uuid, @@ -1196,6 +1297,7 @@ def _materialize_selection( target_device_type=target_device_type, lease_mode=store_daemon_pb2.LeaseMode.LEASE_MODE_NO_LEASE, timeout_s=timeout_s, + **transport_kwargs, ) def _pin( diff --git a/tests/python/api/test_plan_spec.py b/tests/python/api/test_plan_spec.py index 259efa46..1f25ee0d 100644 --- a/tests/python/api/test_plan_spec.py +++ b/tests/python/api/test_plan_spec.py @@ -8,7 +8,11 @@ import torch from tensorcast.api._config import PlanType -from tensorcast.api.context import CallContext, GovernanceContext +from tensorcast.api.context import ( + CallContext, + GovernanceContext, + TransportSchedulingGroup, +) from tensorcast.api.errors import ArtifactError from tensorcast.api.plan import ( ARTIFACT_SET_CARRIER_INLINE, @@ -118,6 +122,32 @@ def test_plan_to_spec_is_deterministic() -> None: assert selection.view_id == "" +def test_plan_context_serializes_transport_group() -> None: + ctx = CallContext( + request_id="req-transport", + transport_group=TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + total_parts=16, + part_id="daemon-1", + priority=7, + epoch=42, + request_id="transport-req-1", + ), + ) + spec = Plan(ctx).to_spec() + + assert spec.context.HasField("transport_group") + group = spec.context.transport_group + assert group.group_kind == "weight_broadcast" + assert group.group_id == "model-a:v42" + assert group.total_parts == 16 + assert group.part_id == "daemon-1" + assert group.priority == 7 + assert group.epoch == 42 + assert group.request_id == "transport-req-1" + + def test_plan_view_selection_hash_populated() -> None: store = _StoreStub() canonical_bytes = _canonical_index_bytes() @@ -169,6 +199,7 @@ def test_plan_publish_serializes_canonical_action() -> None: assert first_step.action.publish.engine_request_id == "rid-123" assert int(first_step.action.publish.ttl_ms) == 60_000 + def test_plan_hydrate_serializes_publish_manifest() -> None: ctx = CallContext(request_id="req-hydrate", idempotency_key="idem-hydrate") plan = Plan(ctx) diff --git a/tests/python/node_agent/test_plan_execution.py b/tests/python/node_agent/test_plan_execution.py index 5a7df1b1..6e127599 100644 --- a/tests/python/node_agent/test_plan_execution.py +++ b/tests/python/node_agent/test_plan_execution.py @@ -10,7 +10,7 @@ import tensorcast.node_agent.executor as executor_mod from tensorcast.api._config import PlanType from tensorcast.api._errors import DeviceMismatch -from tensorcast.api.context import CallContext +from tensorcast.api.context import CallContext, TransportSchedulingGroup from tensorcast.api.plan import ArtifactSetRef, Plan, Worker from tensorcast.api.store import ( BuilderMode, @@ -38,6 +38,7 @@ from tensorcast.node_agent.executor import NodeAgentExecutor from tensorcast.node_agent.server import NodeAgentServicer from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.daemon.v2 import store_daemon_pb2 from tensorcast.proto.node_agent.v1 import node_agent_pb2 from tensorcast.proto.plan.v1 import plan_pb2 from tensorcast.types import ( @@ -53,10 +54,12 @@ def __init__(self) -> None: self.placement_timeout_s: float | None = None self.release_timeout_s: float | None = None self.materialized_artifact_ids: list[str] = [] + self.materialize_calls: list[dict[str, object]] = [] self.wait_for_completion_values: list[bool] = [] def materialize_by_artifact_id_v2(self, *args, **kwargs): # noqa: ANN002, ANN003 self.materialize_timeout_s = kwargs.get("timeout_s") + self.materialize_calls.append(dict(kwargs)) selection = kwargs.get("selection") if selection is not None: self.materialized_artifact_ids.append(str(selection.artifact_id)) @@ -442,6 +445,67 @@ class _StoreStub: assert all(daemon.wait_for_completion_values) +def test_node_agent_prefetch_forwards_transport_group_from_plan_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + daemon = _DaemonStub() + monkeypatch.setattr(executor_mod, "device_uuid_for", lambda _device_id: "gpu-0") + + class _StoreStub: + closed = False + _runtime = None + + canonical_bytes = _canonical_index_bytes() + store = _StoreStub() + artifact = Artifact( + store_ref=weakref.ref(store), + artifact_id="mi2:model-a:v42", + canonical_index_bytes=canonical_bytes, + canonical_index=canonical_index_from_bytes(canonical_bytes), + ) + plan = Plan( + CallContext( + request_id="req-prefetch", + transport_group=TransportSchedulingGroup( + group_kind="weight_broadcast", + group_id="model-a:v42", + total_parts=16, + part_id="daemon-1", + priority=7, + epoch=42, + request_id="transport-req-1", + ), + ) + ) + worker = Worker( + worker_id="worker-1", + daemon_address="127.0.0.1:50051", + daemon_id="daemon-1", + ) + plan.on_worker(worker).prefetch(artifact, device=0) + + executor = NodeAgentExecutor( + daemon_id="daemon-1", + daemon_address="127.0.0.1:50051", + instance_id="inst-1", + engine_adapter=None, + client_factory=lambda _addr: daemon, + ) + result = executor.execute_plan(plan.to_spec()) + + assert result.ok is True + request = daemon.materialize_calls[0] + assert request["transport_request_id"] == "transport-req-1" + group = request["transport_scheduling_group"] + assert isinstance(group, store_daemon_pb2.TransportSchedulingGroupHint) + assert group.group_kind == "weight_broadcast" + assert group.group_id == "model-a:v42" + assert group.total_parts == 16 + assert group.part_id == "daemon-1" + assert group.priority == 7 + assert group.epoch == 42 + + def test_node_agent_executes_manifest_backed_prefetch_set_with_bridge( monkeypatch: pytest.MonkeyPatch, ) -> None: From ff787da7c8ac55ed436d544526557f0c98e4f176 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 13:27:55 +0800 Subject: [PATCH 16/49] docs: document weight prefetch transport groups --- tensorcast/api/README.md | 7 +++++++ tensorcast/api/store/README.md | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/tensorcast/api/README.md b/tensorcast/api/README.md index 8973f837..808b5568 100644 --- a/tensorcast/api/README.md +++ b/tensorcast/api/README.md @@ -60,6 +60,13 @@ By default, the Python SDK surfaces a concise `ArtifactError` stack without gRPC artifact/view identity, but `ctx.idempotency_key` seeds deterministic operation ids for joinable actions. - Collective disk loads are explicit at the API boundary via `CallContext.collective=CollectiveLoadGroup(...)`; the SDK no longer infers collective mode from ambient GPU environment variables or overloads `replica_uuid` with group hints. +- Coordinated model-weight prefetch can opt into Global Store group dispatch via + `CallContext.transport_group=TransportSchedulingGroup(...)`. Use a stable + `group_kind`/`group_id`/`epoch` for the model version and unique + `part_id` per target daemon/rank; `Artifact.prefetch(...)` derives a stable + transport request id when `request_id` is omitted. `Plan` serializes the same + hint so Node Agent prefetches reach the daemon with the matching transport + scheduling group. - Long-tail control-plane actions return `Operation[T]` (sync/blocking): use `status()` / `result()` / `cancel()` to implement wait/cancel without ad-hoc polling loops. - `Artifact.prefetch(...)` warms a **daemon-owned** replica and supports both GPU and CPU/DRAM placement: diff --git a/tensorcast/api/store/README.md b/tensorcast/api/store/README.md index df719186..92f8c991 100644 --- a/tensorcast/api/store/README.md +++ b/tensorcast/api/store/README.md @@ -385,6 +385,13 @@ binding.swap("model:v2") `logical_layout_hash`, `selection_hash`) and target placement (daemon + device/tier). `selection_hash` is computed via `tensorcast.common.selection_identity` (stable `view_id` + `view_subset_hash`), matching Plan selection identity semantics. +- To spread a model-weight warmup across replicas, pass + `ctx=tc.context(transport_group=tc.TransportSchedulingGroup(...))` to + `artifact.prefetch(...)`. All targets for the same model version should share + `group_kind`, `group_id`, and `epoch`, while each target uses a unique + `part_id`; the SDK forwards the group and a stable `transport_request_id` to + the daemon so the existing Global Store group dispatcher can spread source + selection. Unset transport groups keep the ordinary materialization path. - `artifact.pin_device_residency(device=..., ttl_ms=..., ctx=...) -> Operation[PlacementPin]` creates a placement pin (process-independent device residency intent) backed by a daemon-scoped capability token; the returned `PlacementPin` supports `renew()` / `release()`. From e9187e38a68f79ff1b9e04f30126d4285249b06c Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:37:15 +0800 Subject: [PATCH 17/49] chore(communicator): simplify transfer progress formatting --- core/communicator/engine/engine.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/communicator/engine/engine.cc b/core/communicator/engine/engine.cc index 41281160..1d6e3778 100644 --- a/core/communicator/engine/engine.cc +++ b/core/communicator/engine/engine.cc @@ -2107,7 +2107,7 @@ uint64_t Communicator::add_transfer_progress_bytes( const double progress_percent = state->total_bytes > 0 ? static_cast(done) * 100.0 / static_cast(state->total_bytes) : 100.0; LOG(INFO) << std::format( - "[xfer_progress] side={} transport={} state=progress peer={} request={} bar=[{}] {:5.1f}% " + "[xfer_progress] side={} transport={} state=progress peer={} request={} bar=[{}] {:.1f}% " "done_gib={:.3f}/{:.3f} rate_inst_gibps={:.3f} rate_avg_gibps={:.3f}", state->side, state->transport, @@ -2151,7 +2151,7 @@ void Communicator::finish_transfer_progress( const std::string phase = status.ok() ? "done" : "failed"; const std::string status_text = status.ok() ? std::string() : truncate_token(status.message(), 120); const std::string line = std::format( - "[xfer_progress] side={} transport={} state={} peer={} request={} bar=[{}] {:5.1f}% " + "[xfer_progress] side={} transport={} state={} peer={} request={} bar=[{}] {:.1f}% " "done_gib={:.3f}/{:.3f} rate_inst_gibps={:.3f} rate_avg_gibps={:.3f}{}", state->side, state->transport, From 7780e04fb4f1d330605440f368068ce68a36230f Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 16:09:36 +0800 Subject: [PATCH 18/49] docs: design tree broadcast phase2 --- ...117-control-plane-tree-broadcast-phase2.md | 489 ++++++++++++++++++ 1 file changed, 489 insertions(+) create mode 100644 docs/designs/0117-control-plane-tree-broadcast-phase2.md diff --git a/docs/designs/0117-control-plane-tree-broadcast-phase2.md b/docs/designs/0117-control-plane-tree-broadcast-phase2.md new file mode 100644 index 00000000..c1681971 --- /dev/null +++ b/docs/designs/0117-control-plane-tree-broadcast-phase2.md @@ -0,0 +1,489 @@ +--- +slug: control-plane-tree-broadcast-phase2 +title: Control-Plane Tree Broadcast Phase 2 +status: proposed +areas: ["sdk", "daemon", "core", "global_store"] +created: 2026-04-29 +last_updated: 2026-04-29 +related_code: + - proto/tensorcast/global_store/v1/global_store.proto + - proto/tensorcast/daemon/v2/store_daemon.proto + - schema.sql + - tensorcast/global_store/services/transport_service.py + - tensorcast/global_store/repositories/replica_repository.py + - tensorcast/global_store/repositories/transport_repository.py + - tensorcast/global_store/rpc/transport_rpc_handler.py + - core/store/materialization/contracts/loading_spec.h + - core/store/components/global_store_client.h + - core/store/components/global_store_client.cc + - core/store/materialization/control/materialize_orchestrator.cc + - core/store/runtime/ingestion/materialization_facade.cc + - daemon/service/controllers/replica_materialization_service.cc + - tensorcast/api/context.py + - tensorcast/api/store/artifact.py + - tensorcast/api/_materialize.py + - tensorcast/daemon_ctl.py +links: + dependencies: + - ./0116-control-plane-coordinated-weight-broadcast.md + - ./0083-group-aware-transport-scheduling.md + - ./0048-ha-replica-visibility-and-retire.md +--- + +# Summary + +Phase 1 made model-weight prefetches visible to Global Store group dispatch by +threading transport request ids and scheduling group hints through SDK, +Store Daemon, C++ materialization, and `RequestReplicaTransport`. That reduces +source concentration when multiple eligible replicas already exist, but it +does not produce a strict tree. Phase 2 adds explicit control-plane broadcast +sessions and tree edges while keeping all data movement on the existing P2P +materialization path. + +The selected design is an end-to-end, session-aware `RequestReplicaTransport` +flow. SDK and plan callers create a broadcast session through Store Daemon, not +by connecting to Global Store directly. Child materialization calls carry a +`broadcast_session_id`; Global Store resolves the child worker to a planned +edge, validates the assigned parent replica with the same liveness, capacity, +and exportability rules used by normal transport scheduling, and returns only +that parent as the transport source. + +```mermaid +flowchart LR + A["Root exportable replica"] --> B["Global Store
BroadcastSession"] + B --> C["Tree planner
fanout edges"] + C --> D["Child daemon
MaterializeReplica"] + D --> E["RequestReplicaTransport
broadcast session hint"] + E --> F["Assigned parent replica"] + F --> G["Existing P2P data path
RDMA or MTCP"] + G --> H["Child registers replica
and exports if possible"] + H --> I["Global Store marks edge complete
and schedules next layer"] +``` + +# Goals / Non-Goals + +## Goals + +- Create a durable `BroadcastSession` for one artifact, optional view, epoch, + fanout, and target set. +- Generate explicit parent-child tree edges inside Global Store. +- Make child materialization pull only from the parent assigned by the active + edge when strict mode is enabled. +- Allow a successfully materialized child replica to become a parent for later + layers after replica registration/export succeeds. +- Reuse existing replica registry, `artifact_transports`, heartbeat, + accepting-new-requests, capacity, export metadata, verification metadata, and + transport completion outcome machinery. +- Keep SDK control-path access local to Store Daemon. +- Preserve ordinary `Artifact.prefetch()`, `tensor_dict()`, Phase 1 group + dispatch, MTCP fallback, and disk fallback behavior when no broadcast hint is + provided. + +## Non-Goals + +- Do not implement topology-aware rack/host/rail planning in Phase 2. +- Do not implement chunk-level pipeline forwarding in Phase 2. +- Do not introduce NCCL as the cross-cluster model-weight broadcast control + plane. +- Do not create a second control plane beside Global Store. +- Do not require daemon-to-daemon control RPCs for tree assignment; only data + movement remains daemon-to-daemon. + +# Prior Constraints Reviewed + +## Phase 1 group scheduling + +Design `0116` introduced stable transport request ids and group hints for soft +broadcast. Phase 2 keeps that path intact. A transport request may still carry a +Phase 1 scheduling group for observability and fairness, but the broadcast +session hint takes precedence for source selection because it represents a +strict edge assignment rather than a soft source-spread preference. + +## Global Store as sole coordination authority + +The current architecture separates Global Store metadata/coordination from +Store Daemon data movement. This design keeps tree state, failure retry, and +progress tracking in Global Store. Store Daemon remains the gateway used by SDK +callers and the executor of actual materialization. + +## SDK must not connect to Global Store directly + +Root `AGENTS.md` requires Python SDK code to go through Store Daemon for +control-path operations. Therefore session creation, session lookup, and +materialization hints are exposed through Store Daemon APIs. Store Daemon +forwards the session requests to Global Store through its existing metadata +gateway/client boundary. + +## HA and replica visibility + +Existing worker liveness, accepting-new-requests, capacity, and export-state +filters are required before a replica can serve P2P. Phase 2 does not relax +those checks for tree edges. Parent assignment is only valid while the parent +replica is transport-eligible. + +# Architecture & Interfaces + +## Persistent model + +Three tables own broadcast state: + +- `broadcast_sessions`: one row per dissemination attempt. +- `broadcast_targets`: one row per target worker/daemon in the session. +- `broadcast_edges`: one row per parent-child attempt. + +`artifact_transports` gains nullable `broadcast_session_id` and +`broadcast_edge_id` columns so transport lifecycle rows remain the audit trail +for bytes movement and completion outcomes. + +## Global Store RPC + +`ClusterRuntimeService` adds session management RPCs: + +```proto +rpc CreateBroadcastSession(CreateBroadcastSessionRequest) + returns (CreateBroadcastSessionResponse); +rpc GetBroadcastSession(GetBroadcastSessionRequest) + returns (GetBroadcastSessionResponse); +rpc ListBroadcastEdges(ListBroadcastEdgesRequest) + returns (ListBroadcastEdgesResponse); +rpc CancelBroadcastSession(CancelBroadcastSessionRequest) + returns (CancelBroadcastSessionResponse); +``` + +The create request accepts: + +```proto +message BroadcastTargetIdentity { + string worker_id = 1; + string daemon_id = 2; +} + +message CreateBroadcastSessionRequest { + string session_id = 1; + string artifact_id = 2; + tensorcast.common.v1.ByteSpaceRef requested_byte_space = 3; + uint64 epoch = 4; + uint32 fanout = 5; + repeated BroadcastTargetIdentity targets = 6; + string root_replica_id = 7; + bool strict_parent = 8; + uint32 max_attempts = 9; +} +``` + +`session_id` is caller-supplied for idempotency. If `root_replica_id` is empty, +Global Store selects an eligible root replica for the artifact/view. `daemon_id` +targets are resolved to active `worker_id` rows during planning when possible; +the target row stores both identities so daemon restarts remain diagnosable. + +## Session-aware transport request + +`RequestReplicaTransportRequest` gains a broadcast hint: + +```proto +message BroadcastTransportHint { + string session_id = 1; + bool strict_parent = 2; +} + +message RequestReplicaTransportRequest { + // existing fields... + BroadcastTransportHint broadcast = 11; +} +``` + +When `broadcast.session_id` is set, Global Store verifies that: + +- the request artifact/view matches the session, +- the session epoch matches the target epoch stored in the session, +- `requester_worker_id` maps to a session target, +- an active edge exists or can be planned for that target, +- the edge parent replica is currently transport-eligible. + +If `strict_parent` is true, Global Store returns only the assigned edge parent. +If no eligible parent is available before the wait deadline, the request times +out or the edge is failed and requeued according to the scheduler rules below. + +## Store Daemon and SDK API + +Store Daemon exposes daemon-local RPCs that forward broadcast session requests +to Global Store. The SDK surface builds on existing `CallContext`: + +```python +@dataclass(frozen=True, slots=True) +class BroadcastContext: + session_id: str + strict_parent: bool = True + +@dataclass(frozen=True, slots=True) +class CallContext: + # existing fields... + broadcast: BroadcastContext | None = None +``` + +`Artifact.prefetch(..., ctx=tensorcast.context(broadcast=...))` forwards the +broadcast hint to `materialize_artifact_v2`, `DaemonCtl`, and +`MaterializeReplicaRequest`. + +Store Daemon maps the daemon proto hint into `MaterializeHints`: + +```c++ +struct BroadcastHint { + std::string session_id; + bool strict_parent{true}; +}; + +struct MaterializeHints { + // existing fields... + std::optional broadcast; +}; +``` + +`GlobalStoreClient::request_replica_transport()` and +`request_view_transport()` accept the same optional hint and copy it into +`RequestReplicaTransportRequest.broadcast`. + +# Schema Changes + +`schema.sql` adds: + +```sql +CREATE TABLE IF NOT EXISTS broadcast_sessions ( + session_id TEXT PRIMARY KEY, + artifact_id TEXT NOT NULL, + requested_view_id TEXT NULL, + epoch BIGINT NOT NULL, + fanout INTEGER NOT NULL, + max_attempts INTEGER NOT NULL DEFAULT 3, + strict_parent BOOLEAN NOT NULL DEFAULT TRUE, + state TEXT CHECK (state IN ('planning','active','completed','failed','cancelled')) NOT NULL, + root_replica_id UUID NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL +); + +CREATE TABLE IF NOT EXISTS broadcast_targets ( + session_id TEXT NOT NULL, + target_worker_id TEXT NOT NULL, + target_daemon_id TEXT NULL, + state TEXT CHECK (state IN ('pending','assigned','materializing','completed','failed','cancelled')) NOT NULL, + level INTEGER NULL, + attempt INTEGER NOT NULL DEFAULT 0, + assigned_edge_id TEXT NULL, + completed_replica_id UUID NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, + PRIMARY KEY (session_id, target_worker_id) +); + +CREATE TABLE IF NOT EXISTS broadcast_edges ( + edge_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + parent_worker_id TEXT NOT NULL, + parent_replica_id UUID NOT NULL, + child_worker_id TEXT NOT NULL, + level INTEGER NOT NULL, + attempt INTEGER NOT NULL DEFAULT 1, + state TEXT CHECK (state IN ('planned','assigned','materializing','completed','failed','cancelled')) NOT NULL, + transport_request_id TEXT NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_broadcast_edges_one_active_child + ON broadcast_edges(session_id, child_worker_id) + WHERE state IN ('planned','assigned','materializing'); + +ALTER TABLE artifact_transports ADD COLUMN broadcast_session_id TEXT NULL; +ALTER TABLE artifact_transports ADD COLUMN broadcast_edge_id TEXT NULL; +CREATE INDEX IF NOT EXISTS idx_artifact_transports_broadcast + ON artifact_transports(broadcast_session_id, broadcast_edge_id, status); +``` + +Implementation should add a DuckDB migration file for persistent deployments +and keep `schema.sql` as the canonical new database shape. + +# Scheduler State Machine + +## Session planning + +`CreateBroadcastSession` inserts the session and targets, then plans the first +layer. The initial parent pool contains the specified root replica or eligible +root candidates. The scheduler creates at most `fanout` planned edges for +pending targets. A target moves from `pending` to `assigned` when an edge is +created. + +```mermaid +stateDiagram-v2 + [*] --> planning + planning --> active + active --> completed + active --> failed + active --> cancelled + cancelled --> [*] + completed --> [*] + failed --> [*] +``` + +## Transport claim + +When a broadcast transport request arrives, Global Store: + +1. Finds the session by `session_id`. +2. Resolves `requester_worker_id` to a target row. +3. Finds the target active edge or creates one if capacity is available in the + current parent pool. +4. Validates the parent replica with the normal transport eligibility checks. +5. Atomically increments parent replica `current_requests`. +6. Creates an `artifact_transports` row linked to the session and edge. +7. Marks the edge `materializing` and target `materializing`. + +If parent validation fails, the current edge is marked `failed`, the target +returns to `pending`, and a new edge is planned if another eligible parent +exists. + +## Completion and promotion + +For broadcast transports, success means more than copied bytes: + +```text +P2P load succeeded ++ child replica registration succeeded ++ child replica is visible as a usable future source +``` + +The materialization path should therefore call +`CompleteReplicaTransport(SUCCESS)` after successful local replica registration +for broadcast requests. If P2P succeeds but registration/export fails, the +transport completes with `FAILED`, the edge does not advance, and the target is +eligible for retry. + +On success, Global Store: + +- marks the transport completed with `success`, +- marks the edge completed, +- marks the target completed and records `completed_replica_id`, +- adds the child replica to the parent pool for future planning, +- schedules more pending targets up to the fanout limit, +- marks the session completed only after all targets complete. + +## Failure retry + +Failed, expired, and cancelled transports do not count as success. They fail +only the edge attempt. The target returns to `pending` with `attempt + 1`. +Global Store replans from the current eligible parent pool, preferring +completed child replicas and falling back to the root pool when needed. After +`max_attempts`, the target is marked `failed` and the session is marked +`failed` unless it has already been cancelled. + +# Error Model + +- `CreateBroadcastSession` rejects empty artifact ids, zero fanout, zero + targets, duplicate target identities, invalid epoch, and missing eligible + root sources. +- `RequestReplicaTransport` with a broadcast hint rejects artifact/view/session + mismatches and unknown requester workers. +- Parent heartbeat stale, not accepting, over capacity, not exportable, or + missing transport metadata causes the edge attempt to fail and replan. +- Reusing a `transport_request_id` with a different payload remains rejected by + existing idempotency checks. +- `CompleteReplicaTransport(SUCCESS)` on a non-broadcast transport keeps + existing behavior. +- `CompleteReplicaTransport(SUCCESS)` on a broadcast transport with no linked + edge is invalid. +- Cancelled sessions stop generating new edges. In-flight transports may finish + with failed/cancelled outcomes but must not advance cancelled sessions. + +# Compatibility & Rollout + +The feature is additive. Existing materialization calls do not set +`broadcast.session_id`, so normal source selection and Phase 1 group dispatch +continue to run unchanged. + +Rollout should proceed in three gates: + +- Gate 1: Global Store schema, repository, service, and RPC tests pass with the + feature unused. +- Gate 2: daemon/core hint propagation works, but broadcast sessions are + enabled only in tests or explicit callers. +- Gate 3: SDK `BroadcastContext` and daemon-mediated session creation become + available for model-weight prefetch paths. + +Backout is straightforward for callers: omit the broadcast hint and they return +to Phase 1 group dispatch or ordinary source selection. + +# Naming Compliance + +| Interface | Language | Compliance | +| --- | --- | --- | +| `BroadcastContext` | Python class | PascalCase class name matching existing context dataclasses. | +| `broadcast` | Python field | snake_case field name. | +| `CreateBroadcastSessionRequest` | Proto message | PascalCase message name. | +| `BroadcastTransportHint` | Proto message | PascalCase message name. | +| `broadcast_session_id` | SQL/proto field | snake_case field name. | +| `BroadcastHint` | C++ struct | PascalCase struct name. | +| `request_replica_transport` | C++ method | Existing snake_case method name retained. | +| `create_broadcast_session` | Python service method | snake_case method name. | + +# Testing + +Global Store tests should cover: + +- session creation with root selection, targets, and first-layer edges, +- fanout limits, +- broadcast transport requests returning only the edge parent, +- parent ineligibility failing only the edge attempt, +- success completion advancing edge, target, and session state, +- failed/expired/cancelled outcomes not advancing tree progress, +- artifact/view/epoch mismatch rejection, +- max-attempt target failure, +- unchanged group dispatch tests. + +Daemon/core tests should cover: + +- daemon `MaterializeReplicaRequest.broadcast` to `MaterializeHints.broadcast`, +- C++ `GlobalStoreClient` proto mapping, +- broadcast materialization still using existing P2P loader, +- broadcast success completing after registration, +- registration/export failure completing the transport as `FAILED`. + +An end-to-end local test should run with fake CUDA or CPU-compatible settings: + +1. Register an exportable root replica. +2. Create a session with three targets and `fanout=1` or `fanout=2`. +3. Run target prefetches with the same `broadcast_session_id`. +4. Verify transport rows show edge parent assignments. +5. Verify first-layer children become later parents. +6. Force parent failure and verify only the affected target is replanned. +7. Verify ordinary unhinted prefetch still works. + +# Acceptance Criteria + +- A caller can create a broadcast session through Store Daemon. +- Global Store generates a parent-child tree plan for the session. +- Child materialization pulls from the edge-assigned parent in strict mode. +- A completed child registers a replica and can become a next-layer parent. +- Parent failure replans affected targets without failing the whole session + until `max_attempts` is exceeded. +- Session completion matches all target materialization successes. +- Session epoch isolates different model versions. +- Existing `tensorcast.artifact(...).tensor_dict()`, ordinary + `Artifact.prefetch()`, Phase 1 group dispatch, and fallback paths continue to + work without broadcast hints. + +# References + +- `docs/designs/0116-control-plane-coordinated-weight-broadcast.md` +- `docs/designs/0083-group-aware-transport-scheduling.md` +- `schema.sql` +- `tensorcast/global_store/services/transport_service.py` +- `tensorcast/global_store/repositories/replica_repository.py` +- `core/store/materialization/control/materialize_orchestrator.cc` +- `core/store/runtime/ingestion/materialization_facade.cc` +- `proto/tensorcast/global_store/v1/global_store.proto` +- `proto/tensorcast/daemon/v2/store_daemon.proto` From 749cdb180d0ecd0dd8585cda23891fba95f302d7 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 16:22:42 +0800 Subject: [PATCH 19/49] docs: plan tree broadcast phase2 --- ...117-control-plane-tree-broadcast-phase2.md | 2075 +++++++++++++++++ 1 file changed, 2075 insertions(+) create mode 100644 docs/plans/0117-control-plane-tree-broadcast-phase2.md diff --git a/docs/plans/0117-control-plane-tree-broadcast-phase2.md b/docs/plans/0117-control-plane-tree-broadcast-phase2.md new file mode 100644 index 00000000..2e0ae6cc --- /dev/null +++ b/docs/plans/0117-control-plane-tree-broadcast-phase2.md @@ -0,0 +1,2075 @@ +--- +slug: control-plane-tree-broadcast-phase2 +title: Control-Plane Tree Broadcast Phase 2 Implementation Plan +links: + design: ../designs/0117-control-plane-tree-broadcast-phase2.md +--- + +# Control-Plane Tree Broadcast Phase 2 Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build end-to-end Global Store coordinated tree broadcast for model-weight materialization while reusing existing P2P data movement. + +**Architecture:** Add durable broadcast session/target/edge state in Global Store, route broadcast-tagged `RequestReplicaTransport` calls through edge-assigned parents, and propagate broadcast hints from SDK through Store Daemon and C++ `MaterializeHints`. Broadcast success advances only after P2P materialization and local replica registration/export have completed. + +**Tech Stack:** DuckDB schema, Python Global Store repositories/services/RPC handlers, protobuf/buf, Python SDK, C++ Store Daemon controllers, C++ StoreEngine materialization, pytest, Bazel/Catch2. + +--- + +# Current State & Grounding + +- Branch: `runze/broadcast-weight`. +- Design: `docs/designs/0117-control-plane-tree-broadcast-phase2.md`. +- Phase 1 design and implementation are already present in `docs/designs/0116-control-plane-coordinated-weight-broadcast.md`. +- `RequestReplicaTransportRequest` already has `request_id`, `requester_worker_id`, and `scheduling_group`. +- `artifact_transports` and `pending_transport_requests` already track transport request ids, requester worker ids, scheduling groups, and completion outcomes. +- `ReplicaRepository.find_available_for_transport()` already filters source replicas by export metadata, worker liveness, `accepting_new_requests`, capacity, and memory tier ordering. +- `MaterializeHints` already carries `transport_request_id`, `transport_requester_worker_id`, and `transport_scheduling_group`. +- Current gap: there is no broadcast session/edge state, no strict parent assignment in transport selection, and no SDK/daemon/core broadcast hint. +- Existing dirty worktree before this plan includes generated proto artifacts and `pyproject.toml`; implementation must not revert or stage unrelated pre-existing changes. + +# File Structure + +- Create: `tensorcast/global_store/models/broadcast.py` for session/target/edge dataclasses and state enums. +- Create: `tensorcast/global_store/repositories/broadcast_repository.py` for DuckDB CRUD and atomic edge state transitions. +- Create: `tensorcast/global_store/services/broadcast_service.py` for session creation, planning, transport-edge claim, completion advancement, retry, and cancellation. +- Create: `tensorcast/global_store/rpc/broadcast_rpc_handler.py` for Global Store broadcast RPC validation and protobuf mapping. +- Modify: `schema.sql` and add migration `tensorcast/global_store/migrations/0019_broadcast_sessions.py`. +- Modify: `tensorcast/global_store/models/__init__.py`, `repositories/__init__.py`, `services/__init__.py`, `grpc_service.py`, and `rpc_servicer_mixins.py` to wire the new domain. +- Modify: `tensorcast/global_store/repositories/transport_repository.py` to persist `broadcast_session_id` and `broadcast_edge_id`. +- Modify: `tensorcast/global_store/repositories/replica_repository.py` to claim a specific parent replica and to find a child replica by worker after registration. +- Modify: `tensorcast/global_store/services/transport_service.py` to delegate broadcast transport claims and completions to `BroadcastService`. +- Modify: `tensorcast/global_store/rpc/transport_rpc_handler.py` to parse `BroadcastTransportHint`. +- Modify: `proto/tensorcast/global_store/v1/global_store.proto` and `proto/tensorcast/daemon/v2/store_daemon.proto`. +- Modify: `tensorcast/api/context.py`, `tensorcast/api/__init__.py`, `tensorcast/__init__.py`, `tensorcast/api/store/artifact.py`, `tensorcast/api/_materialize.py`, and `tensorcast/daemon_ctl.py` for SDK/daemon hint propagation. +- Modify: `daemon/service/controllers/materialization_policy_utils.{h,cc}`, `daemon/service/controllers/replica_materialization_service.cc`, and daemon RPC controller files for broadcast session forwarding. +- Modify: `core/store/materialization/contracts/loading_spec.h`, `core/store/components/global_store_client.{h,cc}`, `core/store/testing/recording_global_store_client.h`, `core/store/materialization/control/materialize_orchestrator.cc`, and `core/store/runtime/ingestion/materialization_facade.cc`. +- Test: add Global Store repository/service/RPC tests under `tests/python/global_store/`. +- Test: update SDK tests under `tests/python/api/`. +- Test: update C++ daemon and core tests under existing Bazel targets. + +# Phases & Milestones + +- [ ] Phase 1: Add Global Store broadcast schema, domain models, and repository tests. +- [ ] Phase 2: Add Global Store broadcast service/RPC and first-layer tree planning. +- [ ] Phase 3: Route broadcast-tagged transport requests through assigned parent edges. +- [ ] Phase 4: Propagate broadcast hints through SDK, daemon proto, daemon controller, and C++ client. +- [ ] Phase 5: Enforce broadcast completion semantics after child replica registration. +- [ ] Phase 6: Add integration coverage, docs cross-links, and final verification. + +### Task 1: Broadcast Schema, Models, And Repository + +**Files:** +- Create: `tensorcast/global_store/models/broadcast.py` +- Create: `tensorcast/global_store/repositories/broadcast_repository.py` +- Modify: `tensorcast/global_store/models/__init__.py` +- Modify: `tensorcast/global_store/repositories/__init__.py` +- Modify: `schema.sql` +- Create: `tensorcast/global_store/migrations/0019_broadcast_sessions.py` +- Modify: `tests/python/global_store/conftest.py` +- Test: add `tests/python/global_store/test_broadcast_repository.py` + +- [ ] **Step 1: Write failing repository tests** + +Create `tests/python/global_store/test_broadcast_repository.py`: + +```python +from __future__ import annotations + +from uuid import UUID + +from tensorcast.global_store.models import ( + BroadcastEdge, + BroadcastEdgeState, + BroadcastSession, + BroadcastSessionState, + BroadcastTarget, + BroadcastTargetState, +) +from tensorcast.global_store.repositories import BroadcastRepository + + +def test_broadcast_repository_creates_session_targets_and_edges(db_connection): + repo = BroadcastRepository(db_connection) + session = BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=42, + fanout=2, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + repo.create_session(session) + repo.upsert_target( + BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child-1", + target_daemon_id="daemon-child-1", + state=BroadcastTargetState.PENDING, + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child-1", + level=1, + attempt=1, + state=BroadcastEdgeState.PLANNED, + ) + ) + + loaded = repo.find_session("session-a") + assert loaded is not None + assert loaded.artifact_id == "mi2:test" + assert loaded.epoch == 42 + assert loaded.state is BroadcastSessionState.ACTIVE + + target = repo.find_target("session-a", "worker-child-1") + assert target is not None + assert target.target_daemon_id == "daemon-child-1" + assert target.state is BroadcastTargetState.PENDING + + edge = repo.find_active_edge_for_child("session-a", "worker-child-1") + assert edge is not None + assert edge.parent_worker_id == "worker-root" + assert edge.state is BroadcastEdgeState.PLANNED + + +def test_broadcast_repository_prevents_two_active_edges_for_child(db_connection): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + first = BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.PLANNED, + ) + repo.create_edge(first) + + try: + repo.create_edge( + BroadcastEdge( + edge_id="edge-2", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=2, + state=BroadcastEdgeState.ASSIGNED, + ) + ) + except Exception as exc: # noqa: BLE001 + assert "active" in str(exc).lower() or "constraint" in str(exc).lower() + else: + raise AssertionError("expected active edge uniqueness to reject duplicate child") + + +def test_broadcast_repository_marks_edge_completed_and_target_completed(db_connection): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + repo.upsert_target( + BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + state=BroadcastTargetState.MATERIALIZING, + assigned_edge_id="edge-1", + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.MATERIALIZING, + transport_request_id="transport-request-1", + ) + ) + + completed_replica_id = UUID("00000000-0000-0000-0000-000000000002") + assert repo.mark_edge_completed( + edge_id="edge-1", + completed_replica_id=completed_replica_id, + ) + edge = repo.find_edge("edge-1") + target = repo.find_target("session-a", "worker-child") + assert edge is not None + assert target is not None + assert edge.state is BroadcastEdgeState.COMPLETED + assert target.state is BroadcastTargetState.COMPLETED + assert target.completed_replica_id == completed_replica_id +``` + +- [ ] **Step 2: Run repository tests and verify they fail** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_repository.py -v +``` + +Expected: FAIL with import errors for `BroadcastRepository` and broadcast model classes. + +- [ ] **Step 3: Add schema and migration** + +In `schema.sql`, add the `broadcast_sessions`, `broadcast_targets`, and `broadcast_edges` tables from [docs/designs/0117-control-plane-tree-broadcast-phase2.md](/data/tot/tensorcast/docs/designs/0117-control-plane-tree-broadcast-phase2.md), plus the `artifact_transports.broadcast_session_id` and `artifact_transports.broadcast_edge_id` columns. + +Create `tensorcast/global_store/migrations/0019_broadcast_sessions.py`: + +```python +# Copyright (c) 2026, TensorCast Team. + +"""Migration 0019: add broadcast session state.""" + +from __future__ import annotations + +from duckdb import DuckDBPyConnection + +UP_QUERIES: tuple[str, ...] = ( + """ + CREATE TABLE IF NOT EXISTS broadcast_sessions ( + session_id TEXT PRIMARY KEY, + artifact_id TEXT NOT NULL, + requested_view_id TEXT NULL, + epoch BIGINT NOT NULL, + fanout INTEGER NOT NULL, + max_attempts INTEGER NOT NULL DEFAULT 3, + strict_parent BOOLEAN NOT NULL DEFAULT TRUE, + state TEXT CHECK (state IN ('planning','active','completed','failed','cancelled')) NOT NULL, + root_replica_id UUID NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL + ); + """, + """ + CREATE TABLE IF NOT EXISTS broadcast_targets ( + session_id TEXT NOT NULL, + target_worker_id TEXT NOT NULL, + target_daemon_id TEXT NULL, + state TEXT CHECK (state IN ('pending','assigned','materializing','completed','failed','cancelled')) NOT NULL, + level INTEGER NULL, + attempt INTEGER NOT NULL DEFAULT 0, + assigned_edge_id TEXT NULL, + completed_replica_id UUID NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, + PRIMARY KEY (session_id, target_worker_id) + ); + """, + """ + CREATE TABLE IF NOT EXISTS broadcast_edges ( + edge_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + parent_worker_id TEXT NOT NULL, + parent_replica_id UUID NOT NULL, + child_worker_id TEXT NOT NULL, + level INTEGER NOT NULL, + attempt INTEGER NOT NULL DEFAULT 1, + state TEXT CHECK (state IN ('planned','assigned','materializing','completed','failed','cancelled')) NOT NULL, + transport_request_id TEXT NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL + ); + """, + """ + CREATE UNIQUE INDEX IF NOT EXISTS idx_broadcast_edges_one_active_child + ON broadcast_edges(session_id, child_worker_id) + WHERE state IN ('planned','assigned','materializing'); + """, + "ALTER TABLE artifact_transports ADD COLUMN IF NOT EXISTS broadcast_session_id TEXT NULL;", + "ALTER TABLE artifact_transports ADD COLUMN IF NOT EXISTS broadcast_edge_id TEXT NULL;", + """ + CREATE INDEX IF NOT EXISTS idx_artifact_transports_broadcast + ON artifact_transports(broadcast_session_id, broadcast_edge_id, status); + """, +) + +DOWN_QUERIES: tuple[str, ...] = ( + "DROP INDEX IF EXISTS idx_artifact_transports_broadcast;", + "ALTER TABLE artifact_transports DROP COLUMN IF EXISTS broadcast_edge_id;", + "ALTER TABLE artifact_transports DROP COLUMN IF EXISTS broadcast_session_id;", + "DROP INDEX IF EXISTS idx_broadcast_edges_one_active_child;", + "DROP TABLE IF EXISTS broadcast_edges;", + "DROP TABLE IF EXISTS broadcast_targets;", + "DROP TABLE IF EXISTS broadcast_sessions;", +) + + +def upgrade(conn: DuckDBPyConnection) -> None: + """Apply migration 0019.""" + for query in UP_QUERIES: + conn.execute(query) + + +def downgrade(conn: DuckDBPyConnection) -> None: + """Rollback migration 0019.""" + for query in DOWN_QUERIES: + conn.execute(query) +``` + +- [ ] **Step 4: Add broadcast models** + +Create `tensorcast/global_store/models/broadcast.py`: + +```python +# Copyright (c) 2026, TensorCast Team. + +"""Broadcast session domain models.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from uuid import UUID + + +class BroadcastSessionState(str, Enum): + PLANNING = "planning" + ACTIVE = "active" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class BroadcastTargetState(str, Enum): + PENDING = "pending" + ASSIGNED = "assigned" + MATERIALIZING = "materializing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class BroadcastEdgeState(str, Enum): + PLANNED = "planned" + ASSIGNED = "assigned" + MATERIALIZING = "materializing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class BroadcastSession: + session_id: str + artifact_id: str + requested_view_id: str | None + epoch: int + fanout: int + max_attempts: int + strict_parent: bool + state: BroadcastSessionState = BroadcastSessionState.PLANNING + root_replica_id: UUID | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + completed_at: datetime | None = None + + +@dataclass +class BroadcastTarget: + session_id: str + target_worker_id: str + target_daemon_id: str | None + state: BroadcastTargetState = BroadcastTargetState.PENDING + level: int | None = None + attempt: int = 0 + assigned_edge_id: str | None = None + completed_replica_id: UUID | None = None + failure_reason: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + completed_at: datetime | None = None + + +@dataclass +class BroadcastEdge: + edge_id: str + session_id: str + parent_worker_id: str + parent_replica_id: UUID + child_worker_id: str + level: int + attempt: int = 1 + state: BroadcastEdgeState = BroadcastEdgeState.PLANNED + transport_request_id: str | None = None + failure_reason: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + completed_at: datetime | None = None +``` + +Export these names from `tensorcast/global_store/models/__init__.py`. + +- [ ] **Step 5: Add repository** + +Create `tensorcast/global_store/repositories/broadcast_repository.py` and implement these public methods: + +- `create_session(self, session: BroadcastSession, cursor: DuckDBPyConnection | None = None) -> BroadcastSession` +- `find_session(self, session_id: str, cursor: DuckDBPyConnection | None = None) -> BroadcastSession | None` +- `update_session_state(self, session_id: str, state: BroadcastSessionState, cursor: DuckDBPyConnection | None = None) -> bool` +- `upsert_target(self, target: BroadcastTarget, cursor: DuckDBPyConnection | None = None) -> BroadcastTarget` +- `find_target(self, session_id: str, target_worker_id: str, cursor: DuckDBPyConnection | None = None) -> BroadcastTarget | None` +- `list_targets(self, session_id: str, cursor: DuckDBPyConnection | None = None) -> list[BroadcastTarget]` +- `list_targets_by_state(self, session_id: str, state: BroadcastTargetState, limit: int, cursor: DuckDBPyConnection | None = None) -> list[BroadcastTarget]` +- `create_edge(self, edge: BroadcastEdge, cursor: DuckDBPyConnection | None = None) -> BroadcastEdge` +- `find_edge(self, edge_id: str, cursor: DuckDBPyConnection | None = None) -> BroadcastEdge | None` +- `find_active_edge_for_child(self, session_id: str, child_worker_id: str, cursor: DuckDBPyConnection | None = None) -> BroadcastEdge | None` +- `mark_edge_materializing(self, edge_id: str, transport_request_id: str, cursor: DuckDBPyConnection | None = None) -> bool` +- `mark_edge_failed(self, edge_id: str, reason: str, cursor: DuckDBPyConnection | None = None) -> bool` +- `mark_edge_completed(self, edge_id: str, completed_replica_id: UUID | None, cursor: DuckDBPyConnection | None = None) -> bool` +- `count_incomplete_targets(self, session_id: str, cursor: DuckDBPyConnection | None = None) -> int` + +Use `_normalize_required_text()` and `_normalize_optional_text()` helpers patterned after `TransportRepository`. Convert state strings through the enum classes in `_row_to_session`, `_row_to_target`, and `_row_to_edge`. + +Export `BroadcastRepository` from `tensorcast/global_store/repositories/__init__.py` and add it to the `repositories` fixture in `tests/python/global_store/conftest.py`. + +- [ ] **Step 6: Run repository tests and commit** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_repository.py -v +``` + +Expected: PASS. + +Commit: + +```bash +git add schema.sql tensorcast/global_store/migrations/0019_broadcast_sessions.py tensorcast/global_store/models/__init__.py tensorcast/global_store/models/broadcast.py tensorcast/global_store/repositories/__init__.py tensorcast/global_store/repositories/broadcast_repository.py tests/python/global_store/conftest.py tests/python/global_store/test_broadcast_repository.py +git commit -m "feat(global-store): add broadcast state repository" +``` + +### Task 2: Broadcast Service And Global Store RPC + +**Files:** +- Create: `tensorcast/global_store/services/broadcast_service.py` +- Create: `tensorcast/global_store/rpc/broadcast_rpc_handler.py` +- Modify: `tensorcast/global_store/services/__init__.py` +- Modify: `tensorcast/global_store/grpc_service.py` +- Modify: `tensorcast/global_store/rpc_servicer_mixins.py` +- Modify: `proto/tensorcast/global_store/v1/global_store.proto` +- Test: add `tests/python/global_store/test_broadcast_service.py` +- Test: add `tests/python/global_store/test_broadcast_rpc.py` + +- [ ] **Step 1: Write failing service tests** + +Create `tests/python/global_store/test_broadcast_service.py`: + +```python +from __future__ import annotations + +from tensorcast.global_store.models import ( + BroadcastEdgeState, + BroadcastSessionState, + BroadcastTargetState, + ExportState, + MemoryType, + Replica, + Worker, +) +from tensorcast.global_store.services import BroadcastService + + +def _worker(worker_id: str, daemon_id: str, node_id: str) -> Worker: + return Worker( + worker_id=worker_id, + daemon_id=daemon_id, + node_id=node_id, + node_address=f"10.0.0.{node_id[-1]}", + grpc_port=5000 + int(node_id[-1]), + p2p_port=6000 + int(node_id[-1]), + mem_pool_total_size=4096, + mem_pool_available_size=4096, + accepting_new_requests=True, + ) + + +def _exportable_replica(artifact_id: str, worker: Worker) -> Replica: + return Replica( + artifact_id=artifact_id, + node_id=worker.node_id, + node_address=worker.node_address, + node_port=worker.p2p_port, + memory_size=1024, + memory_type=MemoryType.GPU, + device_id=0, + max_concurrency=4, + current_requests=0, + is_available=True, + remote_memory_keys=[f"rk-{worker.worker_id}"], + buffer_sizes=[1024], + export_state=ExportState.EXPORTABLE, + worker_id=worker.worker_id, + ) + + +def test_create_session_plans_first_layer_by_fanout(repositories): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root", "daemon-root", "node1") + child1 = _worker("worker-child-1", "daemon-child-1", "node2") + child2 = _worker("worker-child-2", "daemon-child-2", "node3") + child3 = _worker("worker-child-3", "daemon-child-3", "node4") + for worker in (root, child1, child2, child3): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-a", root)) + + session = service.create_session( + session_id="session-a", + artifact_id="mi2:model-a", + requested_view_id=None, + epoch=42, + fanout=2, + target_daemon_ids=["daemon-child-1", "daemon-child-2", "daemon-child-3"], + root_replica_id=str(root_replica.replica_id), + strict_parent=True, + max_attempts=3, + ) + + assert session.state is BroadcastSessionState.ACTIVE + targets = broadcast_repo.list_targets("session-a") + assert len(targets) == 3 + assigned = [t for t in targets if t.state is BroadcastTargetState.ASSIGNED] + pending = [t for t in targets if t.state is BroadcastTargetState.PENDING] + assert len(assigned) == 2 + assert len(pending) == 1 + edges = [ + broadcast_repo.find_active_edge_for_child("session-a", t.target_worker_id) + for t in assigned + ] + assert all(edge is not None for edge in edges) + assert all(edge.state is BroadcastEdgeState.PLANNED for edge in edges if edge) + assert all(edge.parent_replica_id == root_replica.replica_id for edge in edges if edge) +``` + +- [ ] **Step 2: Run service test and verify it fails** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_service.py::test_create_session_plans_first_layer_by_fanout -v +``` + +Expected: FAIL with `ImportError` for `BroadcastService`. + +- [ ] **Step 3: Implement `BroadcastService.create_session()` and first-layer planning** + +Create `tensorcast/global_store/services/broadcast_service.py`: + +```python +# Copyright (c) 2026, TensorCast Team. + +"""Broadcast session planning and progress service.""" + +from __future__ import annotations + +import uuid +from uuid import UUID + +from tensorcast.global_store.exceptions import NotFoundError, ValidationError +from tensorcast.global_store.models import ( + BroadcastEdge, + BroadcastEdgeState, + BroadcastSession, + BroadcastSessionState, + BroadcastTarget, + BroadcastTargetState, +) +from tensorcast.global_store.repositories.broadcast_repository import BroadcastRepository +from tensorcast.global_store.repositories.replica_repository import ReplicaRepository +from tensorcast.global_store.repositories.worker_repository import WorkerRepository + + +class BroadcastService: + def __init__( + self, + *, + broadcast_repository: BroadcastRepository, + replica_repository: ReplicaRepository, + worker_repository: WorkerRepository, + ) -> None: + self.broadcast_repository = broadcast_repository + self.replica_repository = replica_repository + self.worker_repository = worker_repository + + def create_session( + self, + *, + session_id: str, + artifact_id: str, + requested_view_id: str | None, + epoch: int, + fanout: int, + target_worker_ids: list[str] | None = None, + target_daemon_ids: list[str] | None = None, + root_replica_id: str | None = None, + strict_parent: bool = True, + max_attempts: int = 3, + ) -> BroadcastSession: + session_id = session_id.strip() + artifact_id = artifact_id.strip() + if not session_id: + raise ValidationError("session_id is required") + if not artifact_id: + raise ValidationError("artifact_id is required") + if int(epoch) < 0: + raise ValidationError("epoch must be non-negative") + if int(fanout) <= 0: + raise ValidationError("fanout must be positive") + if int(max_attempts) <= 0: + raise ValidationError("max_attempts must be positive") + + targets = self._resolve_targets( + target_worker_ids=target_worker_ids or [], + target_daemon_ids=target_daemon_ids or [], + ) + if not targets: + raise ValidationError("at least one target is required") + + root_replica = self._resolve_root_replica( + artifact_id=artifact_id, + requested_view_id=requested_view_id, + root_replica_id=root_replica_id, + ) + session = BroadcastSession( + session_id=session_id, + artifact_id=artifact_id, + requested_view_id=requested_view_id, + epoch=int(epoch), + fanout=int(fanout), + max_attempts=int(max_attempts), + strict_parent=bool(strict_parent), + state=BroadcastSessionState.ACTIVE, + root_replica_id=root_replica.replica_id, + ) + with self.broadcast_repository.transaction() as tx: + existing = self.broadcast_repository.find_session(session_id, cursor=tx) + if existing is not None: + return existing + self.broadcast_repository.create_session(session, cursor=tx) + for worker_id, daemon_id in targets: + self.broadcast_repository.upsert_target( + BroadcastTarget( + session_id=session_id, + target_worker_id=worker_id, + target_daemon_id=daemon_id, + state=BroadcastTargetState.PENDING, + ), + cursor=tx, + ) + self._plan_more_edges(session, cursor=tx) + loaded = self.broadcast_repository.find_session(session_id) + if loaded is None: + raise RuntimeError(f"broadcast session missing after create: {session_id}") + return loaded +``` + +Add `_resolve_targets()`, `_resolve_root_replica()`, and `_plan_more_edges()` in the same file. `_resolve_targets()` must call `WorkerRepository.find_by_id(worker_id, include_inactive=False)` and `WorkerRepository.find_by_daemon_id(daemon_id, include_inactive=False)`. `_resolve_root_replica()` must call `ReplicaRepository.find_by_id(UUID(root_replica_id))` when a root id is provided. For the empty root id case, select the root by calling `ReplicaRepository.find_available_for_transport()` with a short heartbeat timeout, store the returned replica id on the session, and immediately decrement that replica counter after planning because session creation only reserves topology, not an active byte transfer. + +`_plan_more_edges()` should list pending targets, list completed targets, build the parent pool as `[root_replica_id] + completed_replica_id values`, and create at most `fanout - active_edges_count` new `BroadcastEdge` rows. + +Export `BroadcastService` from `tensorcast/global_store/services/__init__.py`. + +- [ ] **Step 4: Add Global Store proto RPCs and handler tests** + +Append `tests/python/global_store/test_broadcast_rpc.py`: + +```python +from __future__ import annotations + +from tensorcast.proto.global_store.v1 import global_store_pb2 + + +def test_create_broadcast_session_rpc_returns_edges(servicer, test_context, memory_info): + root_worker = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-root", + node_id="node-root", + node_address="10.10.0.1", + grpc_port=50101, + p2p_port=50102, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + child_worker = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-child", + node_id="node-child", + node_address="10.10.0.2", + grpc_port=50201, + p2p_port=50202, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + memory_info.node_id = "node-root" + memory_info.node_address = "10.10.0.1" + memory_info.node_port = 50102 + register_resp = servicer.RegisterReplica( + global_store_pb2.RegisterReplicaRequest( + artifact_id="mi2:model-rpc", + worker_id=root_worker, + memory_info=memory_info, + max_concurrency=4, + ), + test_context, + ) + assert register_resp.status == global_store_pb2.STATUS_OK + + response = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-rpc", + artifact_id="mi2:model-rpc", + epoch=7, + fanout=1, + strict_parent=True, + max_attempts=3, + root_replica_id=register_resp.replica_id, + targets=[ + global_store_pb2.BroadcastTargetIdentity( + worker_id=child_worker, + daemon_id="daemon-child", + ) + ], + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_OK + assert response.session.session_id == "session-rpc" + assert response.session.state == global_store_pb2.BROADCAST_SESSION_STATE_ACTIVE + edge_resp = servicer.ListBroadcastEdges( + global_store_pb2.ListBroadcastEdgesRequest(session_id="session-rpc"), + test_context, + ) + assert edge_resp.status == global_store_pb2.STATUS_OK + assert len(edge_resp.edges) == 1 + assert edge_resp.edges[0].child_worker_id == child_worker +``` + +- [ ] **Step 5: Run RPC test and verify it fails** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_rpc.py::test_create_broadcast_session_rpc_returns_edges -v +``` + +Expected: FAIL because `CreateBroadcastSessionRequest` and related proto messages do not exist. + +- [ ] **Step 6: Extend Global Store proto and regenerate Python stubs** + +Modify `proto/tensorcast/global_store/v1/global_store.proto`: + +```proto +service ClusterRuntimeService { + // Insert these RPCs beside the existing transport and metadata RPCs. + rpc CreateBroadcastSession(CreateBroadcastSessionRequest) returns (CreateBroadcastSessionResponse) {} + rpc GetBroadcastSession(GetBroadcastSessionRequest) returns (GetBroadcastSessionResponse) {} + rpc ListBroadcastEdges(ListBroadcastEdgesRequest) returns (ListBroadcastEdgesResponse) {} + rpc CancelBroadcastSession(CancelBroadcastSessionRequest) returns (CancelBroadcastSessionResponse) {} +} + +enum BroadcastSessionState { + BROADCAST_SESSION_STATE_UNSPECIFIED = 0; + BROADCAST_SESSION_STATE_PLANNING = 1; + BROADCAST_SESSION_STATE_ACTIVE = 2; + BROADCAST_SESSION_STATE_COMPLETED = 3; + BROADCAST_SESSION_STATE_FAILED = 4; + BROADCAST_SESSION_STATE_CANCELLED = 5; +} + +enum BroadcastTargetState { + BROADCAST_TARGET_STATE_UNSPECIFIED = 0; + BROADCAST_TARGET_STATE_PENDING = 1; + BROADCAST_TARGET_STATE_ASSIGNED = 2; + BROADCAST_TARGET_STATE_MATERIALIZING = 3; + BROADCAST_TARGET_STATE_COMPLETED = 4; + BROADCAST_TARGET_STATE_FAILED = 5; + BROADCAST_TARGET_STATE_CANCELLED = 6; +} + +enum BroadcastEdgeState { + BROADCAST_EDGE_STATE_UNSPECIFIED = 0; + BROADCAST_EDGE_STATE_PLANNED = 1; + BROADCAST_EDGE_STATE_ASSIGNED = 2; + BROADCAST_EDGE_STATE_MATERIALIZING = 3; + BROADCAST_EDGE_STATE_COMPLETED = 4; + BROADCAST_EDGE_STATE_FAILED = 5; + BROADCAST_EDGE_STATE_CANCELLED = 6; +} + +message BroadcastTargetIdentity { + string worker_id = 1; + string daemon_id = 2; +} + +message BroadcastSessionInfo { + string session_id = 1; + string artifact_id = 2; + tensorcast.common.v1.ByteSpaceRef requested_byte_space = 3; + uint64 epoch = 4; + uint32 fanout = 5; + uint32 max_attempts = 6; + bool strict_parent = 7; + BroadcastSessionState state = 8; + string root_replica_id = 9; +} + +message BroadcastTargetInfo { + string session_id = 1; + string target_worker_id = 2; + string target_daemon_id = 3; + BroadcastTargetState state = 4; + uint32 level = 5; + uint32 attempt = 6; + string assigned_edge_id = 7; + string completed_replica_id = 8; + string failure_reason = 9; +} + +message BroadcastEdgeInfo { + string edge_id = 1; + string session_id = 2; + string parent_worker_id = 3; + string parent_replica_id = 4; + string child_worker_id = 5; + uint32 level = 6; + uint32 attempt = 7; + BroadcastEdgeState state = 8; + string transport_request_id = 9; + string failure_reason = 10; +} + +message CreateBroadcastSessionRequest { + string session_id = 1; + string artifact_id = 2; + tensorcast.common.v1.ByteSpaceRef requested_byte_space = 3; + uint64 epoch = 4; + uint32 fanout = 5; + repeated BroadcastTargetIdentity targets = 6; + string root_replica_id = 7; + bool strict_parent = 8; + uint32 max_attempts = 9; +} + +message CreateBroadcastSessionResponse { + Status status = 1; + BroadcastSessionInfo session = 2; + repeated BroadcastTargetInfo targets = 3; + repeated BroadcastEdgeInfo edges = 4; +} + +message GetBroadcastSessionRequest { + string session_id = 1; +} + +message GetBroadcastSessionResponse { + Status status = 1; + BroadcastSessionInfo session = 2; + repeated BroadcastTargetInfo targets = 3; +} + +message ListBroadcastEdgesRequest { + string session_id = 1; +} + +message ListBroadcastEdgesResponse { + Status status = 1; + repeated BroadcastEdgeInfo edges = 2; +} + +message CancelBroadcastSessionRequest { + string session_id = 1; + string reason = 2; +} + +message CancelBroadcastSessionResponse { + Status status = 1; +} +``` + +Regenerate stubs for local validation: + +```bash +bash tools/build_proto_python.sh +``` + +Expected: command exits 0 and Python generated files contain the new broadcast messages. Do not stage unrelated generated files that were dirty before this task. + +- [ ] **Step 7: Implement RPC handler and service wiring** + +Create `tensorcast/global_store/rpc/broadcast_rpc_handler.py` with a `BroadcastRpcHandler` class exposing these methods: `__init__(self, *, broadcast_service: BroadcastService, logger) -> None`, `create_broadcast_session(self, request, context)`, `get_broadcast_session(self, request, context)`, `list_broadcast_edges(self, request, context)`, and `cancel_broadcast_session(self, request, context)`. + +Use explicit mapping helpers: + +```python +_SESSION_STATE_TO_PROTO = { + BroadcastSessionState.PLANNING: global_store_pb2.BROADCAST_SESSION_STATE_PLANNING, + BroadcastSessionState.ACTIVE: global_store_pb2.BROADCAST_SESSION_STATE_ACTIVE, + BroadcastSessionState.COMPLETED: global_store_pb2.BROADCAST_SESSION_STATE_COMPLETED, + BroadcastSessionState.FAILED: global_store_pb2.BROADCAST_SESSION_STATE_FAILED, + BroadcastSessionState.CANCELLED: global_store_pb2.BROADCAST_SESSION_STATE_CANCELLED, +} +``` + +Wire `BroadcastRepository` in `GlobalStoreServicer._init_repositories()`, `BroadcastService` in the service initialization path, `BroadcastRpcHandler` in handler initialization, and add forwarding methods to `ClusterRuntimeRpcServicerMixin`. + +- [ ] **Step 8: Run service/RPC tests and commit** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_service.py tests/python/global_store/test_broadcast_rpc.py -v +``` + +Expected: PASS. + +Commit: + +```bash +git add proto/tensorcast/global_store/v1/global_store.proto tensorcast/global_store/services/__init__.py tensorcast/global_store/services/broadcast_service.py tensorcast/global_store/rpc/broadcast_rpc_handler.py tensorcast/global_store/grpc_service.py tensorcast/global_store/rpc_servicer_mixins.py tests/python/global_store/test_broadcast_service.py tests/python/global_store/test_broadcast_rpc.py +git commit -m "feat(global-store): add broadcast session rpc" +``` + +### Task 3: Broadcast-Aware Transport Selection And Completion + +**Files:** +- Modify: `tensorcast/global_store/models/transport.py` +- Modify: `tensorcast/global_store/repositories/transport_repository.py` +- Modify: `tensorcast/global_store/repositories/replica_repository.py` +- Modify: `tensorcast/global_store/services/broadcast_service.py` +- Modify: `tensorcast/global_store/services/transport_service.py` +- Modify: `tensorcast/global_store/rpc/transport_rpc_handler.py` +- Modify: `proto/tensorcast/global_store/v1/global_store.proto` +- Test: add `tests/python/global_store/test_broadcast_transport.py` +- Test: update `tests/python/global_store/test_services.py` only if existing group-dispatch assertions need broadcast columns in row projections. + +- [ ] **Step 1: Write failing broadcast transport tests** + +Create `tests/python/global_store/test_broadcast_transport.py`: + +```python +from __future__ import annotations + +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.global_store.v1 import global_store_pb2 + + +def _register_worker(servicer, context, worker_id_suffix: str) -> str: + response = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id=f"daemon-{worker_id_suffix}", + node_id=f"node-{worker_id_suffix}", + node_address=f"10.20.0.{worker_id_suffix}", + grpc_port=51000 + int(worker_id_suffix), + p2p_port=52000 + int(worker_id_suffix), + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + context, + ) + assert response.status == global_store_pb2.STATUS_OK + return response.worker_id + + +def _register_replica(servicer, context, worker_id: str, node_suffix: str, key: str) -> str: + memory_info = global_store_pb2.RegisterReplicaRequest().memory_info + memory_info.node_id = f"node-{node_suffix}" + memory_info.node_address = f"10.20.0.{node_suffix}" + memory_info.node_port = 52000 + int(node_suffix) + memory_info.memory_size = 1024 + memory_info.memory_type = common_pb2.MemoryType.MEMORY_TYPE_GPU + memory_info.device_id = 0 + memory_info.transport.export_state = ( + common_pb2.ReplicaTransportMetadata.EXPORT_STATE_EXPORTABLE + ) + memory_info.transport.remote_memory_keys.append(key) + memory_info.transport.buffer_sizes.append(1024) + response = servicer.RegisterReplica( + global_store_pb2.RegisterReplicaRequest( + artifact_id="mi2:model-transport", + worker_id=worker_id, + memory_info=memory_info, + max_concurrency=4, + ), + context, + ) + assert response.status == global_store_pb2.STATUS_OK + return response.replica_id + + +def test_broadcast_transport_uses_edge_parent(servicer, test_context): + root_worker = _register_worker(servicer, test_context, "1") + alternate_worker = _register_worker(servicer, test_context, "2") + child_worker = _register_worker(servicer, test_context, "3") + root_replica_id = _register_replica(servicer, test_context, root_worker, "1", "rk-root") + _register_replica(servicer, test_context, alternate_worker, "2", "rk-alt") + create = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-transport", + artifact_id="mi2:model-transport", + epoch=1, + fanout=1, + root_replica_id=root_replica_id, + strict_parent=True, + max_attempts=3, + targets=[global_store_pb2.BroadcastTargetIdentity(worker_id=child_worker)], + ), + test_context, + ) + assert create.status == global_store_pb2.STATUS_OK + + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id="mi2:model-transport", + source_node_id="node-3", + source_address="10.20.0.3", + source_port=52003, + requester_worker_id=child_worker, + request_id="broadcast-request-1", + ) + request.local_memory_info.memory_type = common_pb2.MemoryType.MEMORY_TYPE_GPU + request.local_memory_info.device_id = 0 + request.broadcast.session_id = "session-transport" + request.broadcast.strict_parent = True + response = servicer.RequestReplicaTransport(request, test_context) + + assert response.status == global_store_pb2.STATUS_OK + assert response.remote_memory_info.transport.remote_memory_keys == ["rk-root"] + assert response.remote_memory_info.node_id == "node-1" + + +def test_broadcast_failed_transport_requeues_target(servicer, test_context): + root_worker = _register_worker(servicer, test_context, "1") + child_worker = _register_worker(servicer, test_context, "3") + root_replica_id = _register_replica(servicer, test_context, root_worker, "1", "rk-root") + servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-failure", + artifact_id="mi2:model-transport", + epoch=1, + fanout=1, + root_replica_id=root_replica_id, + strict_parent=True, + max_attempts=3, + targets=[global_store_pb2.BroadcastTargetIdentity(worker_id=child_worker)], + ), + test_context, + ) + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id="mi2:model-transport", + source_node_id="node-3", + source_address="10.20.0.3", + source_port=52003, + requester_worker_id=child_worker, + request_id="broadcast-request-fail", + ) + request.local_memory_info.memory_type = common_pb2.MemoryType.MEMORY_TYPE_GPU + request.local_memory_info.device_id = 0 + request.broadcast.session_id = "session-failure" + transport = servicer.RequestReplicaTransport(request, test_context) + assert transport.status == global_store_pb2.STATUS_OK + + complete = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=transport.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_FAILED, + outcome_detail="injected failure", + ), + test_context, + ) + assert complete.status == global_store_pb2.STATUS_OK + + edges = servicer.ListBroadcastEdges( + global_store_pb2.ListBroadcastEdgesRequest(session_id="session-failure"), + test_context, + ) + assert edges.status == global_store_pb2.STATUS_OK + assert any(edge.state == global_store_pb2.BROADCAST_EDGE_STATE_FAILED for edge in edges.edges) +``` + +- [ ] **Step 2: Run tests and verify they fail** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_transport.py -v +``` + +Expected: FAIL because `RequestReplicaTransportRequest.broadcast` and broadcast-aware transport routing do not exist. + +- [ ] **Step 3: Add transport broadcast fields** + +Update `Transport` model with: + +```python +broadcast_session_id: str | None = None +broadcast_edge_id: str | None = None +``` + +Update `TransportRepository._TRANSPORT_PROJECTION`, `create_with_cursor()`, `create_if_absent_with_cursor()`, `_row_to_model()`, and `list_rows_in_created_window()` to include the two fields. Preserve existing column order compatibility by appending the new fields to projections and mapping by column name. + +- [ ] **Step 4: Add specific parent claim helpers to `ReplicaRepository`** + +Add: + +```python +def claim_replica_for_transport( + self, + *, + replica_id: UUID, + artifact_id: str, + view_id: str | None, + heartbeat_timeout_seconds: float, + cursor=None, +) -> TransportSelectionResult: + """Claim one exact replica if it is currently transport eligible.""" +``` + +The method should load the same joined worker/liveness row used by `find_available_for_transport()`, call `_evaluate_transport_candidate()`, increment `replica_counters.current_requests` only when eligible, and return `TransportSelectionResult(replica=None, exportable_replicas=candidate_exportable_replicas)` when not eligible. + +Add: + +```python +def find_exportable_replica_for_worker( + self, + *, + artifact_id: str, + view_id: str | None, + worker_id: str, + heartbeat_timeout_seconds: float, + cursor=None, +) -> Replica | None: + """Return the best registered child replica after materialization completes.""" +``` + +Use the same eligibility checks, but do not increment `current_requests`. + +- [ ] **Step 5: Add `BroadcastService.claim_transport_edge()` and completion advancement** + +In `BroadcastService`, add `claim_transport_edge(self, *, session_id: str, artifact_id: str, requested_view_id: str | None, requester_worker_id: str, request_id: str, heartbeat_timeout_seconds: float, cursor: DuckDBPyConnection) -> tuple[Replica, BroadcastEdge]`. + +This method must verify session artifact/view/epoch, find or plan an active edge for `requester_worker_id`, claim `edge.parent_replica_id` with `ReplicaRepository.claim_replica_for_transport()`, mark the edge materializing with `request_id`, and return the claimed replica plus edge. + +Add `complete_transport_edge(self, *, session_id: str, edge_id: str, transport_outcome: TransportCompletionOutcome, outcome_detail: str | None, cursor: DuckDBPyConnection) -> None`. + +On success, call `ReplicaRepository.find_exportable_replica_for_worker()` for the child worker and record that replica id in `BroadcastRepository.mark_edge_completed()`. If no child replica is visible, mark the edge failed with reason `child_replica_not_exportable_after_success`. On failure, mark the edge failed and requeue the target until `max_attempts` is reached. + +- [ ] **Step 6: Integrate broadcast claim into `TransportService`** + +Add a small value object in `tensorcast/global_store/models/transport.py`: + +```python +@dataclass(frozen=True) +class BroadcastTransportHint: + session_id: str + strict_parent: bool = True +``` + +Extend `TransportService.__init__()` with optional `broadcast_service: BroadcastService | None = None`. + +Extend `request_transport()` and `_build_request_fingerprint()` with `broadcast_hint`. In `_dispatch_pending_requests()`, leave normal queued dispatch unchanged. For requests with a broadcast hint, bypass group-dispatch queueing and claim the edge synchronously inside one transaction: + +When `broadcast_hint is not None`, call `_request_transport_broadcast()` with the same normalized artifact, requester, memory, request id, timeout, and scheduling parameters already available in `request_transport()`, plus the parsed `BroadcastTransportHint`. + +`_request_transport_broadcast()` should call `broadcast_service.claim_transport_edge()`, build a `Transport` with `broadcast_session_id` and `broadcast_edge_id`, create it idempotently, and return the claimed parent replica. + +In `complete_transport()`, after `transport_repository.complete_if_in_progress()`, if `transport.broadcast_session_id` and `transport.broadcast_edge_id` are set, call `broadcast_service.complete_transport_edge()`. + +- [ ] **Step 7: Parse broadcast hint in transport RPC** + +In `proto/tensorcast/global_store/v1/global_store.proto`, add: + +```proto +message BroadcastTransportHint { + string session_id = 1; + bool strict_parent = 2; +} +``` + +Add field 11 to `RequestReplicaTransportRequest`: + +```proto +BroadcastTransportHint broadcast = 11; +``` + +Run: + +```bash +bash tools/build_proto_python.sh +``` + +In `TransportRpcHandler.request_replica_transport()`, parse: + +```python +broadcast_hint: BroadcastTransportHint | None = None +if request.HasField("broadcast"): + session_id = request.broadcast.session_id.strip() + if not session_id: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("broadcast.session_id is required") + return global_store_pb2.RequestReplicaTransportResponse( + status=global_store_pb2.Status.STATUS_ERROR + ) + broadcast_hint = BroadcastTransportHint( + session_id=session_id, + strict_parent=bool(request.broadcast.strict_parent), + ) +``` + +Pass `broadcast_hint` into `TransportService.request_transport()`. + +- [ ] **Step 8: Run transport tests and existing group-dispatch regression tests** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_transport.py tests/python/global_store/test_services.py::TestTransportService -v +``` + +Expected: PASS. Existing group-dispatch tests must still pass without broadcast hints. + +Commit: + +```bash +git add proto/tensorcast/global_store/v1/global_store.proto tensorcast/global_store/models/transport.py tensorcast/global_store/repositories/transport_repository.py tensorcast/global_store/repositories/replica_repository.py tensorcast/global_store/services/broadcast_service.py tensorcast/global_store/services/transport_service.py tensorcast/global_store/rpc/transport_rpc_handler.py tests/python/global_store/test_broadcast_transport.py +git commit -m "feat(global-store): route broadcast transports through tree edges" +``` + +### Task 4: SDK And Store Daemon Broadcast Hint Propagation + +**Files:** +- Modify: `proto/tensorcast/daemon/v2/store_daemon.proto` +- Modify: `tensorcast/api/context.py` +- Modify: `tensorcast/api/__init__.py` +- Modify: `tensorcast/__init__.py` +- Modify: `tensorcast/api/store/artifact.py` +- Modify: `tensorcast/api/_materialize.py` +- Modify: `tensorcast/daemon_ctl.py` +- Modify: `daemon/service/controllers/materialization_policy_utils.h` +- Modify: `daemon/service/controllers/materialization_policy_utils.cc` +- Modify: `daemon/service/controllers/replica_materialization_service.cc` +- Test: update `tests/python/api/test_prefetch_operation.py` +- Test: add `tests/python/api/test_daemon_ctl_broadcast_hint.py` +- Test: update `daemon/service/materialization_policy_utils_test.cc` + +- [ ] **Step 1: Write failing SDK prefetch test** + +Append to `tests/python/api/test_prefetch_operation.py`: + +```python +def test_prefetch_forwards_broadcast_context_hint() -> None: + store = _Store() + artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") + ctx = tc.context( + broadcast=tc.BroadcastContext( + session_id="broadcast-session-1", + strict_parent=True, + ) + ) + + artifact.prefetch(device="cuda:0", ctx=ctx) + + call = store._materialization.calls[0] + assert call["broadcast_session_id"] == "broadcast-session-1" + assert call["broadcast_strict_parent"] is True +``` + +- [ ] **Step 2: Run SDK test and verify it fails** + +Run: + +```bash +pytest tests/python/api/test_prefetch_operation.py::test_prefetch_forwards_broadcast_context_hint -v +``` + +Expected: FAIL because `BroadcastContext` and materialization kwargs do not exist. + +- [ ] **Step 3: Add SDK context type and prefetch forwarding** + +In `tensorcast/api/context.py`, add: + +```python +@dataclass(frozen=True, slots=True) +class BroadcastContext: + """Strict tree broadcast session hint for materialization.""" + + session_id: str + strict_parent: bool = True + + def __post_init__(self) -> None: + session_id = str(self.session_id).strip() + if not session_id: + raise ValueError("BroadcastContext.session_id must be non-empty") + object.__setattr__(self, "session_id", session_id) + object.__setattr__(self, "strict_parent", bool(self.strict_parent)) +``` + +Add `broadcast: BroadcastContext | None = None` to `CallContext` and `context()`, and export `BroadcastContext` from `tensorcast/api/context.py`, `tensorcast/api/__init__.py`, and `tensorcast/__init__.py`. + +In `Artifact.prefetch()`, pass: + +```python +broadcast_session_id=ctx.broadcast.session_id if ctx and ctx.broadcast else None, +broadcast_strict_parent=ctx.broadcast.strict_parent if ctx and ctx.broadcast else True, +``` + +to `pipeline.materialize_subset()`. Thread the same kwargs through `MaterializationPipeline.materialize_subset()` into `materialize_artifact_v2()`. + +- [ ] **Step 4: Write failing DaemonCtl proto copy test** + +Create `tests/python/api/test_daemon_ctl_broadcast_hint.py`: + +```python +from __future__ import annotations + +from tensorcast.daemon_ctl import DaemonCtl +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.daemon.v2 import store_daemon_pb2 + + +class _FakeUnary: + def __init__(self) -> None: + self.requests: list[store_daemon_pb2.MaterializeReplicaRequest] = [] + + def __call__(self, request, timeout=None, metadata=None): # noqa: ANN001 + del timeout, metadata + self.requests.append(request) + response = store_daemon_pb2.MaterializeReplicaResponse() + response.status = store_daemon_pb2.MATERIALIZE_REPLICA_STATUS_ALLOCATED + return response + + +class _FakeStub: + def __init__(self) -> None: + self.MaterializeReplica = _FakeUnary() + + +def test_daemon_ctl_copies_broadcast_hint(monkeypatch) -> None: # noqa: ANN001 + fake_stub = _FakeStub() + ctl = DaemonCtl("127.0.0.1:1") + ctl.stub_v2 = fake_stub + monkeypatch.setattr(ctl, "_get_effective_pid", lambda: 123) + + selection = common_pb2.ArtifactSelection(artifact_id="mi2:model") + ctl.materialize_by_artifact_id_v2( + selection=selection, + replica_uuid="replica-1", + device_uuid="gpu-uuid", + return_response=True, + wait_for_completion=False, + broadcast_session_id="session-a", + broadcast_strict_parent=True, + ) + + request = fake_stub.MaterializeReplica.requests[0] + assert request.broadcast.session_id == "session-a" + assert request.broadcast.strict_parent is True +``` + +- [ ] **Step 5: Update daemon proto and DaemonCtl** + +In `proto/tensorcast/daemon/v2/store_daemon.proto`, add: + +```proto +message BroadcastMaterializationHint { + string session_id = 1; + bool strict_parent = 2; +} +``` + +Add field 23 to `MaterializeReplicaRequest`: + +```proto +BroadcastMaterializationHint broadcast = 23; +``` + +Run: + +```bash +bash tools/build_proto_python.sh +``` + +In `tensorcast/daemon_ctl.py`, add overload and implementation kwargs: + +```python +broadcast_session_id: str | None = None, +broadcast_strict_parent: bool = True, +``` + +and copy: + +```python +if broadcast_session_id: + request.broadcast.session_id = str(broadcast_session_id) + request.broadcast.strict_parent = bool(broadcast_strict_parent) +``` + +In `tensorcast/api/_materialize.py`, add the same kwargs and pass them to `client.materialize_by_artifact_id_v2()`. + +- [ ] **Step 6: Add C++ daemon mapping test** + +Append to `daemon/service/materialization_policy_utils_test.cc`: + +```c++ +using tensorcast::daemon::materialization_policy::resolve_broadcast_materialization_hint; + +TEST_CASE("Broadcast materialization hint maps daemon proto", "[daemon][materialization][policy]") { + v2::BroadcastMaterializationHint proto; + proto.set_session_id("session-a"); + proto.set_strict_parent(true); + + auto hint = resolve_broadcast_materialization_hint(&proto); + + REQUIRE(hint.has_value()); + CHECK(hint->session_id == "session-a"); + CHECK(hint->strict_parent); +} +``` + +Run: + +```bash +bazel test //daemon:materialization_policy_utils_test --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: FAIL because `resolve_broadcast_materialization_hint` does not exist. + +- [ ] **Step 7: Implement C++ daemon hint mapping** + +In `core/store/materialization/contracts/loading_spec.h`, add: + +```c++ +struct BroadcastHint { + std::string session_id; + bool strict_parent{true}; +}; +``` + +Add to `MaterializeHints`: + +```c++ +std::optional broadcast; +``` + +In `daemon/service/controllers/materialization_policy_utils.h`, declare: + +```c++ +std::optional resolve_broadcast_materialization_hint( + const v2::BroadcastMaterializationHint* hint); +``` + +In `.cc`, define it to trim/validate non-empty `session_id` and return `std::nullopt` for null or empty hints. + +In `ReplicaMaterializationService::materialize_replica()`, after transport group mapping: + +```c++ +if (req.has_broadcast()) { + auto broadcast_hint = resolve_broadcast_materialization_hint(&req.broadcast()); + if (broadcast_hint.has_value()) { + hints.broadcast = std::move(*broadcast_hint); + } +} +``` + +- [ ] **Step 8: Run SDK and daemon hint tests and commit** + +Run: + +```bash +pytest tests/python/api/test_prefetch_operation.py::test_prefetch_forwards_broadcast_context_hint tests/python/api/test_daemon_ctl_broadcast_hint.py -v +bazel test //daemon:materialization_policy_utils_test --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: PASS. + +Commit: + +```bash +git add proto/tensorcast/daemon/v2/store_daemon.proto tensorcast/api/context.py tensorcast/api/__init__.py tensorcast/__init__.py tensorcast/api/store/artifact.py tensorcast/api/_materialize.py tensorcast/daemon_ctl.py daemon/service/controllers/materialization_policy_utils.h daemon/service/controllers/materialization_policy_utils.cc daemon/service/controllers/replica_materialization_service.cc core/store/materialization/contracts/loading_spec.h tests/python/api/test_prefetch_operation.py tests/python/api/test_daemon_ctl_broadcast_hint.py daemon/service/materialization_policy_utils_test.cc +git commit -m "feat(materialize): propagate broadcast session hints" +``` + +### Task 5: C++ Global Store Client And Materialization Completion Semantics + +**Files:** +- Modify: `core/store/components/global_store_client.h` +- Modify: `core/store/components/global_store_client.cc` +- Modify: `core/store/testing/recording_global_store_client.h` +- Modify: `core/store/materialization/control/materialize_orchestrator.cc` +- Modify: `core/store/runtime/ingestion/materialization_facade.cc` +- Test: update `core/store/materialization/control/materialize_orchestrator_reselection_test.cc` + +- [ ] **Step 1: Write failing C++ propagation test** + +Append to `core/store/materialization/control/materialize_orchestrator_reselection_test.cc`: + +```c++ +TEST_CASE( + "MaterializeOrchestrator propagates broadcast hint to transport request", + "[store][materialize][broadcast]") { + auto gs_client = std::make_shared(); + gs_client->connected = true; + gs_client->allow_replica_transport = true; + gs_client->push_scripted_transport_session(make_transport_session( + "transport-broadcast", "node-remote", "10.9.9.2", 50042, common::memory::MemoryLocation::GPU, 0)); + + FakeMaterializationBackend backend; + MaterializeHints hints; + hints.artifact_id = "artifact-broadcast"; + hints.allow_p2p = true; + hints.allow_disk = false; + hints.transport_request_id = "request-broadcast-1"; + hints.broadcast = loading::BroadcastHint{ + .session_id = "session-a", + .strict_parent = true, + }; + + components::WorkerIdentity local_identity{ + .worker_id = "worker-local", + .node_id = "node-local", + .node_address = "10.9.9.1", + .p2p_port = 50041, + }; + MaterializeOrchestrator orchestrator( + gsl::not_null{&backend}, + gsl::not_null{gs_client.get()}, + local_identity); + + auto result = orchestrator.run("artifact-broadcast", make_gpu_target(0), hints, std::nullopt); + REQUIRE(result.ok()); + REQUIRE(gs_client->replica_request_broadcast_hints.size() == 1); + REQUIRE(gs_client->replica_request_broadcast_hints.front().has_value()); + CHECK(gs_client->replica_request_broadcast_hints.front()->session_id == "session-a"); + CHECK(gs_client->replica_request_broadcast_hints.front()->strict_parent); +} +``` + +- [ ] **Step 2: Run C++ test and verify it fails** + +Run: + +```bash +bazel test //core/store/materialization/control:materialize_orchestrator_reselection_test --test_env=TENSORCAST_CUDA_BACKEND=fake --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: FAIL because `BroadcastHint` is not part of `IGlobalStoreClient` requests. + +- [ ] **Step 3: Add C++ GlobalStoreClient broadcast hint type and proto mapping** + +In `core/store/components/global_store_client.h`, add: + +```c++ +struct BroadcastTransportHint { + std::string session_id; + bool strict_parent{true}; +}; +``` + +Add `const std::optional& broadcast_hint = std::nullopt` to `request_replica_transport()` and `request_view_transport()` interface and implementation signatures, after `scheduling_group`. + +In `core/store/components/global_store_client.cc`, add: + +```c++ +void apply_broadcast_transport_hint( + const std::optional& broadcast_hint, + global_store::RequestReplicaTransportRequest* request) { + if (!broadcast_hint.has_value() || broadcast_hint->session_id.empty()) { + return; + } + auto* out = request->mutable_broadcast(); + out->set_session_id(broadcast_hint->session_id); + out->set_strict_parent(broadcast_hint->strict_parent); +} +``` + +Call this helper in both request methods. + +- [ ] **Step 4: Thread hints from `MaterializeHints` to C++ client** + +In `materialize_orchestrator.cc` and `materialization_facade.cc`, add a helper: + +```c++ +std::optional to_broadcast_transport_hint( + const MaterializeHints& hints) { + if (!hints.broadcast.has_value() || hints.broadcast->session_id.empty()) { + return std::nullopt; + } + return components::BroadcastTransportHint{ + .session_id = hints.broadcast->session_id, + .strict_parent = hints.broadcast->strict_parent, + }; +} +``` + +Pass `broadcast_hint` immediately after `scheduling_group_hint` in all `request_replica_transport()` and `request_view_transport()` calls. Update `core/store/testing/recording_global_store_client.h` to capture `replica_request_broadcast_hints` and `view_request_broadcast_hints`. + +- [ ] **Step 5: Change broadcast success completion ordering** + +In `MaterializeOrchestrator::run()`, change the P2P success block: + +```c++ +if (load_or.ok()) { + const auto& handle = *load_or; + absl::Status reg_status = backend_->register_replica_with_global_store(handle.key(), {}); + const bool broadcast_request = hints.broadcast.has_value() && !hints.broadcast->session_id.empty(); + if (!reg_status.ok()) { + LOG(WARNING) << "register_replica_with_global_store returned error: " << reg_status; + if (broadcast_request) { + absl::Status comp_status = gs_client_->complete_replica_transport( + session.transport_id, + components::TransportCompletionOutcome::kFailed, + reg_status.ToString()); + if (!comp_status.ok()) { + LOG(WARNING) << "complete_replica_transport returned error: " << comp_status; + } + last_p2p_status = reg_status; + // Continue existing reselection/fallback handling. + } else { + absl::Status comp_status = gs_client_->complete_replica_transport( + session.transport_id, components::TransportCompletionOutcome::kSuccess); + if (!comp_status.ok()) { + LOG(WARNING) << "complete_replica_transport returned error: " << comp_status; + } + return load_or; + } + } else { + absl::Status comp_status = gs_client_->complete_replica_transport( + session.transport_id, components::TransportCompletionOutcome::kSuccess); + if (!comp_status.ok()) { + LOG(WARNING) << "complete_replica_transport returned error: " << comp_status; + } + return load_or; + } +} +``` + +Preserve the existing failed P2P completion path for `load_or` failures. Apply the same broadcast-specific ordering in `MaterializationFacade` paths that directly request P2P transport and register replicas. + +- [ ] **Step 6: Add failure completion test** + +Add a C++ test using `FakeMaterializationBackend` with `fail_register_replica` behavior. The assertion should verify: + +```c++ +REQUIRE(gs_client->completed_transport_ids.size() == 1); +CHECK(gs_client->completed_transport_ids.front() == "transport-broadcast"); +CHECK(gs_client->completed_transport_outcomes.front() == components::TransportCompletionOutcome::kFailed); +``` + +Run: + +```bash +bazel test //core/store/materialization/control:materialize_orchestrator_reselection_test --test_env=TENSORCAST_CUDA_BACKEND=fake --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: PASS. + +- [ ] **Step 7: Commit C++ propagation and completion semantics** + +Commit: + +```bash +git add core/store/components/global_store_client.h core/store/components/global_store_client.cc core/store/testing/recording_global_store_client.h core/store/materialization/control/materialize_orchestrator.cc core/store/runtime/ingestion/materialization_facade.cc core/store/materialization/control/materialize_orchestrator_reselection_test.cc +git commit -m "feat(core): enforce broadcast transport parent hints" +``` + +### Task 6: Daemon-Mediated Broadcast Session API + +**Files:** +- Modify: `proto/tensorcast/daemon/v2/store_daemon.proto` +- Modify: `tensorcast/daemon_ctl.py` +- Modify: daemon service implementation files under `daemon/service/` +- Modify: `tensorcast/api/store/__init__.py` +- Test: add `tests/python/api/test_broadcast_session_api.py` +- Test: add daemon C++ controller test if a local controller seam exists; otherwise use Python DaemonCtl request construction test plus Global Store RPC tests. + +- [ ] **Step 1: Add SDK API failing test** + +Create `tests/python/api/test_broadcast_session_api.py`: + +```python +from __future__ import annotations + +import tensorcast as tc +from tensorcast.proto.daemon.v2 import store_daemon_pb2 + + +def test_store_create_broadcast_session_calls_daemon(monkeypatch) -> None: # noqa: ANN001 + calls: list[dict[str, object]] = [] + + class _Client: + def create_broadcast_session(self, **kwargs): + calls.append(kwargs) + response = store_daemon_pb2.CreateBroadcastSessionResponse() + response.status = store_daemon_pb2.BROADCAST_SESSION_STATUS_OK + response.session_id = "session-a" + return response + + class _Runtime: + daemon_endpoint = "daemon" + daemon_id = "daemon-local" + closed = False + + def ensure_client(self): + return _Client() + + from tensorcast.api.store import Store + + store = Store("daemon", runtime=_Runtime()) + session = store.create_broadcast_session( + session_id="session-a", + artifact_id="mi2:model", + epoch=42, + fanout=2, + target_daemon_ids=["daemon-a", "daemon-b"], + root_replica_id="00000000-0000-0000-0000-000000000001", + ) + + assert session.session_id == "session-a" + assert calls[0]["artifact_id"] == "mi2:model" + assert calls[0]["target_daemon_ids"] == ["daemon-a", "daemon-b"] +``` + +- [ ] **Step 2: Add daemon proto for session forwarding** + +In `proto/tensorcast/daemon/v2/store_daemon.proto`, add daemon-facing messages: + +```proto +enum BroadcastSessionStatus { + BROADCAST_SESSION_STATUS_UNSPECIFIED = 0; + BROADCAST_SESSION_STATUS_OK = 1; + BROADCAST_SESSION_STATUS_ERROR = 2; + BROADCAST_SESSION_STATUS_NOT_FOUND = 3; +} + +message CreateBroadcastSessionRequest { + string session_id = 1; + string artifact_id = 2; + string requested_view_id = 3; + uint64 epoch = 4; + uint32 fanout = 5; + repeated string target_worker_ids = 6; + repeated string target_daemon_ids = 7; + string root_replica_id = 8; + bool strict_parent = 9; + uint32 max_attempts = 10; +} + +message CreateBroadcastSessionResponse { + BroadcastSessionStatus status = 1; + string session_id = 2; +} +``` + +Add the Store Daemon service RPC: + +```proto +rpc CreateBroadcastSession(CreateBroadcastSessionRequest) returns (CreateBroadcastSessionResponse); +``` + +Run: + +```bash +bash tools/build_proto_python.sh +``` + +- [ ] **Step 3: Implement `DaemonCtl.create_broadcast_session()` and SDK store helper** + +In `tensorcast/daemon_ctl.py`, add: + +```python +def create_broadcast_session( + self, + *, + session_id: str, + artifact_id: str, + epoch: int, + fanout: int, + target_worker_ids: list[str] | None = None, + target_daemon_ids: list[str] | None = None, + requested_view_id: str | None = None, + root_replica_id: str | None = None, + strict_parent: bool = True, + max_attempts: int = 3, + timeout_s: float = 30.0, +) -> store_daemon_pb2.CreateBroadcastSessionResponse: + request = store_daemon_pb2.CreateBroadcastSessionRequest( + session_id=session_id, + artifact_id=artifact_id, + requested_view_id=requested_view_id or "", + epoch=int(epoch), + fanout=int(fanout), + root_replica_id=root_replica_id or "", + strict_parent=bool(strict_parent), + max_attempts=int(max_attempts), + ) + request.target_worker_ids.extend(target_worker_ids or []) + request.target_daemon_ids.extend(target_daemon_ids or []) + return self._unary_call( + self.stub_v2.CreateBroadcastSession, + request, + timeout=float(timeout_s), + retries=1, + ) +``` + +In `tensorcast/api/store/__init__.py`, add `from dataclasses import dataclass` near the other standard-library imports, then add a small return dataclass: + +```python +@dataclass(frozen=True, slots=True) +class BroadcastSessionHandle: + session_id: str +``` + +and `Store.create_broadcast_session()` that calls `self._runtime.ensure_client().create_broadcast_session()` with the SDK method arguments and returns `BroadcastSessionHandle(session_id=response.session_id)`. + +- [ ] **Step 4: Implement daemon forwarding** + +Add a daemon service handler that maps daemon request fields to `global_store::CreateBroadcastSessionRequest` using the existing `GlobalStoreClient` stub. Return `BROADCAST_SESSION_STATUS_ERROR` when Global Store is disconnected or returns a non-OK status. Keep this handler thin; planning remains in Global Store. + +- [ ] **Step 5: Run API tests and commit** + +Run: + +```bash +pytest tests/python/api/test_broadcast_session_api.py tests/python/api/test_prefetch_operation.py::test_prefetch_forwards_broadcast_context_hint -v +``` + +Expected: PASS. + +Commit: + +```bash +git add proto/tensorcast/daemon/v2/store_daemon.proto tensorcast/daemon_ctl.py tensorcast/api/store/__init__.py tests/python/api/test_broadcast_session_api.py daemon/service +git commit -m "feat(daemon): expose broadcast session creation" +``` + +### Task 7: End-To-End Regression And Documentation Updates + +**Files:** +- Test: add `tests/python/global_store/test_broadcast_e2e.py` +- Modify: `tensorcast/global_store/README.md` +- Modify: `core/store/README.md` only if broadcast hint behavior changes public materialization semantics. +- Modify: `docs/designs/0117-control-plane-tree-broadcast-phase2.md` if implementation changes any accepted interface. + +- [ ] **Step 1: Add lightweight Global Store E2E test** + +Create `tests/python/global_store/test_broadcast_e2e.py`: + +```python +from __future__ import annotations + +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.global_store.v1 import global_store_pb2 + + +def _register_worker(servicer, context, idx: int) -> str: + response = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id=f"daemon-e2e-{idx}", + node_id=f"node-e2e-{idx}", + node_address=f"10.30.0.{idx}", + grpc_port=53000 + idx, + p2p_port=54000 + idx, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + context, + ) + assert response.status == global_store_pb2.STATUS_OK + return response.worker_id + + +def _register_exportable_replica(servicer, context, worker_id: str, idx: int) -> str: + request = global_store_pb2.RegisterReplicaRequest( + artifact_id="mi2:model-e2e", + worker_id=worker_id, + max_concurrency=4, + ) + request.memory_info.node_id = f"node-e2e-{idx}" + request.memory_info.node_address = f"10.30.0.{idx}" + request.memory_info.node_port = 54000 + idx + request.memory_info.memory_size = 1024 + request.memory_info.memory_type = common_pb2.MemoryType.MEMORY_TYPE_GPU + request.memory_info.device_id = 0 + request.memory_info.transport.export_state = ( + common_pb2.ReplicaTransportMetadata.EXPORT_STATE_EXPORTABLE + ) + request.memory_info.transport.remote_memory_keys.append(f"rk-e2e-{idx}") + request.memory_info.transport.buffer_sizes.append(1024) + response = servicer.RegisterReplica(request, context) + assert response.status == global_store_pb2.STATUS_OK + return response.replica_id + + +def _request_broadcast_transport(servicer, context, session_id: str, worker_id: str, idx: int, request_id: str): + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id="mi2:model-e2e", + source_node_id=f"node-e2e-{idx}", + source_address=f"10.30.0.{idx}", + source_port=54000 + idx, + requester_worker_id=worker_id, + request_id=request_id, + ) + request.local_memory_info.memory_type = common_pb2.MemoryType.MEMORY_TYPE_GPU + request.local_memory_info.device_id = 0 + request.broadcast.session_id = session_id + request.broadcast.strict_parent = True + response = servicer.RequestReplicaTransport(request, context) + assert response.status == global_store_pb2.STATUS_OK + return response + + +def test_tree_broadcast_promotes_first_child_to_second_layer_parent(servicer, test_context): + root = _register_worker(servicer, test_context, 1) + child1 = _register_worker(servicer, test_context, 2) + child2 = _register_worker(servicer, test_context, 3) + root_replica = _register_exportable_replica(servicer, test_context, root, 1) + create = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-e2e", + artifact_id="mi2:model-e2e", + epoch=1, + fanout=1, + root_replica_id=root_replica, + strict_parent=True, + max_attempts=3, + targets=[ + global_store_pb2.BroadcastTargetIdentity(worker_id=child1), + global_store_pb2.BroadcastTargetIdentity(worker_id=child2), + ], + ), + test_context, + ) + assert create.status == global_store_pb2.STATUS_OK + + first = _request_broadcast_transport(servicer, test_context, "session-e2e", child1, 2, "request-child-1") + assert first.remote_memory_info.node_id == "node-e2e-1" + _register_exportable_replica(servicer, test_context, child1, 2) + complete = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=first.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_SUCCESS, + ), + test_context, + ) + assert complete.status == global_store_pb2.STATUS_OK + + second = _request_broadcast_transport(servicer, test_context, "session-e2e", child2, 3, "request-child-2") + assert second.remote_memory_info.node_id in {"node-e2e-1", "node-e2e-2"} +``` + +- [ ] **Step 2: Run E2E and core regression suites** + +Run: + +```bash +pytest tests/python/global_store/test_broadcast_repository.py tests/python/global_store/test_broadcast_service.py tests/python/global_store/test_broadcast_rpc.py tests/python/global_store/test_broadcast_transport.py tests/python/global_store/test_broadcast_e2e.py -v +pytest tests/python/api/test_prefetch_operation.py tests/python/api/test_daemon_ctl_broadcast_hint.py tests/python/api/test_broadcast_session_api.py -v +bazel test //daemon:materialization_policy_utils_test --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +bazel test //core/store/materialization/control:materialize_orchestrator_reselection_test --test_env=TENSORCAST_CUDA_BACKEND=fake --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: PASS. + +- [ ] **Step 3: Update docs** + +In `tensorcast/global_store/README.md`, add a short section under the transport scheduling area: + +```markdown +### Broadcast Sessions + +Broadcast sessions coordinate strict tree dissemination for model-weight prefetch. A session records the artifact/view/epoch, target workers, fanout, and parent-child edges. Broadcast-tagged `RequestReplicaTransport` calls resolve to the parent replica assigned by the active edge; untagged requests continue to use group dispatch or ordinary source selection. +``` + +Update `docs/designs/0117-control-plane-tree-broadcast-phase2.md` if implementation changes a proto field name, state name, or completion rule from the accepted design. + +- [ ] **Step 4: Final verification and commit** + +Run: + +```bash +ruff check tensorcast/global_store tensorcast/api tests/python/global_store tests/python/api +ruff format tensorcast/global_store tensorcast/api tests/python/global_store tests/python/api +pytest tests/python/global_store/test_broadcast_repository.py tests/python/global_store/test_broadcast_service.py tests/python/global_store/test_broadcast_rpc.py tests/python/global_store/test_broadcast_transport.py tests/python/global_store/test_broadcast_e2e.py -v +pytest tests/python/api/test_prefetch_operation.py tests/python/api/test_daemon_ctl_broadcast_hint.py tests/python/api/test_broadcast_session_api.py -v +bazel test //daemon:materialization_policy_utils_test --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +bazel test //core/store/materialization/control:materialize_orchestrator_reselection_test --test_env=TENSORCAST_CUDA_BACKEND=fake --test_output=errors --noshow_progress --noshow_loading_progress --ui_event_filters=warning,error +``` + +Expected: PASS for all commands. + +Commit: + +```bash +git add tests/python/global_store/test_broadcast_e2e.py tensorcast/global_store/README.md docs/designs/0117-control-plane-tree-broadcast-phase2.md +git commit -m "test: cover broadcast tree dissemination" +``` + +# Self-Review Checklist + +- [ ] Schema changes are represented in `schema.sql` and migration `0019_broadcast_sessions.py`. +- [ ] Global Store broadcast RPCs never move artifact bytes. +- [ ] SDK broadcast session creation goes through Store Daemon, not directly to Global Store. +- [ ] Broadcast source selection uses edge-assigned parents only in strict mode. +- [ ] Parent selection still enforces heartbeat, accepting-new-requests, capacity, export state, remote memory keys, and buffer sizes. +- [ ] `FAILED`, `EXPIRED`, and `CANCELLED` transport outcomes do not advance tree progress. +- [ ] Broadcast success waits for child replica visibility after registration. +- [ ] Existing unhinted materialization and Phase 1 group dispatch tests still pass. From 8ac8c826a189df2540f182a13972c82c0b96678a Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 16:42:39 +0800 Subject: [PATCH 20/49] feat(global-store): add broadcast state repository --- schema.sql | 56 ++ .../migrations/0019_broadcast_sessions.py | 88 +++ tensorcast/global_store/models/__init__.py | 14 + tensorcast/global_store/models/broadcast.py | 97 +++ .../global_store/repositories/__init__.py | 2 + .../repositories/broadcast_repository.py | 731 ++++++++++++++++++ tests/python/global_store/conftest.py | 2 + .../global_store/test_broadcast_repository.py | 163 ++++ 8 files changed, 1153 insertions(+) create mode 100644 tensorcast/global_store/migrations/0019_broadcast_sessions.py create mode 100644 tensorcast/global_store/models/broadcast.py create mode 100644 tensorcast/global_store/repositories/broadcast_repository.py create mode 100644 tests/python/global_store/test_broadcast_repository.py diff --git a/schema.sql b/schema.sql index 2612f1d4..e9d518ee 100644 --- a/schema.sql +++ b/schema.sql @@ -286,6 +286,8 @@ CREATE TABLE IF NOT EXISTS artifact_transports ( group_part_id TEXT NULL, group_priority INTEGER NULL, group_epoch BIGINT NULL, + broadcast_session_id TEXT NULL, + broadcast_edge_id TEXT NULL, completion_outcome TEXT NULL, completion_detail TEXT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, @@ -301,6 +303,60 @@ CREATE INDEX IF NOT EXISTS idx_artifact_transports_completed_at ON artifact_tran CREATE UNIQUE INDEX IF NOT EXISTS idx_artifact_transports_request_id_unique ON artifact_transports(request_id); CREATE INDEX IF NOT EXISTS idx_artifact_transports_group_status ON artifact_transports(group_kind, group_id, group_epoch, status); CREATE INDEX IF NOT EXISTS idx_artifact_transports_requester_status ON artifact_transports(requester_worker_id, status); +CREATE INDEX IF NOT EXISTS idx_artifact_transports_broadcast ON artifact_transports(broadcast_session_id, broadcast_edge_id, status); + +-- Tree broadcast session state. +CREATE TABLE IF NOT EXISTS broadcast_sessions ( + session_id TEXT PRIMARY KEY, + artifact_id TEXT NOT NULL, + requested_view_id TEXT NULL, + epoch BIGINT NOT NULL, + fanout INTEGER NOT NULL, + max_attempts INTEGER NOT NULL DEFAULT 3, + strict_parent BOOLEAN NOT NULL DEFAULT TRUE, + state TEXT CHECK (state IN ('planning','active','completed','failed','cancelled')) NOT NULL, + root_replica_id UUID NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL +); + +CREATE TABLE IF NOT EXISTS broadcast_targets ( + session_id TEXT NOT NULL, + target_worker_id TEXT NOT NULL, + target_daemon_id TEXT NULL, + state TEXT CHECK (state IN ('pending','assigned','materializing','completed','failed','cancelled')) NOT NULL, + level INTEGER NULL, + attempt INTEGER NOT NULL DEFAULT 0, + assigned_edge_id TEXT NULL, + completed_replica_id UUID NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, + PRIMARY KEY (session_id, target_worker_id) +); + +CREATE TABLE IF NOT EXISTS broadcast_edges ( + edge_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + parent_worker_id TEXT NOT NULL, + parent_replica_id UUID NOT NULL, + child_worker_id TEXT NOT NULL, + level INTEGER NOT NULL, + attempt INTEGER NOT NULL DEFAULT 1, + state TEXT CHECK (state IN ('planned','assigned','materializing','completed','failed','cancelled')) NOT NULL, + transport_request_id TEXT NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL +); + +CREATE INDEX IF NOT EXISTS idx_broadcast_targets_session_state + ON broadcast_targets(session_id, state, updated_at); +CREATE INDEX IF NOT EXISTS idx_broadcast_edges_session_child_state + ON broadcast_edges(session_id, child_worker_id, state); -- Pending request queue for group-aware transport scheduling. CREATE TABLE IF NOT EXISTS pending_transport_requests ( diff --git a/tensorcast/global_store/migrations/0019_broadcast_sessions.py b/tensorcast/global_store/migrations/0019_broadcast_sessions.py new file mode 100644 index 00000000..3af26b46 --- /dev/null +++ b/tensorcast/global_store/migrations/0019_broadcast_sessions.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026, TensorCast Team. + +"""Migration 0019: add broadcast session state tables.""" + +from __future__ import annotations + +from duckdb import DuckDBPyConnection + +UP_QUERIES: tuple[str, ...] = ( + """ + CREATE TABLE IF NOT EXISTS broadcast_sessions ( + session_id TEXT PRIMARY KEY, + artifact_id TEXT NOT NULL, + requested_view_id TEXT NULL, + epoch BIGINT NOT NULL, + fanout INTEGER NOT NULL, + max_attempts INTEGER NOT NULL DEFAULT 3, + strict_parent BOOLEAN NOT NULL DEFAULT TRUE, + state TEXT CHECK (state IN ('planning','active','completed','failed','cancelled')) NOT NULL, + root_replica_id UUID NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL + ); + """, + """ + CREATE TABLE IF NOT EXISTS broadcast_targets ( + session_id TEXT NOT NULL, + target_worker_id TEXT NOT NULL, + target_daemon_id TEXT NULL, + state TEXT CHECK (state IN ('pending','assigned','materializing','completed','failed','cancelled')) NOT NULL, + level INTEGER NULL, + attempt INTEGER NOT NULL DEFAULT 0, + assigned_edge_id TEXT NULL, + completed_replica_id UUID NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, + PRIMARY KEY (session_id, target_worker_id) + ); + """, + """ + CREATE TABLE IF NOT EXISTS broadcast_edges ( + edge_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + parent_worker_id TEXT NOT NULL, + parent_replica_id UUID NOT NULL, + child_worker_id TEXT NOT NULL, + level INTEGER NOT NULL, + attempt INTEGER NOT NULL DEFAULT 1, + state TEXT CHECK (state IN ('planned','assigned','materializing','completed','failed','cancelled')) NOT NULL, + transport_request_id TEXT NULL, + failure_reason TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL + ); + """, + "CREATE INDEX IF NOT EXISTS idx_broadcast_targets_session_state ON broadcast_targets(session_id, state, updated_at);", + "CREATE INDEX IF NOT EXISTS idx_broadcast_edges_session_child_state ON broadcast_edges(session_id, child_worker_id, state);", + "ALTER TABLE artifact_transports ADD COLUMN IF NOT EXISTS broadcast_session_id TEXT NULL;", + "ALTER TABLE artifact_transports ADD COLUMN IF NOT EXISTS broadcast_edge_id TEXT NULL;", + "CREATE INDEX IF NOT EXISTS idx_artifact_transports_broadcast ON artifact_transports(broadcast_session_id, broadcast_edge_id, status);", +) + +DOWN_QUERIES: tuple[str, ...] = ( + "DROP INDEX IF EXISTS idx_artifact_transports_broadcast;", + "ALTER TABLE artifact_transports DROP COLUMN IF EXISTS broadcast_edge_id;", + "ALTER TABLE artifact_transports DROP COLUMN IF EXISTS broadcast_session_id;", + "DROP INDEX IF EXISTS idx_broadcast_edges_session_child_state;", + "DROP INDEX IF EXISTS idx_broadcast_targets_session_state;", + "DROP TABLE IF EXISTS broadcast_edges;", + "DROP TABLE IF EXISTS broadcast_targets;", + "DROP TABLE IF EXISTS broadcast_sessions;", +) + + +def upgrade(conn: DuckDBPyConnection) -> None: + """Apply migration 0019.""" + for query in UP_QUERIES: + conn.execute(query) + + +def downgrade(conn: DuckDBPyConnection) -> None: + """Rollback migration 0019.""" + for query in DOWN_QUERIES: + conn.execute(query) diff --git a/tensorcast/global_store/models/__init__.py b/tensorcast/global_store/models/__init__.py index 4a72b8ed..62933c95 100644 --- a/tensorcast/global_store/models/__init__.py +++ b/tensorcast/global_store/models/__init__.py @@ -2,6 +2,14 @@ """Domain models for Global Store.""" +from .broadcast import ( + BroadcastEdge, + BroadcastEdgeState, + BroadcastSession, + BroadcastSessionState, + BroadcastTarget, + BroadcastTargetState, +) from .instance import Instance from .pending_transport_request import ( PendingTransportRequest, @@ -24,6 +32,12 @@ from .worker import Worker, WorkerMemoryTierState __all__ = [ + "BroadcastEdge", + "BroadcastEdgeState", + "BroadcastSession", + "BroadcastSessionState", + "BroadcastTarget", + "BroadcastTargetState", "Instance", "Replica", "Transport", diff --git a/tensorcast/global_store/models/broadcast.py b/tensorcast/global_store/models/broadcast.py new file mode 100644 index 00000000..7b365d82 --- /dev/null +++ b/tensorcast/global_store/models/broadcast.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026, TensorCast Team. + +"""Broadcast domain models.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from uuid import UUID + + +class BroadcastSessionState(str, Enum): + """Lifecycle states for a broadcast session.""" + + PLANNING = "planning" + ACTIVE = "active" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class BroadcastTargetState(str, Enum): + """Lifecycle states for one broadcast target worker.""" + + PENDING = "pending" + ASSIGNED = "assigned" + MATERIALIZING = "materializing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class BroadcastEdgeState(str, Enum): + """Lifecycle states for one parent-child broadcast attempt.""" + + PLANNED = "planned" + ASSIGNED = "assigned" + MATERIALIZING = "materializing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class BroadcastSession: + """Persistent state for a tree broadcast attempt.""" + + session_id: str + artifact_id: str + requested_view_id: str | None + epoch: int + fanout: int + max_attempts: int + strict_parent: bool + state: BroadcastSessionState + root_replica_id: UUID | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + completed_at: datetime | None = None + + +@dataclass +class BroadcastTarget: + """Persistent state for one target in a broadcast session.""" + + session_id: str + target_worker_id: str + target_daemon_id: str | None + state: BroadcastTargetState + level: int | None = None + attempt: int = 0 + assigned_edge_id: str | None = None + completed_replica_id: UUID | None = None + failure_reason: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + completed_at: datetime | None = None + + +@dataclass +class BroadcastEdge: + """Persistent state for one parent-child broadcast edge attempt.""" + + edge_id: str + session_id: str + parent_worker_id: str + parent_replica_id: UUID + child_worker_id: str + level: int + attempt: int + state: BroadcastEdgeState + transport_request_id: str | None = None + failure_reason: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + completed_at: datetime | None = None diff --git a/tensorcast/global_store/repositories/__init__.py b/tensorcast/global_store/repositories/__init__.py index 38fa4b45..a1e27419 100644 --- a/tensorcast/global_store/repositories/__init__.py +++ b/tensorcast/global_store/repositories/__init__.py @@ -9,6 +9,7 @@ from .assembly_layout_binding_repository import AssemblyLayoutBindingRepository from .assembly_readiness_cut_repository import AssemblyReadinessCutRepository from .assembly_slot_occupancy_repository import AssemblySlotOccupancyRepository +from .broadcast_repository import BroadcastRepository from .chunk_directory_repository import ChunkDirectoryRepository from .cluster_info_repository import ClusterInfoRepository from .idempotency_repository import IdempotencyRepository @@ -39,6 +40,7 @@ "AssemblyLayoutBindingRepository", "AssemblyReadinessCutRepository", "AssemblySlotOccupancyRepository", + "BroadcastRepository", "InstanceRepository", "IdempotencyRepository", "MemoryTierSnapshotRepository", diff --git a/tensorcast/global_store/repositories/broadcast_repository.py b/tensorcast/global_store/repositories/broadcast_repository.py new file mode 100644 index 00000000..4070f8f7 --- /dev/null +++ b/tensorcast/global_store/repositories/broadcast_repository.py @@ -0,0 +1,731 @@ +# Copyright (c) 2026, TensorCast Team. + +"""Repository for broadcast state data access.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from uuid import UUID + +from duckdb import DuckDBPyConnection + +from tensorcast.global_store.models import ( + BroadcastEdge, + BroadcastEdgeState, + BroadcastSession, + BroadcastSessionState, + BroadcastTarget, + BroadcastTargetState, +) +from tensorcast.global_store.repositories.base import BaseRepository + + +class BroadcastRepository(BaseRepository): + """Repository for managing persistent tree broadcast state.""" + + _ACTIVE_EDGE_STATES = ( + BroadcastEdgeState.PLANNED, + BroadcastEdgeState.ASSIGNED, + BroadcastEdgeState.MATERIALIZING, + ) + _SESSION_PROJECTION = ( + "session_id, artifact_id, requested_view_id, epoch, fanout, max_attempts, " + "strict_parent, state, root_replica_id, created_at, updated_at, completed_at" + ) + _TARGET_PROJECTION = ( + "session_id, target_worker_id, target_daemon_id, state, level, attempt, " + "assigned_edge_id, completed_replica_id, failure_reason, created_at, " + "updated_at, completed_at" + ) + _EDGE_PROJECTION = ( + "edge_id, session_id, parent_worker_id, parent_replica_id, child_worker_id, " + "level, attempt, state, transport_request_id, failure_reason, created_at, " + "updated_at, completed_at" + ) + + def create_session( + self, + session: BroadcastSession, + cursor: DuckDBPyConnection | None = None, + ) -> BroadcastSession: + """Create a broadcast session row.""" + normalized_session_id = self._normalize_required_text(session.session_id) + normalized_artifact_id = self._normalize_required_text(session.artifact_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + cursor.execute( + """ + INSERT INTO broadcast_sessions ( + session_id, artifact_id, requested_view_id, epoch, fanout, + max_attempts, strict_parent, state, root_replica_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + normalized_session_id, + normalized_artifact_id, + self._normalize_optional_text(session.requested_view_id), + int(session.epoch), + int(session.fanout), + int(session.max_attempts), + bool(session.strict_parent), + session.state.value, + self._uuid_to_text(session.root_replica_id), + ], + ) + session.session_id = normalized_session_id + session.artifact_id = normalized_artifact_id + return session + finally: + if owns_cursor: + cursor.close() + + def find_session( + self, + session_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> BroadcastSession | None: + """Find a broadcast session by ID.""" + normalized_session_id = self._normalize_required_text(session_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._SESSION_PROJECTION} + FROM broadcast_sessions + WHERE session_id = ? + """, + [normalized_session_id], + ) + row = query.fetchone() + if row is None: + return None + assert query.description is not None + columns = [desc[0] for desc in query.description] + return self._row_to_session(row, columns) + finally: + if owns_cursor: + cursor.close() + + def update_session_state( + self, + session_id: str, + state: BroadcastSessionState, + cursor: DuckDBPyConnection | None = None, + ) -> bool: + """Update the lifecycle state for a broadcast session.""" + normalized_session_id = self._normalize_required_text(session_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + completed_sql = ( + ", completed_at = CURRENT_TIMESTAMP" + if state + in ( + BroadcastSessionState.COMPLETED, + BroadcastSessionState.FAILED, + BroadcastSessionState.CANCELLED, + ) + else "" + ) + row = cursor.execute( + f""" + UPDATE broadcast_sessions + SET state = ?, updated_at = CURRENT_TIMESTAMP {completed_sql} + WHERE session_id = ? + RETURNING session_id + """, + [state.value, normalized_session_id], + ).fetchone() + return row is not None + finally: + if owns_cursor: + cursor.close() + + def upsert_target( + self, + target: BroadcastTarget, + cursor: DuckDBPyConnection | None = None, + ) -> BroadcastTarget: + """Insert or update one broadcast target row.""" + normalized_session_id = self._normalize_required_text(target.session_id) + normalized_target_worker_id = self._normalize_required_text( + target.target_worker_id + ) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + updated_at = datetime.now() + cursor.execute( + """ + INSERT INTO broadcast_targets ( + session_id, target_worker_id, target_daemon_id, state, level, + attempt, assigned_edge_id, completed_replica_id, failure_reason, + completed_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (session_id, target_worker_id) DO UPDATE SET + target_daemon_id = excluded.target_daemon_id, + state = excluded.state, + level = excluded.level, + attempt = excluded.attempt, + assigned_edge_id = excluded.assigned_edge_id, + completed_replica_id = excluded.completed_replica_id, + failure_reason = excluded.failure_reason, + updated_at = ?, + completed_at = excluded.completed_at + """, + [ + normalized_session_id, + normalized_target_worker_id, + self._normalize_optional_text(target.target_daemon_id), + target.state.value, + self._normalize_optional_int(target.level), + int(target.attempt), + self._normalize_optional_text(target.assigned_edge_id), + self._uuid_to_text(target.completed_replica_id), + self._normalize_optional_text(target.failure_reason), + target.completed_at, + updated_at, + ], + ) + target.session_id = normalized_session_id + target.target_worker_id = normalized_target_worker_id + return target + finally: + if owns_cursor: + cursor.close() + + def find_target( + self, + session_id: str, + target_worker_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> BroadcastTarget | None: + """Find a broadcast target by session and worker ID.""" + normalized_session_id = self._normalize_required_text(session_id) + normalized_target_worker_id = self._normalize_required_text(target_worker_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._TARGET_PROJECTION} + FROM broadcast_targets + WHERE session_id = ? AND target_worker_id = ? + """, + [normalized_session_id, normalized_target_worker_id], + ) + row = query.fetchone() + if row is None: + return None + assert query.description is not None + columns = [desc[0] for desc in query.description] + return self._row_to_target(row, columns) + finally: + if owns_cursor: + cursor.close() + + def list_targets( + self, + session_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> list[BroadcastTarget]: + """List all targets for a broadcast session.""" + normalized_session_id = self._normalize_required_text(session_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._TARGET_PROJECTION} + FROM broadcast_targets + WHERE session_id = ? + ORDER BY target_worker_id ASC + """, + [normalized_session_id], + ) + rows = query.fetchall() + assert query.description is not None + columns = [desc[0] for desc in query.description] + return [self._row_to_target(row, columns) for row in rows] + finally: + if owns_cursor: + cursor.close() + + def list_targets_by_state( + self, + session_id: str, + state: BroadcastTargetState, + limit: int, + cursor: DuckDBPyConnection | None = None, + ) -> list[BroadcastTarget]: + """List targets for a session with the requested state.""" + normalized_session_id = self._normalize_required_text(session_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._TARGET_PROJECTION} + FROM broadcast_targets + WHERE session_id = ? AND state = ? + ORDER BY updated_at ASC, target_worker_id ASC + LIMIT ? + """, + [normalized_session_id, state.value, int(limit)], + ) + rows = query.fetchall() + assert query.description is not None + columns = [desc[0] for desc in query.description] + return [self._row_to_target(row, columns) for row in rows] + finally: + if owns_cursor: + cursor.close() + + def create_edge( + self, + edge: BroadcastEdge, + cursor: DuckDBPyConnection | None = None, + ) -> BroadcastEdge: + """Create a broadcast edge row.""" + normalized_edge_id = self._normalize_required_text(edge.edge_id) + normalized_session_id = self._normalize_required_text(edge.session_id) + normalized_child_worker_id = self._normalize_required_text(edge.child_worker_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + if edge.state in self._ACTIVE_EDGE_STATES: + existing = self.find_active_edge_for_child( + normalized_session_id, + normalized_child_worker_id, + cursor=cursor, + ) + if existing is not None: + raise ValueError( + "active broadcast edge already exists for child " + f"{normalized_child_worker_id} in session {normalized_session_id}" + ) + + cursor.execute( + """ + INSERT INTO broadcast_edges ( + edge_id, session_id, parent_worker_id, parent_replica_id, + child_worker_id, level, attempt, state, transport_request_id, + failure_reason + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + normalized_edge_id, + normalized_session_id, + self._normalize_required_text(edge.parent_worker_id), + self._uuid_to_text(edge.parent_replica_id), + normalized_child_worker_id, + int(edge.level), + int(edge.attempt), + edge.state.value, + self._normalize_optional_text(edge.transport_request_id), + self._normalize_optional_text(edge.failure_reason), + ], + ) + edge.edge_id = normalized_edge_id + edge.session_id = normalized_session_id + edge.child_worker_id = normalized_child_worker_id + return edge + finally: + if owns_cursor: + cursor.close() + + def find_edge( + self, + edge_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> BroadcastEdge | None: + """Find a broadcast edge by ID.""" + normalized_edge_id = self._normalize_required_text(edge_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._EDGE_PROJECTION} + FROM broadcast_edges + WHERE edge_id = ? + """, + [normalized_edge_id], + ) + row = query.fetchone() + if row is None: + return None + assert query.description is not None + columns = [desc[0] for desc in query.description] + return self._row_to_edge(row, columns) + finally: + if owns_cursor: + cursor.close() + + def find_active_edge_for_child( + self, + session_id: str, + child_worker_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> BroadcastEdge | None: + """Find an active edge for a session/child worker pair.""" + normalized_session_id = self._normalize_required_text(session_id) + normalized_child_worker_id = self._normalize_required_text(child_worker_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._EDGE_PROJECTION} + FROM broadcast_edges + WHERE session_id = ? + AND child_worker_id = ? + AND state IN ('planned', 'assigned', 'materializing') + ORDER BY attempt DESC, updated_at DESC + LIMIT 1 + """, + [normalized_session_id, normalized_child_worker_id], + ) + row = query.fetchone() + if row is None: + return None + assert query.description is not None + columns = [desc[0] for desc in query.description] + return self._row_to_edge(row, columns) + finally: + if owns_cursor: + cursor.close() + + def mark_edge_materializing( + self, + edge_id: str, + transport_request_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> bool: + """Mark an edge and its target as materializing.""" + normalized_edge_id = self._normalize_required_text(edge_id) + normalized_transport_request_id = self._normalize_required_text( + transport_request_id + ) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + edge = self.find_edge(normalized_edge_id, cursor=cursor) + if edge is None: + return False + row = cursor.execute( + """ + UPDATE broadcast_edges + SET state = 'materializing', + transport_request_id = ?, + updated_at = CURRENT_TIMESTAMP + WHERE edge_id = ? + RETURNING edge_id + """, + [normalized_transport_request_id, normalized_edge_id], + ).fetchone() + if row is None: + return False + cursor.execute( + """ + UPDATE broadcast_targets + SET state = 'materializing', + assigned_edge_id = ?, + level = ?, + attempt = ?, + updated_at = CURRENT_TIMESTAMP + WHERE session_id = ? AND target_worker_id = ? + """, + [ + normalized_edge_id, + int(edge.level), + int(edge.attempt), + edge.session_id, + edge.child_worker_id, + ], + ) + return True + finally: + if owns_cursor: + cursor.close() + + def mark_edge_failed( + self, + edge_id: str, + reason: str, + cursor: DuckDBPyConnection | None = None, + ) -> bool: + """Mark an edge and its target as failed.""" + normalized_edge_id = self._normalize_required_text(edge_id) + normalized_reason = self._normalize_required_text(reason) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + edge = self.find_edge(normalized_edge_id, cursor=cursor) + if edge is None: + return False + row = cursor.execute( + """ + UPDATE broadcast_edges + SET state = 'failed', + failure_reason = ?, + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE edge_id = ? + RETURNING edge_id + """, + [normalized_reason, normalized_edge_id], + ).fetchone() + if row is None: + return False + cursor.execute( + """ + UPDATE broadcast_targets + SET state = 'failed', + failure_reason = ?, + assigned_edge_id = ?, + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE session_id = ? AND target_worker_id = ? + """, + [ + normalized_reason, + normalized_edge_id, + edge.session_id, + edge.child_worker_id, + ], + ) + return True + finally: + if owns_cursor: + cursor.close() + + def mark_edge_completed( + self, + edge_id: str, + completed_replica_id: UUID | None, + cursor: DuckDBPyConnection | None = None, + ) -> bool: + """Mark an edge and its target as completed.""" + normalized_edge_id = self._normalize_required_text(edge_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + edge = self.find_edge(normalized_edge_id, cursor=cursor) + if edge is None: + return False + row = cursor.execute( + """ + UPDATE broadcast_edges + SET state = 'completed', + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE edge_id = ? + RETURNING edge_id + """, + [normalized_edge_id], + ).fetchone() + if row is None: + return False + cursor.execute( + """ + UPDATE broadcast_targets + SET state = 'completed', + completed_replica_id = ?, + assigned_edge_id = ?, + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE session_id = ? AND target_worker_id = ? + """, + [ + self._uuid_to_text(completed_replica_id), + normalized_edge_id, + edge.session_id, + edge.child_worker_id, + ], + ) + return True + finally: + if owns_cursor: + cursor.close() + + def count_incomplete_targets( + self, + session_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> int: + """Count non-terminal targets for a broadcast session.""" + normalized_session_id = self._normalize_required_text(session_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + row = cursor.execute( + """ + SELECT COUNT(*) + FROM broadcast_targets + WHERE session_id = ? + AND state NOT IN ('completed', 'failed', 'cancelled') + """, + [normalized_session_id], + ).fetchone() + return int(row[0]) if row is not None else 0 + finally: + if owns_cursor: + cursor.close() + + @staticmethod + def _normalize_optional_text(value: str | None) -> str | None: + return BroadcastRepository._normalize_text_token(value) + + @staticmethod + def _normalize_required_text(value: str | None) -> str: + normalized = BroadcastRepository._normalize_text_token(value) + if normalized is None: + raise ValueError("required broadcast field is missing") + return normalized + + @staticmethod + def _normalize_optional_int(value: int | None) -> int | None: + if value is None: + return None + return int(value) + + @staticmethod + def _normalize_text_token(value: str | None) -> str | None: + if value is None: + return None + stripped = str(value).strip() + if not stripped: + return None + return BroadcastRepository._collapse_exact_double(stripped) + + @staticmethod + def _collapse_exact_double(value: str) -> str: + size = len(value) + if size < 8 or (size % 2) != 0: + return value + half = size // 2 + if value[:half] != value[half:]: + return value + return value[:half] + + @staticmethod + def _uuid_to_text(value: UUID | None) -> str | None: + if value is None: + return None + return str(value) + + @staticmethod + def _uuid_or_none(raw: Any) -> UUID | None: + if raw is None: + return None + if isinstance(raw, UUID): + return raw + return UUID(str(raw)) + + @staticmethod + def _coerce_datetime_optional(raw: Any) -> datetime | None: + if raw is None: + return None + if isinstance(raw, datetime): + return raw + return datetime.fromisoformat(str(raw)) + + @classmethod + def _row_to_session( + cls, + row: tuple[Any, ...], + columns: list[str], + ) -> BroadcastSession: + idx = {column: i for i, column in enumerate(columns)} + return BroadcastSession( + session_id=str(row[idx["session_id"]]), + artifact_id=str(row[idx["artifact_id"]]), + requested_view_id=cls._normalize_optional_text( + row[idx["requested_view_id"]] + ), + epoch=int(row[idx["epoch"]]), + fanout=int(row[idx["fanout"]]), + max_attempts=int(row[idx["max_attempts"]]), + strict_parent=bool(row[idx["strict_parent"]]), + state=BroadcastSessionState(str(row[idx["state"]])), + root_replica_id=cls._uuid_or_none(row[idx["root_replica_id"]]), + created_at=cls._coerce_datetime_optional(row[idx["created_at"]]), + updated_at=cls._coerce_datetime_optional(row[idx["updated_at"]]), + completed_at=cls._coerce_datetime_optional(row[idx["completed_at"]]), + ) + + @classmethod + def _row_to_target( + cls, + row: tuple[Any, ...], + columns: list[str], + ) -> BroadcastTarget: + idx = {column: i for i, column in enumerate(columns)} + raw_level = row[idx["level"]] + return BroadcastTarget( + session_id=str(row[idx["session_id"]]), + target_worker_id=str(row[idx["target_worker_id"]]), + target_daemon_id=cls._normalize_optional_text( + row[idx["target_daemon_id"]] + ), + state=BroadcastTargetState(str(row[idx["state"]])), + level=int(raw_level) if raw_level is not None else None, + attempt=int(row[idx["attempt"]]), + assigned_edge_id=cls._normalize_optional_text( + row[idx["assigned_edge_id"]] + ), + completed_replica_id=cls._uuid_or_none(row[idx["completed_replica_id"]]), + failure_reason=cls._normalize_optional_text(row[idx["failure_reason"]]), + created_at=cls._coerce_datetime_optional(row[idx["created_at"]]), + updated_at=cls._coerce_datetime_optional(row[idx["updated_at"]]), + completed_at=cls._coerce_datetime_optional(row[idx["completed_at"]]), + ) + + @classmethod + def _row_to_edge( + cls, + row: tuple[Any, ...], + columns: list[str], + ) -> BroadcastEdge: + idx = {column: i for i, column in enumerate(columns)} + parent_replica_id = cls._uuid_or_none(row[idx["parent_replica_id"]]) + if parent_replica_id is None: + raise ValueError("broadcast edge parent_replica_id is missing") + return BroadcastEdge( + edge_id=str(row[idx["edge_id"]]), + session_id=str(row[idx["session_id"]]), + parent_worker_id=str(row[idx["parent_worker_id"]]), + parent_replica_id=parent_replica_id, + child_worker_id=str(row[idx["child_worker_id"]]), + level=int(row[idx["level"]]), + attempt=int(row[idx["attempt"]]), + state=BroadcastEdgeState(str(row[idx["state"]])), + transport_request_id=cls._normalize_optional_text( + row[idx["transport_request_id"]] + ), + failure_reason=cls._normalize_optional_text(row[idx["failure_reason"]]), + created_at=cls._coerce_datetime_optional(row[idx["created_at"]]), + updated_at=cls._coerce_datetime_optional(row[idx["updated_at"]]), + completed_at=cls._coerce_datetime_optional(row[idx["completed_at"]]), + ) diff --git a/tests/python/global_store/conftest.py b/tests/python/global_store/conftest.py index 35444420..80244153 100644 --- a/tests/python/global_store/conftest.py +++ b/tests/python/global_store/conftest.py @@ -19,6 +19,7 @@ AssemblyLayoutBindingRepository, AssemblyReadinessCutRepository, AssemblySlotOccupancyRepository, + BroadcastRepository, InstanceRepository, LayoutSpecRepository, LeafRepository, @@ -211,6 +212,7 @@ def repositories(db_connection): "assembly_layout_binding": AssemblyLayoutBindingRepository(db_connection), "assembly_readiness_cut": AssemblyReadinessCutRepository(db_connection), "assembly_slot_occupancy": AssemblySlotOccupancyRepository(db_connection), + "broadcast": BroadcastRepository(db_connection), "proof": ProofRepository(db_connection), } diff --git a/tests/python/global_store/test_broadcast_repository.py b/tests/python/global_store/test_broadcast_repository.py new file mode 100644 index 00000000..0c030db3 --- /dev/null +++ b/tests/python/global_store/test_broadcast_repository.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from uuid import UUID + +from tensorcast.global_store.models import ( + BroadcastEdge, + BroadcastEdgeState, + BroadcastSession, + BroadcastSessionState, + BroadcastTarget, + BroadcastTargetState, +) +from tensorcast.global_store.repositories import BroadcastRepository + + +def test_broadcast_repository_creates_session_targets_and_edges(db_connection): + repo = BroadcastRepository(db_connection) + session = BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=42, + fanout=2, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + repo.create_session(session) + repo.upsert_target( + BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child-1", + target_daemon_id="daemon-child-1", + state=BroadcastTargetState.PENDING, + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child-1", + level=1, + attempt=1, + state=BroadcastEdgeState.PLANNED, + ) + ) + + loaded = repo.find_session("session-a") + assert loaded is not None + assert loaded.artifact_id == "mi2:test" + assert loaded.epoch == 42 + assert loaded.state is BroadcastSessionState.ACTIVE + + target = repo.find_target("session-a", "worker-child-1") + assert target is not None + assert target.target_daemon_id == "daemon-child-1" + assert target.state is BroadcastTargetState.PENDING + + edge = repo.find_active_edge_for_child("session-a", "worker-child-1") + assert edge is not None + assert edge.parent_worker_id == "worker-root" + assert edge.state is BroadcastEdgeState.PLANNED + + +def test_broadcast_repository_prevents_two_active_edges_for_child(db_connection): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + first = BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.PLANNED, + ) + repo.create_edge(first) + + try: + repo.create_edge( + BroadcastEdge( + edge_id="edge-2", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=2, + state=BroadcastEdgeState.ASSIGNED, + ) + ) + except Exception as exc: # noqa: BLE001 + assert "active" in str(exc).lower() or "constraint" in str(exc).lower() + else: + raise AssertionError("expected active edge uniqueness to reject duplicate child") + + +def test_broadcast_repository_marks_edge_completed_and_target_completed(db_connection): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + repo.upsert_target( + BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + state=BroadcastTargetState.MATERIALIZING, + assigned_edge_id="edge-1", + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.MATERIALIZING, + transport_request_id="transport-request-1", + ) + ) + + completed_replica_id = UUID("00000000-0000-0000-0000-000000000002") + assert repo.mark_edge_completed( + edge_id="edge-1", + completed_replica_id=completed_replica_id, + ) + edge = repo.find_edge("edge-1") + target = repo.find_target("session-a", "worker-child") + assert edge is not None + assert target is not None + assert edge.state is BroadcastEdgeState.COMPLETED + assert target.state is BroadcastTargetState.COMPLETED + assert target.completed_replica_id == completed_replica_id From c7a056123c4343f426d4ce4c0cafedba5a08073c Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 16:46:55 +0800 Subject: [PATCH 21/49] fix(global-store): enforce broadcast repository invariants --- tensorcast/global_store/models/broadcast.py | 8 +- .../repositories/broadcast_repository.py | 10 +++ .../global_store/test_broadcast_repository.py | 83 +++++++++++++++++++ 3 files changed, 97 insertions(+), 4 deletions(-) diff --git a/tensorcast/global_store/models/broadcast.py b/tensorcast/global_store/models/broadcast.py index 7b365d82..a5c3e318 100644 --- a/tensorcast/global_store/models/broadcast.py +++ b/tensorcast/global_store/models/broadcast.py @@ -53,7 +53,7 @@ class BroadcastSession: fanout: int max_attempts: int strict_parent: bool - state: BroadcastSessionState + state: BroadcastSessionState = BroadcastSessionState.PLANNING root_replica_id: UUID | None = None created_at: datetime | None = None updated_at: datetime | None = None @@ -67,7 +67,7 @@ class BroadcastTarget: session_id: str target_worker_id: str target_daemon_id: str | None - state: BroadcastTargetState + state: BroadcastTargetState = BroadcastTargetState.PENDING level: int | None = None attempt: int = 0 assigned_edge_id: str | None = None @@ -88,8 +88,8 @@ class BroadcastEdge: parent_replica_id: UUID child_worker_id: str level: int - attempt: int - state: BroadcastEdgeState + attempt: int = 1 + state: BroadcastEdgeState = BroadcastEdgeState.PLANNED transport_request_id: str | None = None failure_reason: str | None = None created_at: datetime | None = None diff --git a/tensorcast/global_store/repositories/broadcast_repository.py b/tensorcast/global_store/repositories/broadcast_repository.py index 4070f8f7..7c5d34af 100644 --- a/tensorcast/global_store/repositories/broadcast_repository.py +++ b/tensorcast/global_store/repositories/broadcast_repository.py @@ -430,6 +430,16 @@ def mark_edge_materializing( edge = self.find_edge(normalized_edge_id, cursor=cursor) if edge is None: return False + existing = self.find_active_edge_for_child( + edge.session_id, + edge.child_worker_id, + cursor=cursor, + ) + if existing is not None and existing.edge_id != edge.edge_id: + raise ValueError( + "active broadcast edge already exists for child " + f"{edge.child_worker_id} in session {edge.session_id}" + ) row = cursor.execute( """ UPDATE broadcast_edges diff --git a/tests/python/global_store/test_broadcast_repository.py b/tests/python/global_store/test_broadcast_repository.py index 0c030db3..2d879ab1 100644 --- a/tests/python/global_store/test_broadcast_repository.py +++ b/tests/python/global_store/test_broadcast_repository.py @@ -13,6 +13,36 @@ from tensorcast.global_store.repositories import BroadcastRepository +def test_broadcast_models_default_to_initial_states(): + session = BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=2, + max_attempts=3, + strict_parent=True, + ) + target = BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + ) + edge = BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + ) + + assert session.state is BroadcastSessionState.PLANNING + assert target.state is BroadcastTargetState.PENDING + assert edge.attempt == 1 + assert edge.state is BroadcastEdgeState.PLANNED + + def test_broadcast_repository_creates_session_targets_and_edges(db_connection): repo = BroadcastRepository(db_connection) session = BroadcastSession( @@ -111,6 +141,59 @@ def test_broadcast_repository_prevents_two_active_edges_for_child(db_connection) raise AssertionError("expected active edge uniqueness to reject duplicate child") +def test_broadcast_repository_prevents_materializing_edge_when_child_has_active_edge( + db_connection, +): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-active", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.PLANNED, + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-failed", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=2, + state=BroadcastEdgeState.FAILED, + ) + ) + + try: + repo.mark_edge_materializing( + edge_id="edge-failed", + transport_request_id="transport-request-1", + ) + except Exception as exc: # noqa: BLE001 + assert "active" in str(exc).lower() + else: + raise AssertionError("expected materializing transition to reject duplicate child") + + def test_broadcast_repository_marks_edge_completed_and_target_completed(db_connection): repo = BroadcastRepository(db_connection) repo.create_session( From 1f3c281f685b8a0629a2559956678839df6c2343 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 16:56:17 +0800 Subject: [PATCH 22/49] fix(global-store): harden broadcast edge transitions --- .../repositories/broadcast_repository.py | 409 ++++++++++-------- .../global_store/test_broadcast_repository.py | 186 +++++++- 2 files changed, 407 insertions(+), 188 deletions(-) diff --git a/tensorcast/global_store/repositories/broadcast_repository.py b/tensorcast/global_store/repositories/broadcast_repository.py index 7c5d34af..28b2f929 100644 --- a/tensorcast/global_store/repositories/broadcast_repository.py +++ b/tensorcast/global_store/repositories/broadcast_repository.py @@ -299,54 +299,59 @@ def create_edge( cursor: DuckDBPyConnection | None = None, ) -> BroadcastEdge: """Create a broadcast edge row.""" + if cursor is None: + conflict: ValueError | None = None + with self.transaction() as tx: + try: + return self.create_edge(edge, cursor=tx) + except ValueError as exc: + conflict = exc + if conflict is not None: + raise conflict + raise RuntimeError("broadcast edge transaction exited without result") + normalized_edge_id = self._normalize_required_text(edge.edge_id) normalized_session_id = self._normalize_required_text(edge.session_id) normalized_child_worker_id = self._normalize_required_text(edge.child_worker_id) - owns_cursor = cursor is None - if owns_cursor: - cursor = self.get_cursor() - try: - if edge.state in self._ACTIVE_EDGE_STATES: - existing = self.find_active_edge_for_child( - normalized_session_id, - normalized_child_worker_id, - cursor=cursor, - ) - if existing is not None: - raise ValueError( - "active broadcast edge already exists for child " - f"{normalized_child_worker_id} in session {normalized_session_id}" - ) - cursor.execute( - """ - INSERT INTO broadcast_edges ( - edge_id, session_id, parent_worker_id, parent_replica_id, - child_worker_id, level, attempt, state, transport_request_id, - failure_reason + if edge.state in self._ACTIVE_EDGE_STATES: + existing = self.find_active_edge_for_child( + normalized_session_id, + normalized_child_worker_id, + cursor=cursor, + ) + if existing is not None: + raise ValueError( + "active broadcast edge already exists for child " + f"{normalized_child_worker_id} in session {normalized_session_id}" ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - [ - normalized_edge_id, - normalized_session_id, - self._normalize_required_text(edge.parent_worker_id), - self._uuid_to_text(edge.parent_replica_id), - normalized_child_worker_id, - int(edge.level), - int(edge.attempt), - edge.state.value, - self._normalize_optional_text(edge.transport_request_id), - self._normalize_optional_text(edge.failure_reason), - ], + + cursor.execute( + """ + INSERT INTO broadcast_edges ( + edge_id, session_id, parent_worker_id, parent_replica_id, + child_worker_id, level, attempt, state, transport_request_id, + failure_reason ) - edge.edge_id = normalized_edge_id - edge.session_id = normalized_session_id - edge.child_worker_id = normalized_child_worker_id - return edge - finally: - if owns_cursor: - cursor.close() + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + normalized_edge_id, + normalized_session_id, + self._normalize_required_text(edge.parent_worker_id), + self._uuid_to_text(edge.parent_replica_id), + normalized_child_worker_id, + int(edge.level), + int(edge.attempt), + edge.state.value, + self._normalize_optional_text(edge.transport_request_id), + self._normalize_optional_text(edge.failure_reason), + ], + ) + edge.edge_id = normalized_edge_id + edge.session_id = normalized_session_id + edge.child_worker_id = normalized_child_worker_id + return edge def find_edge( self, @@ -419,62 +424,96 @@ def mark_edge_materializing( cursor: DuckDBPyConnection | None = None, ) -> bool: """Mark an edge and its target as materializing.""" + if cursor is None: + conflict: ValueError | None = None + with self.transaction() as tx: + try: + return self.mark_edge_materializing( + edge_id=edge_id, + transport_request_id=transport_request_id, + cursor=tx, + ) + except ValueError as exc: + conflict = exc + if conflict is not None: + raise conflict + raise RuntimeError("broadcast edge transaction exited without result") + normalized_edge_id = self._normalize_required_text(edge_id) normalized_transport_request_id = self._normalize_required_text( transport_request_id ) - owns_cursor = cursor is None - if owns_cursor: - cursor = self.get_cursor() - try: - edge = self.find_edge(normalized_edge_id, cursor=cursor) - if edge is None: - return False - existing = self.find_active_edge_for_child( + edge = self.find_edge(normalized_edge_id, cursor=cursor) + if edge is None: + return False + if edge.state not in (BroadcastEdgeState.PLANNED, BroadcastEdgeState.ASSIGNED): + return False + + existing = cursor.execute( + """ + SELECT edge_id + FROM broadcast_edges + WHERE session_id = ? + AND child_worker_id = ? + AND edge_id != ? + AND state IN ('planned', 'assigned', 'materializing') + LIMIT 1 + """, + [edge.session_id, edge.child_worker_id, edge.edge_id], + ).fetchone() + if existing is not None: + raise ValueError( + "active broadcast edge already exists for child " + f"{edge.child_worker_id} in session {edge.session_id}" + ) + + target = self.find_target(edge.session_id, edge.child_worker_id, cursor=cursor) + if target is None: + return False + if ( + target.assigned_edge_id is not None + and target.assigned_edge_id != edge.edge_id + ): + return False + + edge_row = cursor.execute( + """ + UPDATE broadcast_edges + SET state = 'materializing', + transport_request_id = ?, + updated_at = CURRENT_TIMESTAMP + WHERE edge_id = ? + AND state IN ('planned', 'assigned') + RETURNING edge_id + """, + [normalized_transport_request_id, normalized_edge_id], + ).fetchone() + if edge_row is None: + return False + + target_row = cursor.execute( + """ + UPDATE broadcast_targets + SET state = 'materializing', + assigned_edge_id = ?, + level = ?, + attempt = ?, + updated_at = CURRENT_TIMESTAMP + WHERE session_id = ? + AND target_worker_id = ? + AND (assigned_edge_id IS NULL OR assigned_edge_id = ?) + RETURNING target_worker_id + """, + [ + normalized_edge_id, + int(edge.level), + int(edge.attempt), edge.session_id, edge.child_worker_id, - cursor=cursor, - ) - if existing is not None and existing.edge_id != edge.edge_id: - raise ValueError( - "active broadcast edge already exists for child " - f"{edge.child_worker_id} in session {edge.session_id}" - ) - row = cursor.execute( - """ - UPDATE broadcast_edges - SET state = 'materializing', - transport_request_id = ?, - updated_at = CURRENT_TIMESTAMP - WHERE edge_id = ? - RETURNING edge_id - """, - [normalized_transport_request_id, normalized_edge_id], - ).fetchone() - if row is None: - return False - cursor.execute( - """ - UPDATE broadcast_targets - SET state = 'materializing', - assigned_edge_id = ?, - level = ?, - attempt = ?, - updated_at = CURRENT_TIMESTAMP - WHERE session_id = ? AND target_worker_id = ? - """, - [ - normalized_edge_id, - int(edge.level), - int(edge.attempt), - edge.session_id, - edge.child_worker_id, - ], - ) - return True - finally: - if owns_cursor: - cursor.close() + normalized_edge_id, + ], + ).fetchone() + return target_row is not None def mark_edge_failed( self, @@ -483,50 +522,60 @@ def mark_edge_failed( cursor: DuckDBPyConnection | None = None, ) -> bool: """Mark an edge and its target as failed.""" + if cursor is None: + with self.transaction() as tx: + return self.mark_edge_failed( + edge_id=edge_id, + reason=reason, + cursor=tx, + ) + normalized_edge_id = self._normalize_required_text(edge_id) normalized_reason = self._normalize_required_text(reason) - owns_cursor = cursor is None - if owns_cursor: - cursor = self.get_cursor() - try: - edge = self.find_edge(normalized_edge_id, cursor=cursor) - if edge is None: - return False - row = cursor.execute( - """ - UPDATE broadcast_edges - SET state = 'failed', - failure_reason = ?, - updated_at = CURRENT_TIMESTAMP, - completed_at = CURRENT_TIMESTAMP - WHERE edge_id = ? - RETURNING edge_id - """, - [normalized_reason, normalized_edge_id], - ).fetchone() - if row is None: - return False - cursor.execute( - """ - UPDATE broadcast_targets - SET state = 'failed', - failure_reason = ?, - assigned_edge_id = ?, - updated_at = CURRENT_TIMESTAMP, - completed_at = CURRENT_TIMESTAMP - WHERE session_id = ? AND target_worker_id = ? - """, - [ - normalized_reason, - normalized_edge_id, - edge.session_id, - edge.child_worker_id, - ], - ) - return True - finally: - if owns_cursor: - cursor.close() + edge = self.find_edge(normalized_edge_id, cursor=cursor) + if edge is None: + return False + target = self.find_target(edge.session_id, edge.child_worker_id, cursor=cursor) + if target is None or target.assigned_edge_id != edge.edge_id: + return False + + row = cursor.execute( + """ + UPDATE broadcast_edges + SET state = 'failed', + failure_reason = ?, + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE edge_id = ? + RETURNING edge_id + """, + [normalized_reason, normalized_edge_id], + ).fetchone() + if row is None: + return False + + target_row = cursor.execute( + """ + UPDATE broadcast_targets + SET state = 'failed', + failure_reason = ?, + assigned_edge_id = ?, + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE session_id = ? + AND target_worker_id = ? + AND assigned_edge_id = ? + RETURNING target_worker_id + """, + [ + normalized_reason, + normalized_edge_id, + edge.session_id, + edge.child_worker_id, + normalized_edge_id, + ], + ).fetchone() + return target_row is not None def mark_edge_completed( self, @@ -535,48 +584,58 @@ def mark_edge_completed( cursor: DuckDBPyConnection | None = None, ) -> bool: """Mark an edge and its target as completed.""" + if cursor is None: + with self.transaction() as tx: + return self.mark_edge_completed( + edge_id=edge_id, + completed_replica_id=completed_replica_id, + cursor=tx, + ) + normalized_edge_id = self._normalize_required_text(edge_id) - owns_cursor = cursor is None - if owns_cursor: - cursor = self.get_cursor() - try: - edge = self.find_edge(normalized_edge_id, cursor=cursor) - if edge is None: - return False - row = cursor.execute( - """ - UPDATE broadcast_edges - SET state = 'completed', - updated_at = CURRENT_TIMESTAMP, - completed_at = CURRENT_TIMESTAMP - WHERE edge_id = ? - RETURNING edge_id - """, - [normalized_edge_id], - ).fetchone() - if row is None: - return False - cursor.execute( - """ - UPDATE broadcast_targets - SET state = 'completed', - completed_replica_id = ?, - assigned_edge_id = ?, - updated_at = CURRENT_TIMESTAMP, - completed_at = CURRENT_TIMESTAMP - WHERE session_id = ? AND target_worker_id = ? - """, - [ - self._uuid_to_text(completed_replica_id), - normalized_edge_id, - edge.session_id, - edge.child_worker_id, - ], - ) - return True - finally: - if owns_cursor: - cursor.close() + edge = self.find_edge(normalized_edge_id, cursor=cursor) + if edge is None: + return False + target = self.find_target(edge.session_id, edge.child_worker_id, cursor=cursor) + if target is None or target.assigned_edge_id != edge.edge_id: + return False + + row = cursor.execute( + """ + UPDATE broadcast_edges + SET state = 'completed', + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE edge_id = ? + RETURNING edge_id + """, + [normalized_edge_id], + ).fetchone() + if row is None: + return False + + target_row = cursor.execute( + """ + UPDATE broadcast_targets + SET state = 'completed', + completed_replica_id = ?, + assigned_edge_id = ?, + updated_at = CURRENT_TIMESTAMP, + completed_at = CURRENT_TIMESTAMP + WHERE session_id = ? + AND target_worker_id = ? + AND assigned_edge_id = ? + RETURNING target_worker_id + """, + [ + self._uuid_to_text(completed_replica_id), + normalized_edge_id, + edge.session_id, + edge.child_worker_id, + normalized_edge_id, + ], + ).fetchone() + return target_row is not None def count_incomplete_targets( self, diff --git a/tests/python/global_store/test_broadcast_repository.py b/tests/python/global_store/test_broadcast_repository.py index 2d879ab1..1c4b0e42 100644 --- a/tests/python/global_store/test_broadcast_repository.py +++ b/tests/python/global_store/test_broadcast_repository.py @@ -2,6 +2,8 @@ from uuid import UUID +import pytest + from tensorcast.global_store.models import ( BroadcastEdge, BroadcastEdgeState, @@ -122,7 +124,7 @@ def test_broadcast_repository_prevents_two_active_edges_for_child(db_connection) ) repo.create_edge(first) - try: + with pytest.raises(ValueError, match="active"): repo.create_edge( BroadcastEdge( edge_id="edge-2", @@ -135,10 +137,6 @@ def test_broadcast_repository_prevents_two_active_edges_for_child(db_connection) state=BroadcastEdgeState.ASSIGNED, ) ) - except Exception as exc: # noqa: BLE001 - assert "active" in str(exc).lower() or "constraint" in str(exc).lower() - else: - raise AssertionError("expected active edge uniqueness to reject duplicate child") def test_broadcast_repository_prevents_materializing_edge_when_child_has_active_edge( @@ -170,6 +168,66 @@ def test_broadcast_repository_prevents_materializing_edge_when_child_has_active_ state=BroadcastEdgeState.PLANNED, ) ) + repo.upsert_target( + BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + state=BroadcastTargetState.ASSIGNED, + assigned_edge_id="edge-contender", + ) + ) + db_connection.execute( + """ + INSERT INTO broadcast_edges ( + edge_id, session_id, parent_worker_id, parent_replica_id, + child_worker_id, level, attempt, state + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + "edge-contender", + "session-a", + "worker-root", + "00000000-0000-0000-0000-000000000001", + "worker-child", + 1, + 2, + "planned", + ], + ) + + with pytest.raises(ValueError, match="active"): + repo.mark_edge_materializing( + edge_id="edge-contender", + transport_request_id="transport-request-1", + ) + + +def test_broadcast_repository_does_not_materialize_terminal_edge(db_connection): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + repo.upsert_target( + BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + state=BroadcastTargetState.FAILED, + assigned_edge_id="edge-failed", + ) + ) repo.create_edge( BroadcastEdge( edge_id="edge-failed", @@ -178,20 +236,122 @@ def test_broadcast_repository_prevents_materializing_edge_when_child_has_active_ parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), child_worker_id="worker-child", level=1, - attempt=2, + attempt=1, state=BroadcastEdgeState.FAILED, ) ) - try: - repo.mark_edge_materializing( - edge_id="edge-failed", - transport_request_id="transport-request-1", + assert not repo.mark_edge_materializing( + edge_id="edge-failed", + transport_request_id="transport-request-1", + ) + edge = repo.find_edge("edge-failed") + target = repo.find_target("session-a", "worker-child") + assert edge is not None + assert target is not None + assert edge.state is BroadcastEdgeState.FAILED + assert edge.transport_request_id is None + assert target.state is BroadcastTargetState.FAILED + + +def test_broadcast_repository_leaves_edge_unchanged_when_target_missing( + db_connection, +): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-1", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.MATERIALIZING, ) - except Exception as exc: # noqa: BLE001 - assert "active" in str(exc).lower() + ) + + assert not repo.mark_edge_completed( + edge_id="edge-1", + completed_replica_id=UUID("00000000-0000-0000-0000-000000000002"), + ) + edge = repo.find_edge("edge-1") + assert edge is not None + assert edge.state is BroadcastEdgeState.MATERIALIZING + assert edge.completed_at is None + + +@pytest.mark.parametrize("transition", ["failed", "completed"]) +def test_broadcast_repository_stale_edge_transition_does_not_clobber_newer_target( + db_connection, + transition, +): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-a", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + repo.upsert_target( + BroadcastTarget( + session_id="session-a", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + state=BroadcastTargetState.ASSIGNED, + assigned_edge_id="edge-new", + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-old", + session_id="session-a", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.MATERIALIZING, + ) + ) + + if transition == "failed": + changed = repo.mark_edge_failed(edge_id="edge-old", reason="source unavailable") else: - raise AssertionError("expected materializing transition to reject duplicate child") + changed = repo.mark_edge_completed( + edge_id="edge-old", + completed_replica_id=UUID("00000000-0000-0000-0000-000000000002"), + ) + + edge = repo.find_edge("edge-old") + target = repo.find_target("session-a", "worker-child") + assert not changed + assert edge is not None + assert target is not None + assert edge.state is BroadcastEdgeState.MATERIALIZING + assert edge.completed_at is None + assert target.state is BroadcastTargetState.ASSIGNED + assert target.assigned_edge_id == "edge-new" def test_broadcast_repository_marks_edge_completed_and_target_completed(db_connection): From 8db320ad73f161779b4a4a5fcc95a7ab41cea2ba Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:13:32 +0800 Subject: [PATCH 23/49] feat(global-store): add broadcast session rpc --- .../global_store/v1/global_store.proto | 131 +++++++++ tensorcast/global_store/grpc_service.py | 13 + .../global_store/rpc/broadcast_rpc_handler.py | 211 ++++++++++++++ .../global_store/rpc_servicer_mixins.py | 18 ++ tensorcast/global_store/services/__init__.py | 2 + .../services/broadcast_service.py | 272 ++++++++++++++++++ .../python/global_store/test_broadcast_rpc.py | 73 +++++ .../global_store/test_broadcast_service.py | 91 ++++++ 8 files changed, 811 insertions(+) create mode 100644 tensorcast/global_store/rpc/broadcast_rpc_handler.py create mode 100644 tensorcast/global_store/services/broadcast_service.py create mode 100644 tests/python/global_store/test_broadcast_rpc.py create mode 100644 tests/python/global_store/test_broadcast_service.py diff --git a/proto/tensorcast/global_store/v1/global_store.proto b/proto/tensorcast/global_store/v1/global_store.proto index 5b398b23..d11ec3c2 100644 --- a/proto/tensorcast/global_store/v1/global_store.proto +++ b/proto/tensorcast/global_store/v1/global_store.proto @@ -46,6 +46,12 @@ service ClusterRuntimeService { rpc CompleteReplicaTransport(CompleteReplicaTransportRequest) returns (CompleteReplicaTransportResponse) {} rpc QueryTransportWindow(QueryTransportWindowRequest) returns (QueryTransportWindowResponse) {} + // Tree broadcast planning + rpc CreateBroadcastSession(CreateBroadcastSessionRequest) returns (CreateBroadcastSessionResponse) {} + rpc GetBroadcastSession(GetBroadcastSessionRequest) returns (GetBroadcastSessionResponse) {} + rpc ListBroadcastEdges(ListBroadcastEdgesRequest) returns (ListBroadcastEdgesResponse) {} + rpc CancelBroadcastSession(CancelBroadcastSessionRequest) returns (CancelBroadcastSessionResponse) {} + // Chunk directory operations rpc QueryChunkLocations(QueryChunkLocationsRequest) returns (QueryChunkLocationsResponse) {} rpc BatchUpdateChunkStates(BatchUpdateChunkStatesRequest) returns (BatchUpdateChunkStatesResponse) {} @@ -153,6 +159,35 @@ enum ConnectionStatus { CONNECTION_STATUS_CONNECTED = 3; } +enum BroadcastSessionState { + BROADCAST_SESSION_STATE_UNSPECIFIED = 0; + BROADCAST_SESSION_STATE_PLANNING = 1; + BROADCAST_SESSION_STATE_ACTIVE = 2; + BROADCAST_SESSION_STATE_COMPLETED = 3; + BROADCAST_SESSION_STATE_FAILED = 4; + BROADCAST_SESSION_STATE_CANCELLED = 5; +} + +enum BroadcastTargetState { + BROADCAST_TARGET_STATE_UNSPECIFIED = 0; + BROADCAST_TARGET_STATE_PENDING = 1; + BROADCAST_TARGET_STATE_ASSIGNED = 2; + BROADCAST_TARGET_STATE_MATERIALIZING = 3; + BROADCAST_TARGET_STATE_COMPLETED = 4; + BROADCAST_TARGET_STATE_FAILED = 5; + BROADCAST_TARGET_STATE_CANCELLED = 6; +} + +enum BroadcastEdgeState { + BROADCAST_EDGE_STATE_UNSPECIFIED = 0; + BROADCAST_EDGE_STATE_PLANNED = 1; + BROADCAST_EDGE_STATE_ASSIGNED = 2; + BROADCAST_EDGE_STATE_MATERIALIZING = 3; + BROADCAST_EDGE_STATE_COMPLETED = 4; + BROADCAST_EDGE_STATE_FAILED = 5; + BROADCAST_EDGE_STATE_CANCELLED = 6; +} + enum ReconcileRequestKind { RECONCILE_REQUEST_KIND_UNSPECIFIED = 0; RECONCILE_REQUEST_KIND_SNAPSHOT = 1; @@ -1059,6 +1094,102 @@ message QueryTransportWindowResponse { repeated TransportWindowRow rows = 2; } +message BroadcastTargetIdentity { + string worker_id = 1; + string daemon_id = 2; +} + +message BroadcastSessionInfo { + string session_id = 1; + string artifact_id = 2; + optional string requested_view_id = 3; + uint64 epoch = 4; + uint32 fanout = 5; + uint32 max_attempts = 6; + bool strict_parent = 7; + BroadcastSessionState state = 8; + string root_replica_id = 9; + google.protobuf.Timestamp created_at = 10; + google.protobuf.Timestamp updated_at = 11; + google.protobuf.Timestamp completed_at = 12; +} + +message BroadcastTargetInfo { + string session_id = 1; + string target_worker_id = 2; + string target_daemon_id = 3; + BroadcastTargetState state = 4; + uint32 level = 5; + uint32 attempt = 6; + string assigned_edge_id = 7; + string completed_replica_id = 8; + string failure_reason = 9; + google.protobuf.Timestamp created_at = 10; + google.protobuf.Timestamp updated_at = 11; + google.protobuf.Timestamp completed_at = 12; +} + +message BroadcastEdgeInfo { + string edge_id = 1; + string session_id = 2; + string parent_worker_id = 3; + string parent_replica_id = 4; + string child_worker_id = 5; + uint32 level = 6; + uint32 attempt = 7; + BroadcastEdgeState state = 8; + string transport_request_id = 9; + string failure_reason = 10; + google.protobuf.Timestamp created_at = 11; + google.protobuf.Timestamp updated_at = 12; + google.protobuf.Timestamp completed_at = 13; +} + +message CreateBroadcastSessionRequest { + string session_id = 1; + string artifact_id = 2; + optional string requested_view_id = 3; + uint64 epoch = 4; + uint32 fanout = 5; + bool strict_parent = 6; + uint32 max_attempts = 7; + string root_replica_id = 8; + repeated BroadcastTargetIdentity targets = 9; +} + +message CreateBroadcastSessionResponse { + Status status = 1; + BroadcastSessionInfo session = 2; + repeated BroadcastEdgeInfo edges = 3; +} + +message GetBroadcastSessionRequest { + string session_id = 1; +} + +message GetBroadcastSessionResponse { + Status status = 1; + BroadcastSessionInfo session = 2; +} + +message ListBroadcastEdgesRequest { + string session_id = 1; +} + +message ListBroadcastEdgesResponse { + Status status = 1; + repeated BroadcastEdgeInfo edges = 2; +} + +message CancelBroadcastSessionRequest { + string session_id = 1; +} + +message CancelBroadcastSessionResponse { + Status status = 1; + bool cancelled = 2; +} + // ========== Health Check ========== // Simple ping request to verify connectivity. diff --git a/tensorcast/global_store/grpc_service.py b/tensorcast/global_store/grpc_service.py index 8e18125e..e0c45def 100644 --- a/tensorcast/global_store/grpc_service.py +++ b/tensorcast/global_store/grpc_service.py @@ -56,6 +56,7 @@ AssemblyLayoutBindingRepository, AssemblyReadinessCutRepository, AssemblySlotOccupancyRepository, + BroadcastRepository, ChunkDirectoryRepository, ClusterInfoRepository, InstanceRepository, @@ -100,6 +101,7 @@ from tensorcast.global_store.rpc.assembly_slot_occupancy_rpc_handler import ( AssemblySlotOccupancyRpcHandler, ) +from tensorcast.global_store.rpc.broadcast_rpc_handler import BroadcastRpcHandler from tensorcast.global_store.rpc.chunk_rpc_handler import ChunkRpcHandler from tensorcast.global_store.rpc.disk_location_rpc_handler import DiskLocationRpcHandler from tensorcast.global_store.rpc.instance_rpc_handler import InstanceRpcHandler @@ -135,6 +137,7 @@ ) from tensorcast.global_store.services import ( ArtifactService, + BroadcastService, ChunkService, InstanceService, PlacementService, @@ -259,6 +262,7 @@ def _init_repositories(self) -> None: self.assembly_slot_occupancy_repository = AssemblySlotOccupancyRepository( self.connection ) + self.broadcast_repository = BroadcastRepository(self.connection) self.proof_repository = ProofRepository(self.connection) self.operation_repository = OperationRepository(self.connection) self.shard_home_lease_repository = ShardHomeLeaseRepository(self.connection) @@ -547,6 +551,15 @@ def _rebuild_runtime_services_and_handlers(self) -> None: datetime_to_timestamp=datetime_to_timestamp, logger=logger, ) + self.broadcast_service = BroadcastService( + broadcast_repository=self.broadcast_repository, + replica_repository=self.replica_repository, + worker_repository=self.worker_repository, + ) + self.broadcast_rpc_handler = BroadcastRpcHandler( + broadcast_service=self.broadcast_service, + logger=logger, + ) def set_runtime_info( self, diff --git a/tensorcast/global_store/rpc/broadcast_rpc_handler.py b/tensorcast/global_store/rpc/broadcast_rpc_handler.py new file mode 100644 index 00000000..5ac3887d --- /dev/null +++ b/tensorcast/global_store/rpc/broadcast_rpc_handler.py @@ -0,0 +1,211 @@ +# Copyright (c) 2026, TensorCast Team. + +"""Broadcast session RPC handler.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import grpc +from google.protobuf.timestamp_pb2 import Timestamp + +from tensorcast.global_store.models import ( + BroadcastEdge, + BroadcastEdgeState, + BroadcastSession, + BroadcastSessionState, +) +from tensorcast.global_store.services import BroadcastService +from tensorcast.proto.global_store.v1 import global_store_pb2 + + +_SESSION_STATE_TO_PROTO = { + BroadcastSessionState.PLANNING: global_store_pb2.BROADCAST_SESSION_STATE_PLANNING, + BroadcastSessionState.ACTIVE: global_store_pb2.BROADCAST_SESSION_STATE_ACTIVE, + BroadcastSessionState.COMPLETED: global_store_pb2.BROADCAST_SESSION_STATE_COMPLETED, + BroadcastSessionState.FAILED: global_store_pb2.BROADCAST_SESSION_STATE_FAILED, + BroadcastSessionState.CANCELLED: global_store_pb2.BROADCAST_SESSION_STATE_CANCELLED, +} + +_EDGE_STATE_TO_PROTO = { + BroadcastEdgeState.PLANNED: global_store_pb2.BROADCAST_EDGE_STATE_PLANNED, + BroadcastEdgeState.ASSIGNED: global_store_pb2.BROADCAST_EDGE_STATE_ASSIGNED, + BroadcastEdgeState.MATERIALIZING: global_store_pb2.BROADCAST_EDGE_STATE_MATERIALIZING, + BroadcastEdgeState.COMPLETED: global_store_pb2.BROADCAST_EDGE_STATE_COMPLETED, + BroadcastEdgeState.FAILED: global_store_pb2.BROADCAST_EDGE_STATE_FAILED, + BroadcastEdgeState.CANCELLED: global_store_pb2.BROADCAST_EDGE_STATE_CANCELLED, +} + + +class BroadcastRpcHandler: + """Owns broadcast planning RPC behavior and error mapping.""" + + def __init__(self, *, broadcast_service: BroadcastService, logger) -> None: + self._broadcast_service = broadcast_service + self._logger = logger + + def create_broadcast_session( + self, + request: global_store_pb2.CreateBroadcastSessionRequest, + context: grpc.ServicerContext, + ) -> global_store_pb2.CreateBroadcastSessionResponse: + try: + requested_view_id = ( + request.requested_view_id + if request.HasField("requested_view_id") + else None + ) + session = self._broadcast_service.create_session( + session_id=request.session_id, + artifact_id=request.artifact_id, + requested_view_id=requested_view_id, + epoch=int(request.epoch), + fanout=int(request.fanout), + target_worker_ids=[target.worker_id for target in request.targets], + target_daemon_ids=[target.daemon_id for target in request.targets], + root_replica_id=request.root_replica_id, + strict_parent=bool(request.strict_parent), + max_attempts=int(request.max_attempts), + ) + edges = self._broadcast_service.list_edges(session.session_id) + return global_store_pb2.CreateBroadcastSessionResponse( + status=global_store_pb2.STATUS_OK, + session=self._session_to_proto(session), + edges=[self._edge_to_proto(edge) for edge in edges], + ) + except ValueError as exc: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details(str(exc)) + return global_store_pb2.CreateBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR + ) + except Exception as exc: # noqa: BLE001 + self._logger.exception("CreateBroadcastSession failed") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(exc)) + return global_store_pb2.CreateBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR + ) + + def get_broadcast_session( + self, + request: global_store_pb2.GetBroadcastSessionRequest, + context: grpc.ServicerContext, + ) -> global_store_pb2.GetBroadcastSessionResponse: + try: + session = self._broadcast_service.get_session(request.session_id) + if session is None: + return global_store_pb2.GetBroadcastSessionResponse( + status=global_store_pb2.STATUS_NOT_FOUND + ) + return global_store_pb2.GetBroadcastSessionResponse( + status=global_store_pb2.STATUS_OK, + session=self._session_to_proto(session), + ) + except Exception as exc: # noqa: BLE001 + self._logger.exception("GetBroadcastSession failed") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(exc)) + return global_store_pb2.GetBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR + ) + + def list_broadcast_edges( + self, + request: global_store_pb2.ListBroadcastEdgesRequest, + context: grpc.ServicerContext, + ) -> global_store_pb2.ListBroadcastEdgesResponse: + try: + return global_store_pb2.ListBroadcastEdgesResponse( + status=global_store_pb2.STATUS_OK, + edges=[ + self._edge_to_proto(edge) + for edge in self._broadcast_service.list_edges(request.session_id) + ], + ) + except Exception as exc: # noqa: BLE001 + self._logger.exception("ListBroadcastEdges failed") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(exc)) + return global_store_pb2.ListBroadcastEdgesResponse( + status=global_store_pb2.STATUS_ERROR + ) + + def cancel_broadcast_session( + self, + request: global_store_pb2.CancelBroadcastSessionRequest, + context: grpc.ServicerContext, + ) -> global_store_pb2.CancelBroadcastSessionResponse: + try: + cancelled = self._broadcast_service.cancel_session(request.session_id) + return global_store_pb2.CancelBroadcastSessionResponse( + status=( + global_store_pb2.STATUS_OK + if cancelled + else global_store_pb2.STATUS_NOT_FOUND + ), + cancelled=cancelled, + ) + except Exception as exc: # noqa: BLE001 + self._logger.exception("CancelBroadcastSession failed") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(exc)) + return global_store_pb2.CancelBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR + ) + + def _session_to_proto( + self, + session: BroadcastSession, + ) -> global_store_pb2.BroadcastSessionInfo: + message = global_store_pb2.BroadcastSessionInfo( + session_id=session.session_id, + artifact_id=session.artifact_id, + epoch=int(session.epoch), + fanout=int(session.fanout), + max_attempts=int(session.max_attempts), + strict_parent=bool(session.strict_parent), + state=_SESSION_STATE_TO_PROTO.get( + session.state, + global_store_pb2.BROADCAST_SESSION_STATE_UNSPECIFIED, + ), + root_replica_id=str(session.root_replica_id or ""), + ) + if session.requested_view_id is not None: + message.requested_view_id = session.requested_view_id + self._copy_timestamp(message.created_at, session.created_at) + self._copy_timestamp(message.updated_at, session.updated_at) + self._copy_timestamp(message.completed_at, session.completed_at) + return message + + def _edge_to_proto( + self, + edge: BroadcastEdge, + ) -> global_store_pb2.BroadcastEdgeInfo: + message = global_store_pb2.BroadcastEdgeInfo( + edge_id=edge.edge_id, + session_id=edge.session_id, + parent_worker_id=edge.parent_worker_id, + parent_replica_id=str(edge.parent_replica_id), + child_worker_id=edge.child_worker_id, + level=int(edge.level), + attempt=int(edge.attempt), + state=_EDGE_STATE_TO_PROTO.get( + edge.state, + global_store_pb2.BROADCAST_EDGE_STATE_UNSPECIFIED, + ), + transport_request_id=edge.transport_request_id or "", + failure_reason=edge.failure_reason or "", + ) + self._copy_timestamp(message.created_at, edge.created_at) + self._copy_timestamp(message.updated_at, edge.updated_at) + self._copy_timestamp(message.completed_at, edge.completed_at) + return message + + @staticmethod + def _copy_timestamp(target: Timestamp, value: datetime | None) -> None: + if value is None: + return + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + target.FromDatetime(value) diff --git a/tensorcast/global_store/rpc_servicer_mixins.py b/tensorcast/global_store/rpc_servicer_mixins.py index 0207e6c3..e05e7d13 100644 --- a/tensorcast/global_store/rpc_servicer_mixins.py +++ b/tensorcast/global_store/rpc_servicer_mixins.py @@ -26,6 +26,7 @@ from tensorcast.global_store.rpc.assembly_slot_occupancy_rpc_handler import ( AssemblySlotOccupancyRpcHandler, ) +from tensorcast.global_store.rpc.broadcast_rpc_handler import BroadcastRpcHandler from tensorcast.global_store.rpc.chunk_rpc_handler import ChunkRpcHandler from tensorcast.global_store.rpc.disk_location_rpc_handler import DiskLocationRpcHandler from tensorcast.global_store.rpc.instance_rpc_handler import InstanceRpcHandler @@ -275,6 +276,7 @@ class ClusterRuntimeRpcServicerMixin: worker_state_sync_rpc_handler: WorkerStateSyncRpcHandler chunk_rpc_handler: ChunkRpcHandler shard_home_lease_rpc_handler: ShardHomeLeaseRpcHandler + broadcast_rpc_handler: BroadcastRpcHandler def RegisterReplica(self, request: Any, context: grpc.ServicerContext) -> Any: return self.replica_registration_rpc_handler.register_replica(request, context) @@ -323,6 +325,22 @@ def CompleteReplicaTransport( def QueryTransportWindow(self, request: Any, context: grpc.ServicerContext) -> Any: return self.transport_rpc_handler.query_transport_window(request, context) + def CreateBroadcastSession( + self, request: Any, context: grpc.ServicerContext + ) -> Any: + return self.broadcast_rpc_handler.create_broadcast_session(request, context) + + def GetBroadcastSession(self, request: Any, context: grpc.ServicerContext) -> Any: + return self.broadcast_rpc_handler.get_broadcast_session(request, context) + + def ListBroadcastEdges(self, request: Any, context: grpc.ServicerContext) -> Any: + return self.broadcast_rpc_handler.list_broadcast_edges(request, context) + + def CancelBroadcastSession( + self, request: Any, context: grpc.ServicerContext + ) -> Any: + return self.broadcast_rpc_handler.cancel_broadcast_session(request, context) + def RegisterWorker(self, request: Any, context: grpc.ServicerContext) -> Any: return self.worker_rpc_handler.register_worker(request, context) diff --git a/tensorcast/global_store/services/__init__.py b/tensorcast/global_store/services/__init__.py index 533f90ca..e9aa201c 100644 --- a/tensorcast/global_store/services/__init__.py +++ b/tensorcast/global_store/services/__init__.py @@ -3,6 +3,7 @@ """Global Store services.""" from .artifact_service import ArtifactService +from .broadcast_service import BroadcastService from .chunk_service import ChunkService from .instance_service import InstanceService from .memory_tier_service import MemoryTierService @@ -17,6 +18,7 @@ __all__ = [ "ChunkService", "ArtifactService", + "BroadcastService", "TransportService", "WorkerService", "InstanceService", diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py new file mode 100644 index 00000000..f432ea99 --- /dev/null +++ b/tensorcast/global_store/services/broadcast_service.py @@ -0,0 +1,272 @@ +# Copyright (c) 2026, TensorCast Team. + +"""Broadcast session planning service.""" + +from __future__ import annotations + +from uuid import UUID, uuid4 + +from tensorcast.global_store.models import ( + BroadcastEdge, + BroadcastEdgeState, + BroadcastSession, + BroadcastSessionState, + BroadcastTarget, + BroadcastTargetState, + Replica, + Worker, +) +from tensorcast.global_store.repositories import ( + BroadcastRepository, + ReplicaRepository, + WorkerRepository, +) + + +class BroadcastService: + """Coordinates broadcast session topology state.""" + + _ROOT_HEARTBEAT_TIMEOUT_SECONDS = 5.0 + + def __init__( + self, + *, + broadcast_repository: BroadcastRepository, + replica_repository: ReplicaRepository, + worker_repository: WorkerRepository, + ) -> None: + self._broadcast_repository = broadcast_repository + self._replica_repository = replica_repository + self._worker_repository = worker_repository + + def create_session( + self, + *, + session_id: str, + artifact_id: str, + requested_view_id: str | None, + epoch: int, + fanout: int, + target_daemon_ids: list[str] | tuple[str, ...], + root_replica_id: str | None, + strict_parent: bool, + max_attempts: int, + target_worker_ids: list[str] | tuple[str, ...] | None = None, + ) -> BroadcastSession: + """Create a broadcast session and reserve the first planned edges.""" + if fanout <= 0: + raise ValueError("fanout must be > 0") + if max_attempts <= 0: + raise ValueError("max_attempts must be > 0") + + targets = self._resolve_targets( + target_worker_ids=target_worker_ids or (), + target_daemon_ids=target_daemon_ids, + ) + root_replica, selected_root = self._resolve_root_replica( + artifact_id=artifact_id, + requested_view_id=requested_view_id, + root_replica_id=root_replica_id, + ) + + session = BroadcastSession( + session_id=session_id, + artifact_id=artifact_id, + requested_view_id=requested_view_id, + epoch=int(epoch), + fanout=int(fanout), + max_attempts=int(max_attempts), + strict_parent=bool(strict_parent), + state=BroadcastSessionState.ACTIVE, + root_replica_id=root_replica.replica_id, + ) + self._broadcast_repository.create_session(session) + + for worker in targets: + self._broadcast_repository.upsert_target( + BroadcastTarget( + session_id=session.session_id, + target_worker_id=worker.worker_id, + target_daemon_id=worker.daemon_id, + state=BroadcastTargetState.PENDING, + ) + ) + + self._plan_more_edges(session) + + if selected_root: + self._replica_repository.decrement_requests(root_replica.replica_id) + + return session + + def get_session(self, session_id: str) -> BroadcastSession | None: + """Return a broadcast session by ID.""" + return self._broadcast_repository.find_session(session_id) + + def list_edges(self, session_id: str) -> list[BroadcastEdge]: + """List broadcast edges for a session.""" + cursor = self._broadcast_repository.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._broadcast_repository._EDGE_PROJECTION} + FROM broadcast_edges + WHERE session_id = ? + ORDER BY level ASC, created_at ASC, edge_id ASC + """, + [session_id], + ) + rows = query.fetchall() + assert query.description is not None + columns = [desc[0] for desc in query.description] + return [ + self._broadcast_repository._row_to_edge(row, columns) for row in rows + ] + finally: + cursor.close() + + def cancel_session(self, session_id: str) -> bool: + """Mark a broadcast session cancelled.""" + return self._broadcast_repository.update_session_state( + session_id, + BroadcastSessionState.CANCELLED, + ) + + def _resolve_targets( + self, + *, + target_worker_ids: list[str] | tuple[str, ...], + target_daemon_ids: list[str] | tuple[str, ...], + ) -> list[Worker]: + targets: dict[str, Worker] = {} + for worker_id in target_worker_ids: + worker = self._worker_repository.find_by_id( + worker_id, + include_inactive=False, + ) + if worker is None: + raise ValueError(f"target worker not found: {worker_id}") + targets[worker.worker_id] = worker + + for daemon_id in target_daemon_ids: + worker = self._worker_repository.find_by_daemon_id( + daemon_id, + include_inactive=False, + ) + if worker is None: + raise ValueError(f"target daemon not found: {daemon_id}") + targets[worker.worker_id] = worker + + if not targets: + raise ValueError("at least one broadcast target is required") + return list(targets.values()) + + def _resolve_root_replica( + self, + *, + artifact_id: str, + requested_view_id: str | None, + root_replica_id: str | None, + ) -> tuple[Replica, bool]: + normalized_root_id = (root_replica_id or "").strip() + if normalized_root_id: + replica = self._replica_repository.find_by_id( + UUID(normalized_root_id), + artifact_id, + ) + if replica is None: + raise ValueError(f"root replica not found: {normalized_root_id}") + return replica, False + + result = self._replica_repository.find_available_for_transport( + artifact_id=artifact_id, + heartbeat_timeout_seconds=self._ROOT_HEARTBEAT_TIMEOUT_SECONDS, + view_id=requested_view_id, + ) + if result.replica is None: + raise ValueError(f"no available root replica for artifact: {artifact_id}") + return result.replica, True + + def _plan_more_edges(self, session: BroadcastSession) -> list[BroadcastEdge]: + pending_targets = self._broadcast_repository.list_targets_by_state( + session.session_id, + BroadcastTargetState.PENDING, + limit=max(0, int(session.fanout)), + ) + if not pending_targets or session.root_replica_id is None: + return [] + + active_edges_count = self._count_active_edges(session.session_id) + capacity = max(0, int(session.fanout) - active_edges_count) + if capacity <= 0: + return [] + + parent_pool = self._parent_pool(session) + if not parent_pool: + return [] + + planned: list[BroadcastEdge] = [] + for target, parent in zip(pending_targets[:capacity], parent_pool * capacity): + parent_replica, parent_level = parent + edge = BroadcastEdge( + edge_id=str(uuid4()), + session_id=session.session_id, + parent_worker_id=parent_replica.worker_id or "", + parent_replica_id=parent_replica.replica_id, + child_worker_id=target.target_worker_id, + level=parent_level + 1, + attempt=target.attempt + 1, + state=BroadcastEdgeState.PLANNED, + ) + self._broadcast_repository.create_edge(edge) + target.state = BroadcastTargetState.ASSIGNED + target.level = edge.level + target.attempt = edge.attempt + target.assigned_edge_id = edge.edge_id + self._broadcast_repository.upsert_target(target) + planned.append(edge) + return planned + + def _count_active_edges(self, session_id: str) -> int: + cursor = self._broadcast_repository.get_cursor() + try: + row = cursor.execute( + """ + SELECT COUNT(*) + FROM broadcast_edges + WHERE session_id = ? + AND state IN ('planned', 'assigned', 'materializing') + """, + [session_id], + ).fetchone() + return int(row[0]) if row is not None else 0 + finally: + cursor.close() + + def _parent_pool( + self, + session: BroadcastSession, + ) -> list[tuple[Replica, int]]: + parents: list[tuple[Replica, int]] = [] + if session.root_replica_id is not None: + root = self._replica_repository.find_by_id( + session.root_replica_id, + session.artifact_id, + ) + if root is not None: + parents.append((root, 0)) + + completed_targets = self._broadcast_repository.list_targets_by_state( + session.session_id, + BroadcastTargetState.COMPLETED, + limit=10_000, + ) + for target in completed_targets: + if target.completed_replica_id is None: + continue + replica = self._replica_repository.find_by_replica_id( + target.completed_replica_id + ) + if replica is not None: + parents.append((replica, int(target.level or 0))) + return parents diff --git a/tests/python/global_store/test_broadcast_rpc.py b/tests/python/global_store/test_broadcast_rpc.py new file mode 100644 index 00000000..ef06dd4e --- /dev/null +++ b/tests/python/global_store/test_broadcast_rpc.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from tensorcast.proto.global_store.v1 import global_store_pb2 + + +def test_create_broadcast_session_rpc_returns_edges(servicer, test_context, memory_info): + root_worker = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-root", + node_id="node-root", + node_address="10.10.0.1", + grpc_port=50101, + p2p_port=50102, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + child_worker = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-child", + node_id="node-child", + node_address="10.10.0.2", + grpc_port=50201, + p2p_port=50202, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + memory_info.node_id = "node-root" + memory_info.node_address = "10.10.0.1" + memory_info.node_port = 50102 + register_resp = servicer.RegisterReplica( + global_store_pb2.RegisterReplicaRequest( + artifact_id="mi2:model-rpc", + worker_id=root_worker, + mem_info=memory_info, + max_concurrency=4, + ), + test_context, + ) + assert register_resp.status == global_store_pb2.STATUS_OK + + response = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-rpc", + artifact_id="mi2:model-rpc", + epoch=7, + fanout=1, + strict_parent=True, + max_attempts=3, + root_replica_id=register_resp.replica_id, + targets=[ + global_store_pb2.BroadcastTargetIdentity( + worker_id=child_worker, + daemon_id="daemon-child", + ) + ], + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_OK + assert response.session.session_id == "session-rpc" + assert response.session.state == global_store_pb2.BROADCAST_SESSION_STATE_ACTIVE + edge_resp = servicer.ListBroadcastEdges( + global_store_pb2.ListBroadcastEdgesRequest(session_id="session-rpc"), + test_context, + ) + assert edge_resp.status == global_store_pb2.STATUS_OK + assert len(edge_resp.edges) == 1 + assert edge_resp.edges[0].child_worker_id == child_worker diff --git a/tests/python/global_store/test_broadcast_service.py b/tests/python/global_store/test_broadcast_service.py new file mode 100644 index 00000000..e14ae7f8 --- /dev/null +++ b/tests/python/global_store/test_broadcast_service.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from tensorcast.global_store.models import ( + BroadcastEdgeState, + BroadcastSessionState, + BroadcastTargetState, + ExportState, + MemoryType, + Replica, + Worker, +) +from tensorcast.global_store.services import BroadcastService + + +def _worker(worker_id: str, daemon_id: str, node_id: str) -> Worker: + return Worker( + worker_id=worker_id, + daemon_id=daemon_id, + node_id=node_id, + node_address=f"10.0.0.{node_id[-1]}", + grpc_port=5000 + int(node_id[-1]), + p2p_port=6000 + int(node_id[-1]), + mem_pool_total_size=4096, + mem_pool_available_size=4096, + accepting_new_requests=True, + ) + + +def _exportable_replica(artifact_id: str, worker: Worker) -> Replica: + return Replica( + artifact_id=artifact_id, + node_id=worker.node_id, + node_address=worker.node_address, + node_port=worker.p2p_port, + memory_size=1024, + memory_type=MemoryType.GPU, + device_id=0, + max_concurrency=4, + current_requests=0, + is_available=True, + remote_memory_keys=[f"rk-{worker.worker_id}"], + buffer_sizes=[1024], + export_state=ExportState.EXPORTABLE, + worker_id=worker.worker_id, + ) + + +def test_create_session_plans_first_layer_by_fanout(repositories): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root", "daemon-root", "node1") + child1 = _worker("worker-child-1", "daemon-child-1", "node2") + child2 = _worker("worker-child-2", "daemon-child-2", "node3") + child3 = _worker("worker-child-3", "daemon-child-3", "node4") + for worker in (root, child1, child2, child3): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-a", root)) + + session = service.create_session( + session_id="session-a", + artifact_id="mi2:model-a", + requested_view_id=None, + epoch=42, + fanout=2, + target_daemon_ids=["daemon-child-1", "daemon-child-2", "daemon-child-3"], + root_replica_id=str(root_replica.replica_id), + strict_parent=True, + max_attempts=3, + ) + + assert session.state is BroadcastSessionState.ACTIVE + targets = broadcast_repo.list_targets("session-a") + assert len(targets) == 3 + assigned = [t for t in targets if t.state is BroadcastTargetState.ASSIGNED] + pending = [t for t in targets if t.state is BroadcastTargetState.PENDING] + assert len(assigned) == 2 + assert len(pending) == 1 + edges = [ + broadcast_repo.find_active_edge_for_child("session-a", t.target_worker_id) + for t in assigned + ] + assert all(edge is not None for edge in edges) + assert all(edge.state is BroadcastEdgeState.PLANNED for edge in edges if edge) + assert all(edge.parent_replica_id == root_replica.replica_id for edge in edges if edge) From 462308cdaa23421267bd1781f6cc1b2b6877f53e Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:23:45 +0800 Subject: [PATCH 24/49] fix(global-store): align broadcast rpc contract --- .../global_store/v1/global_store.proto | 9 +- .../global_store/rpc/broadcast_rpc_handler.py | 81 +++++++++-- .../services/broadcast_service.py | 18 +++ .../python/global_store/test_broadcast_rpc.py | 130 ++++++++++++++++++ .../global_store/test_broadcast_service.py | 46 +++++++ 5 files changed, 273 insertions(+), 11 deletions(-) diff --git a/proto/tensorcast/global_store/v1/global_store.proto b/proto/tensorcast/global_store/v1/global_store.proto index d11ec3c2..f0c884b2 100644 --- a/proto/tensorcast/global_store/v1/global_store.proto +++ b/proto/tensorcast/global_store/v1/global_store.proto @@ -1102,7 +1102,7 @@ message BroadcastTargetIdentity { message BroadcastSessionInfo { string session_id = 1; string artifact_id = 2; - optional string requested_view_id = 3; + tensorcast.common.v1.ByteSpaceRef requested_byte_space = 3; uint64 epoch = 4; uint32 fanout = 5; uint32 max_attempts = 6; @@ -1148,7 +1148,7 @@ message BroadcastEdgeInfo { message CreateBroadcastSessionRequest { string session_id = 1; string artifact_id = 2; - optional string requested_view_id = 3; + tensorcast.common.v1.ByteSpaceRef requested_byte_space = 3; uint64 epoch = 4; uint32 fanout = 5; bool strict_parent = 6; @@ -1160,7 +1160,8 @@ message CreateBroadcastSessionRequest { message CreateBroadcastSessionResponse { Status status = 1; BroadcastSessionInfo session = 2; - repeated BroadcastEdgeInfo edges = 3; + repeated BroadcastTargetInfo targets = 3; + repeated BroadcastEdgeInfo edges = 4; } message GetBroadcastSessionRequest { @@ -1170,6 +1171,7 @@ message GetBroadcastSessionRequest { message GetBroadcastSessionResponse { Status status = 1; BroadcastSessionInfo session = 2; + repeated BroadcastTargetInfo targets = 3; } message ListBroadcastEdgesRequest { @@ -1183,6 +1185,7 @@ message ListBroadcastEdgesResponse { message CancelBroadcastSessionRequest { string session_id = 1; + string reason = 2; } message CancelBroadcastSessionResponse { diff --git a/tensorcast/global_store/rpc/broadcast_rpc_handler.py b/tensorcast/global_store/rpc/broadcast_rpc_handler.py index 5ac3887d..7523e039 100644 --- a/tensorcast/global_store/rpc/broadcast_rpc_handler.py +++ b/tensorcast/global_store/rpc/broadcast_rpc_handler.py @@ -14,8 +14,11 @@ BroadcastEdgeState, BroadcastSession, BroadcastSessionState, + BroadcastTarget, + BroadcastTargetState, ) from tensorcast.global_store.services import BroadcastService +from tensorcast.proto.common.v1 import common_pb2 from tensorcast.proto.global_store.v1 import global_store_pb2 @@ -36,6 +39,15 @@ BroadcastEdgeState.CANCELLED: global_store_pb2.BROADCAST_EDGE_STATE_CANCELLED, } +_TARGET_STATE_TO_PROTO = { + BroadcastTargetState.PENDING: global_store_pb2.BROADCAST_TARGET_STATE_PENDING, + BroadcastTargetState.ASSIGNED: global_store_pb2.BROADCAST_TARGET_STATE_ASSIGNED, + BroadcastTargetState.MATERIALIZING: global_store_pb2.BROADCAST_TARGET_STATE_MATERIALIZING, + BroadcastTargetState.COMPLETED: global_store_pb2.BROADCAST_TARGET_STATE_COMPLETED, + BroadcastTargetState.FAILED: global_store_pb2.BROADCAST_TARGET_STATE_FAILED, + BroadcastTargetState.CANCELLED: global_store_pb2.BROADCAST_TARGET_STATE_CANCELLED, +} + class BroadcastRpcHandler: """Owns broadcast planning RPC behavior and error mapping.""" @@ -50,27 +62,34 @@ def create_broadcast_session( context: grpc.ServicerContext, ) -> global_store_pb2.CreateBroadcastSessionResponse: try: - requested_view_id = ( - request.requested_view_id - if request.HasField("requested_view_id") - else None - ) + requested_view_id = self._requested_view_id_from_byte_space(request) + target_worker_ids: list[str] = [] + target_daemon_ids: list[str] = [] + for target in request.targets: + worker_id = target.worker_id.strip() + daemon_id = target.daemon_id.strip() + if worker_id: + target_worker_ids.append(worker_id) + if daemon_id: + target_daemon_ids.append(daemon_id) session = self._broadcast_service.create_session( session_id=request.session_id, artifact_id=request.artifact_id, requested_view_id=requested_view_id, epoch=int(request.epoch), fanout=int(request.fanout), - target_worker_ids=[target.worker_id for target in request.targets], - target_daemon_ids=[target.daemon_id for target in request.targets], + target_worker_ids=target_worker_ids, + target_daemon_ids=target_daemon_ids, root_replica_id=request.root_replica_id, strict_parent=bool(request.strict_parent), max_attempts=int(request.max_attempts), ) edges = self._broadcast_service.list_edges(session.session_id) + targets = self._broadcast_service.list_targets(session.session_id) return global_store_pb2.CreateBroadcastSessionResponse( status=global_store_pb2.STATUS_OK, session=self._session_to_proto(session), + targets=[self._target_to_proto(target) for target in targets], edges=[self._edge_to_proto(edge) for edge in edges], ) except ValueError as exc: @@ -101,6 +120,12 @@ def get_broadcast_session( return global_store_pb2.GetBroadcastSessionResponse( status=global_store_pb2.STATUS_OK, session=self._session_to_proto(session), + targets=[ + self._target_to_proto(target) + for target in self._broadcast_service.list_targets( + session.session_id + ) + ], ) except Exception as exc: # noqa: BLE001 self._logger.exception("GetBroadcastSession failed") @@ -172,12 +197,38 @@ def _session_to_proto( root_replica_id=str(session.root_replica_id or ""), ) if session.requested_view_id is not None: - message.requested_view_id = session.requested_view_id + message.requested_byte_space.kind = common_pb2.BYTE_SPACE_KIND_VIEW + message.requested_byte_space.id = session.requested_view_id + else: + message.requested_byte_space.kind = common_pb2.BYTE_SPACE_KIND_CANONICAL self._copy_timestamp(message.created_at, session.created_at) self._copy_timestamp(message.updated_at, session.updated_at) self._copy_timestamp(message.completed_at, session.completed_at) return message + def _target_to_proto( + self, + target: BroadcastTarget, + ) -> global_store_pb2.BroadcastTargetInfo: + message = global_store_pb2.BroadcastTargetInfo( + session_id=target.session_id, + target_worker_id=target.target_worker_id, + target_daemon_id=target.target_daemon_id or "", + state=_TARGET_STATE_TO_PROTO.get( + target.state, + global_store_pb2.BROADCAST_TARGET_STATE_UNSPECIFIED, + ), + level=int(target.level or 0), + attempt=int(target.attempt), + assigned_edge_id=target.assigned_edge_id or "", + completed_replica_id=str(target.completed_replica_id or ""), + failure_reason=target.failure_reason or "", + ) + self._copy_timestamp(message.created_at, target.created_at) + self._copy_timestamp(message.updated_at, target.updated_at) + self._copy_timestamp(message.completed_at, target.completed_at) + return message + def _edge_to_proto( self, edge: BroadcastEdge, @@ -209,3 +260,17 @@ def _copy_timestamp(target: Timestamp, value: datetime | None) -> None: if value.tzinfo is None: value = value.replace(tzinfo=timezone.utc) target.FromDatetime(value) + + @staticmethod + def _requested_view_id_from_byte_space( + request: global_store_pb2.CreateBroadcastSessionRequest, + ) -> str | None: + if not request.HasField("requested_byte_space"): + return None + byte_space = request.requested_byte_space + if byte_space.kind == common_pb2.BYTE_SPACE_KIND_VIEW: + view_id = byte_space.id.strip() + if not view_id: + raise ValueError("requested_byte_space VIEW requires id") + return view_id + return None diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index f432ea99..5dfa6379 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -54,6 +54,14 @@ def create_session( target_worker_ids: list[str] | tuple[str, ...] | None = None, ) -> BroadcastSession: """Create a broadcast session and reserve the first planned edges.""" + session_id = str(session_id).strip() + artifact_id = str(artifact_id).strip() + if not session_id: + raise ValueError("session_id is required") + if not artifact_id: + raise ValueError("artifact_id is required") + if epoch < 0: + raise ValueError("epoch must be >= 0") if fanout <= 0: raise ValueError("fanout must be > 0") if max_attempts <= 0: @@ -125,6 +133,10 @@ def list_edges(self, session_id: str) -> list[BroadcastEdge]: finally: cursor.close() + def list_targets(self, session_id: str) -> list[BroadcastTarget]: + """List broadcast targets for a session.""" + return self._broadcast_repository.list_targets(session_id) + def cancel_session(self, session_id: str) -> bool: """Mark a broadcast session cancelled.""" return self._broadcast_repository.update_session_state( @@ -140,6 +152,9 @@ def _resolve_targets( ) -> list[Worker]: targets: dict[str, Worker] = {} for worker_id in target_worker_ids: + worker_id = str(worker_id).strip() + if not worker_id: + continue worker = self._worker_repository.find_by_id( worker_id, include_inactive=False, @@ -149,6 +164,9 @@ def _resolve_targets( targets[worker.worker_id] = worker for daemon_id in target_daemon_ids: + daemon_id = str(daemon_id).strip() + if not daemon_id: + continue worker = self._worker_repository.find_by_daemon_id( daemon_id, include_inactive=False, diff --git a/tests/python/global_store/test_broadcast_rpc.py b/tests/python/global_store/test_broadcast_rpc.py index ef06dd4e..6107bb58 100644 --- a/tests/python/global_store/test_broadcast_rpc.py +++ b/tests/python/global_store/test_broadcast_rpc.py @@ -1,8 +1,40 @@ from __future__ import annotations +from tensorcast.proto.common.v1 import common_pb2 from tensorcast.proto.global_store.v1 import global_store_pb2 +def test_broadcast_proto_contract_uses_byte_space_and_targets(): + session_fields = global_store_pb2.BroadcastSessionInfo.DESCRIPTOR.fields_by_name + create_fields = ( + global_store_pb2.CreateBroadcastSessionRequest.DESCRIPTOR.fields_by_name + ) + create_response_fields = ( + global_store_pb2.CreateBroadcastSessionResponse.DESCRIPTOR.fields_by_name + ) + get_response_fields = ( + global_store_pb2.GetBroadcastSessionResponse.DESCRIPTOR.fields_by_name + ) + cancel_fields = global_store_pb2.CancelBroadcastSessionRequest.DESCRIPTOR.fields_by_name + + assert "requested_view_id" not in session_fields + assert session_fields["requested_byte_space"].number == 3 + assert ( + session_fields["requested_byte_space"].message_type.full_name + == "tensorcast.common.v1.ByteSpaceRef" + ) + assert "requested_view_id" not in create_fields + assert create_fields["requested_byte_space"].number == 3 + assert ( + create_fields["requested_byte_space"].message_type.full_name + == "tensorcast.common.v1.ByteSpaceRef" + ) + assert create_response_fields["targets"].number == 3 + assert create_response_fields["edges"].number == 4 + assert get_response_fields["targets"].number == 3 + assert cancel_fields["reason"].number == 2 + + def test_create_broadcast_session_rpc_returns_edges(servicer, test_context, memory_info): root_worker = servicer.RegisterWorker( global_store_pb2.RegisterWorkerRequest( @@ -46,6 +78,10 @@ def test_create_broadcast_session_rpc_returns_edges(servicer, test_context, memo global_store_pb2.CreateBroadcastSessionRequest( session_id="session-rpc", artifact_id="mi2:model-rpc", + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_VIEW, + id="view-rpc", + ), epoch=7, fanout=1, strict_parent=True, @@ -63,7 +99,18 @@ def test_create_broadcast_session_rpc_returns_edges(servicer, test_context, memo assert response.status == global_store_pb2.STATUS_OK assert response.session.session_id == "session-rpc" + assert response.session.requested_byte_space.kind == common_pb2.BYTE_SPACE_KIND_VIEW + assert response.session.requested_byte_space.id == "view-rpc" assert response.session.state == global_store_pb2.BROADCAST_SESSION_STATE_ACTIVE + assert len(response.targets) == 1 + assert response.targets[0].target_worker_id == child_worker + get_resp = servicer.GetBroadcastSession( + global_store_pb2.GetBroadcastSessionRequest(session_id="session-rpc"), + test_context, + ) + assert get_resp.status == global_store_pb2.STATUS_OK + assert len(get_resp.targets) == 1 + assert get_resp.targets[0].target_worker_id == child_worker edge_resp = servicer.ListBroadcastEdges( global_store_pb2.ListBroadcastEdgesRequest(session_id="session-rpc"), test_context, @@ -71,3 +118,86 @@ def test_create_broadcast_session_rpc_returns_edges(servicer, test_context, memo assert edge_resp.status == global_store_pb2.STATUS_OK assert len(edge_resp.edges) == 1 assert edge_resp.edges[0].child_worker_id == child_worker + + +def test_create_broadcast_session_accepts_worker_only_and_daemon_only_targets( + servicer, + test_context, + memory_info, +): + root_worker = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-root-mixed", + node_id="node-root-mixed", + node_address="10.20.0.1", + grpc_port=51101, + p2p_port=51102, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + worker_only_target = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-worker-only", + node_id="node-worker-only", + node_address="10.20.0.2", + grpc_port=51201, + p2p_port=51202, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + daemon_only_target = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-daemon-only", + node_id="node-daemon-only", + node_address="10.20.0.3", + grpc_port=51301, + p2p_port=51302, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + memory_info.node_id = "node-root-mixed" + memory_info.node_address = "10.20.0.1" + memory_info.node_port = 51102 + register_resp = servicer.RegisterReplica( + global_store_pb2.RegisterReplicaRequest( + artifact_id="mi2:model-mixed-targets", + worker_id=root_worker, + mem_info=memory_info, + max_concurrency=4, + ), + test_context, + ) + assert register_resp.status == global_store_pb2.STATUS_OK + + response = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-mixed-targets", + artifact_id="mi2:model-mixed-targets", + epoch=9, + fanout=2, + strict_parent=True, + max_attempts=3, + root_replica_id=register_resp.replica_id, + targets=[ + global_store_pb2.BroadcastTargetIdentity( + worker_id=worker_only_target, + ), + global_store_pb2.BroadcastTargetIdentity( + daemon_id="daemon-daemon-only", + ), + ], + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_OK + assert {target.target_worker_id for target in response.targets} == { + worker_only_target, + daemon_only_target, + } diff --git a/tests/python/global_store/test_broadcast_service.py b/tests/python/global_store/test_broadcast_service.py index e14ae7f8..3d14f4fb 100644 --- a/tests/python/global_store/test_broadcast_service.py +++ b/tests/python/global_store/test_broadcast_service.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from tensorcast.global_store.models import ( BroadcastEdgeState, BroadcastSessionState, @@ -89,3 +91,47 @@ def test_create_session_plans_first_layer_by_fanout(repositories): assert all(edge is not None for edge in edges) assert all(edge.state is BroadcastEdgeState.PLANNED for edge in edges if edge) assert all(edge.parent_replica_id == root_replica.replica_id for edge in edges if edge) + + +@pytest.mark.parametrize( + ("overrides", "message"), + [ + ({"session_id": ""}, "session_id is required"), + ({"artifact_id": ""}, "artifact_id is required"), + ({"epoch": -1}, "epoch must be >= 0"), + ], +) +def test_create_session_validates_required_inputs(repositories, overrides, message): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root-validation", "daemon-root-validation", "node1") + child = _worker("worker-child-validation", "daemon-child-validation", "node2") + for worker in (root, child): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create( + _exportable_replica("mi2:model-validation", root) + ) + + kwargs = { + "session_id": "session-validation", + "artifact_id": "mi2:model-validation", + "requested_view_id": None, + "epoch": 1, + "fanout": 1, + "target_daemon_ids": ["daemon-child-validation"], + "root_replica_id": str(root_replica.replica_id), + "strict_parent": True, + "max_attempts": 3, + } + kwargs.update(overrides) + + with pytest.raises(ValueError, match=message): + service.create_session(**kwargs) + assert broadcast_repo.find_session("session-validation") is None From 9d78a4cf3ce2db93d86e023f226d8e49dc602746 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:28:52 +0800 Subject: [PATCH 25/49] fix(global-store): stabilize broadcast create proto fields --- proto/tensorcast/global_store/v1/global_store.proto | 8 ++++---- tests/python/global_store/test_broadcast_rpc.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/proto/tensorcast/global_store/v1/global_store.proto b/proto/tensorcast/global_store/v1/global_store.proto index f0c884b2..28f15d8a 100644 --- a/proto/tensorcast/global_store/v1/global_store.proto +++ b/proto/tensorcast/global_store/v1/global_store.proto @@ -1151,10 +1151,10 @@ message CreateBroadcastSessionRequest { tensorcast.common.v1.ByteSpaceRef requested_byte_space = 3; uint64 epoch = 4; uint32 fanout = 5; - bool strict_parent = 6; - uint32 max_attempts = 7; - string root_replica_id = 8; - repeated BroadcastTargetIdentity targets = 9; + repeated BroadcastTargetIdentity targets = 6; + string root_replica_id = 7; + bool strict_parent = 8; + uint32 max_attempts = 9; } message CreateBroadcastSessionResponse { diff --git a/tests/python/global_store/test_broadcast_rpc.py b/tests/python/global_store/test_broadcast_rpc.py index 6107bb58..4001a3c6 100644 --- a/tests/python/global_store/test_broadcast_rpc.py +++ b/tests/python/global_store/test_broadcast_rpc.py @@ -29,6 +29,10 @@ def test_broadcast_proto_contract_uses_byte_space_and_targets(): create_fields["requested_byte_space"].message_type.full_name == "tensorcast.common.v1.ByteSpaceRef" ) + assert create_fields["targets"].number == 6 + assert create_fields["root_replica_id"].number == 7 + assert create_fields["strict_parent"].number == 8 + assert create_fields["max_attempts"].number == 9 assert create_response_fields["targets"].number == 3 assert create_response_fields["edges"].number == 4 assert get_response_fields["targets"].number == 3 From 323075c4ab6a3c01a50e8e08666e76860f2ba3c7 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:39:41 +0800 Subject: [PATCH 26/49] fix(global-store): harden broadcast session creation --- .../repositories/broadcast_repository.py | 28 +++ .../global_store/rpc/broadcast_rpc_handler.py | 67 +++++-- .../services/broadcast_service.py | 169 ++++++++++++------ .../global_store/test_broadcast_repository.py | 2 + .../python/global_store/test_broadcast_rpc.py | 111 ++++++++++++ .../global_store/test_broadcast_service.py | 141 +++++++++++++++ 6 files changed, 449 insertions(+), 69 deletions(-) diff --git a/tensorcast/global_store/repositories/broadcast_repository.py b/tensorcast/global_store/repositories/broadcast_repository.py index 28b2f929..80094b3d 100644 --- a/tensorcast/global_store/repositories/broadcast_repository.py +++ b/tensorcast/global_store/repositories/broadcast_repository.py @@ -382,6 +382,34 @@ def find_edge( if owns_cursor: cursor.close() + def list_edges( + self, + session_id: str, + cursor: DuckDBPyConnection | None = None, + ) -> list[BroadcastEdge]: + """List all broadcast edges for a session.""" + normalized_session_id = self._normalize_required_text(session_id) + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = cursor.execute( + f""" + SELECT {self._EDGE_PROJECTION} + FROM broadcast_edges + WHERE session_id = ? + ORDER BY level ASC, created_at ASC, edge_id ASC + """, + [normalized_session_id], + ) + rows = query.fetchall() + assert query.description is not None + columns = [desc[0] for desc in query.description] + return [self._row_to_edge(row, columns) for row in rows] + finally: + if owns_cursor: + cursor.close() + def find_active_edge_for_child( self, session_id: str, diff --git a/tensorcast/global_store/rpc/broadcast_rpc_handler.py b/tensorcast/global_store/rpc/broadcast_rpc_handler.py index 7523e039..f2f6c015 100644 --- a/tensorcast/global_store/rpc/broadcast_rpc_handler.py +++ b/tensorcast/global_store/rpc/broadcast_rpc_handler.py @@ -63,23 +63,21 @@ def create_broadcast_session( ) -> global_store_pb2.CreateBroadcastSessionResponse: try: requested_view_id = self._requested_view_id_from_byte_space(request) - target_worker_ids: list[str] = [] - target_daemon_ids: list[str] = [] + target_identities: list[tuple[str, str]] = [] for target in request.targets: worker_id = target.worker_id.strip() daemon_id = target.daemon_id.strip() - if worker_id: - target_worker_ids.append(worker_id) - if daemon_id: - target_daemon_ids.append(daemon_id) + if worker_id or daemon_id: + target_identities.append((worker_id, daemon_id)) session = self._broadcast_service.create_session( session_id=request.session_id, artifact_id=request.artifact_id, requested_view_id=requested_view_id, epoch=int(request.epoch), fanout=int(request.fanout), - target_worker_ids=target_worker_ids, - target_daemon_ids=target_daemon_ids, + target_worker_ids=(), + target_daemon_ids=(), + target_identities=target_identities, root_replica_id=request.root_replica_id, strict_parent=bool(request.strict_parent), max_attempts=int(request.max_attempts), @@ -111,8 +109,13 @@ def get_broadcast_session( request: global_store_pb2.GetBroadcastSessionRequest, context: grpc.ServicerContext, ) -> global_store_pb2.GetBroadcastSessionResponse: + session_id = self._required_session_id(request.session_id, context) + if session_id is None: + return global_store_pb2.GetBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR + ) try: - session = self._broadcast_service.get_session(request.session_id) + session = self._broadcast_service.get_session(session_id) if session is None: return global_store_pb2.GetBroadcastSessionResponse( status=global_store_pb2.STATUS_NOT_FOUND @@ -127,6 +130,12 @@ def get_broadcast_session( ) ], ) + except ValueError as exc: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details(str(exc)) + return global_store_pb2.GetBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR + ) except Exception as exc: # noqa: BLE001 self._logger.exception("GetBroadcastSession failed") context.set_code(grpc.StatusCode.INTERNAL) @@ -140,14 +149,25 @@ def list_broadcast_edges( request: global_store_pb2.ListBroadcastEdgesRequest, context: grpc.ServicerContext, ) -> global_store_pb2.ListBroadcastEdgesResponse: + session_id = self._required_session_id(request.session_id, context) + if session_id is None: + return global_store_pb2.ListBroadcastEdgesResponse( + status=global_store_pb2.STATUS_ERROR + ) try: return global_store_pb2.ListBroadcastEdgesResponse( status=global_store_pb2.STATUS_OK, edges=[ self._edge_to_proto(edge) - for edge in self._broadcast_service.list_edges(request.session_id) + for edge in self._broadcast_service.list_edges(session_id) ], ) + except ValueError as exc: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details(str(exc)) + return global_store_pb2.ListBroadcastEdgesResponse( + status=global_store_pb2.STATUS_ERROR + ) except Exception as exc: # noqa: BLE001 self._logger.exception("ListBroadcastEdges failed") context.set_code(grpc.StatusCode.INTERNAL) @@ -161,8 +181,14 @@ def cancel_broadcast_session( request: global_store_pb2.CancelBroadcastSessionRequest, context: grpc.ServicerContext, ) -> global_store_pb2.CancelBroadcastSessionResponse: + session_id = self._required_session_id(request.session_id, context) + if session_id is None: + return global_store_pb2.CancelBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR, + cancelled=False, + ) try: - cancelled = self._broadcast_service.cancel_session(request.session_id) + cancelled = self._broadcast_service.cancel_session(session_id) return global_store_pb2.CancelBroadcastSessionResponse( status=( global_store_pb2.STATUS_OK @@ -171,6 +197,13 @@ def cancel_broadcast_session( ), cancelled=cancelled, ) + except ValueError as exc: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details(str(exc)) + return global_store_pb2.CancelBroadcastSessionResponse( + status=global_store_pb2.STATUS_ERROR, + cancelled=False, + ) except Exception as exc: # noqa: BLE001 self._logger.exception("CancelBroadcastSession failed") context.set_code(grpc.StatusCode.INTERNAL) @@ -274,3 +307,15 @@ def _requested_view_id_from_byte_space( raise ValueError("requested_byte_space VIEW requires id") return view_id return None + + @staticmethod + def _required_session_id( + raw_session_id: str, + context: grpc.ServicerContext, + ) -> str | None: + session_id = raw_session_id.strip() + if session_id: + return session_id + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("session_id is required") + return None diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index 5dfa6379..f562c45c 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -4,8 +4,10 @@ from __future__ import annotations +from collections.abc import Sequence from uuid import UUID, uuid4 +from tensorcast.global_store.exceptions import DatabaseError from tensorcast.global_store.models import ( BroadcastEdge, BroadcastEdgeState, @@ -52,6 +54,7 @@ def create_session( strict_parent: bool, max_attempts: int, target_worker_ids: list[str] | tuple[str, ...] | None = None, + target_identities: Sequence[tuple[str, str]] | None = None, ) -> BroadcastSession: """Create a broadcast session and reserve the first planned edges.""" session_id = str(session_id).strip() @@ -67,45 +70,65 @@ def create_session( if max_attempts <= 0: raise ValueError("max_attempts must be > 0") + existing = self._broadcast_repository.find_session(session_id) + if existing is not None: + return existing + targets = self._resolve_targets( target_worker_ids=target_worker_ids or (), target_daemon_ids=target_daemon_ids, + target_identities=target_identities or (), ) - root_replica, selected_root = self._resolve_root_replica( - artifact_id=artifact_id, - requested_view_id=requested_view_id, - root_replica_id=root_replica_id, - ) - - session = BroadcastSession( - session_id=session_id, - artifact_id=artifact_id, - requested_view_id=requested_view_id, - epoch=int(epoch), - fanout=int(fanout), - max_attempts=int(max_attempts), - strict_parent=bool(strict_parent), - state=BroadcastSessionState.ACTIVE, - root_replica_id=root_replica.replica_id, - ) - self._broadcast_repository.create_session(session) - - for worker in targets: - self._broadcast_repository.upsert_target( - BroadcastTarget( - session_id=session.session_id, - target_worker_id=worker.worker_id, - target_daemon_id=worker.daemon_id, - state=BroadcastTargetState.PENDING, + root_replica: Replica | None = None + selected_root = False + try: + try: + root_replica, selected_root = self._resolve_root_replica( + artifact_id=artifact_id, + requested_view_id=requested_view_id, + root_replica_id=root_replica_id, ) - ) - self._plan_more_edges(session) + session = BroadcastSession( + session_id=session_id, + artifact_id=artifact_id, + requested_view_id=requested_view_id, + epoch=int(epoch), + fanout=int(fanout), + max_attempts=int(max_attempts), + strict_parent=bool(strict_parent), + state=BroadcastSessionState.ACTIVE, + root_replica_id=root_replica.replica_id, + ) + with self._broadcast_repository.transaction() as tx: + existing = self._broadcast_repository.find_session( + session_id, + cursor=tx, + ) + if existing is not None: + return existing + self._broadcast_repository.create_session(session, cursor=tx) + for worker in targets: + self._broadcast_repository.upsert_target( + BroadcastTarget( + session_id=session.session_id, + target_worker_id=worker.worker_id, + target_daemon_id=worker.daemon_id, + state=BroadcastTargetState.PENDING, + ), + cursor=tx, + ) - if selected_root: - self._replica_repository.decrement_requests(root_replica.replica_id) + self._plan_more_edges(session, cursor=tx) - return session + return session + except DatabaseError as exc: + if exc.__cause__ is not None: + raise exc.__cause__ from exc + raise + finally: + if selected_root and root_replica is not None: + self._replica_repository.decrement_requests(root_replica.replica_id) def get_session(self, session_id: str) -> BroadcastSession | None: """Return a broadcast session by ID.""" @@ -113,25 +136,7 @@ def get_session(self, session_id: str) -> BroadcastSession | None: def list_edges(self, session_id: str) -> list[BroadcastEdge]: """List broadcast edges for a session.""" - cursor = self._broadcast_repository.get_cursor() - try: - query = cursor.execute( - f""" - SELECT {self._broadcast_repository._EDGE_PROJECTION} - FROM broadcast_edges - WHERE session_id = ? - ORDER BY level ASC, created_at ASC, edge_id ASC - """, - [session_id], - ) - rows = query.fetchall() - assert query.description is not None - columns = [desc[0] for desc in query.description] - return [ - self._broadcast_repository._row_to_edge(row, columns) for row in rows - ] - finally: - cursor.close() + return self._broadcast_repository.list_edges(session_id) def list_targets(self, session_id: str) -> list[BroadcastTarget]: """List broadcast targets for a session.""" @@ -149,8 +154,43 @@ def _resolve_targets( *, target_worker_ids: list[str] | tuple[str, ...], target_daemon_ids: list[str] | tuple[str, ...], + target_identities: Sequence[tuple[str, str]], ) -> list[Worker]: targets: dict[str, Worker] = {} + for raw_worker_id, raw_daemon_id in target_identities: + worker_id = str(raw_worker_id).strip() + daemon_id = str(raw_daemon_id).strip() + if not worker_id and not daemon_id: + continue + + worker_by_id: Worker | None = None + worker_by_daemon: Worker | None = None + if worker_id: + worker_by_id = self._worker_repository.find_by_id( + worker_id, + include_inactive=False, + ) + if worker_by_id is None: + raise ValueError(f"target worker not found: {worker_id}") + if daemon_id: + worker_by_daemon = self._worker_repository.find_by_daemon_id( + daemon_id, + include_inactive=False, + ) + if worker_by_daemon is None: + raise ValueError(f"target daemon not found: {daemon_id}") + if ( + worker_by_id is not None + and worker_by_daemon is not None + and worker_by_id.worker_id != worker_by_daemon.worker_id + ): + raise ValueError( + "target worker_id and daemon_id resolve to different workers" + ) + worker = worker_by_id or worker_by_daemon + if worker is not None: + targets[worker.worker_id] = worker + for worker_id in target_worker_ids: worker_id = str(worker_id).strip() if not worker_id: @@ -205,21 +245,27 @@ def _resolve_root_replica( raise ValueError(f"no available root replica for artifact: {artifact_id}") return result.replica, True - def _plan_more_edges(self, session: BroadcastSession) -> list[BroadcastEdge]: + def _plan_more_edges( + self, + session: BroadcastSession, + *, + cursor=None, + ) -> list[BroadcastEdge]: pending_targets = self._broadcast_repository.list_targets_by_state( session.session_id, BroadcastTargetState.PENDING, limit=max(0, int(session.fanout)), + cursor=cursor, ) if not pending_targets or session.root_replica_id is None: return [] - active_edges_count = self._count_active_edges(session.session_id) + active_edges_count = self._count_active_edges(session.session_id, cursor=cursor) capacity = max(0, int(session.fanout) - active_edges_count) if capacity <= 0: return [] - parent_pool = self._parent_pool(session) + parent_pool = self._parent_pool(session, cursor=cursor) if not parent_pool: return [] @@ -236,17 +282,19 @@ def _plan_more_edges(self, session: BroadcastSession) -> list[BroadcastEdge]: attempt=target.attempt + 1, state=BroadcastEdgeState.PLANNED, ) - self._broadcast_repository.create_edge(edge) + self._broadcast_repository.create_edge(edge, cursor=cursor) target.state = BroadcastTargetState.ASSIGNED target.level = edge.level target.attempt = edge.attempt target.assigned_edge_id = edge.edge_id - self._broadcast_repository.upsert_target(target) + self._broadcast_repository.upsert_target(target, cursor=cursor) planned.append(edge) return planned - def _count_active_edges(self, session_id: str) -> int: - cursor = self._broadcast_repository.get_cursor() + def _count_active_edges(self, session_id: str, *, cursor=None) -> int: + owns_cursor = cursor is None + if owns_cursor: + cursor = self._broadcast_repository.get_cursor() try: row = cursor.execute( """ @@ -259,17 +307,21 @@ def _count_active_edges(self, session_id: str) -> int: ).fetchone() return int(row[0]) if row is not None else 0 finally: - cursor.close() + if owns_cursor: + cursor.close() def _parent_pool( self, session: BroadcastSession, + *, + cursor=None, ) -> list[tuple[Replica, int]]: parents: list[tuple[Replica, int]] = [] if session.root_replica_id is not None: root = self._replica_repository.find_by_id( session.root_replica_id, session.artifact_id, + cursor=cursor, ) if root is not None: parents.append((root, 0)) @@ -278,6 +330,7 @@ def _parent_pool( session.session_id, BroadcastTargetState.COMPLETED, limit=10_000, + cursor=cursor, ) for target in completed_targets: if target.completed_replica_id is None: diff --git a/tests/python/global_store/test_broadcast_repository.py b/tests/python/global_store/test_broadcast_repository.py index 1c4b0e42..069029ed 100644 --- a/tests/python/global_store/test_broadcast_repository.py +++ b/tests/python/global_store/test_broadcast_repository.py @@ -95,6 +95,8 @@ def test_broadcast_repository_creates_session_targets_and_edges(db_connection): assert edge is not None assert edge.parent_worker_id == "worker-root" assert edge.state is BroadcastEdgeState.PLANNED + edges = repo.list_edges("session-a") + assert [edge.edge_id for edge in edges] == ["edge-1"] def test_broadcast_repository_prevents_two_active_edges_for_child(db_connection): diff --git a/tests/python/global_store/test_broadcast_rpc.py b/tests/python/global_store/test_broadcast_rpc.py index 4001a3c6..40dc5d3c 100644 --- a/tests/python/global_store/test_broadcast_rpc.py +++ b/tests/python/global_store/test_broadcast_rpc.py @@ -1,5 +1,7 @@ from __future__ import annotations +import grpc + from tensorcast.proto.common.v1 import common_pb2 from tensorcast.proto.global_store.v1 import global_store_pb2 @@ -205,3 +207,112 @@ def test_create_broadcast_session_accepts_worker_only_and_daemon_only_targets( worker_only_target, daemon_only_target, } + assert ( + response.session.requested_byte_space.kind + == common_pb2.BYTE_SPACE_KIND_CANONICAL + ) + + +def test_blank_broadcast_session_ids_are_invalid(servicer, test_context): + get_resp = servicer.GetBroadcastSession( + global_store_pb2.GetBroadcastSessionRequest(session_id=" "), + test_context, + ) + assert get_resp.status == global_store_pb2.STATUS_ERROR + assert test_context.code == grpc.StatusCode.INVALID_ARGUMENT + + test_context.code = None + test_context.details = None + list_resp = servicer.ListBroadcastEdges( + global_store_pb2.ListBroadcastEdgesRequest(session_id=" "), + test_context, + ) + assert list_resp.status == global_store_pb2.STATUS_ERROR + assert test_context.code == grpc.StatusCode.INVALID_ARGUMENT + + test_context.code = None + test_context.details = None + cancel_resp = servicer.CancelBroadcastSession( + global_store_pb2.CancelBroadcastSessionRequest(session_id=" "), + test_context, + ) + assert cancel_resp.status == global_store_pb2.STATUS_ERROR + assert test_context.code == grpc.StatusCode.INVALID_ARGUMENT + + +def test_create_broadcast_session_rejects_mismatched_target_identity( + servicer, + test_context, + memory_info, +): + root_worker = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-root-mismatch", + node_id="node-root-mismatch", + node_address="10.30.0.1", + grpc_port=52101, + p2p_port=52102, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + worker_target = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-worker-mismatch", + node_id="node-worker-mismatch", + node_address="10.30.0.2", + grpc_port=52201, + p2p_port=52202, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ).worker_id + servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id="daemon-other-mismatch", + node_id="node-other-mismatch", + node_address="10.30.0.3", + grpc_port=52301, + p2p_port=52302, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + test_context, + ) + memory_info.node_id = "node-root-mismatch" + memory_info.node_address = "10.30.0.1" + memory_info.node_port = 52102 + register_resp = servicer.RegisterReplica( + global_store_pb2.RegisterReplicaRequest( + artifact_id="mi2:model-mismatch", + worker_id=root_worker, + mem_info=memory_info, + max_concurrency=4, + ), + test_context, + ) + assert register_resp.status == global_store_pb2.STATUS_OK + + response = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-mismatch", + artifact_id="mi2:model-mismatch", + epoch=1, + fanout=1, + strict_parent=True, + max_attempts=3, + root_replica_id=register_resp.replica_id, + targets=[ + global_store_pb2.BroadcastTargetIdentity( + worker_id=worker_target, + daemon_id="daemon-other-mismatch", + ) + ], + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_ERROR + assert test_context.code == grpc.StatusCode.INVALID_ARGUMENT diff --git a/tests/python/global_store/test_broadcast_service.py b/tests/python/global_store/test_broadcast_service.py index 3d14f4fb..5793d454 100644 --- a/tests/python/global_store/test_broadcast_service.py +++ b/tests/python/global_store/test_broadcast_service.py @@ -91,6 +91,147 @@ def test_create_session_plans_first_layer_by_fanout(repositories): assert all(edge is not None for edge in edges) assert all(edge.state is BroadcastEdgeState.PLANNED for edge in edges if edge) assert all(edge.parent_replica_id == root_replica.replica_id for edge in edges if edge) + assert len(service.list_edges("session-a")) == 2 + + +def test_create_session_duplicate_explicit_root_returns_existing_without_counter_change( + repositories, +): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root-explicit", "daemon-root-explicit", "node1") + child = _worker("worker-child-explicit", "daemon-child-explicit", "node2") + for worker in (root, child): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-explicit", root)) + + first = service.create_session( + session_id="session-explicit", + artifact_id="mi2:model-explicit", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=["daemon-child-explicit"], + root_replica_id=str(root_replica.replica_id), + strict_parent=True, + max_attempts=3, + ) + assert replica_repo.get_current_requests(root_replica.replica_id) == 0 + + second = service.create_session( + session_id="session-explicit", + artifact_id="mi2:model-explicit", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=["daemon-child-explicit"], + root_replica_id=str(root_replica.replica_id), + strict_parent=True, + max_attempts=3, + ) + + assert second.session_id == first.session_id + assert replica_repo.get_current_requests(root_replica.replica_id) == 0 + assert len(broadcast_repo.list_targets("session-explicit")) == 1 + assert len(service.list_edges("session-explicit")) == 1 + + +def test_create_session_duplicate_auto_root_returns_existing_without_counter_change( + repositories, +): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root-auto", "daemon-root-auto", "node1") + child = _worker("worker-child-auto", "daemon-child-auto", "node2") + for worker in (root, child): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-auto", root)) + + first = service.create_session( + session_id="session-auto", + artifact_id="mi2:model-auto", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=["daemon-child-auto"], + root_replica_id="", + strict_parent=True, + max_attempts=3, + ) + assert replica_repo.get_current_requests(root_replica.replica_id) == 0 + + second = service.create_session( + session_id="session-auto", + artifact_id="mi2:model-auto", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=["daemon-child-auto"], + root_replica_id="", + strict_parent=True, + max_attempts=3, + ) + + assert second.session_id == first.session_id + assert replica_repo.get_current_requests(root_replica.replica_id) == 0 + assert len(broadcast_repo.list_targets("session-auto")) == 1 + assert len(service.list_edges("session-auto")) == 1 + + +def test_create_session_auto_root_failure_releases_counter_and_rolls_back( + repositories, + monkeypatch, +): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root-failure", "daemon-root-failure", "node1") + child = _worker("worker-child-failure", "daemon-child-failure", "node2") + for worker in (root, child): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-failure", root)) + + def fail_planning(*args, **kwargs): + raise RuntimeError("forced planning failure") + + monkeypatch.setattr(service, "_plan_more_edges", fail_planning) + + with pytest.raises(RuntimeError, match="forced planning failure"): + service.create_session( + session_id="session-failure", + artifact_id="mi2:model-failure", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=["daemon-child-failure"], + root_replica_id=None, + strict_parent=True, + max_attempts=3, + ) + + assert replica_repo.get_current_requests(root_replica.replica_id) == 0 + assert broadcast_repo.find_session("session-failure") is None + assert broadcast_repo.list_targets("session-failure") == [] @pytest.mark.parametrize( From dbea72e2a5ea09aeddda7c4f2a2a2dd229fa1383 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:56:50 +0800 Subject: [PATCH 27/49] feat(global-store): route broadcast transports through tree edges --- .../global_store/v1/global_store.proto | 6 + tensorcast/global_store/grpc_service.py | 11 +- tensorcast/global_store/models/__init__.py | 2 + tensorcast/global_store/models/transport.py | 10 + .../repositories/replica_repository.py | 152 ++++++++++++ .../repositories/transport_repository.py | 43 +++- .../global_store/rpc/transport_rpc_handler.py | 16 ++ .../services/broadcast_service.py | 163 ++++++++++++- .../services/transport_service.py | 136 +++++++++++ .../global_store/test_broadcast_transport.py | 221 ++++++++++++++++++ 10 files changed, 744 insertions(+), 16 deletions(-) create mode 100644 tests/python/global_store/test_broadcast_transport.py diff --git a/proto/tensorcast/global_store/v1/global_store.proto b/proto/tensorcast/global_store/v1/global_store.proto index 28f15d8a..9ecd9c87 100644 --- a/proto/tensorcast/global_store/v1/global_store.proto +++ b/proto/tensorcast/global_store/v1/global_store.proto @@ -1031,6 +1031,7 @@ message RequestReplicaTransportRequest { TransportSchedulingGroup scheduling_group = 8; string requester_worker_id = 9; string request_id = 10; + BroadcastTransportHint broadcast = 11; } message RequestReplicaTransportResponse { @@ -1048,6 +1049,11 @@ message TransportSchedulingGroup { uint64 epoch = 6; } +message BroadcastTransportHint { + string session_id = 1; + bool strict_parent = 2; +} + enum TransportCompletionOutcome { TRANSPORT_COMPLETION_OUTCOME_UNSPECIFIED = 0; TRANSPORT_COMPLETION_OUTCOME_SUCCESS = 1; diff --git a/tensorcast/global_store/grpc_service.py b/tensorcast/global_store/grpc_service.py index e0c45def..c49cdc28 100644 --- a/tensorcast/global_store/grpc_service.py +++ b/tensorcast/global_store/grpc_service.py @@ -466,10 +466,16 @@ def _rebuild_runtime_services_and_handlers(self) -> None: control_reducer=self.worker_control_reducer, logger=logger, ) + self.broadcast_service = BroadcastService( + broadcast_repository=self.broadcast_repository, + replica_repository=self.replica_repository, + worker_repository=self.worker_repository, + ) self.transport_service = TransportService( self.replica_repository, self.transport_repository, self.pending_transport_request_repository, + broadcast_service=self.broadcast_service, ) self.transport_rpc_handler = TransportRpcHandler( transport_service=self.transport_service, @@ -551,11 +557,6 @@ def _rebuild_runtime_services_and_handlers(self) -> None: datetime_to_timestamp=datetime_to_timestamp, logger=logger, ) - self.broadcast_service = BroadcastService( - broadcast_repository=self.broadcast_repository, - replica_repository=self.replica_repository, - worker_repository=self.worker_repository, - ) self.broadcast_rpc_handler = BroadcastRpcHandler( broadcast_service=self.broadcast_service, logger=logger, diff --git a/tensorcast/global_store/models/__init__.py b/tensorcast/global_store/models/__init__.py index 62933c95..ef47fbbf 100644 --- a/tensorcast/global_store/models/__init__.py +++ b/tensorcast/global_store/models/__init__.py @@ -25,6 +25,7 @@ from .replica import ByteSpaceKind, ByteSpaceRef, ExportState, MemoryType, Replica from .shard_home_lease import ShardHomeLease from .transport import ( + BroadcastTransportHint, Transport, TransportCompletionOutcome, TransportSchedulingGroup, @@ -40,6 +41,7 @@ "BroadcastTargetState", "Instance", "Replica", + "BroadcastTransportHint", "Transport", "TransportCompletionOutcome", "TransportSchedulingGroup", diff --git a/tensorcast/global_store/models/transport.py b/tensorcast/global_store/models/transport.py index 4e2459a9..87e0f7d3 100644 --- a/tensorcast/global_store/models/transport.py +++ b/tensorcast/global_store/models/transport.py @@ -18,6 +18,14 @@ class TransportCompletionOutcome(str, Enum): CANCELLED = "cancelled" +@dataclass(frozen=True) +class BroadcastTransportHint: + """Optional broadcast-tree transport routing hint.""" + + session_id: str + strict_parent: bool = True + + @dataclass(frozen=True) class TransportSchedulingGroup: """Optional scheduling-group metadata attached to transport requests.""" @@ -50,6 +58,8 @@ class Transport: requester_worker_id: str | None = None request_id: str | None = None request_fingerprint: str | None = None + broadcast_session_id: str | None = None + broadcast_edge_id: str | None = None # Optional scheduling-group metadata group_id: str | None = None diff --git a/tensorcast/global_store/repositories/replica_repository.py b/tensorcast/global_store/repositories/replica_repository.py index ba9817be..6de4ab77 100644 --- a/tensorcast/global_store/repositories/replica_repository.py +++ b/tensorcast/global_store/repositories/replica_repository.py @@ -584,6 +584,158 @@ def find_available_for_transport( if owns_cursor: cursor.close() + def claim_replica_for_transport( + self, + *, + replica_id: UUID, + artifact_id: str, + view_id: str | None, + heartbeat_timeout_seconds: float, + cursor=None, + ) -> TransportSelectionResult: + """Claim one exact replica if it is currently transport eligible.""" + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = ( + "SELECT " + + self._REPLICA_PROJECTION + + ", COALESCE(w.worker_id, '') AS gs_worker_id, " + + "wl.accepting_new_requests AS worker_accepting, " + + "wl.last_heartbeat AS worker_last_heartbeat, " + + "w.inactive_at AS worker_inactive_at " + + "FROM artifact_replicas mr " + + "LEFT JOIN replica_counters rc ON rc.replica_id = mr.replica_id " + + "LEFT JOIN workers w ON mr.worker_id = w.worker_id " + + "LEFT JOIN worker_liveness wl ON wl.worker_id = w.worker_id " + + "WHERE mr.replica_id = ? " + + "AND mr.artifact_id = ? " + + "AND COALESCE(mr.view_id, '') = COALESCE(?, '')" + ) + result = cursor.execute(query, [str(replica_id), artifact_id, view_id or ""]) + row = result.fetchone() + if row is None: + return TransportSelectionResult(replica=None, exportable_replicas=0) + + assert result.description is not None + columns = [desc[0] for desc in result.description] + candidate = self._build_transport_candidate( + row, + columns, + now_ts=time.time(), + heartbeat_timeout_seconds=heartbeat_timeout_seconds, + ) + transport_ok, _ = self._evaluate_transport_metadata(candidate.replica) + exportable_replicas = 1 if transport_ok else 0 + eligible, reason = self._evaluate_transport_candidate(candidate) + if not eligible: + inc_transport_filter(artifact_id, reason) + return TransportSelectionResult( + replica=None, + exportable_replicas=exportable_replicas, + ) + + claim = cursor.execute( + """ + UPDATE replica_counters + SET current_requests = current_requests + 1, + last_assigned_at = now() + WHERE replica_id = ? + AND current_requests < ( + SELECT max_concurrency FROM artifact_replicas WHERE replica_id = ? + ) + RETURNING current_requests + """, + [str(replica_id), str(replica_id)], + ).fetchone() + if not claim: + return TransportSelectionResult( + replica=None, + exportable_replicas=exportable_replicas, + ) + + full_result = cursor.execute( + self._replica_select_sql("JOIN") + " WHERE mr.replica_id = ?", + [str(replica_id)], + ) + full_row = full_result.fetchone() + if full_row is None: + return TransportSelectionResult( + replica=None, + exportable_replicas=exportable_replicas, + ) + assert full_result.description is not None + full_columns = [desc[0] for desc in full_result.description] + return TransportSelectionResult( + replica=self._row_to_model(full_row, full_columns), + exportable_replicas=exportable_replicas, + ) + finally: + if owns_cursor: + cursor.close() + + def find_exportable_replica_for_worker( + self, + *, + artifact_id: str, + view_id: str | None, + worker_id: str, + heartbeat_timeout_seconds: float, + cursor=None, + ) -> Replica | None: + """Return the best registered child replica after materialization completes.""" + owns_cursor = cursor is None + if owns_cursor: + cursor = self.get_cursor() + try: + query = ( + "SELECT " + + self._REPLICA_PROJECTION + + ", COALESCE(w.worker_id, '') AS gs_worker_id, " + + "wl.accepting_new_requests AS worker_accepting, " + + "wl.last_heartbeat AS worker_last_heartbeat, " + + "w.inactive_at AS worker_inactive_at " + + "FROM artifact_replicas mr " + + "LEFT JOIN replica_counters rc ON rc.replica_id = mr.replica_id " + + "LEFT JOIN workers w ON mr.worker_id = w.worker_id " + + "LEFT JOIN worker_liveness wl ON wl.worker_id = w.worker_id " + + "WHERE mr.artifact_id = ? " + + "AND COALESCE(mr.view_id, '') = COALESCE(?, '') " + + "AND mr.worker_id = ? " + + "ORDER BY " + + "CASE " + + "WHEN mr.memory_type = 'GPU' THEN 0 " + + "WHEN mr.memory_type = 'RAM' THEN 1 " + + "WHEN mr.memory_type = 'DISK' THEN 2 " + + "ELSE 3 " + + "END, " + + "COALESCE(rc.current_requests, 0) ASC, " + + "mr.updated_at DESC" + ) + result = cursor.execute(query, [artifact_id, view_id or "", worker_id]) + rows = result.fetchall() + if not rows: + return None + + assert result.description is not None + columns = [desc[0] for desc in result.description] + now_ts = time.time() + for row in rows: + candidate = self._build_transport_candidate( + row, + columns, + now_ts=now_ts, + heartbeat_timeout_seconds=heartbeat_timeout_seconds, + ) + eligible, _ = self._evaluate_transport_candidate(candidate) + if eligible: + return candidate.replica + return None + finally: + if owns_cursor: + cursor.close() + @staticmethod def _memory_priority(memory_type: MemoryType) -> int: if memory_type is MemoryType.GPU: diff --git a/tensorcast/global_store/repositories/transport_repository.py b/tensorcast/global_store/repositories/transport_repository.py index bc1a1e04..26bebb2b 100644 --- a/tensorcast/global_store/repositories/transport_repository.py +++ b/tensorcast/global_store/repositories/transport_repository.py @@ -37,6 +37,8 @@ class TransportWindowRow: completion_outcome: str request_id: str requester_worker_id: str + broadcast_session_id: str + broadcast_edge_id: str group_id: str group_kind: str group_part_id: str @@ -55,7 +57,8 @@ class TransportRepository(BaseRepository): "request_id, request_fingerprint, " "requester_worker_id, group_id, group_kind, group_total_parts, " "group_part_id, group_priority, group_epoch, completion_outcome, " - "completion_detail, created_at, completed_at, status" + "completion_detail, created_at, completed_at, status, " + "broadcast_session_id, broadcast_edge_id" ) def find_by_id(self, transport_id: UUID, cursor=None) -> Transport | None: @@ -134,9 +137,10 @@ def create_with_cursor(self, transport: Transport, cursor) -> Transport: request_id, request_fingerprint, requester_worker_id, group_id, group_kind, group_total_parts, group_part_id, group_priority, group_epoch, + broadcast_session_id, broadcast_edge_id, status ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, [ str(transport.transport_id), @@ -157,6 +161,8 @@ def create_with_cursor(self, transport: Transport, cursor) -> Transport: self._normalize_optional_text(transport.group_part_id), self._normalize_optional_int(transport.group_priority), self._normalize_optional_int(transport.group_epoch), + self._normalize_optional_text(transport.broadcast_session_id), + self._normalize_optional_text(transport.broadcast_edge_id), "in_progress", ], ) @@ -205,9 +211,10 @@ def create_if_absent_with_cursor( request_id, request_fingerprint, requester_worker_id, group_id, group_kind, group_total_parts, group_part_id, group_priority, group_epoch, + broadcast_session_id, broadcast_edge_id, status ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, [ str(transport.transport_id), @@ -228,6 +235,8 @@ def create_if_absent_with_cursor( self._normalize_optional_text(transport.group_part_id), self._normalize_optional_int(transport.group_priority), self._normalize_optional_int(transport.group_epoch), + self._normalize_optional_text(transport.broadcast_session_id), + self._normalize_optional_text(transport.broadcast_edge_id), "in_progress", ], ) @@ -460,6 +469,8 @@ def list_rows_in_created_window( COALESCE(t.completion_outcome, '') AS completion_outcome, COALESCE(t.request_id, '') AS request_id, COALESCE(t.requester_worker_id, '') AS requester_worker_id, + COALESCE(t.broadcast_session_id, '') AS broadcast_session_id, + COALESCE(t.broadcast_edge_id, '') AS broadcast_edge_id, COALESCE(t.group_id, '') AS group_id, COALESCE(t.group_kind, '') AS group_kind, COALESCE(t.group_part_id, '') AS group_part_id, @@ -477,8 +488,8 @@ def list_rows_in_created_window( ).fetchall() result: list[TransportWindowRow] = [] for row in rows: - created_at = self._coerce_datetime(row[11]) - completed_at = self._coerce_datetime_optional(row[12]) + created_at = self._coerce_datetime(row[13]) + completed_at = self._coerce_datetime_optional(row[14]) result.append( TransportWindowRow( transport_id=str(row[0] or ""), @@ -493,15 +504,23 @@ def list_rows_in_created_window( str(row[6] or "") ) or "", - group_id=self._normalize_optional_text(str(row[7] or "")) or "", - group_kind=self._normalize_optional_text(str(row[8] or "")) + broadcast_session_id=self._normalize_optional_text( + str(row[7] or "") + ) + or "", + broadcast_edge_id=self._normalize_optional_text( + str(row[8] or "") + ) or "", - group_part_id=self._normalize_optional_text(str(row[9] or "")) + group_id=self._normalize_optional_text(str(row[9] or "")) or "", + group_kind=self._normalize_optional_text(str(row[10] or "")) or "", - group_total_parts=int(row[10] or 0), + group_part_id=self._normalize_optional_text(str(row[11] or "")) + or "", + group_total_parts=int(row[12] or 0), created_at=created_at, completed_at=completed_at, - replica_memory_size_bytes=int(row[13] or 0), + replica_memory_size_bytes=int(row[15] or 0), ) ) return result @@ -736,6 +755,10 @@ def get(column: str, default=None): group_part_id=self._normalize_optional_text(get("group_part_id")), group_priority=self._normalize_optional_int(get("group_priority")), group_epoch=self._normalize_optional_int(get("group_epoch")), + broadcast_session_id=self._normalize_optional_text( + get("broadcast_session_id") + ), + broadcast_edge_id=self._normalize_optional_text(get("broadcast_edge_id")), completion_outcome=outcome, completion_detail=self._normalize_optional_text(get("completion_detail")), created_at=get("created_at"), diff --git a/tensorcast/global_store/rpc/transport_rpc_handler.py b/tensorcast/global_store/rpc/transport_rpc_handler.py index 9e5852bc..0fb3522e 100644 --- a/tensorcast/global_store/rpc/transport_rpc_handler.py +++ b/tensorcast/global_store/rpc/transport_rpc_handler.py @@ -17,6 +17,7 @@ ValidationError, ) from tensorcast.global_store.models import ( + BroadcastTransportHint, Replica, TransportCompletionOutcome, TransportSchedulingGroup, @@ -125,6 +126,20 @@ def request_replica_transport( status=global_store_pb2.Status.STATUS_ERROR ) + broadcast_hint: BroadcastTransportHint | None = None + if request.HasField("broadcast"): + session_id = request.broadcast.session_id.strip() + if not session_id: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("broadcast.session_id is required") + return global_store_pb2.RequestReplicaTransportResponse( + status=global_store_pb2.Status.STATUS_ERROR + ) + broadcast_hint = BroadcastTransportHint( + session_id=session_id, + strict_parent=bool(request.broadcast.strict_parent), + ) + replica, transport_id = self._transport_service.request_transport( artifact_id=request.artifact_id, view_id=requested_view_id, @@ -135,6 +150,7 @@ def request_replica_transport( scheduling_group=scheduling_group, requester_worker_id=requester_worker_id, request_id=request_id, + broadcast_hint=broadcast_hint, ) remote_info = self._replica_to_memory_info(replica) diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index f562c45c..d7c6ec58 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -7,7 +7,13 @@ from collections.abc import Sequence from uuid import UUID, uuid4 -from tensorcast.global_store.exceptions import DatabaseError +from duckdb import DuckDBPyConnection + +from tensorcast.global_store.exceptions import ( + DatabaseError, + NotFoundError, + ValidationError, +) from tensorcast.global_store.models import ( BroadcastEdge, BroadcastEdgeState, @@ -16,6 +22,7 @@ BroadcastTarget, BroadcastTargetState, Replica, + TransportCompletionOutcome, Worker, ) from tensorcast.global_store.repositories import ( @@ -149,6 +156,160 @@ def cancel_session(self, session_id: str) -> bool: BroadcastSessionState.CANCELLED, ) + def claim_transport_edge( + self, + *, + session_id: str, + artifact_id: str, + requested_view_id: str | None, + requester_worker_id: str, + request_id: str, + heartbeat_timeout_seconds: float, + cursor: DuckDBPyConnection, + ) -> tuple[Replica, BroadcastEdge]: + """Claim the planned broadcast parent for one target worker transport.""" + session = self._broadcast_repository.find_session(session_id, cursor=cursor) + if session is None: + raise NotFoundError(f"broadcast session not found: {session_id}") + if session.state is not BroadcastSessionState.ACTIVE: + raise ValidationError(f"broadcast session is not active: {session_id}") + if session.artifact_id != artifact_id: + raise ValidationError("broadcast session artifact does not match request") + if (session.requested_view_id or "") != (requested_view_id or ""): + raise ValidationError("broadcast session byte space does not match request") + + target = self._broadcast_repository.find_target( + session.session_id, + requester_worker_id, + cursor=cursor, + ) + if target is None: + raise NotFoundError( + f"broadcast target not found for worker: {requester_worker_id}" + ) + if target.state is BroadcastTargetState.COMPLETED: + raise ValidationError("broadcast target is already completed") + + edge = self._broadcast_repository.find_active_edge_for_child( + session.session_id, + requester_worker_id, + cursor=cursor, + ) + if edge is None: + self._plan_more_edges(session, cursor=cursor) + edge = self._broadcast_repository.find_active_edge_for_child( + session.session_id, + requester_worker_id, + cursor=cursor, + ) + if edge is None: + raise NotFoundError( + f"no broadcast edge available for worker: {requester_worker_id}" + ) + + selection = self._replica_repository.claim_replica_for_transport( + replica_id=edge.parent_replica_id, + artifact_id=artifact_id, + view_id=requested_view_id, + heartbeat_timeout_seconds=heartbeat_timeout_seconds, + cursor=cursor, + ) + if selection.replica is None: + raise NotFoundError("broadcast parent replica is not transport eligible") + + materialized = self._broadcast_repository.mark_edge_materializing( + edge.edge_id, + request_id, + cursor=cursor, + ) + if not materialized: + self._replica_repository.decrement_requests_with_cursor( + selection.replica.replica_id, + cursor, + ) + raise ValidationError("broadcast edge is no longer claimable") + + claimed_edge = self._broadcast_repository.find_edge(edge.edge_id, cursor=cursor) + if claimed_edge is None: + raise NotFoundError(f"broadcast edge not found: {edge.edge_id}") + return selection.replica, claimed_edge + + def complete_transport_edge( + self, + *, + session_id: str, + edge_id: str, + transport_outcome: TransportCompletionOutcome, + outcome_detail: str | None, + cursor: DuckDBPyConnection, + ) -> None: + """Advance broadcast edge state from a completed transport outcome.""" + session = self._broadcast_repository.find_session(session_id, cursor=cursor) + if session is None: + raise NotFoundError(f"broadcast session not found: {session_id}") + edge = self._broadcast_repository.find_edge(edge_id, cursor=cursor) + if edge is None: + raise NotFoundError(f"broadcast edge not found: {edge_id}") + + if transport_outcome is TransportCompletionOutcome.SUCCESS: + child_replica = self._replica_repository.find_exportable_replica_for_worker( + artifact_id=session.artifact_id, + view_id=session.requested_view_id, + worker_id=edge.child_worker_id, + heartbeat_timeout_seconds=self._ROOT_HEARTBEAT_TIMEOUT_SECONDS, + cursor=cursor, + ) + if child_replica is None: + self._broadcast_repository.mark_edge_failed( + edge.edge_id, + "child_replica_not_exportable_after_success", + cursor=cursor, + ) + return + self._broadcast_repository.mark_edge_completed( + edge.edge_id, + child_replica.replica_id, + cursor=cursor, + ) + self._plan_more_edges(session, cursor=cursor) + self._mark_session_complete_if_done(session.session_id, cursor=cursor) + return + + reason = ( + outcome_detail or transport_outcome.value or "transport_failed" + ).strip() + if not reason: + reason = "transport_failed" + self._broadcast_repository.mark_edge_failed( + edge.edge_id, + reason, + cursor=cursor, + ) + if int(edge.attempt) < int(session.max_attempts): + target = self._broadcast_repository.find_target( + edge.session_id, + edge.child_worker_id, + cursor=cursor, + ) + if target is not None: + target.state = BroadcastTargetState.PENDING + target.assigned_edge_id = None + target.completed_replica_id = None + target.completed_at = None + self._broadcast_repository.upsert_target(target, cursor=cursor) + self._plan_more_edges(session, cursor=cursor) + + def _mark_session_complete_if_done(self, session_id: str, *, cursor=None) -> None: + if self._broadcast_repository.count_incomplete_targets( + session_id, + cursor=cursor, + ) == 0: + self._broadcast_repository.update_session_state( + session_id, + BroadcastSessionState.COMPLETED, + cursor=cursor, + ) + def _resolve_targets( self, *, diff --git a/tensorcast/global_store/services/transport_service.py b/tensorcast/global_store/services/transport_service.py index 643b48ea..257ab74a 100644 --- a/tensorcast/global_store/services/transport_service.py +++ b/tensorcast/global_store/services/transport_service.py @@ -29,6 +29,7 @@ record_transport_source_assignment, ) from tensorcast.global_store.models import ( + BroadcastTransportHint, PendingTransportRequest, PendingTransportState, Replica, @@ -49,6 +50,7 @@ from tensorcast.global_store.repositories.transport_repository import ( TransportWindowRow, ) +from tensorcast.global_store.services.broadcast_service import BroadcastService from tensorcast.logger import init_logger logger = init_logger(__name__) @@ -63,11 +65,13 @@ def __init__( replica_repository: ReplicaRepository, transport_repository: TransportRepository, pending_transport_request_repository: PendingTransportRequestRepository, + broadcast_service: BroadcastService | None = None, ): """Initialize service with repositories.""" self.replica_repository = replica_repository self.transport_repository = transport_repository self.pending_transport_request_repository = pending_transport_request_repository + self.broadcast_service = broadcast_service self.config = get_config() # Serialize queue-wide dispatch to avoid multi-thread transaction storms. self._dispatch_loop_lock = threading.Lock() @@ -136,6 +140,7 @@ def _build_request_fingerprint( source_port: int, requester_worker_id: str | None, scheduling_group: TransportSchedulingGroup | None, + broadcast_hint: BroadcastTransportHint | None = None, ) -> str: group_kind = ( str(scheduling_group.group_kind).strip().lower() @@ -152,6 +157,14 @@ def _build_request_fingerprint( "source_address": source_address, "source_port": int(source_port), "requester_worker_id": (requester_worker_id or "").strip(), + "broadcast": ( + { + "session_id": broadcast_hint.session_id, + "strict_parent": bool(broadcast_hint.strict_parent), + } + if broadcast_hint is not None + else None + ), "scheduling_group": ( { "group_id": scheduling_group.group_id, @@ -181,6 +194,8 @@ def _build_transport( request_fingerprint: str | None, requester_worker_id: str | None, scheduling_group: TransportSchedulingGroup | None, + broadcast_session_id: str | None = None, + broadcast_edge_id: str | None = None, ) -> Transport: transport = Transport( replica_id=replica.replica_id, @@ -193,6 +208,8 @@ def _build_transport( request_id=request_id, request_fingerprint=request_fingerprint, requester_worker_id=requester_worker_id, + broadcast_session_id=broadcast_session_id, + broadcast_edge_id=broadcast_edge_id, ) transport.set_scheduling_group(scheduling_group) return transport @@ -310,6 +327,7 @@ def request_transport( view_id: str | None = None, scheduling_group: TransportSchedulingGroup | None = None, requester_worker_id: str | None = None, + broadcast_hint: BroadcastTransportHint | None = None, ) -> tuple[Replica, UUID]: """ Request an artifact transport via unified pending-queue dispatch. @@ -343,6 +361,7 @@ def request_transport( source_port=source_port, requester_worker_id=normalized_requester_worker_id, scheduling_group=scheduling_group, + broadcast_hint=broadcast_hint, ) existing = self._resolve_existing_request( normalized_request_id, request_fingerprint @@ -350,6 +369,19 @@ def request_transport( if existing is not None: return existing + if broadcast_hint is not None: + return self._request_transport_broadcast( + artifact_id=artifact_id, + view_id=view_id, + source_node_id=source_node_id, + source_address=source_address, + source_port=source_port, + requester_worker_id=normalized_requester_worker_id, + request_fingerprint=request_fingerprint, + request_id=normalized_request_id, + broadcast_hint=broadcast_hint, + ) + return self._request_transport_group_dispatch( artifact_id=artifact_id, view_id=view_id, @@ -363,6 +395,98 @@ def request_transport( request_id=normalized_request_id, ) + def _request_transport_broadcast( + self, + *, + artifact_id: str, + view_id: str | None, + source_node_id: str, + source_address: str, + source_port: int, + requester_worker_id: str | None, + request_fingerprint: str, + request_id: str, + broadcast_hint: BroadcastTransportHint, + ) -> tuple[Replica, UUID]: + start_time = time.time() + if self.broadcast_service is None: + raise ValidationError( + "broadcast transport requested but service is unavailable" + ) + if requester_worker_id is None: + raise ValidationError("broadcast transport requires requester_worker_id") + + try: + with self._dispatch_loop_lock, self.replica_repository.transaction() as tx: + existing = self.transport_repository.find_by_request_id( + request_id, + cursor=tx, + ) + if existing is not None: + if ( + existing.request_fingerprint is not None + and existing.request_fingerprint != request_fingerprint + ): + raise ValidationError( + f"request_id={request_id} already used with different payload" + ) + replica = self.replica_repository.find_by_id( + existing.replica_id, + existing.artifact_id, + cursor=tx, + ) + if replica is not None: + return replica, existing.transport_id + + replica, edge = self.broadcast_service.claim_transport_edge( + session_id=broadcast_hint.session_id, + artifact_id=artifact_id, + requested_view_id=view_id, + requester_worker_id=requester_worker_id, + request_id=request_id, + heartbeat_timeout_seconds=self.config.heartbeat_timeout_ms / 1000, + cursor=tx, + ) + transport = self._build_transport( + replica=replica, + artifact_id=artifact_id, + requested_view_id=view_id, + source_node_id=source_node_id, + source_address=source_address, + source_port=source_port, + request_id=request_id, + request_fingerprint=request_fingerprint, + requester_worker_id=requester_worker_id, + scheduling_group=None, + broadcast_session_id=broadcast_hint.session_id, + broadcast_edge_id=edge.edge_id, + ) + resolved_transport, created = ( + self.transport_repository.create_if_absent_with_cursor( + transport, + tx, + ) + ) + if not created: + self.replica_repository.decrement_requests_with_cursor( + replica.replica_id, + tx, + ) + else: + inc_active_transports() + record_transport_source_assignment( + artifact_id=artifact_id, + replica_id=str(replica.replica_id), + source_created_at=replica.created_at, + ) + inc_transport_request(artifact_id, "success") + observe_transport_wait(artifact_id, time.time() - start_time) + return replica, resolved_transport.transport_id + except DatabaseError as exc: + if isinstance(exc.__cause__, (NotFoundError, ValidationError)): + raise exc.__cause__ from exc + raise + def _request_transport_group_dispatch( self, *, @@ -1099,6 +1223,18 @@ def _complete_transport_in_single_tx( self.replica_repository.decrement_requests_with_cursor( transport.replica_id, tx ) + if ( + self.broadcast_service is not None + and transport.broadcast_session_id + and transport.broadcast_edge_id + ): + self.broadcast_service.complete_transport_edge( + session_id=transport.broadcast_session_id, + edge_id=transport.broadcast_edge_id, + transport_outcome=outcome, + outcome_detail=outcome_detail, + cursor=tx, + ) replica_after = self.replica_repository.find_by_id( transport.replica_id, transport.artifact_id, cursor=tx ) diff --git a/tests/python/global_store/test_broadcast_transport.py b/tests/python/global_store/test_broadcast_transport.py new file mode 100644 index 00000000..e092dc32 --- /dev/null +++ b/tests/python/global_store/test_broadcast_transport.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +from tensorcast.global_store.models import BroadcastEdgeState +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.global_store.v1 import global_store_pb2 + + +def _register_worker(servicer, context, *, worker_id: str, node_id: str, port: int) -> str: + response = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id=f"daemon-{worker_id}", + node_id=node_id, + node_address=f"10.30.0.{port % 100}", + grpc_port=port, + p2p_port=port + 1, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + context, + ) + assert response.status == global_store_pb2.STATUS_OK + return response.worker_id + + +def _register_exportable_replica( + servicer, + context, + *, + artifact_id: str, + worker_id: str, + node_id: str, + node_address: str, + node_port: int, + remote_key: str, +) -> str: + mem_info = common_pb2.MemoryInfo( + node_id=node_id, + node_address=node_address, + node_port=node_port, + memory_size=1024, + memory_type=common_pb2.MEMORY_TYPE_GPU, + device_id=0, + byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + ) + transport = mem_info.transport + transport.export_state = common_pb2.ReplicaTransportMetadata.EXPORT_STATE_EXPORTABLE + transport.export_generation = 1 + transport.remote_memory_keys.append(remote_key) + transport.buffer_sizes.append(1024) + + response = servicer.RegisterReplica( + global_store_pb2.RegisterReplicaRequest( + artifact_id=artifact_id, + worker_id=worker_id, + mem_info=mem_info, + max_concurrency=4, + ), + context, + ) + assert response.status == global_store_pb2.STATUS_OK + return response.replica_id + + +def _create_broadcast_session( + servicer, + context, + *, + session_id: str, + artifact_id: str, + root_replica_id: str, + child_worker_id: str, +) -> None: + response = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id=session_id, + artifact_id=artifact_id, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + epoch=1, + fanout=1, + strict_parent=True, + max_attempts=3, + root_replica_id=root_replica_id, + targets=[ + global_store_pb2.BroadcastTargetIdentity( + worker_id=child_worker_id, + ) + ], + ), + context, + ) + assert response.status == global_store_pb2.STATUS_OK + assert len(response.edges) == 1 + + +def test_broadcast_transport_uses_edge_parent(servicer, test_context): + artifact_id = "mi2:broadcast-transport-parent" + root_worker = _register_worker( + servicer, test_context, worker_id="root", node_id="node-1", port=53100 + ) + alternate_worker = _register_worker( + servicer, test_context, worker_id="alt", node_id="node-2", port=53200 + ) + child_worker = _register_worker( + servicer, test_context, worker_id="child", node_id="node-3", port=53300 + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.30.0.1", + node_port=53101, + remote_key="rk-root", + ) + _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=alternate_worker, + node_id="node-2", + node_address="10.30.0.2", + node_port=53201, + remote_key="rk-alt", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-transport-parent", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + + response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.30.9.9", + source_port=59000, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-broadcast-parent", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-transport-parent", + strict_parent=True, + ), + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_OK + assert response.remote_memory_info.node_id == "node-1" + assert list(response.remote_memory_info.transport.remote_memory_keys) == ["rk-root"] + + +def test_broadcast_failed_transport_requeues_target(servicer, test_context): + artifact_id = "mi2:broadcast-transport-failed" + root_worker = _register_worker( + servicer, test_context, worker_id="root-fail", node_id="node-1", port=54100 + ) + child_worker = _register_worker( + servicer, test_context, worker_id="child-fail", node_id="node-2", port=54200 + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.40.0.1", + node_port=54101, + remote_key="rk-root-fail", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-transport-failed", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + + transport_response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.40.9.9", + source_port=59001, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-broadcast-failed", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-transport-failed", + strict_parent=True, + ), + ), + test_context, + ) + assert transport_response.status == global_store_pb2.STATUS_OK + + complete_response = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=transport_response.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_FAILED, + outcome_detail="forced test failure", + ), + test_context, + ) + assert complete_response.status == global_store_pb2.STATUS_OK + + edges = servicer.broadcast_service.list_edges("session-transport-failed") + assert any(edge.state is BroadcastEdgeState.FAILED for edge in edges) From c7a3fb82fdd17afe6913de07295e7c160172d565 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:08:24 +0800 Subject: [PATCH 28/49] fix(global-store): harden broadcast transport lifecycle --- .../repositories/broadcast_repository.py | 6 + .../services/broadcast_service.py | 42 ++- .../services/transport_service.py | 20 + .../global_store/test_broadcast_repository.py | 60 +++ .../global_store/test_broadcast_transport.py | 342 +++++++++++++++++- 5 files changed, 462 insertions(+), 8 deletions(-) diff --git a/tensorcast/global_store/repositories/broadcast_repository.py b/tensorcast/global_store/repositories/broadcast_repository.py index 80094b3d..35f1c302 100644 --- a/tensorcast/global_store/repositories/broadcast_repository.py +++ b/tensorcast/global_store/repositories/broadcast_repository.py @@ -563,6 +563,8 @@ def mark_edge_failed( edge = self.find_edge(normalized_edge_id, cursor=cursor) if edge is None: return False + if edge.state is not BroadcastEdgeState.MATERIALIZING: + return False target = self.find_target(edge.session_id, edge.child_worker_id, cursor=cursor) if target is None or target.assigned_edge_id != edge.edge_id: return False @@ -575,6 +577,7 @@ def mark_edge_failed( updated_at = CURRENT_TIMESTAMP, completed_at = CURRENT_TIMESTAMP WHERE edge_id = ? + AND state = 'materializing' RETURNING edge_id """, [normalized_reason, normalized_edge_id], @@ -624,6 +627,8 @@ def mark_edge_completed( edge = self.find_edge(normalized_edge_id, cursor=cursor) if edge is None: return False + if edge.state is not BroadcastEdgeState.MATERIALIZING: + return False target = self.find_target(edge.session_id, edge.child_worker_id, cursor=cursor) if target is None or target.assigned_edge_id != edge.edge_id: return False @@ -635,6 +640,7 @@ def mark_edge_completed( updated_at = CURRENT_TIMESTAMP, completed_at = CURRENT_TIMESTAMP WHERE edge_id = ? + AND state = 'materializing' RETURNING edge_id """, [normalized_edge_id], diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index d7c6ec58..e6f3fb64 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -260,9 +260,10 @@ def complete_transport_edge( cursor=cursor, ) if child_replica is None: - self._broadcast_repository.mark_edge_failed( - edge.edge_id, - "child_replica_not_exportable_after_success", + self._fail_edge_and_maybe_retry( + session=session, + edge=edge, + reason="child_replica_not_exportable_after_success", cursor=cursor, ) return @@ -272,7 +273,7 @@ def complete_transport_edge( cursor=cursor, ) self._plan_more_edges(session, cursor=cursor) - self._mark_session_complete_if_done(session.session_id, cursor=cursor) + self._mark_session_terminal_if_done(session.session_id, cursor=cursor) return reason = ( @@ -280,6 +281,21 @@ def complete_transport_edge( ).strip() if not reason: reason = "transport_failed" + self._fail_edge_and_maybe_retry( + session=session, + edge=edge, + reason=reason, + cursor=cursor, + ) + + def _fail_edge_and_maybe_retry( + self, + *, + session: BroadcastSession, + edge: BroadcastEdge, + reason: str, + cursor: DuckDBPyConnection, + ) -> None: self._broadcast_repository.mark_edge_failed( edge.edge_id, reason, @@ -298,12 +314,26 @@ def complete_transport_edge( target.completed_at = None self._broadcast_repository.upsert_target(target, cursor=cursor) self._plan_more_edges(session, cursor=cursor) + return + self._mark_session_terminal_if_done(session.session_id, cursor=cursor) - def _mark_session_complete_if_done(self, session_id: str, *, cursor=None) -> None: + def _mark_session_terminal_if_done(self, session_id: str, *, cursor=None) -> None: if self._broadcast_repository.count_incomplete_targets( session_id, cursor=cursor, - ) == 0: + ) != 0: + return + targets = self._broadcast_repository.list_targets(session_id, cursor=cursor) + if any(target.state is BroadcastTargetState.FAILED for target in targets): + self._broadcast_repository.update_session_state( + session_id, + BroadcastSessionState.FAILED, + cursor=cursor, + ) + return + if targets and all( + target.state is BroadcastTargetState.COMPLETED for target in targets + ): self._broadcast_repository.update_session_state( session_id, BroadcastSessionState.COMPLETED, diff --git a/tensorcast/global_store/services/transport_service.py b/tensorcast/global_store/services/transport_service.py index 257ab74a..1dba60ce 100644 --- a/tensorcast/global_store/services/transport_service.py +++ b/tensorcast/global_store/services/transport_service.py @@ -437,6 +437,9 @@ def _request_transport_broadcast( ) if replica is not None: return replica, existing.transport_id + raise NotFoundError( + "existing broadcast transport source replica is missing" + ) replica, edge = self.broadcast_service.claim_transport_edge( session_id=broadcast_hint.session_id, @@ -472,6 +475,23 @@ def _request_transport_broadcast( replica.replica_id, tx, ) + self.broadcast_service.complete_transport_edge( + session_id=broadcast_hint.session_id, + edge_id=edge.edge_id, + transport_outcome=TransportCompletionOutcome.FAILED, + outcome_detail="duplicate_request_after_broadcast_claim", + cursor=tx, + ) + existing_replica = self.replica_repository.find_by_id( + resolved_transport.replica_id, + resolved_transport.artifact_id, + cursor=tx, + ) + if existing_replica is None: + raise NotFoundError( + "existing broadcast transport source replica is missing" + ) + return existing_replica, resolved_transport.transport_id else: inc_active_transports() record_transport_source_assignment( diff --git a/tests/python/global_store/test_broadcast_repository.py b/tests/python/global_store/test_broadcast_repository.py index 069029ed..0bab63ad 100644 --- a/tests/python/global_store/test_broadcast_repository.py +++ b/tests/python/global_store/test_broadcast_repository.py @@ -406,3 +406,63 @@ def test_broadcast_repository_marks_edge_completed_and_target_completed(db_conne assert edge.state is BroadcastEdgeState.COMPLETED assert target.state is BroadcastTargetState.COMPLETED assert target.completed_replica_id == completed_replica_id + + +@pytest.mark.parametrize("transition", ["failed", "completed"]) +def test_broadcast_repository_rejects_non_materializing_terminal_transition( + db_connection, + transition, +): + repo = BroadcastRepository(db_connection) + repo.create_session( + BroadcastSession( + session_id="session-non-materializing", + artifact_id="mi2:test", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + ) + ) + repo.upsert_target( + BroadcastTarget( + session_id="session-non-materializing", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + state=BroadcastTargetState.ASSIGNED, + assigned_edge_id="edge-planned", + ) + ) + repo.create_edge( + BroadcastEdge( + edge_id="edge-planned", + session_id="session-non-materializing", + parent_worker_id="worker-root", + parent_replica_id=UUID("00000000-0000-0000-0000-000000000001"), + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.PLANNED, + ) + ) + + if transition == "failed": + changed = repo.mark_edge_failed(edge_id="edge-planned", reason="not ready") + else: + changed = repo.mark_edge_completed( + edge_id="edge-planned", + completed_replica_id=UUID("00000000-0000-0000-0000-000000000002"), + ) + + edge = repo.find_edge("edge-planned") + target = repo.find_target("session-non-materializing", "worker-child") + assert not changed + assert edge is not None + assert target is not None + assert edge.state is BroadcastEdgeState.PLANNED + assert edge.completed_at is None + assert target.state is BroadcastTargetState.ASSIGNED + assert target.completed_at is None diff --git a/tests/python/global_store/test_broadcast_transport.py b/tests/python/global_store/test_broadcast_transport.py index e092dc32..4d6fdc7a 100644 --- a/tests/python/global_store/test_broadcast_transport.py +++ b/tests/python/global_store/test_broadcast_transport.py @@ -1,6 +1,13 @@ from __future__ import annotations -from tensorcast.global_store.models import BroadcastEdgeState +from uuid import UUID + +from tensorcast.global_store.models import ( + BroadcastEdgeState, + BroadcastSessionState, + BroadcastTargetState, + Transport, +) from tensorcast.proto.common.v1 import common_pb2 from tensorcast.proto.global_store.v1 import global_store_pb2 @@ -71,6 +78,7 @@ def _create_broadcast_session( artifact_id: str, root_replica_id: str, child_worker_id: str, + max_attempts: int = 3, ) -> None: response = servicer.CreateBroadcastSession( global_store_pb2.CreateBroadcastSessionRequest( @@ -82,7 +90,7 @@ def _create_broadcast_session( epoch=1, fanout=1, strict_parent=True, - max_attempts=3, + max_attempts=max_attempts, root_replica_id=root_replica_id, targets=[ global_store_pb2.BroadcastTargetIdentity( @@ -219,3 +227,333 @@ def test_broadcast_failed_transport_requeues_target(servicer, test_context): edges = servicer.broadcast_service.list_edges("session-transport-failed") assert any(edge.state is BroadcastEdgeState.FAILED for edge in edges) + + +def test_duplicate_broadcast_request_reuses_existing_edge(servicer, test_context): + artifact_id = "mi2:broadcast-transport-duplicate" + root_worker = _register_worker( + servicer, test_context, worker_id="root-dup", node_id="node-1", port=55100 + ) + child_worker = _register_worker( + servicer, test_context, worker_id="child-dup", node_id="node-2", port=55200 + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.50.0.1", + node_port=55101, + remote_key="rk-root-dup", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-transport-duplicate", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.50.9.9", + source_port=59002, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-broadcast-duplicate", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-transport-duplicate", + strict_parent=True, + ), + ) + + first = servicer.RequestReplicaTransport(request, test_context) + second = servicer.RequestReplicaTransport(request, test_context) + + assert first.status == global_store_pb2.STATUS_OK + assert second.status == global_store_pb2.STATUS_OK + assert second.transport_id == first.transport_id + edges = servicer.broadcast_service.list_edges("session-transport-duplicate") + assert len(edges) == 1 + assert edges[0].state is BroadcastEdgeState.MATERIALIZING + targets = servicer.broadcast_service.list_targets("session-transport-duplicate") + assert len(targets) == 1 + assert targets[0].state is BroadcastTargetState.MATERIALIZING + assert targets[0].assigned_edge_id == edges[0].edge_id + + +def test_broadcast_request_existing_transport_missing_replica_does_not_claim_edge( + servicer, + test_context, +): + artifact_id = "mi2:broadcast-transport-missing-replica" + root_worker = _register_worker( + servicer, + test_context, + worker_id="root-missing-replica", + node_id="node-1", + port=55500, + ) + child_worker = _register_worker( + servicer, + test_context, + worker_id="child-missing-replica", + node_id="node-2", + port=55600, + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.55.0.1", + node_port=55501, + remote_key="rk-root-missing-replica", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-missing-replica", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + servicer.transport_repository.create( + Transport( + replica_id=UUID("00000000-0000-0000-0000-00000000dead"), + artifact_id=artifact_id, + source_node_id="stale-node", + source_address="10.55.9.9", + source_port=59006, + requester_worker_id=child_worker, + request_id="request-missing-replica", + ) + ) + + response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.55.9.10", + source_port=59007, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-missing-replica", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-missing-replica", + strict_parent=True, + ), + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_NOT_FOUND + edges = servicer.broadcast_service.list_edges("session-missing-replica") + assert len(edges) == 1 + assert edges[0].state is BroadcastEdgeState.PLANNED + target = servicer.broadcast_service.list_targets("session-missing-replica")[0] + assert target.state is BroadcastTargetState.ASSIGNED + assert target.assigned_edge_id == edges[0].edge_id + + +def test_broadcast_success_without_child_replica_requeues(servicer, test_context): + artifact_id = "mi2:broadcast-transport-success-no-child" + root_worker = _register_worker( + servicer, + test_context, + worker_id="root-no-child", + node_id="node-1", + port=56100, + ) + child_worker = _register_worker( + servicer, + test_context, + worker_id="child-no-child", + node_id="node-2", + port=56200, + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.60.0.1", + node_port=56101, + remote_key="rk-root-no-child", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-success-no-child", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + transport_response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.60.9.9", + source_port=59003, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-success-no-child", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-success-no-child", + strict_parent=True, + ), + ), + test_context, + ) + assert transport_response.status == global_store_pb2.STATUS_OK + + complete_response = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=transport_response.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_SUCCESS, + ), + test_context, + ) + + assert complete_response.status == global_store_pb2.STATUS_OK + edges = servicer.broadcast_service.list_edges("session-success-no-child") + assert [edge.state for edge in edges].count(BroadcastEdgeState.FAILED) == 1 + assert [edge.state for edge in edges].count(BroadcastEdgeState.PLANNED) == 1 + target = servicer.broadcast_service.list_targets("session-success-no-child")[0] + assert target.state is BroadcastTargetState.ASSIGNED + assert target.attempt == 2 + + +def test_broadcast_max_attempt_exhaustion_marks_session_failed(servicer, test_context): + artifact_id = "mi2:broadcast-transport-max-attempts" + root_worker = _register_worker( + servicer, test_context, worker_id="root-max", node_id="node-1", port=57100 + ) + child_worker = _register_worker( + servicer, test_context, worker_id="child-max", node_id="node-2", port=57200 + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.70.0.1", + node_port=57101, + remote_key="rk-root-max", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-max-attempts", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + max_attempts=1, + ) + transport_response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.70.9.9", + source_port=59004, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-max-attempts", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-max-attempts", + strict_parent=True, + ), + ), + test_context, + ) + assert transport_response.status == global_store_pb2.STATUS_OK + + complete_response = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=transport_response.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_FAILED, + ), + test_context, + ) + + assert complete_response.status == global_store_pb2.STATUS_OK + session = servicer.broadcast_service.get_session("session-max-attempts") + target = servicer.broadcast_service.list_targets("session-max-attempts")[0] + assert session is not None + assert session.state is BroadcastSessionState.FAILED + assert target.state is BroadcastTargetState.FAILED + + +def test_duplicate_broadcast_completion_is_noop(servicer, test_context): + artifact_id = "mi2:broadcast-transport-complete-twice" + root_worker = _register_worker( + servicer, test_context, worker_id="root-twice", node_id="node-1", port=58100 + ) + child_worker = _register_worker( + servicer, test_context, worker_id="child-twice", node_id="node-2", port=58200 + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.80.0.1", + node_port=58101, + remote_key="rk-root-twice", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-complete-twice", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + transport_response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.80.9.9", + source_port=59005, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-complete-twice", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-complete-twice", + strict_parent=True, + ), + ), + test_context, + ) + assert transport_response.status == global_store_pb2.STATUS_OK + complete_request = global_store_pb2.CompleteReplicaTransportRequest( + transport_id=transport_response.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_FAILED, + ) + + first = servicer.CompleteReplicaTransport(complete_request, test_context) + edges_after_first = servicer.broadcast_service.list_edges("session-complete-twice") + second = servicer.CompleteReplicaTransport(complete_request, test_context) + edges_after_second = servicer.broadcast_service.list_edges("session-complete-twice") + + assert first.status == global_store_pb2.STATUS_OK + assert second.status == global_store_pb2.STATUS_OK + assert [(e.edge_id, e.state) for e in edges_after_second] == [ + (e.edge_id, e.state) for e in edges_after_first + ] From 5cd916c3776b415ef179f5285cd889c59b073ee2 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:14:43 +0800 Subject: [PATCH 29/49] fix(global-store): ignore stale broadcast edge failures --- .../services/broadcast_service.py | 19 +-- .../global_store/test_broadcast_service.py | 116 ++++++++++++++++++ 2 files changed, 127 insertions(+), 8 deletions(-) diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index e6f3fb64..5493e3ef 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -296,24 +296,27 @@ def _fail_edge_and_maybe_retry( reason: str, cursor: DuckDBPyConnection, ) -> None: - self._broadcast_repository.mark_edge_failed( + failed = self._broadcast_repository.mark_edge_failed( edge.edge_id, reason, cursor=cursor, ) + if not failed: + return if int(edge.attempt) < int(session.max_attempts): target = self._broadcast_repository.find_target( edge.session_id, edge.child_worker_id, cursor=cursor, ) - if target is not None: - target.state = BroadcastTargetState.PENDING - target.assigned_edge_id = None - target.completed_replica_id = None - target.completed_at = None - self._broadcast_repository.upsert_target(target, cursor=cursor) - self._plan_more_edges(session, cursor=cursor) + if target is None or target.assigned_edge_id != edge.edge_id: + return + target.state = BroadcastTargetState.PENDING + target.assigned_edge_id = None + target.completed_replica_id = None + target.completed_at = None + self._broadcast_repository.upsert_target(target, cursor=cursor) + self._plan_more_edges(session, cursor=cursor) return self._mark_session_terminal_if_done(session.session_id, cursor=cursor) diff --git a/tests/python/global_store/test_broadcast_service.py b/tests/python/global_store/test_broadcast_service.py index 5793d454..2db6bf48 100644 --- a/tests/python/global_store/test_broadcast_service.py +++ b/tests/python/global_store/test_broadcast_service.py @@ -1,14 +1,20 @@ from __future__ import annotations +from uuid import UUID + import pytest from tensorcast.global_store.models import ( + BroadcastEdge, BroadcastEdgeState, + BroadcastSession, BroadcastSessionState, + BroadcastTarget, BroadcastTargetState, ExportState, MemoryType, Replica, + TransportCompletionOutcome, Worker, ) from tensorcast.global_store.services import BroadcastService @@ -276,3 +282,113 @@ def test_create_session_validates_required_inputs(repositories, overrides, messa with pytest.raises(ValueError, match=message): service.create_session(**kwargs) assert broadcast_repo.find_session("session-validation") is None + + +@pytest.mark.parametrize( + ("target_state", "assigned_edge_id"), + [ + (BroadcastTargetState.COMPLETED, "edge-stale"), + (BroadcastTargetState.ASSIGNED, "edge-new"), + ], +) +def test_complete_transport_edge_stale_failure_does_not_requeue_target( + repositories, + target_state, + assigned_edge_id, +): + replica_repo = repositories["replica"] + worker_repo = repositories["worker"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root_replica_id = UUID("00000000-0000-0000-0000-000000000001") + completed_replica_id = UUID("00000000-0000-0000-0000-000000000002") + broadcast_repo.create_session( + BroadcastSession( + session_id="session-stale-failure", + artifact_id="mi2:stale-failure", + requested_view_id=None, + epoch=1, + fanout=1, + max_attempts=3, + strict_parent=True, + state=BroadcastSessionState.ACTIVE, + root_replica_id=root_replica_id, + ) + ) + broadcast_repo.upsert_target( + BroadcastTarget( + session_id="session-stale-failure", + target_worker_id="worker-child", + target_daemon_id="daemon-child", + state=target_state, + level=1, + attempt=2, + assigned_edge_id=assigned_edge_id, + completed_replica_id=( + completed_replica_id + if target_state is BroadcastTargetState.COMPLETED + else None + ), + ) + ) + broadcast_repo.create_edge( + BroadcastEdge( + edge_id="edge-stale", + session_id="session-stale-failure", + parent_worker_id="worker-root", + parent_replica_id=root_replica_id, + child_worker_id="worker-child", + level=1, + attempt=1, + state=BroadcastEdgeState.PLANNED, + ) + ) + if assigned_edge_id == "edge-new": + with broadcast_repo.transaction() as tx: + tx.execute( + """ + INSERT INTO broadcast_edges ( + edge_id, session_id, parent_worker_id, parent_replica_id, + child_worker_id, level, attempt, state + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + "edge-new", + "session-stale-failure", + "worker-root", + str(root_replica_id), + "worker-child", + 1, + 2, + BroadcastEdgeState.PLANNED.value, + ], + ) + + with broadcast_repo.transaction() as tx: + service.complete_transport_edge( + session_id="session-stale-failure", + edge_id="edge-stale", + transport_outcome=TransportCompletionOutcome.FAILED, + outcome_detail="stale completion", + cursor=tx, + ) + + target = broadcast_repo.find_target("session-stale-failure", "worker-child") + edges = broadcast_repo.list_edges("session-stale-failure") + session = broadcast_repo.find_session("session-stale-failure") + assert target is not None + assert session is not None + assert target.state is target_state + assert target.assigned_edge_id == assigned_edge_id + assert target.completed_replica_id == ( + completed_replica_id if target_state is BroadcastTargetState.COMPLETED else None + ) + assert {edge.edge_id for edge in edges} == ( + {"edge-stale", "edge-new"} if assigned_edge_id == "edge-new" else {"edge-stale"} + ) + assert session.state is BroadcastSessionState.ACTIVE From 7ab9aeb8d5b8e495b0a2c124bc497ce096c9280b Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:21:31 +0800 Subject: [PATCH 30/49] fix(global-store): retry ineligible broadcast parents --- .../services/broadcast_service.py | 17 ++ .../services/transport_service.py | 121 +++++++------- .../global_store/test_broadcast_transport.py | 149 ++++++++++++++++++ 3 files changed, 232 insertions(+), 55 deletions(-) diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index 5493e3ef..d570240c 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -215,6 +215,23 @@ def claim_transport_edge( cursor=cursor, ) if selection.replica is None: + materialized = self._broadcast_repository.mark_edge_materializing( + edge.edge_id, + request_id, + cursor=cursor, + ) + if materialized: + claimed_edge = self._broadcast_repository.find_edge( + edge.edge_id, + cursor=cursor, + ) + if claimed_edge is not None: + self._fail_edge_and_maybe_retry( + session=session, + edge=claimed_edge, + reason="parent_replica_not_transport_eligible", + cursor=cursor, + ) raise NotFoundError("broadcast parent replica is not transport eligible") materialized = self._broadcast_repository.mark_edge_materializing( diff --git a/tensorcast/global_store/services/transport_service.py b/tensorcast/global_store/services/transport_service.py index 1dba60ce..55804c28 100644 --- a/tensorcast/global_store/services/transport_service.py +++ b/tensorcast/global_store/services/transport_service.py @@ -416,6 +416,7 @@ def _request_transport_broadcast( if requester_worker_id is None: raise ValidationError("broadcast transport requires requester_worker_id") + claim_error: NotFoundError | ValidationError | None = None try: with self._dispatch_loop_lock, self.replica_repository.transaction() as tx: existing = self.transport_repository.find_by_request_id( @@ -441,67 +442,77 @@ def _request_transport_broadcast( "existing broadcast transport source replica is missing" ) - replica, edge = self.broadcast_service.claim_transport_edge( - session_id=broadcast_hint.session_id, - artifact_id=artifact_id, - requested_view_id=view_id, - requester_worker_id=requester_worker_id, - request_id=request_id, - heartbeat_timeout_seconds=self.config.heartbeat_timeout_ms / 1000, - cursor=tx, - ) - transport = self._build_transport( - replica=replica, - artifact_id=artifact_id, - requested_view_id=view_id, - source_node_id=source_node_id, - source_address=source_address, - source_port=source_port, - request_id=request_id, - request_fingerprint=request_fingerprint, - requester_worker_id=requester_worker_id, - scheduling_group=None, - broadcast_session_id=broadcast_hint.session_id, - broadcast_edge_id=edge.edge_id, - ) - resolved_transport, created = ( - self.transport_repository.create_if_absent_with_cursor( - transport, - tx, - ) - ) - if not created: - self.replica_repository.decrement_requests_with_cursor( - replica.replica_id, - tx, - ) - self.broadcast_service.complete_transport_edge( + try: + replica, edge = self.broadcast_service.claim_transport_edge( session_id=broadcast_hint.session_id, - edge_id=edge.edge_id, - transport_outcome=TransportCompletionOutcome.FAILED, - outcome_detail="duplicate_request_after_broadcast_claim", + artifact_id=artifact_id, + requested_view_id=view_id, + requester_worker_id=requester_worker_id, + request_id=request_id, + heartbeat_timeout_seconds=self.config.heartbeat_timeout_ms + / 1000, cursor=tx, ) - existing_replica = self.replica_repository.find_by_id( - resolved_transport.replica_id, - resolved_transport.artifact_id, - cursor=tx, + except (NotFoundError, ValidationError) as exc: + claim_error = exc + replica = None + edge = None + + if claim_error is None: + transport = self._build_transport( + replica=replica, + artifact_id=artifact_id, + requested_view_id=view_id, + source_node_id=source_node_id, + source_address=source_address, + source_port=source_port, + request_id=request_id, + request_fingerprint=request_fingerprint, + requester_worker_id=requester_worker_id, + scheduling_group=None, + broadcast_session_id=broadcast_hint.session_id, + broadcast_edge_id=edge.edge_id, ) - if existing_replica is None: - raise NotFoundError( - "existing broadcast transport source replica is missing" + resolved_transport, created = ( + self.transport_repository.create_if_absent_with_cursor( + transport, + tx, ) - return existing_replica, resolved_transport.transport_id - else: - inc_active_transports() - record_transport_source_assignment( - artifact_id=artifact_id, - replica_id=str(replica.replica_id), - source_created_at=replica.created_at, ) - inc_transport_request(artifact_id, "success") - observe_transport_wait(artifact_id, time.time() - start_time) - return replica, resolved_transport.transport_id + if not created: + self.replica_repository.decrement_requests_with_cursor( + replica.replica_id, + tx, + ) + self.broadcast_service.complete_transport_edge( + session_id=broadcast_hint.session_id, + edge_id=edge.edge_id, + transport_outcome=TransportCompletionOutcome.FAILED, + outcome_detail="duplicate_request_after_broadcast_claim", + cursor=tx, + ) + existing_replica = self.replica_repository.find_by_id( + resolved_transport.replica_id, + resolved_transport.artifact_id, + cursor=tx, + ) + if existing_replica is None: + raise NotFoundError( + "existing broadcast transport source replica is missing" + ) + return existing_replica, resolved_transport.transport_id + else: + inc_active_transports() + record_transport_source_assignment( + artifact_id=artifact_id, + replica_id=str(replica.replica_id), + source_created_at=replica.created_at, + ) + inc_transport_request(artifact_id, "success") + observe_transport_wait(artifact_id, time.time() - start_time) + return replica, resolved_transport.transport_id + if claim_error is not None: + raise claim_error except DatabaseError as exc: if isinstance(exc.__cause__, (NotFoundError, ValidationError)): raise exc.__cause__ from exc diff --git a/tests/python/global_store/test_broadcast_transport.py b/tests/python/global_store/test_broadcast_transport.py index 4d6fdc7a..c2cb667f 100644 --- a/tests/python/global_store/test_broadcast_transport.py +++ b/tests/python/global_store/test_broadcast_transport.py @@ -363,6 +363,155 @@ def test_broadcast_request_existing_transport_missing_replica_does_not_claim_edg assert target.assigned_edge_id == edges[0].edge_id +def test_broadcast_parent_ineligible_exhausts_attempt_and_fails_session( + servicer, + test_context, +): + artifact_id = "mi2:broadcast-parent-ineligible-max" + root_worker = _register_worker( + servicer, + test_context, + worker_id="root-ineligible-max", + node_id="node-1", + port=55700, + ) + child_worker = _register_worker( + servicer, + test_context, + worker_id="child-ineligible-max", + node_id="node-2", + port=55800, + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.57.0.1", + node_port=55701, + remote_key="rk-root-ineligible-max", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-parent-ineligible-max", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + max_attempts=1, + ) + servicer.replica_repository.mark_unavailable(UUID(root_replica_id)) + + response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.57.9.9", + source_port=59008, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-parent-ineligible-max", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-parent-ineligible-max", + strict_parent=True, + ), + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_NOT_FOUND + edges = servicer.broadcast_service.list_edges("session-parent-ineligible-max") + target = servicer.broadcast_service.list_targets("session-parent-ineligible-max")[0] + session = servicer.broadcast_service.get_session("session-parent-ineligible-max") + assert session is not None + assert len(edges) == 1 + assert edges[0].state is BroadcastEdgeState.FAILED + assert edges[0].failure_reason == "parent_replica_not_transport_eligible" + assert target.state is BroadcastTargetState.FAILED + assert session.state is BroadcastSessionState.FAILED + + +def test_broadcast_parent_ineligible_retries_without_stuck_active_edge( + servicer, + test_context, +): + artifact_id = "mi2:broadcast-parent-ineligible-retry" + root_worker = _register_worker( + servicer, + test_context, + worker_id="root-ineligible-retry", + node_id="node-1", + port=55900, + ) + child_worker = _register_worker( + servicer, + test_context, + worker_id="child-ineligible-retry", + node_id="node-2", + port=56000, + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.59.0.1", + node_port=55901, + remote_key="rk-root-ineligible-retry", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-parent-ineligible-retry", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + max_attempts=3, + ) + original_edge = servicer.broadcast_service.list_edges( + "session-parent-ineligible-retry" + )[0] + servicer.replica_repository.mark_unavailable(UUID(root_replica_id)) + + response = servicer.RequestReplicaTransport( + global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.59.9.9", + source_port=59009, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-parent-ineligible-retry", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-parent-ineligible-retry", + strict_parent=True, + ), + ), + test_context, + ) + + assert response.status == global_store_pb2.STATUS_NOT_FOUND + edges = servicer.broadcast_service.list_edges("session-parent-ineligible-retry") + target = servicer.broadcast_service.list_targets( + "session-parent-ineligible-retry" + )[0] + original = [edge for edge in edges if edge.edge_id == original_edge.edge_id][0] + retry_edges = [edge for edge in edges if edge.edge_id != original_edge.edge_id] + assert original.state is BroadcastEdgeState.FAILED + assert original.failure_reason == "parent_replica_not_transport_eligible" + assert len(retry_edges) == 1 + assert retry_edges[0].state is BroadcastEdgeState.PLANNED + assert retry_edges[0].attempt == 2 + assert target.state is BroadcastTargetState.ASSIGNED + assert target.assigned_edge_id == retry_edges[0].edge_id + + def test_broadcast_success_without_child_replica_requeues(servicer, test_context): artifact_id = "mi2:broadcast-transport-success-no-child" root_worker = _register_worker( From e7fcdef697a42dba952e2a8ca28109cd8fbd75a1 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:29:35 +0800 Subject: [PATCH 31/49] fix(global-store): recycle stale broadcast request ids --- .../repositories/transport_repository.py | 24 +++ .../services/broadcast_service.py | 14 ++ .../services/transport_service.py | 84 ++++++++-- .../global_store/test_broadcast_transport.py | 148 ++++++++++++++++++ 4 files changed, 256 insertions(+), 14 deletions(-) diff --git a/tensorcast/global_store/repositories/transport_repository.py b/tensorcast/global_store/repositories/transport_repository.py index 26bebb2b..2a99a74d 100644 --- a/tensorcast/global_store/repositories/transport_repository.py +++ b/tensorcast/global_store/repositories/transport_repository.py @@ -283,6 +283,30 @@ def delete_with_cursor(self, transport_id: UUID, cursor) -> bool: ).fetchone() return row is not None + def clear_request_identity_with_cursor( + self, + *, + transport_id: UUID, + request_id: str, + cursor, + ) -> bool: + """Clear request idempotency fields for one known transport row.""" + normalized_request_id = self._normalize_request_id(request_id) + if normalized_request_id is None: + return False + row = cursor.execute( + """ + UPDATE artifact_transports + SET request_id = NULL, + request_fingerprint = NULL + WHERE transport_id = ? + AND request_id = ? + RETURNING transport_id + """, + [str(transport_id), normalized_request_id], + ).fetchone() + return row is not None + def update_status(self, transport_id: UUID, status: str, completed_at=None) -> bool: """Update transport status and optionally set completed_at.""" cursor = self.get_cursor() diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index d570240c..931e69e9 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -149,6 +149,20 @@ def list_targets(self, session_id: str) -> list[BroadcastTarget]: """List broadcast targets for a session.""" return self._broadcast_repository.list_targets(session_id) + def find_active_edge_for_target( + self, + *, + session_id: str, + target_worker_id: str, + cursor: DuckDBPyConnection, + ) -> BroadcastEdge | None: + """Return the current active edge for one target worker.""" + return self._broadcast_repository.find_active_edge_for_child( + session_id, + target_worker_id, + cursor=cursor, + ) + def cancel_session(self, session_id: str) -> bool: """Mark a broadcast session cancelled.""" return self._broadcast_repository.update_session_state( diff --git a/tensorcast/global_store/services/transport_service.py b/tensorcast/global_store/services/transport_service.py index 55804c28..2338f9e0 100644 --- a/tensorcast/global_store/services/transport_service.py +++ b/tensorcast/global_store/services/transport_service.py @@ -363,12 +363,6 @@ def request_transport( scheduling_group=scheduling_group, broadcast_hint=broadcast_hint, ) - existing = self._resolve_existing_request( - normalized_request_id, request_fingerprint - ) - if existing is not None: - return existing - if broadcast_hint is not None: return self._request_transport_broadcast( artifact_id=artifact_id, @@ -382,6 +376,12 @@ def request_transport( broadcast_hint=broadcast_hint, ) + existing = self._resolve_existing_request( + normalized_request_id, request_fingerprint + ) + if existing is not None: + return existing + return self._request_transport_group_dispatch( artifact_id=artifact_id, view_id=view_id, @@ -431,16 +431,15 @@ def _request_transport_broadcast( raise ValidationError( f"request_id={request_id} already used with different payload" ) - replica = self.replica_repository.find_by_id( - existing.replica_id, - existing.artifact_id, + replay = self._resolve_existing_broadcast_transport_replay( + existing=existing, + request_id=request_id, + session_id=broadcast_hint.session_id, + requester_worker_id=requester_worker_id, cursor=tx, ) - if replica is not None: - return replica, existing.transport_id - raise NotFoundError( - "existing broadcast transport source replica is missing" - ) + if replay is not None: + return replay try: replica, edge = self.broadcast_service.claim_transport_edge( @@ -518,6 +517,63 @@ def _request_transport_broadcast( raise exc.__cause__ from exc raise + def _resolve_existing_broadcast_transport_replay( + self, + *, + existing: Transport, + request_id: str, + session_id: str, + requester_worker_id: str, + cursor, + ) -> tuple[Replica, UUID] | None: + active_edge = None + if self.broadcast_service is not None: + active_edge = self.broadcast_service.find_active_edge_for_target( + session_id=session_id, + target_worker_id=requester_worker_id, + cursor=cursor, + ) + existing_edge_id = existing.broadcast_edge_id or "" + existing_session_id = existing.broadcast_session_id or "" + if ( + active_edge is not None + and existing_session_id == session_id + and existing_edge_id == active_edge.edge_id + ): + replica = self.replica_repository.find_by_id( + existing.replica_id, + existing.artifact_id, + cursor=cursor, + ) + if replica is None: + raise NotFoundError( + "existing broadcast transport source replica is missing" + ) + return replica, existing.transport_id + + if ( + active_edge is not None + and existing_session_id == session_id + and existing_edge_id + and existing_edge_id != active_edge.edge_id + ): + cleared = self.transport_repository.clear_request_identity_with_cursor( + transport_id=existing.transport_id, + request_id=request_id, + cursor=cursor, + ) + if cleared: + return None + + replica = self.replica_repository.find_by_id( + existing.replica_id, + existing.artifact_id, + cursor=cursor, + ) + if replica is not None: + return replica, existing.transport_id + raise NotFoundError("existing broadcast transport source replica is missing") + def _request_transport_group_dispatch( self, *, diff --git a/tests/python/global_store/test_broadcast_transport.py b/tests/python/global_store/test_broadcast_transport.py index c2cb667f..6f0aefdf 100644 --- a/tests/python/global_store/test_broadcast_transport.py +++ b/tests/python/global_store/test_broadcast_transport.py @@ -286,6 +286,154 @@ def test_duplicate_broadcast_request_reuses_existing_edge(servicer, test_context assert targets[0].assigned_edge_id == edges[0].edge_id +def test_failed_broadcast_transport_replay_claims_retry_edge(servicer, test_context): + artifact_id = "mi2:broadcast-replay-after-failure" + root_worker = _register_worker( + servicer, + test_context, + worker_id="root-replay-failure", + node_id="node-1", + port=55300, + ) + child_worker = _register_worker( + servicer, + test_context, + worker_id="child-replay-failure", + node_id="node-2", + port=55400, + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.53.0.1", + node_port=55301, + remote_key="rk-root-replay-failure", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-replay-failure", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.53.9.9", + source_port=59010, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-replay-failure", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-replay-failure", + strict_parent=True, + ), + ) + first = servicer.RequestReplicaTransport(request, test_context) + assert first.status == global_store_pb2.STATUS_OK + complete = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=first.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_FAILED, + ), + test_context, + ) + assert complete.status == global_store_pb2.STATUS_OK + retry_edge = [ + edge + for edge in servicer.broadcast_service.list_edges("session-replay-failure") + if edge.state is BroadcastEdgeState.PLANNED + ][0] + + replay = servicer.RequestReplicaTransport(request, test_context) + + assert replay.status == global_store_pb2.STATUS_OK + assert replay.transport_id != first.transport_id + edges = servicer.broadcast_service.list_edges("session-replay-failure") + retry = [edge for edge in edges if edge.edge_id == retry_edge.edge_id][0] + assert retry.state is BroadcastEdgeState.MATERIALIZING + + +def test_success_without_child_replay_claims_retry_edge(servicer, test_context): + artifact_id = "mi2:broadcast-replay-after-no-child" + root_worker = _register_worker( + servicer, + test_context, + worker_id="root-replay-no-child", + node_id="node-1", + port=56300, + ) + child_worker = _register_worker( + servicer, + test_context, + worker_id="child-replay-no-child", + node_id="node-2", + port=56400, + ) + root_replica_id = _register_exportable_replica( + servicer, + test_context, + artifact_id=artifact_id, + worker_id=root_worker, + node_id="node-1", + node_address="10.63.0.1", + node_port=56301, + remote_key="rk-root-replay-no-child", + ) + _create_broadcast_session( + servicer, + test_context, + session_id="session-replay-no-child", + artifact_id=artifact_id, + root_replica_id=root_replica_id, + child_worker_id=child_worker, + ) + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.63.9.9", + source_port=59011, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-replay-no-child", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-replay-no-child", + strict_parent=True, + ), + ) + first = servicer.RequestReplicaTransport(request, test_context) + assert first.status == global_store_pb2.STATUS_OK + complete = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=first.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_SUCCESS, + ), + test_context, + ) + assert complete.status == global_store_pb2.STATUS_OK + retry_edge = [ + edge + for edge in servicer.broadcast_service.list_edges("session-replay-no-child") + if edge.state is BroadcastEdgeState.PLANNED + ][0] + + replay = servicer.RequestReplicaTransport(request, test_context) + + assert replay.status == global_store_pb2.STATUS_OK + assert replay.transport_id != first.transport_id + edges = servicer.broadcast_service.list_edges("session-replay-no-child") + retry = [edge for edge in edges if edge.edge_id == retry_edge.edge_id][0] + assert retry.state is BroadcastEdgeState.MATERIALIZING + + def test_broadcast_request_existing_transport_missing_replica_does_not_claim_edge( servicer, test_context, From a785ff9adb94272eb7c5db0745624e1de3368f19 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:36:41 +0800 Subject: [PATCH 32/49] fix(global-store): reject exhausted broadcast replays --- .../services/transport_service.py | 5 +++ .../global_store/test_broadcast_transport.py | 40 +++++++++++-------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/tensorcast/global_store/services/transport_service.py b/tensorcast/global_store/services/transport_service.py index 2338f9e0..fd748c17 100644 --- a/tensorcast/global_store/services/transport_service.py +++ b/tensorcast/global_store/services/transport_service.py @@ -565,6 +565,11 @@ def _resolve_existing_broadcast_transport_replay( if cleared: return None + if self._is_terminal_failed_transport(existing): + raise NotFoundError( + "broadcast transport request is terminal and no retry edge is active" + ) + replica = self.replica_repository.find_by_id( existing.replica_id, existing.artifact_id, diff --git a/tests/python/global_store/test_broadcast_transport.py b/tests/python/global_store/test_broadcast_transport.py index 6f0aefdf..f45ce37f 100644 --- a/tests/python/global_store/test_broadcast_transport.py +++ b/tests/python/global_store/test_broadcast_transport.py @@ -758,24 +758,22 @@ def test_broadcast_max_attempt_exhaustion_marks_session_failed(servicer, test_co child_worker_id=child_worker, max_attempts=1, ) - transport_response = servicer.RequestReplicaTransport( - global_store_pb2.RequestReplicaTransportRequest( - artifact_id=artifact_id, - source_node_id="requester-node", - source_address="10.70.9.9", - source_port=59004, - requested_byte_space=common_pb2.ByteSpaceRef( - kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, - ), - requester_worker_id=child_worker, - request_id="request-max-attempts", - broadcast=global_store_pb2.BroadcastTransportHint( - session_id="session-max-attempts", - strict_parent=True, - ), + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id=artifact_id, + source_node_id="requester-node", + source_address="10.70.9.9", + source_port=59004, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + requester_worker_id=child_worker, + request_id="request-max-attempts", + broadcast=global_store_pb2.BroadcastTransportHint( + session_id="session-max-attempts", + strict_parent=True, ), - test_context, ) + transport_response = servicer.RequestReplicaTransport(request, test_context) assert transport_response.status == global_store_pb2.STATUS_OK complete_response = servicer.CompleteReplicaTransport( @@ -793,6 +791,16 @@ def test_broadcast_max_attempt_exhaustion_marks_session_failed(servicer, test_co assert session.state is BroadcastSessionState.FAILED assert target.state is BroadcastTargetState.FAILED + replay_response = servicer.RequestReplicaTransport(request, test_context) + session_after_replay = servicer.broadcast_service.get_session( + "session-max-attempts" + ) + + assert replay_response.status != global_store_pb2.STATUS_OK + assert replay_response.transport_id != transport_response.transport_id + assert session_after_replay is not None + assert session_after_replay.state is BroadcastSessionState.FAILED + def test_duplicate_broadcast_completion_is_noop(servicer, test_context): artifact_id = "mi2:broadcast-transport-complete-twice" From 72fc42c7f6ecadf77424ac9611343d7bcd6b4396 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:44:00 +0800 Subject: [PATCH 33/49] fix(global-store): preserve legacy transport fingerprints --- .../services/transport_service.py | 13 ++- tests/python/global_store/test_services.py | 81 +++++++++++++++++++ 2 files changed, 86 insertions(+), 8 deletions(-) diff --git a/tensorcast/global_store/services/transport_service.py b/tensorcast/global_store/services/transport_service.py index fd748c17..c4c78754 100644 --- a/tensorcast/global_store/services/transport_service.py +++ b/tensorcast/global_store/services/transport_service.py @@ -157,14 +157,6 @@ def _build_request_fingerprint( "source_address": source_address, "source_port": int(source_port), "requester_worker_id": (requester_worker_id or "").strip(), - "broadcast": ( - { - "session_id": broadcast_hint.session_id, - "strict_parent": bool(broadcast_hint.strict_parent), - } - if broadcast_hint is not None - else None - ), "scheduling_group": ( { "group_id": scheduling_group.group_id, @@ -178,6 +170,11 @@ def _build_request_fingerprint( else None ), } + if broadcast_hint is not None: + payload["broadcast"] = { + "session_id": broadcast_hint.session_id, + "strict_parent": bool(broadcast_hint.strict_parent), + } serialized = json.dumps(payload, sort_keys=True, separators=(",", ":")) return hashlib.sha256(serialized.encode("utf-8")).hexdigest() diff --git a/tests/python/global_store/test_services.py b/tests/python/global_store/test_services.py index f3ca20b2..b46d4581 100644 --- a/tests/python/global_store/test_services.py +++ b/tests/python/global_store/test_services.py @@ -4,6 +4,7 @@ import base64 import hashlib +import json from datetime import datetime, timedelta, timezone from uuid import UUID @@ -19,6 +20,7 @@ ValidationError, ) from tensorcast.global_store.models import ( + BroadcastTransportHint, ByteSpaceRef, ExportState, MemoryType, @@ -50,6 +52,85 @@ class TestServices: """Test service layer.""" + def test_transport_fingerprint_preserves_legacy_non_broadcast_shape(self): + expected_payload = { + "artifact_id": "mi2:legacy-fingerprint", + "view_id": "", + "source_node_id": "source-node", + "source_address": "10.1.2.3", + "source_port": 9090, + "requester_worker_id": "", + "scheduling_group": None, + } + expected = hashlib.sha256( + json.dumps( + expected_payload, + sort_keys=True, + separators=(",", ":"), + ).encode("utf-8") + ).hexdigest() + + actual = TransportService._build_request_fingerprint( + artifact_id="mi2:legacy-fingerprint", + view_id=None, + source_node_id="source-node", + source_address="10.1.2.3", + source_port=9090, + requester_worker_id=None, + scheduling_group=None, + broadcast_hint=None, + ) + + assert actual == expected + + def test_transport_fingerprint_includes_broadcast_hint(self): + without_broadcast = TransportService._build_request_fingerprint( + artifact_id="mi2:broadcast-fingerprint", + view_id=None, + source_node_id="source-node", + source_address="10.1.2.3", + source_port=9090, + requester_worker_id="worker-child", + scheduling_group=None, + broadcast_hint=None, + ) + with_broadcast = TransportService._build_request_fingerprint( + artifact_id="mi2:broadcast-fingerprint", + view_id=None, + source_node_id="source-node", + source_address="10.1.2.3", + source_port=9090, + requester_worker_id="worker-child", + scheduling_group=None, + broadcast_hint=BroadcastTransportHint( + session_id="session-a", + strict_parent=True, + ), + ) + expected_payload = { + "artifact_id": "mi2:broadcast-fingerprint", + "view_id": "", + "source_node_id": "source-node", + "source_address": "10.1.2.3", + "source_port": 9090, + "requester_worker_id": "worker-child", + "broadcast": { + "session_id": "session-a", + "strict_parent": True, + }, + "scheduling_group": None, + } + expected = hashlib.sha256( + json.dumps( + expected_payload, + sort_keys=True, + separators=(",", ":"), + ).encode("utf-8") + ).hexdigest() + + assert with_broadcast == expected + assert with_broadcast != without_broadcast + def test_worker_service_registration(self, services): """Test worker registration logic.""" worker_service = services["worker"] From bf580139525a0ad078a97af4cac6d42d43cefa0b Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:50:52 +0800 Subject: [PATCH 34/49] fix(global-store): stop cancelled broadcast advancement --- .../services/broadcast_service.py | 17 ++++ .../global_store/test_broadcast_service.py | 89 +++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index 931e69e9..23e8e516 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -278,6 +278,8 @@ def complete_transport_edge( session = self._broadcast_repository.find_session(session_id, cursor=cursor) if session is None: raise NotFoundError(f"broadcast session not found: {session_id}") + if session.state is not BroadcastSessionState.ACTIVE: + return edge = self._broadcast_repository.find_edge(edge_id, cursor=cursor) if edge is None: raise NotFoundError(f"broadcast edge not found: {edge_id}") @@ -352,6 +354,9 @@ def _fail_edge_and_maybe_retry( self._mark_session_terminal_if_done(session.session_id, cursor=cursor) def _mark_session_terminal_if_done(self, session_id: str, *, cursor=None) -> None: + session = self._broadcast_repository.find_session(session_id, cursor=cursor) + if session is None or session.state is not BroadcastSessionState.ACTIVE: + return if self._broadcast_repository.count_incomplete_targets( session_id, cursor=cursor, @@ -476,6 +481,18 @@ def _plan_more_edges( *, cursor=None, ) -> list[BroadcastEdge]: + if session.state is not BroadcastSessionState.ACTIVE: + return [] + current_session = self._broadcast_repository.find_session( + session.session_id, + cursor=cursor, + ) + if ( + current_session is None + or current_session.state is not BroadcastSessionState.ACTIVE + ): + return [] + session = current_session pending_targets = self._broadcast_repository.list_targets_by_state( session.session_id, BroadcastTargetState.PENDING, diff --git a/tests/python/global_store/test_broadcast_service.py b/tests/python/global_store/test_broadcast_service.py index 2db6bf48..901660f1 100644 --- a/tests/python/global_store/test_broadcast_service.py +++ b/tests/python/global_store/test_broadcast_service.py @@ -392,3 +392,92 @@ def test_complete_transport_edge_stale_failure_does_not_requeue_target( {"edge-stale", "edge-new"} if assigned_edge_id == "edge-new" else {"edge-stale"} ) assert session.state is BroadcastSessionState.ACTIVE + + +def test_complete_transport_edge_success_noops_after_session_cancelled(repositories): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root-cancel", "daemon-root-cancel", "node1") + child1 = _worker("worker-child-cancel-1", "daemon-child-cancel-1", "node2") + child2 = _worker("worker-child-cancel-2", "daemon-child-cancel-2", "node3") + for worker in (root, child1, child2): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-cancel", root)) + + service.create_session( + session_id="session-cancel-inflight", + artifact_id="mi2:model-cancel", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=[ + "daemon-child-cancel-1", + "daemon-child-cancel-2", + ], + root_replica_id=str(root_replica.replica_id), + strict_parent=True, + max_attempts=3, + ) + assigned_before_claim = [ + target + for target in broadcast_repo.list_targets("session-cancel-inflight") + if target.state is BroadcastTargetState.ASSIGNED + ] + assert len(assigned_before_claim) == 1 + claimed_worker_id = assigned_before_claim[0].target_worker_id + + with broadcast_repo.transaction() as tx: + _, edge = service.claim_transport_edge( + session_id="session-cancel-inflight", + artifact_id="mi2:model-cancel", + requested_view_id=None, + requester_worker_id=claimed_worker_id, + request_id="request-cancel-inflight", + heartbeat_timeout_seconds=30.0, + cursor=tx, + ) + + assert service.cancel_session("session-cancel-inflight") + completed_child = child1 if child1.worker_id == claimed_worker_id else child2 + replica_repo.create(_exportable_replica("mi2:model-cancel", completed_child)) + + with broadcast_repo.transaction() as tx: + service.complete_transport_edge( + session_id="session-cancel-inflight", + edge_id=edge.edge_id, + transport_outcome=TransportCompletionOutcome.SUCCESS, + outcome_detail=None, + cursor=tx, + ) + + session = broadcast_repo.find_session("session-cancel-inflight") + edge_after = broadcast_repo.find_edge(edge.edge_id) + targets = broadcast_repo.list_targets("session-cancel-inflight") + edges = broadcast_repo.list_edges("session-cancel-inflight") + + assert session is not None + assert edge_after is not None + assert session.state is BroadcastSessionState.CANCELLED + assert edge_after.state is BroadcastEdgeState.MATERIALIZING + assert edge_after.completed_at is None + assert len(edges) == 1 + materializing_targets = [ + target + for target in targets + if target.state is BroadcastTargetState.MATERIALIZING + ] + pending_targets = [ + target for target in targets if target.state is BroadcastTargetState.PENDING + ] + assert len(materializing_targets) == 1 + assert materializing_targets[0].target_worker_id == claimed_worker_id + assert materializing_targets[0].assigned_edge_id == edge.edge_id + assert len(pending_targets) == 1 + assert pending_targets[0].assigned_edge_id is None From c2aae7b521d2ae1778f520b1403879498ce0a0f3 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:09:31 +0800 Subject: [PATCH 35/49] feat(materialize): propagate broadcast session hints --- .../materialization/contracts/loading_spec.h | 7 +++ .../materialization_policy_utils.cc | 16 +++++ .../materialization_policy_utils.h | 3 + .../replica_materialization_service.cc | 7 +++ .../materialization_policy_utils_test.cc | 13 ++++ proto/tensorcast/daemon/v2/store_daemon.proto | 6 ++ tensorcast/__init__.py | 1 + tensorcast/api/__init__.py | 2 + tensorcast/api/_materialize.py | 4 ++ tensorcast/api/context.py | 19 ++++++ tensorcast/api/store/artifact.py | 8 +++ tensorcast/api/store/materialization.py | 22 +++++++ tensorcast/daemon_ctl.py | 11 ++++ .../api/test_daemon_ctl_broadcast_hint.py | 59 +++++++++++++++++++ tests/python/api/test_prefetch_operation.py | 17 ++++++ 15 files changed, 195 insertions(+) create mode 100644 tests/python/api/test_daemon_ctl_broadcast_hint.py diff --git a/core/store/materialization/contracts/loading_spec.h b/core/store/materialization/contracts/loading_spec.h index af7b535d..2499ca7a 100644 --- a/core/store/materialization/contracts/loading_spec.h +++ b/core/store/materialization/contracts/loading_spec.h @@ -171,6 +171,11 @@ struct CollectiveLoadGroupHint { uint32_t rank{0}; }; +struct BroadcastHint { + std::string session_id; + bool strict_parent{true}; +}; + struct RetrievalPolicy { SourcePreference preference{SourcePreference::kAuto}; bool allow_p2p{true}; @@ -213,6 +218,8 @@ struct MaterializeHints { std::optional transport_scheduling_group; // Optional same-host multi-rank hint for shared-window disk loading. std::optional collective_load_group; + // Optional broadcast session hint for coordinated global-store transport. + std::optional broadcast; // Optional topology-locality hint. This remains distinct from retrieval // policy so the strategy plane can reason about source sharing without // rewriting semantic or transport policy. diff --git a/daemon/service/controllers/materialization_policy_utils.cc b/daemon/service/controllers/materialization_policy_utils.cc index b247ba7a..c7ffe204 100644 --- a/daemon/service/controllers/materialization_policy_utils.cc +++ b/daemon/service/controllers/materialization_policy_utils.cc @@ -9,6 +9,7 @@ #include #include "absl/status/status.h" +#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "core/store/materialization/dataplane/view/view_identity.h" @@ -122,6 +123,21 @@ std::optional resolve_transport_sc }; } +std::optional resolve_broadcast_materialization_hint( + const v2::BroadcastMaterializationHint* hint) { + if (hint == nullptr) { + return std::nullopt; + } + std::string session_id = std::string(absl::StripAsciiWhitespace(hint->session_id())); + if (session_id.empty()) { + return std::nullopt; + } + return store::loading::BroadcastHint{ + .session_id = std::move(session_id), + .strict_parent = hint->strict_parent(), + }; +} + absl::StatusOr resolve_source_execution_topology( const v2::SourceExecutionTopology* topology) { ExecutionTopologyContext execution_topology; diff --git a/daemon/service/controllers/materialization_policy_utils.h b/daemon/service/controllers/materialization_policy_utils.h index 1f51d2f9..a9bd5ecf 100644 --- a/daemon/service/controllers/materialization_policy_utils.h +++ b/daemon/service/controllers/materialization_policy_utils.h @@ -57,6 +57,9 @@ std::optional resolve_collective_group_ std::optional resolve_transport_scheduling_group_hint( const v2::TransportSchedulingGroupHint* group); +std::optional resolve_broadcast_materialization_hint( + const v2::BroadcastMaterializationHint* hint); + absl::StatusOr resolve_source_execution_topology(const v2::SourceExecutionTopology* topology); absl::StatusOr resolve_collective_policy( diff --git a/daemon/service/controllers/replica_materialization_service.cc b/daemon/service/controllers/replica_materialization_service.cc index 697e8e22..1c878973 100644 --- a/daemon/service/controllers/replica_materialization_service.cc +++ b/daemon/service/controllers/replica_materialization_service.cc @@ -55,6 +55,7 @@ using materialization_policy::build_view_spec_proto; using materialization_policy::compute_view_id_from_spec; using materialization_policy::convert_view_spec; using materialization_policy::NormalizedMaterializationRequestContext; +using materialization_policy::resolve_broadcast_materialization_hint; using materialization_policy::resolve_collective_group_hint; using materialization_policy::resolve_materialization_request_context; using materialization_policy::resolve_transport_scheduling_group_hint; @@ -695,6 +696,12 @@ grpc::Status ReplicaMaterializationService::materialize_replica( hints.transport_scheduling_group = std::move(*group_hint); } } + if (req.has_broadcast()) { + auto broadcast_hint = resolve_broadcast_materialization_hint(&req.broadcast()); + if (broadcast_hint.has_value()) { + hints.broadcast = std::move(*broadcast_hint); + } + } if (prefer_direct_disk_for_local_import) { hints.set_retrieval_policy( store::loading::RetrievalPolicy{ diff --git a/daemon/service/materialization_policy_utils_test.cc b/daemon/service/materialization_policy_utils_test.cc index 2e23d824..4c981839 100644 --- a/daemon/service/materialization_policy_utils_test.cc +++ b/daemon/service/materialization_policy_utils_test.cc @@ -7,6 +7,7 @@ namespace { using tensorcast::daemon::materialization_policy::default_collective_policy_for_mapped_target; +using tensorcast::daemon::materialization_policy::resolve_broadcast_materialization_hint; using tensorcast::daemon::materialization_policy::resolve_transport_scheduling_group_hint; using tensorcast::store::loading::CollectiveLoadGroupHint; using tensorcast::store::loading::ExecutionTopologyContext; @@ -52,4 +53,16 @@ TEST_CASE("Transport scheduling group hint maps daemon proto", "[daemon][materia CHECK(hint->epoch == 42); } +TEST_CASE("Broadcast materialization hint maps daemon proto", "[daemon][materialization][policy]") { + v2::BroadcastMaterializationHint proto; + proto.set_session_id("session-a"); + proto.set_strict_parent(true); + + auto hint = resolve_broadcast_materialization_hint(&proto); + + REQUIRE(hint.has_value()); + CHECK(hint->session_id == "session-a"); + CHECK(hint->strict_parent); +} + } // namespace diff --git a/proto/tensorcast/daemon/v2/store_daemon.proto b/proto/tensorcast/daemon/v2/store_daemon.proto index 9503edb7..a392b1bd 100644 --- a/proto/tensorcast/daemon/v2/store_daemon.proto +++ b/proto/tensorcast/daemon/v2/store_daemon.proto @@ -524,6 +524,11 @@ message TransportSchedulingGroupHint { uint64 epoch = 6; } +message BroadcastMaterializationHint { + string session_id = 1; + bool strict_parent = 2; +} + message MaterializeReplicaRequest { tensorcast.common.v1.ArtifactSelection selection = 1; reserved 2, 9; @@ -557,6 +562,7 @@ message MaterializeReplicaRequest { ServingArtifactRuntimePolicy serving_artifact_policy = 20; string transport_request_id = 21; TransportSchedulingGroupHint transport_scheduling_group = 22; + BroadcastMaterializationHint broadcast = 23; } message CollectiveLoadGroup { diff --git a/tensorcast/__init__.py b/tensorcast/__init__.py index 60b01a0b..fb85fcf0 100644 --- a/tensorcast/__init__.py +++ b/tensorcast/__init__.py @@ -141,6 +141,7 @@ def _install_c_extension_bootstrap() -> None: "tensorcast.api", "calculate_tensor_device_offsets", ), + "BroadcastContext": ("tensorcast.api", "BroadcastContext"), "CallContext": ("tensorcast.api", "CallContext"), "CollectiveLoadGroup": ("tensorcast.api", "CollectiveLoadGroup"), "TransportSchedulingGroup": ("tensorcast.api", "TransportSchedulingGroup"), diff --git a/tensorcast/api/__init__.py b/tensorcast/api/__init__.py index 9cccd544..80bbdccc 100644 --- a/tensorcast/api/__init__.py +++ b/tensorcast/api/__init__.py @@ -26,6 +26,7 @@ ) from tensorcast.api._register import RegisteredLease, RegistrationResult from tensorcast.api.context import ( + BroadcastContext, CallContext, CollectiveLoadGroup, GovernanceContext, @@ -194,6 +195,7 @@ "BindingRealizationEntry", "BindingRealizationPlan", "Artifact", + "BroadcastContext", "CallContext", "CollectiveLoadGroup", "DirectorySnapshot", diff --git a/tensorcast/api/_materialize.py b/tensorcast/api/_materialize.py index 9fbac768..ed078c7a 100644 --- a/tensorcast/api/_materialize.py +++ b/tensorcast/api/_materialize.py @@ -288,6 +288,8 @@ def materialize_artifact_v2( lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, transport_request_id: str | None = None, transport_scheduling_group: TransportSchedulingGroup | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, ) -> MaterializationPayload: if artifact_id is not None and key is not None: raise ValueError("Exactly one of artifact_id or key must be provided") @@ -417,6 +419,8 @@ def materialize_artifact_v2( collective_load_group=collective_load_group, transport_request_id=transport_request_id, transport_scheduling_group=transport_group_proto, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, timeout_s=effective_timeout_s, timing_out=materialize_timing, ) diff --git a/tensorcast/api/context.py b/tensorcast/api/context.py index 34828e35..aad93589 100644 --- a/tensorcast/api/context.py +++ b/tensorcast/api/context.py @@ -59,6 +59,21 @@ def __post_init__(self) -> None: object.__setattr__(self, "request_id", request_id or None) +@dataclass(frozen=True, slots=True) +class BroadcastContext: + """Broadcast session hint for a materialization call.""" + + session_id: str + strict_parent: bool = True + + def __post_init__(self) -> None: + session_id = str(self.session_id).strip() + if not session_id: + raise ValueError("BroadcastContext.session_id must be non-empty") + object.__setattr__(self, "session_id", session_id) + object.__setattr__(self, "strict_parent", bool(self.strict_parent)) + + @dataclass(frozen=True, slots=True) class GovernanceContext: """Typed low-cardinality governance hints propagated with a plan.""" @@ -79,6 +94,7 @@ class CallContext: tags: Mapping[str, SpanAttributeValue] | None = None collective: CollectiveLoadGroup | None = None transport_group: TransportSchedulingGroup | None = None + broadcast: BroadcastContext | None = None governance: GovernanceContext | None = None @@ -91,6 +107,7 @@ def context( tags: Mapping[str, SpanAttributeValue] | None = None, collective: CollectiveLoadGroup | None = None, transport_group: TransportSchedulingGroup | None = None, + broadcast: BroadcastContext | None = None, governance: GovernanceContext | None = None, ) -> CallContext: return CallContext( @@ -101,12 +118,14 @@ def context( tags=tags, collective=collective, transport_group=transport_group, + broadcast=broadcast, governance=governance, ) __all__ = [ "CallContext", + "BroadcastContext", "CollectiveLoadGroup", "GovernanceContext", "QosClass", diff --git a/tensorcast/api/store/artifact.py b/tensorcast/api/store/artifact.py index a51158c7..f5448d6d 100644 --- a/tensorcast/api/store/artifact.py +++ b/tensorcast/api/store/artifact.py @@ -2274,6 +2274,14 @@ def _device_uuid_value() -> str: lease_mode=store_daemon_pb2.LeaseMode.LEASE_MODE_NO_LEASE, transport_request_id=transport_request_id, transport_scheduling_group=transport_scheduling_group, + broadcast_session_id=( + ctx.broadcast.session_id if ctx is not None and ctx.broadcast else None + ), + broadcast_strict_parent=( + ctx.broadcast.strict_parent + if ctx is not None and ctx.broadcast + else True + ), ) self._update_metadata_from_payload(payload, runtime) operation_id = payload.ticket_replica_uuid or payload.replica_uuid or "" diff --git a/tensorcast/api/store/materialization.py b/tensorcast/api/store/materialization.py index 8bb772b8..fb7203ef 100644 --- a/tensorcast/api/store/materialization.py +++ b/tensorcast/api/store/materialization.py @@ -382,6 +382,8 @@ def materialize_subset( lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, transport_request_id: str | None = None, transport_scheduling_group: TransportSchedulingGroup | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, ) -> tuple[MaterializationPayload, int]: return self._perform_get_with_retry( method="get", @@ -402,6 +404,8 @@ def materialize_subset( lease_mode=lease_mode, transport_request_id=transport_request_id, transport_scheduling_group=transport_scheduling_group, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, ) def get_view( @@ -1822,6 +1826,8 @@ def _materialize( replica_uuid: str | None = None, transport_request_id: str | None = None, transport_scheduling_group: TransportSchedulingGroup | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, ) -> MaterializationPayload: return self._materialize_payload( artifact_id=artifact_id, @@ -1843,6 +1849,8 @@ def _materialize( replica_uuid=replica_uuid, transport_request_id=transport_request_id, transport_scheduling_group=transport_scheduling_group, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, ) def _materialize_payload( @@ -1867,6 +1875,8 @@ def _materialize_payload( replica_uuid: str | None = None, transport_request_id: str | None = None, transport_scheduling_group: TransportSchedulingGroup | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, ) -> MaterializationPayload: client = self._runtime.ensure_client() resolved_artifact_id = artifact_id @@ -1920,6 +1930,8 @@ def _materialize_payload( lease_mode=lease_mode, transport_request_id=transport_request_id, transport_scheduling_group=transport_scheduling_group, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, ) disallowed_sources: set[store_daemon_pb2.MaterializationSource] = set() if not allow_p2p: @@ -2088,6 +2100,8 @@ def _perform_get_with_retry( lease_mode: store_daemon_pb2.LeaseMode = store_daemon_pb2.LeaseMode.LEASE_MODE_UNSPECIFIED, transport_request_id: str | None = None, transport_scheduling_group: TransportSchedulingGroup | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, ) -> tuple[MaterializationPayload, int]: options_snapshot = self._build_get_options(options_override) retrieval_policy = options_snapshot.source or RetrievalPolicy() @@ -2245,6 +2259,8 @@ def record_outcome(status: str) -> None: timeout_s=rpc_timeout_s, transport_request_id=transport_request_id, transport_scheduling_group=transport_scheduling_group, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, ) summary = self._summarize_materialized(materialized, tensor_names) selection_label = summary["selection"] @@ -2347,6 +2363,8 @@ def record_outcome(status: str) -> None: timeout_s=wait_timeout_s, transport_request_id=transport_request_id, transport_scheduling_group=transport_scheduling_group, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, ) except Exception as exc: # noqa: BLE001 error = map_materialization_error(exc) @@ -2480,6 +2498,8 @@ def _attempt_get( allow_cpu: bool = False, transport_request_id: str | None = None, transport_scheduling_group: TransportSchedulingGroup | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, ) -> tuple[MaterializationPayload, int]: artifact_id, key = self._resolve_identifiers(artifact_id, key) options = self._build_get_options(options_override) @@ -2505,6 +2525,8 @@ def _attempt_get( lease_mode=lease_mode, transport_request_id=transport_request_id, transport_scheduling_group=transport_scheduling_group, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, ) except Exception as exc: # noqa: BLE001 if "selection.logical_layout_hash does not match resolved selection" in str( diff --git a/tensorcast/daemon_ctl.py b/tensorcast/daemon_ctl.py index 4596057e..402cbe37 100644 --- a/tensorcast/daemon_ctl.py +++ b/tensorcast/daemon_ctl.py @@ -1991,6 +1991,8 @@ def materialize_by_artifact_id_v2( transport_request_id: str | None = None, transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> store_daemon_pb2.MaterializeReplicaResponse: ... @@ -2017,6 +2019,8 @@ def materialize_by_artifact_id_v2( transport_request_id: str | None = None, transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> tuple[bytes, store_daemon_pb2.MaterializeReplicaStatus]: ... @@ -2042,6 +2046,8 @@ def materialize_by_artifact_id_v2( transport_request_id: str | None = None, transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> bytes: ... @@ -2066,6 +2072,8 @@ def materialize_by_artifact_id_v2( transport_request_id: str | None = None, transport_scheduling_group: store_daemon_pb2.TransportSchedulingGroupHint | None = None, + broadcast_session_id: str | None = None, + broadcast_strict_parent: bool = True, timeout_s: float | int | None = None, timing_out: dict[str, float] | None = None, ) -> ( @@ -2106,6 +2114,9 @@ def materialize_by_artifact_id_v2( request.transport_request_id = str(transport_request_id) if transport_scheduling_group is not None: request.transport_scheduling_group.CopyFrom(transport_scheduling_group) + if broadcast_session_id: + request.broadcast.session_id = str(broadcast_session_id) + request.broadcast.strict_parent = bool(broadcast_strict_parent) if wait_for_shared_disk_ms: request.wait_for_shared_disk_ms = int(wait_for_shared_disk_ms) request.source_policy.CopyFrom(resolved_source_policy) diff --git a/tests/python/api/test_daemon_ctl_broadcast_hint.py b/tests/python/api/test_daemon_ctl_broadcast_hint.py new file mode 100644 index 00000000..41ebd25a --- /dev/null +++ b/tests/python/api/test_daemon_ctl_broadcast_hint.py @@ -0,0 +1,59 @@ +# Copyright (c) 2026, TensorCast Team. + +from __future__ import annotations + +from tensorcast.daemon_ctl import DaemonCtl +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.daemon.v2 import store_daemon_pb2 + + +class _FakeUnary: + _method = b"/tensorcast.daemon.v2.StoreDaemonService/MaterializeReplica" + + def __init__(self) -> None: + self.requests: list[store_daemon_pb2.MaterializeReplicaRequest] = [] + + def __call__(self, request, timeout=None): # noqa: ANN001, ANN204 + del timeout + self.requests.append(request) + response = store_daemon_pb2.MaterializeReplicaResponse() + response.status = store_daemon_pb2.MaterializeReplicaStatus.MATERIALIZE_REPLICA_STATUS_ALLOCATED + response.ticket.replica_uuid = request.replica_uuid + return response + + +class _FakeStub: + def __init__(self) -> None: + self.MaterializeReplica = _FakeUnary() + + +def test_daemon_ctl_copies_broadcast_hint_to_materialize_request(monkeypatch) -> None: # noqa: ANN001 + ctl = DaemonCtl.__new__(DaemonCtl) + ctl.server_address = "fake-daemon" + fake_stub = _FakeStub() + ctl.stub_v2 = fake_stub + ctl.stub = fake_stub + monkeypatch.setattr(ctl, "_get_effective_pid", lambda: 123) + monkeypatch.setattr( + ctl, + "_unary_call", + lambda method, request, **kwargs: method( + request, + timeout=kwargs.get("timeout"), + ), + ) + selection = common_pb2.ArtifactSelection(artifact_id="aid") + + ctl.materialize_by_artifact_id_v2( + selection=selection, + replica_uuid="replica-1", + device_uuid="device-uuid", + wait_for_completion=False, + return_response=True, + broadcast_session_id="session-a", + broadcast_strict_parent=True, + ) + + request = fake_stub.MaterializeReplica.requests[0] + assert request.broadcast.session_id == "session-a" + assert request.broadcast.strict_parent is True diff --git a/tests/python/api/test_prefetch_operation.py b/tests/python/api/test_prefetch_operation.py index 9a7a5979..dc4a3c6c 100644 --- a/tests/python/api/test_prefetch_operation.py +++ b/tests/python/api/test_prefetch_operation.py @@ -235,3 +235,20 @@ def test_prefetch_without_group_sends_no_transport_hint() -> None: call = store._materialization.calls[0] assert call["transport_request_id"] is None assert call["transport_scheduling_group"] is None + + +def test_prefetch_forwards_broadcast_context_hint() -> None: + store = _Store() + artifact = Artifact(store_ref=weakref.ref(store), artifact_id="aid") + ctx = tc.context( + broadcast=tc.BroadcastContext( + session_id="broadcast-session-1", + strict_parent=True, + ) + ) + + artifact.prefetch(device="cuda:0", ctx=ctx) + + call = store._materialization.calls[0] + assert call["broadcast_session_id"] == "broadcast-session-1" + assert call["broadcast_strict_parent"] is True From abdc1481bfc46228a3265cf964516a145684a77b Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:12:39 +0800 Subject: [PATCH 36/49] fix(api): export broadcast context at top level --- tensorcast/__init__.py | 1 + tests/python/api/test_prefetch_operation.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/tensorcast/__init__.py b/tensorcast/__init__.py index fb85fcf0..58bd9e27 100644 --- a/tensorcast/__init__.py +++ b/tensorcast/__init__.py @@ -521,6 +521,7 @@ def __dir__() -> list[str]: "calculate_tensor_device_offsets", "build_indices_from_safetensors", "binding_realization_plan_to_proto", + "BroadcastContext", "CallContext", "CollectiveLoadGroup", "TransportSchedulingGroup", diff --git a/tests/python/api/test_prefetch_operation.py b/tests/python/api/test_prefetch_operation.py index dc4a3c6c..2884140f 100644 --- a/tests/python/api/test_prefetch_operation.py +++ b/tests/python/api/test_prefetch_operation.py @@ -252,3 +252,12 @@ def test_prefetch_forwards_broadcast_context_hint() -> None: call = store._materialization.calls[0] assert call["broadcast_session_id"] == "broadcast-session-1" assert call["broadcast_strict_parent"] is True + + +def test_top_level_exports_broadcast_context() -> None: + assert "BroadcastContext" in tc.__all__ + + ctx = tc.BroadcastContext(session_id=" broadcast-session-1 ") + + assert ctx.session_id == "broadcast-session-1" + assert ctx.strict_parent is True From c21cce7f80ddad45846b297365be8d16d3497857 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:19:26 +0800 Subject: [PATCH 37/49] fix(api): propagate broadcast context through materialization --- tensorcast/api/store/artifact.py | 8 +++ tensorcast/api/store/materialization.py | 18 ++++++ tests/python/api/test_artifact_handle.py | 25 ++++++++ .../api/test_materialization_pipeline_v2.py | 61 +++++++++++++++++++ 4 files changed, 112 insertions(+) diff --git a/tensorcast/api/store/artifact.py b/tensorcast/api/store/artifact.py index f5448d6d..471ad6e6 100644 --- a/tensorcast/api/store/artifact.py +++ b/tensorcast/api/store/artifact.py @@ -886,6 +886,14 @@ def tensor_dict_with_diagnostics( replica_uuid=replica_uuid, options=options, ctx=ctx, + broadcast_session_id=( + ctx.broadcast.session_id if ctx is not None and ctx.broadcast else None + ), + broadcast_strict_parent=( + ctx.broadcast.strict_parent + if ctx is not None and ctx.broadcast + else True + ), ) materialize_end = time.perf_counter() state: dict[str, torch.Tensor] | None = None diff --git a/tensorcast/api/store/materialization.py b/tensorcast/api/store/materialization.py index fb7203ef..d9311074 100644 --- a/tensorcast/api/store/materialization.py +++ b/tensorcast/api/store/materialization.py @@ -83,6 +83,19 @@ class _MaterializationSummary(TypedDict): selection: str | None +def _resolve_broadcast_hints( + *, + ctx: CallContext | None, + broadcast_session_id: str | None, + broadcast_strict_parent: bool, +) -> tuple[str | None, bool]: + if broadcast_session_id is not None: + return str(broadcast_session_id), bool(broadcast_strict_parent) + if ctx is None or ctx.broadcast is None: + return None, bool(broadcast_strict_parent) + return ctx.broadcast.session_id, ctx.broadcast.strict_parent + + @dataclass(frozen=True, slots=True) class _RegionBackedLayout: layout: store_daemon_pb2.TargetLayout @@ -2103,6 +2116,11 @@ def _perform_get_with_retry( broadcast_session_id: str | None = None, broadcast_strict_parent: bool = True, ) -> tuple[MaterializationPayload, int]: + broadcast_session_id, broadcast_strict_parent = _resolve_broadcast_hints( + ctx=ctx, + broadcast_session_id=broadcast_session_id, + broadcast_strict_parent=broadcast_strict_parent, + ) options_snapshot = self._build_get_options(options_override) retrieval_policy = options_snapshot.source or RetrievalPolicy() wait_for_shared_disk_ms = int(options_snapshot.wait_for_shared_disk_ms) diff --git a/tests/python/api/test_artifact_handle.py b/tests/python/api/test_artifact_handle.py index 2c376c38..7f18593d 100644 --- a/tests/python/api/test_artifact_handle.py +++ b/tests/python/api/test_artifact_handle.py @@ -13,6 +13,7 @@ import torch from tensorcast.api._materialize import MaterializationPayload, TensorPayloadDescriptor +from tensorcast.api.context import BroadcastContext, CallContext from tensorcast.api.store import Store from tensorcast.api.store.artifact import Artifact from tensorcast.api.store.cache import ArtifactCache, ArtifactCacheEntry @@ -359,6 +360,30 @@ def test_tensor_dict_with_diagnostics_reports_source_and_bytes(): assert diagnostics.total_sec >= diagnostics.materialize_sec +def test_tensor_dict_with_diagnostics_forwards_broadcast_context_hint(): + canonical_bytes, payload = _build_payload({"foo": torch.ones(1)}) + runtime = _RuntimeStub(_ClientStub(canonical_bytes)) + pipeline = _PipelineStub(payload) + store = _StoreStub(runtime, pipeline) + artifact = Artifact( + store_ref=_store_ref(store), + artifact_id="aid", + canonical_index_bytes=canonical_bytes, + ) + ctx = CallContext( + broadcast=BroadcastContext( + session_id="broadcast-session-1", + strict_parent=False, + ) + ) + + result = artifact.tensor_dict_with_diagnostics(device="cpu", ctx=ctx) + + assert set(result.tensors) == {"foo"} + assert pipeline.calls[0]["broadcast_session_id"] == "broadcast-session-1" + assert pipeline.calls[0]["broadcast_strict_parent"] is False + + def test_bind_coerces_serving_manifest_into_runtime_policy( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/python/api/test_materialization_pipeline_v2.py b/tests/python/api/test_materialization_pipeline_v2.py index 68b1e2d1..66822a44 100644 --- a/tests/python/api/test_materialization_pipeline_v2.py +++ b/tests/python/api/test_materialization_pipeline_v2.py @@ -18,6 +18,7 @@ _resolve_collective_load_group, ) from tensorcast.api.context import ( + BroadcastContext, CallContext, CollectiveLoadGroup, TransportSchedulingGroup, @@ -496,3 +497,63 @@ def fake_materialize(**kwargs): assert materialized.replica_uuid == "transport" assert captured["transport_request_id"] == "transport-req-1" assert captured["transport_scheduling_group"] == group + + +def test_get_forwards_broadcast_context_hint(): + runtime = _RuntimeStub() + views = ViewOrchestrator(runtime) + pipeline = MaterializationPipeline(runtime, views) + payload = _make_payload({"a": torch.ones(1)}, replica_uuid="broadcast") + captured: dict[str, object] = {} + ctx = CallContext( + broadcast=BroadcastContext( + session_id="broadcast-session-1", + strict_parent=False, + ) + ) + + def fake_materialize(**kwargs): + captured.update(kwargs) + return payload + + pipeline.set_materialize_fn(fake_materialize) + result = pipeline.get(artifact_id="aid", ctx=ctx) + runtime.close() + + assert torch.equal(result["a"], torch.ones(1)) + assert captured["broadcast_session_id"] == "broadcast-session-1" + assert captured["broadcast_strict_parent"] is False + + +def test_materialize_subset_explicit_broadcast_hint_overrides_context(): + runtime = _RuntimeStub() + views = ViewOrchestrator(runtime) + pipeline = MaterializationPipeline(runtime, views) + payload = _make_payload({"a": torch.ones(1)}, replica_uuid="broadcast") + captured: dict[str, object] = {} + ctx = CallContext( + broadcast=BroadcastContext( + session_id="ctx-session", + strict_parent=False, + ) + ) + + def fake_materialize(**kwargs): + captured.update(kwargs) + return payload + + pipeline.set_materialize_fn(fake_materialize) + materialized, _ = pipeline.materialize_subset( + artifact_id="aid", + key=None, + device=0, + tensor_names=None, + ctx=ctx, + broadcast_session_id="explicit-session", + broadcast_strict_parent=True, + ) + runtime.close() + + assert materialized.replica_uuid == "broadcast" + assert captured["broadcast_session_id"] == "explicit-session" + assert captured["broadcast_strict_parent"] is True From cb84e576419f810532445355fb67a661f0f4f9de Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:29:55 +0800 Subject: [PATCH 38/49] feat(core): enforce broadcast transport parent hints --- core/store/components/global_store_client.cc | 15 ++++ core/store/components/global_store_client.h | 9 ++ .../control/materialize_orchestrator.cc | 58 +++++++++++-- ...terialize_orchestrator_reselection_test.cc | 84 ++++++++++++++++++- .../ingestion/materialization_facade.cc | 17 ++++ .../testing/recording_global_store_client.h | 6 ++ 6 files changed, 183 insertions(+), 6 deletions(-) diff --git a/core/store/components/global_store_client.cc b/core/store/components/global_store_client.cc index 027d9b6d..14188d1b 100644 --- a/core/store/components/global_store_client.cc +++ b/core/store/components/global_store_client.cc @@ -132,6 +132,17 @@ void apply_transport_scheduling_group_hint( group->set_epoch(scheduling_group->epoch); } +void apply_broadcast_transport_hint( + const std::optional& broadcast_hint, + global_store::RequestReplicaTransportRequest* request) { + if (!broadcast_hint.has_value() || broadcast_hint->session_id.empty()) { + return; + } + auto* broadcast = request->mutable_broadcast(); + broadcast->set_session_id(broadcast_hint->session_id); + broadcast->set_strict_parent(broadcast_hint->strict_parent); +} + std::string build_transport_request_id(std::string_view operation_kind) { static std::atomic transport_request_sequence{1}; const std::uint64_t sequence = transport_request_sequence.fetch_add(1, std::memory_order_relaxed); @@ -2376,6 +2387,7 @@ absl::StatusOr GlobalStoreClient::request_replica_transport( const DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, + const std::optional& broadcast_hint, std::string_view requester_worker_id, std::string_view request_id) { const std::string effective_request_id = @@ -2386,6 +2398,7 @@ absl::StatusOr GlobalStoreClient::request_replica_transport( request.set_source_address(std::string(source_address)); request.set_source_port(source_port); apply_transport_scheduling_group_hint(scheduling_group, &request); + apply_broadcast_transport_hint(broadcast_hint, &request); if (!requester_worker_id.empty()) { request.set_requester_worker_id(std::string(requester_worker_id)); } @@ -2459,6 +2472,7 @@ absl::StatusOr GlobalStoreClient::request_view_transport( const DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, + const std::optional& broadcast_hint, std::string_view requester_worker_id, std::string_view request_id) { if (view_id.empty()) { @@ -2473,6 +2487,7 @@ absl::StatusOr GlobalStoreClient::request_view_transport( request.set_source_address(std::string(source_address)); request.set_source_port(source_port); apply_transport_scheduling_group_hint(scheduling_group, &request); + apply_broadcast_transport_hint(broadcast_hint, &request); if (!requester_worker_id.empty()) { request.set_requester_worker_id(std::string(requester_worker_id)); } diff --git a/core/store/components/global_store_client.h b/core/store/components/global_store_client.h index 15dd52b3..cc7b41f6 100644 --- a/core/store/components/global_store_client.h +++ b/core/store/components/global_store_client.h @@ -136,6 +136,11 @@ struct TransportSchedulingGroupHint { uint64_t epoch{0}; }; +struct BroadcastTransportHint { + std::string session_id; + bool strict_parent{true}; +}; + enum class TransportCompletionOutcome : uint8_t { kUnspecified = 0, kSuccess = 1, @@ -573,6 +578,7 @@ class IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, + const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, std::string_view request_id = {}) = 0; @@ -585,6 +591,7 @@ class IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, + const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, std::string_view request_id = {}) = 0; @@ -1006,6 +1013,7 @@ class GlobalStoreClient : public IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, + const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, std::string_view request_id = {}) override; absl::StatusOr request_view_transport( @@ -1017,6 +1025,7 @@ class GlobalStoreClient : public IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, + const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, std::string_view request_id = {}) override; diff --git a/core/store/materialization/control/materialize_orchestrator.cc b/core/store/materialization/control/materialize_orchestrator.cc index dd886cd8..d732aac9 100644 --- a/core/store/materialization/control/materialize_orchestrator.cc +++ b/core/store/materialization/control/materialize_orchestrator.cc @@ -257,6 +257,16 @@ std::optional to_transport_scheduling_ return out; } +std::optional to_broadcast_transport_hint(const MaterializeHints& hints) { + if (!hints.broadcast.has_value() || hints.broadcast->session_id.empty()) { + return std::nullopt; + } + components::BroadcastTransportHint out; + out.session_id = hints.broadcast->session_id; + out.strict_parent = hints.broadcast->strict_parent; + return out; +} + bool should_log_reselection_attempt(int reselection_attempt) { return reselection_attempt <= 5 || (reselection_attempt % 10) == 0; } @@ -367,6 +377,7 @@ absl::StatusOr MaterializeOrchestrator::run( const auto request_deadline = resolve_request_deadline(hints); const int max_reselection_attempts = resolve_max_reselection_attempts(hints); const auto scheduling_group_hint = to_transport_scheduling_group_hint(hints); + const auto broadcast_hint = to_broadcast_transport_hint(hints); const std::string_view requester_worker_id = hints.transport_requester_worker_id.empty() ? std::string_view(local_identity_.worker_id) : std::string_view(hints.transport_requester_worker_id); @@ -390,6 +401,7 @@ absl::StatusOr MaterializeOrchestrator::run( target_device, view_probe_timeout_ms, scheduling_group_hint, + broadcast_hint, requester_worker_id, transport_request_id); if (!view_transport_or.ok() && @@ -414,6 +426,7 @@ absl::StatusOr MaterializeOrchestrator::run( target_device, wait_timeout_ms, scheduling_group_hint, + broadcast_hint, requester_worker_id, transport_request_id); used_canonical_transport_fallback = true; @@ -429,6 +442,7 @@ absl::StatusOr MaterializeOrchestrator::run( target_device, wait_timeout_ms, scheduling_group_hint, + broadcast_hint, requester_worker_id, transport_request_id); }; @@ -499,17 +513,51 @@ absl::StatusOr MaterializeOrchestrator::run( auto load_or = backend_->ingest_from_p2p(std::string(artifact_id), p2p_src, target, hints); if (load_or.ok()) { + const auto& handle = *load_or; + absl::Status reg_status = backend_->register_replica_with_global_store(handle.key(), {}); + if (!reg_status.ok()) { + LOG(WARNING) << "register_replica_with_global_store returned error: " << reg_status; + if (broadcast_hint.has_value()) { + absl::Status comp_status = gs_client_->complete_replica_transport( + session.transport_id, components::TransportCompletionOutcome::kFailed, reg_status.ToString()); + if (!comp_status.ok()) { + LOG(WARNING) << "complete_replica_transport after broadcast registration failure returned error: " + << comp_status; + } + last_p2p_status = reg_status; + if (can_retry_source_selection( + last_p2p_status, reselection_attempt, request_deadline, max_reselection_attempts)) { + record_stale_source_detected("broadcast_register_failure", view_id.has_value(), reselection_attempt); + reselection_attempt += 1; + record_source_reselection_attempt("broadcast_register_failure", view_id.has_value(), reselection_attempt); + if (should_log_reselection_attempt(reselection_attempt)) { + const std::chrono::milliseconds retry_remaining = remaining_request_budget(request_deadline); + const std::string remaining_label = retry_remaining == std::chrono::milliseconds::max() + ? "unbounded" + : std::to_string(retry_remaining.count()); + LOG(WARNING) << "Retrying source selection after broadcast registration failure: artifact_id=" + << artifact_id << " attempt=" << reselection_attempt << "/" + << max_reselection_attempts << " remaining_budget_ms=" << remaining_label + << " status=" << last_p2p_status; + } + continue; + } + if (should_retry_source_selection(last_p2p_status)) { + last_p2p_status = source_reselection_exhausted_status( + artifact_id, reselection_attempt, request_deadline, last_p2p_status); + record_source_reselection_exhausted( + "broadcast_register_failure", view_id.has_value(), reselection_attempt); + } + break; + } + } + // Notify GS that transport finished absl::Status comp_status = gs_client_->complete_replica_transport( session.transport_id, components::TransportCompletionOutcome::kSuccess); if (!comp_status.ok()) { LOG(WARNING) << "complete_replica_transport returned error: " << comp_status; } - const auto& handle = *load_or; - absl::Status reg_status = backend_->register_replica_with_global_store(handle.key(), {}); - if (!reg_status.ok()) { - LOG(WARNING) << "register_replica_with_global_store returned error: " << reg_status; - } if (used_canonical_transport_fallback && view_id.has_value()) { LOG(INFO) << "materialize_view loaded via canonical transport fallback: artifact_id=" << artifact_id diff --git a/core/store/materialization/control/materialize_orchestrator_reselection_test.cc b/core/store/materialization/control/materialize_orchestrator_reselection_test.cc index 2ca8fbe2..36fb1e37 100644 --- a/core/store/materialization/control/materialize_orchestrator_reselection_test.cc +++ b/core/store/materialization/control/materialize_orchestrator_reselection_test.cc @@ -37,6 +37,7 @@ class FakeMaterializationBackend final : public MaterializationBackend { std::vector p2p_attempts; std::vector p2p_scripted_statuses; + absl::Status register_status{absl::OkStatus()}; int register_calls{0}; absl::StatusOr ingest_from_p2p( @@ -80,7 +81,7 @@ class FakeMaterializationBackend final : public MaterializationBackend { absl::Status register_replica_with_global_store(const ReplicaKey&, std::string_view, std::string_view) override { register_calls += 1; - return absl::OkStatus(); + return register_status; } }; @@ -283,6 +284,87 @@ TEST_CASE( CHECK(gs_client->replica_request_wait_timeouts_ms.front() == 5000); } +TEST_CASE( + "MaterializeOrchestrator propagates broadcast transport parent hint", + "[store][materialize][reselection][broadcast]") { + auto gs_client = std::make_shared(); + gs_client->connected = true; + gs_client->allow_replica_transport = true; + gs_client->push_scripted_transport_session(make_transport_session( + "transport-broadcast", "node-remote", "10.6.6.2", 50052, common::memory::MemoryLocation::GPU, 0)); + + FakeMaterializationBackend backend; + MaterializeHints hints; + hints.artifact_id = "artifact-broadcast-hint"; + hints.allow_p2p = true; + hints.allow_disk = false; + hints.transport_request_id = "request-broadcast-1"; + hints.broadcast = loading::BroadcastHint{ + .session_id = "session-a", + .strict_parent = true, + }; + + components::WorkerIdentity local_identity{ + .worker_id = "worker-local", + .node_id = "node-local", + .node_address = "10.6.6.1", + .p2p_port = 50051, + }; + MaterializeOrchestrator orchestrator( + gsl::not_null{&backend}, + gsl::not_null{gs_client.get()}, + local_identity); + + auto result = orchestrator.run("artifact-broadcast-hint", make_gpu_target(0), hints, std::nullopt); + REQUIRE(result.ok()); + REQUIRE(gs_client->replica_request_broadcast_hints.size() == 1); + REQUIRE(gs_client->replica_request_broadcast_hints.front().has_value()); + CHECK(gs_client->replica_request_broadcast_hints.front()->session_id == "session-a"); + CHECK(gs_client->replica_request_broadcast_hints.front()->strict_parent); +} + +TEST_CASE( + "MaterializeOrchestrator completes broadcast transport as failed when local registration fails", + "[store][materialize][reselection][broadcast]") { + auto gs_client = std::make_shared(); + gs_client->connected = true; + gs_client->allow_replica_transport = false; + gs_client->push_scripted_transport_session(make_transport_session( + "transport-broadcast", "node-remote", "10.7.7.2", 50062, common::memory::MemoryLocation::GPU, 0)); + + FakeMaterializationBackend backend; + backend.register_status = absl::UnavailableError("register failed"); + + MaterializeHints hints; + hints.artifact_id = "artifact-broadcast-register-fails"; + hints.allow_p2p = true; + hints.allow_disk = false; + hints.transport_request_id = "request-broadcast-register-fails"; + hints.broadcast = loading::BroadcastHint{ + .session_id = "session-register-fails", + .strict_parent = true, + }; + + components::WorkerIdentity local_identity{ + .worker_id = "worker-local", + .node_id = "node-local", + .node_address = "10.7.7.1", + .p2p_port = 50061, + }; + MaterializeOrchestrator orchestrator( + gsl::not_null{&backend}, + gsl::not_null{gs_client.get()}, + local_identity); + + auto result = orchestrator.run("artifact-broadcast-register-fails", make_gpu_target(0), hints, std::nullopt); + REQUIRE_FALSE(result.ok()); + REQUIRE(backend.register_calls == 1); + REQUIRE(gs_client->completed_transport_ids.size() == 1); + CHECK(gs_client->completed_transport_ids.front() == "transport-broadcast"); + REQUIRE(gs_client->completed_transport_outcomes.size() == 1); + CHECK(gs_client->completed_transport_outcomes.front() == TransportCompletionOutcome::kFailed); +} + TEST_CASE( "MaterializeOrchestrator propagates transport scheduler hint metadata", "[store][materialize][reselection][scheduler_hint]") { diff --git a/core/store/runtime/ingestion/materialization_facade.cc b/core/store/runtime/ingestion/materialization_facade.cc index cfc967db..751047f5 100644 --- a/core/store/runtime/ingestion/materialization_facade.cc +++ b/core/store/runtime/ingestion/materialization_facade.cc @@ -1264,6 +1264,17 @@ std::optional to_transport_scheduling_ return out; } +std::optional to_broadcast_transport_hint( + const loading::MaterializeHints& hints) { + if (!hints.broadcast.has_value() || hints.broadcast->session_id.empty()) { + return std::nullopt; + } + components::BroadcastTransportHint out; + out.session_id = hints.broadcast->session_id; + out.strict_parent = hints.broadcast->strict_parent; + return out; +} + loading::ReplicaHandle build_local_replica_handle( const loading::ReplicaKey& key, const std::shared_ptr& replica, @@ -3753,6 +3764,7 @@ absl::StatusOr MaterializationFacade::mate } const auto scheduling_group_hint = to_transport_scheduling_group_hint(hints); + const auto broadcast_hint = to_broadcast_transport_hint(hints); const std::string_view requester_worker_id = hints.transport_requester_worker_id.empty() ? std::string_view(local_identity.worker_id) : std::string_view(hints.transport_requester_worker_id); @@ -3767,6 +3779,7 @@ absl::StatusOr MaterializationFacade::mate target_device, resolve_transport_wait_timeout_ms(hints), scheduling_group_hint, + broadcast_hint, requester_worker_id, transport_request_id); if (transport_or.ok()) { @@ -4571,6 +4584,7 @@ absl::StatusOr MaterializationFacade::mate requested_view_id = *request_hints.variant->view_id; } const auto scheduling_group_hint = to_transport_scheduling_group_hint(request_hints); + const auto broadcast_hint = to_broadcast_transport_hint(request_hints); const std::string_view requester_worker_id = request_hints.transport_requester_worker_id.empty() ? std::string_view(local_identity.worker_id) : std::string_view(request_hints.transport_requester_worker_id); @@ -4592,6 +4606,7 @@ absl::StatusOr MaterializationFacade::mate target_device, view_probe_timeout_ms, scheduling_group_hint, + broadcast_hint, requester_worker_id, transport_request_id); if (!view_transport_or.ok() && @@ -4619,6 +4634,7 @@ absl::StatusOr MaterializationFacade::mate target_device, wait_timeout_ms, scheduling_group_hint, + broadcast_hint, requester_worker_id, transport_request_id); } @@ -4632,6 +4648,7 @@ absl::StatusOr MaterializationFacade::mate target_device, wait_timeout_ms, scheduling_group_hint, + broadcast_hint, requester_worker_id, transport_request_id); }; diff --git a/core/store/testing/recording_global_store_client.h b/core/store/testing/recording_global_store_client.h index ae5415c0..b9663861 100644 --- a/core/store/testing/recording_global_store_client.h +++ b/core/store/testing/recording_global_store_client.h @@ -31,6 +31,8 @@ class RecordingGlobalStoreClient final : public components::IGlobalStoreClient { std::vector replica_requests; std::vector> view_request_groups; std::vector> replica_request_groups; + std::vector> view_request_broadcast_hints; + std::vector> replica_request_broadcast_hints; std::vector view_request_request_ids; std::vector replica_request_request_ids; std::vector view_request_requester_worker_ids; @@ -522,10 +524,12 @@ class RecordingGlobalStoreClient final : public components::IGlobalStoreClient { const tensorcast::store::DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, + const std::optional& broadcast_hint, std::string_view requester_worker_id, std::string_view request_id) override { replica_requests.emplace_back(std::string(artifact_id)); replica_request_groups.push_back(scheduling_group); + replica_request_broadcast_hints.push_back(broadcast_hint); replica_request_request_ids.emplace_back(std::string(request_id)); replica_request_requester_worker_ids.emplace_back(std::string(requester_worker_id)); replica_request_wait_timeouts_ms.push_back(wait_timeout_ms); @@ -554,10 +558,12 @@ class RecordingGlobalStoreClient final : public components::IGlobalStoreClient { const tensorcast::store::DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, + const std::optional& broadcast_hint, std::string_view requester_worker_id, std::string_view request_id) override { view_requests.emplace_back(std::string(view_id)); view_request_groups.push_back(scheduling_group); + view_request_broadcast_hints.push_back(broadcast_hint); view_request_request_ids.emplace_back(std::string(request_id)); view_request_requester_worker_ids.emplace_back(std::string(requester_worker_id)); view_request_wait_timeouts_ms.push_back(wait_timeout_ms); From 6c1f6671ab270a18863d0e74b5fe962f2c8f9648 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:36:49 +0800 Subject: [PATCH 39/49] fix(core): update global store client test stubs --- core/store/runtime/metadata/metadata_gateway_test.cc | 2 ++ core/store/testing/global_store_client_stub.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/core/store/runtime/metadata/metadata_gateway_test.cc b/core/store/runtime/metadata/metadata_gateway_test.cc index 5445d129..84c670bf 100644 --- a/core/store/runtime/metadata/metadata_gateway_test.cc +++ b/core/store/runtime/metadata/metadata_gateway_test.cc @@ -484,6 +484,7 @@ class TestGlobalStoreClient final : public tensorcast::store::components::IGloba const DeviceKey&, uint32_t, const std::optional&, + const std::optional&, std::string_view, std::string_view) override { return absl::UnimplementedError("request_replica_transport not used in tests"); @@ -498,6 +499,7 @@ class TestGlobalStoreClient final : public tensorcast::store::components::IGloba const DeviceKey&, uint32_t, const std::optional&, + const std::optional&, std::string_view, std::string_view) override { return absl::UnimplementedError("request_view_transport not used in tests"); diff --git a/core/store/testing/global_store_client_stub.h b/core/store/testing/global_store_client_stub.h index 3415dd93..c6537fef 100644 --- a/core/store/testing/global_store_client_stub.h +++ b/core/store/testing/global_store_client_stub.h @@ -146,6 +146,7 @@ class GlobalStoreClientStub : public components::IGlobalStoreClient { const DeviceKey&, uint32_t, const std::optional&, + const std::optional&, std::string_view, std::string_view) override { return absl::UnimplementedError("request_replica_transport not supported in GlobalStoreClientStub"); @@ -160,6 +161,7 @@ class GlobalStoreClientStub : public components::IGlobalStoreClient { const DeviceKey&, uint32_t, const std::optional&, + const std::optional&, std::string_view, std::string_view) override { return absl::UnimplementedError("request_view_transport not supported in GlobalStoreClientStub"); From 7c80302485d4fed0f0d6bc7fb82a350cdcad6649 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 19:48:46 +0800 Subject: [PATCH 40/49] fix(core): preserve broadcast transport compatibility --- core/store/components/global_store_client.cc | 8 +- core/store/components/global_store_client.h | 16 +- .../control/materialize_orchestrator.cc | 91 ++++++---- ...terialize_orchestrator_reselection_test.cc | 170 ++++++++++++++++++ .../ingestion/materialization_facade.cc | 16 +- .../runtime/metadata/metadata_gateway_test.cc | 8 +- core/store/testing/global_store_client_stub.h | 8 +- .../testing/recording_global_store_client.h | 8 +- 8 files changed, 256 insertions(+), 69 deletions(-) diff --git a/core/store/components/global_store_client.cc b/core/store/components/global_store_client.cc index 14188d1b..42755e11 100644 --- a/core/store/components/global_store_client.cc +++ b/core/store/components/global_store_client.cc @@ -2387,9 +2387,9 @@ absl::StatusOr GlobalStoreClient::request_replica_transport( const DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, - const std::optional& broadcast_hint, std::string_view requester_worker_id, - std::string_view request_id) { + std::string_view request_id, + const std::optional& broadcast_hint) { const std::string effective_request_id = request_id.empty() ? build_transport_request_id("canonical") : std::string(request_id); global_store::RequestReplicaTransportRequest request; @@ -2472,9 +2472,9 @@ absl::StatusOr GlobalStoreClient::request_view_transport( const DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, - const std::optional& broadcast_hint, std::string_view requester_worker_id, - std::string_view request_id) { + std::string_view request_id, + const std::optional& broadcast_hint) { if (view_id.empty()) { return absl::InvalidArgumentError("view_id must be non-empty for view transport"); } diff --git a/core/store/components/global_store_client.h b/core/store/components/global_store_client.h index cc7b41f6..53df4e48 100644 --- a/core/store/components/global_store_client.h +++ b/core/store/components/global_store_client.h @@ -578,9 +578,9 @@ class IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, - const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, - std::string_view request_id = {}) = 0; + std::string_view request_id = {}, + const std::optional& broadcast_hint = std::nullopt) = 0; virtual absl::StatusOr request_view_transport( std::string_view artifact_id, @@ -591,9 +591,9 @@ class IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, - const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, - std::string_view request_id = {}) = 0; + std::string_view request_id = {}, + const std::optional& broadcast_hint = std::nullopt) = 0; virtual absl::Status complete_replica_transport( std::string_view transport_id, @@ -1013,9 +1013,9 @@ class GlobalStoreClient : public IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, - const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, - std::string_view request_id = {}) override; + std::string_view request_id = {}, + const std::optional& broadcast_hint = std::nullopt) override; absl::StatusOr request_view_transport( std::string_view artifact_id, std::string_view view_id, @@ -1025,9 +1025,9 @@ class GlobalStoreClient : public IGlobalStoreClient { const DeviceKey& target_device, uint32_t wait_timeout_ms = 30000, const std::optional& scheduling_group = std::nullopt, - const std::optional& broadcast_hint = std::nullopt, std::string_view requester_worker_id = {}, - std::string_view request_id = {}) override; + std::string_view request_id = {}, + const std::optional& broadcast_hint = std::nullopt) override; absl::Status complete_replica_transport( std::string_view transport_id, diff --git a/core/store/materialization/control/materialize_orchestrator.cc b/core/store/materialization/control/materialize_orchestrator.cc index d732aac9..1ac8c7ea 100644 --- a/core/store/materialization/control/materialize_orchestrator.cc +++ b/core/store/materialization/control/materialize_orchestrator.cc @@ -401,9 +401,9 @@ absl::StatusOr MaterializeOrchestrator::run( target_device, view_probe_timeout_ms, scheduling_group_hint, - broadcast_hint, requester_worker_id, - transport_request_id); + transport_request_id, + broadcast_hint); if (!view_transport_or.ok() && (absl::IsNotFound(view_transport_or.status()) || absl::IsUnimplemented(view_transport_or.status()) || absl::IsDeadlineExceeded(view_transport_or.status()))) { @@ -426,9 +426,9 @@ absl::StatusOr MaterializeOrchestrator::run( target_device, wait_timeout_ms, scheduling_group_hint, - broadcast_hint, requester_worker_id, - transport_request_id); + transport_request_id, + broadcast_hint); used_canonical_transport_fallback = true; return canonical_transport_or; } @@ -442,9 +442,9 @@ absl::StatusOr MaterializeOrchestrator::run( target_device, wait_timeout_ms, scheduling_group_hint, - broadcast_hint, requester_worker_id, - transport_request_id); + transport_request_id, + broadcast_hint); }; while (gs_connected && allow_p2p) { @@ -514,42 +514,59 @@ absl::StatusOr MaterializeOrchestrator::run( auto load_or = backend_->ingest_from_p2p(std::string(artifact_id), p2p_src, target, hints); if (load_or.ok()) { const auto& handle = *load_or; + if (!broadcast_hint.has_value()) { + absl::Status comp_status = gs_client_->complete_replica_transport( + session.transport_id, components::TransportCompletionOutcome::kSuccess); + if (!comp_status.ok()) { + LOG(WARNING) << "complete_replica_transport returned error: " << comp_status; + } + absl::Status reg_status = backend_->register_replica_with_global_store(handle.key(), {}); + if (!reg_status.ok()) { + LOG(WARNING) << "register_replica_with_global_store returned error: " << reg_status; + } + + if (used_canonical_transport_fallback && view_id.has_value()) { + LOG(INFO) << "materialize_view loaded via canonical transport fallback: artifact_id=" << artifact_id + << " view_id=" << *view_id; + } + if (reselection_attempt > 0) { + record_source_reselection_success("p2p_load", view_id.has_value(), reselection_attempt); + } + return load_or; + } + absl::Status reg_status = backend_->register_replica_with_global_store(handle.key(), {}); if (!reg_status.ok()) { LOG(WARNING) << "register_replica_with_global_store returned error: " << reg_status; - if (broadcast_hint.has_value()) { - absl::Status comp_status = gs_client_->complete_replica_transport( - session.transport_id, components::TransportCompletionOutcome::kFailed, reg_status.ToString()); - if (!comp_status.ok()) { - LOG(WARNING) << "complete_replica_transport after broadcast registration failure returned error: " - << comp_status; - } - last_p2p_status = reg_status; - if (can_retry_source_selection( - last_p2p_status, reselection_attempt, request_deadline, max_reselection_attempts)) { - record_stale_source_detected("broadcast_register_failure", view_id.has_value(), reselection_attempt); - reselection_attempt += 1; - record_source_reselection_attempt("broadcast_register_failure", view_id.has_value(), reselection_attempt); - if (should_log_reselection_attempt(reselection_attempt)) { - const std::chrono::milliseconds retry_remaining = remaining_request_budget(request_deadline); - const std::string remaining_label = retry_remaining == std::chrono::milliseconds::max() - ? "unbounded" - : std::to_string(retry_remaining.count()); - LOG(WARNING) << "Retrying source selection after broadcast registration failure: artifact_id=" - << artifact_id << " attempt=" << reselection_attempt << "/" - << max_reselection_attempts << " remaining_budget_ms=" << remaining_label - << " status=" << last_p2p_status; - } - continue; - } - if (should_retry_source_selection(last_p2p_status)) { - last_p2p_status = source_reselection_exhausted_status( - artifact_id, reselection_attempt, request_deadline, last_p2p_status); - record_source_reselection_exhausted( - "broadcast_register_failure", view_id.has_value(), reselection_attempt); + absl::Status comp_status = gs_client_->complete_replica_transport( + session.transport_id, components::TransportCompletionOutcome::kFailed, reg_status.ToString()); + if (!comp_status.ok()) { + LOG(WARNING) << "complete_replica_transport after broadcast registration failure returned error: " + << comp_status; + } + last_p2p_status = reg_status; + if (can_retry_source_selection( + last_p2p_status, reselection_attempt, request_deadline, max_reselection_attempts)) { + record_stale_source_detected("broadcast_register_failure", view_id.has_value(), reselection_attempt); + reselection_attempt += 1; + record_source_reselection_attempt("broadcast_register_failure", view_id.has_value(), reselection_attempt); + if (should_log_reselection_attempt(reselection_attempt)) { + const std::chrono::milliseconds retry_remaining = remaining_request_budget(request_deadline); + const std::string remaining_label = retry_remaining == std::chrono::milliseconds::max() + ? "unbounded" + : std::to_string(retry_remaining.count()); + LOG(WARNING) << "Retrying source selection after broadcast registration failure: artifact_id=" + << artifact_id << " attempt=" << reselection_attempt << "/" << max_reselection_attempts + << " remaining_budget_ms=" << remaining_label << " status=" << last_p2p_status; } - break; + continue; + } + if (should_retry_source_selection(last_p2p_status)) { + last_p2p_status = + source_reselection_exhausted_status(artifact_id, reselection_attempt, request_deadline, last_p2p_status); + record_source_reselection_exhausted("broadcast_register_failure", view_id.has_value(), reselection_attempt); } + break; } // Notify GS that transport finished diff --git a/core/store/materialization/control/materialize_orchestrator_reselection_test.cc b/core/store/materialization/control/materialize_orchestrator_reselection_test.cc index 36fb1e37..6797c92f 100644 --- a/core/store/materialization/control/materialize_orchestrator_reselection_test.cc +++ b/core/store/materialization/control/materialize_orchestrator_reselection_test.cc @@ -3,6 +3,7 @@ #include "core/store/materialization/control/materialize_orchestrator.h" #include +#include #include #include #include @@ -37,7 +38,9 @@ class FakeMaterializationBackend final : public MaterializationBackend { std::vector p2p_attempts; std::vector p2p_scripted_statuses; + std::vector register_scripted_statuses; absl::Status register_status{absl::OkStatus()}; + std::function on_register; int register_calls{0}; absl::StatusOr ingest_from_p2p( @@ -80,7 +83,14 @@ class FakeMaterializationBackend final : public MaterializationBackend { } absl::Status register_replica_with_global_store(const ReplicaKey&, std::string_view, std::string_view) override { + const size_t call_index = static_cast(register_calls); register_calls += 1; + if (on_register) { + on_register(); + } + if (call_index < register_scripted_statuses.size()) { + return register_scripted_statuses[call_index]; + } return register_status; } }; @@ -114,6 +124,21 @@ DeviceKey make_gpu_target(int ordinal) { }; } +absl::StatusOr request_replica_transport_with_legacy_positional_args( + components::IGlobalStoreClient& client, + const DeviceKey& target_device) { + return client.request_replica_transport( + "artifact-legacy", + "node-legacy", + "10.9.9.1", + 50090, + target_device, + 5000, + std::nullopt, + "worker-legacy", + "request-legacy"); +} + TEST_CASE("MaterializeOrchestrator accepts local route returned by Global Store", "[store][materialize][reselection]") { auto gs_client = std::make_shared(); gs_client->connected = true; @@ -148,6 +173,61 @@ TEST_CASE("MaterializeOrchestrator accepts local route returned by Global Store" CHECK(gs_client->completed_transport_outcomes[0] == TransportCompletionOutcome::kSuccess); } +TEST_CASE("GlobalStoreClient request transport keeps legacy positional arguments", "[store][materialize][reselection]") { + RecordingGlobalStoreClient gs_client; + gs_client.connected = true; + gs_client.allow_replica_transport = true; + + auto result = request_replica_transport_with_legacy_positional_args(gs_client, make_gpu_target(0)); + + REQUIRE(result.ok()); + REQUIRE(gs_client.replica_requests.size() == 1); + CHECK(gs_client.replica_request_requester_worker_ids.front() == "worker-legacy"); + CHECK(gs_client.replica_request_request_ids.front() == "request-legacy"); + REQUIRE(gs_client.replica_request_broadcast_hints.size() == 1); + CHECK_FALSE(gs_client.replica_request_broadcast_hints.front().has_value()); +} + +TEST_CASE( + "MaterializeOrchestrator preserves non-broadcast transport completion before best-effort registration", + "[store][materialize][reselection]") { + auto gs_client = std::make_shared(); + gs_client->connected = true; + gs_client->allow_replica_transport = true; + gs_client->push_scripted_transport_session(make_transport_session( + "transport-non-broadcast", "node-remote", "10.1.1.2", 50002, common::memory::MemoryLocation::GPU, 0)); + + bool completed_before_register = false; + FakeMaterializationBackend backend; + backend.register_status = absl::UnavailableError("best-effort register failed"); + backend.on_register = [&]() { + completed_before_register = !gs_client->completed_transport_ids.empty(); + }; + + MaterializeHints hints; + hints.artifact_id = "artifact-non-broadcast-register-fails"; + hints.allow_p2p = true; + hints.allow_disk = false; + + components::WorkerIdentity local_identity{ + .worker_id = "worker-local", + .node_id = "node-local", + .node_address = "10.1.1.1", + .p2p_port = 50001, + }; + MaterializeOrchestrator orchestrator( + gsl::not_null{&backend}, + gsl::not_null{gs_client.get()}, + local_identity); + + auto result = orchestrator.run("artifact-non-broadcast-register-fails", make_gpu_target(0), hints, std::nullopt); + REQUIRE(result.ok()); + REQUIRE(backend.register_calls == 1); + REQUIRE(gs_client->completed_transport_ids.size() == 1); + CHECK(gs_client->completed_transport_outcomes.front() == TransportCompletionOutcome::kSuccess); + CHECK(completed_before_register); +} + TEST_CASE("MaterializeOrchestrator reselects source after retryable P2P failure", "[store][materialize][reselection]") { auto gs_client = std::make_shared(); gs_client->connected = true; @@ -365,6 +445,96 @@ TEST_CASE( CHECK(gs_client->completed_transport_outcomes.front() == TransportCompletionOutcome::kFailed); } +TEST_CASE( + "MaterializeOrchestrator retries broadcast source selection after registration failure", + "[store][materialize][reselection][broadcast]") { + auto gs_client = std::make_shared(); + gs_client->connected = true; + gs_client->allow_replica_transport = true; + gs_client->push_scripted_transport_session(make_transport_session( + "transport-broadcast-a", "node-a", "10.8.8.2", 50072, common::memory::MemoryLocation::GPU, 0)); + gs_client->push_scripted_transport_session(make_transport_session( + "transport-broadcast-b", "node-b", "10.8.8.3", 50073, common::memory::MemoryLocation::GPU, 0)); + + FakeMaterializationBackend backend; + backend.register_scripted_statuses = {absl::UnavailableError("register failed"), absl::OkStatus()}; + + MaterializeHints hints; + hints.artifact_id = "artifact-broadcast-register-retry"; + hints.allow_p2p = true; + hints.allow_disk = false; + hints.transport_request_id = "request-broadcast-register-retry"; + hints.broadcast = loading::BroadcastHint{ + .session_id = "session-register-retry", + .strict_parent = true, + }; + + components::WorkerIdentity local_identity{ + .worker_id = "worker-local", + .node_id = "node-local", + .node_address = "10.8.8.1", + .p2p_port = 50071, + }; + MaterializeOrchestrator orchestrator( + gsl::not_null{&backend}, + gsl::not_null{gs_client.get()}, + local_identity); + + auto result = orchestrator.run("artifact-broadcast-register-retry", make_gpu_target(0), hints, std::nullopt); + REQUIRE(result.ok()); + REQUIRE(backend.register_calls == 2); + REQUIRE(backend.p2p_attempts.size() == 2); + CHECK(backend.p2p_attempts[0].source_ip == "10.8.8.2"); + CHECK(backend.p2p_attempts[1].source_ip == "10.8.8.3"); + REQUIRE(gs_client->completed_transport_ids.size() == 2); + CHECK(gs_client->completed_transport_ids[0] == "transport-broadcast-a"); + CHECK(gs_client->completed_transport_outcomes[0] == TransportCompletionOutcome::kFailed); + CHECK(gs_client->completed_transport_ids[1] == "transport-broadcast-b"); + CHECK(gs_client->completed_transport_outcomes[1] == TransportCompletionOutcome::kSuccess); +} + +TEST_CASE( + "MaterializeOrchestrator returns terminal broadcast registration failure status", + "[store][materialize][reselection][broadcast]") { + auto gs_client = std::make_shared(); + gs_client->connected = true; + gs_client->allow_replica_transport = true; + gs_client->push_scripted_transport_session(make_transport_session( + "transport-broadcast-terminal", "node-remote", "10.9.9.2", 50092, common::memory::MemoryLocation::GPU, 0)); + + FakeMaterializationBackend backend; + backend.register_status = absl::InvalidArgumentError("invalid child replica"); + + MaterializeHints hints; + hints.artifact_id = "artifact-broadcast-terminal-register-fails"; + hints.allow_p2p = true; + hints.allow_disk = false; + hints.transport_request_id = "request-broadcast-terminal-register-fails"; + hints.broadcast = loading::BroadcastHint{ + .session_id = "session-terminal-register-fails", + .strict_parent = true, + }; + + components::WorkerIdentity local_identity{ + .worker_id = "worker-local", + .node_id = "node-local", + .node_address = "10.9.9.1", + .p2p_port = 50091, + }; + MaterializeOrchestrator orchestrator( + gsl::not_null{&backend}, + gsl::not_null{gs_client.get()}, + local_identity); + + auto result = + orchestrator.run("artifact-broadcast-terminal-register-fails", make_gpu_target(0), hints, std::nullopt); + REQUIRE_FALSE(result.ok()); + CHECK(absl::IsInvalidArgument(result.status())); + REQUIRE(gs_client->completed_transport_ids.size() == 1); + CHECK(gs_client->completed_transport_ids.front() == "transport-broadcast-terminal"); + CHECK(gs_client->completed_transport_outcomes.front() == TransportCompletionOutcome::kFailed); +} + TEST_CASE( "MaterializeOrchestrator propagates transport scheduler hint metadata", "[store][materialize][reselection][scheduler_hint]") { diff --git a/core/store/runtime/ingestion/materialization_facade.cc b/core/store/runtime/ingestion/materialization_facade.cc index 751047f5..11d12a2a 100644 --- a/core/store/runtime/ingestion/materialization_facade.cc +++ b/core/store/runtime/ingestion/materialization_facade.cc @@ -3779,9 +3779,9 @@ absl::StatusOr MaterializationFacade::mate target_device, resolve_transport_wait_timeout_ms(hints), scheduling_group_hint, - broadcast_hint, requester_worker_id, - transport_request_id); + transport_request_id, + broadcast_hint); if (transport_or.ok()) { const auto& session = *transport_or; const auto& remote = session.remote_replica; @@ -4606,9 +4606,9 @@ absl::StatusOr MaterializationFacade::mate target_device, view_probe_timeout_ms, scheduling_group_hint, - broadcast_hint, requester_worker_id, - transport_request_id); + transport_request_id, + broadcast_hint); if (!view_transport_or.ok() && (absl::IsNotFound(view_transport_or.status()) || absl::IsUnimplemented(view_transport_or.status()) || absl::IsDeadlineExceeded(view_transport_or.status()))) { @@ -4634,9 +4634,9 @@ absl::StatusOr MaterializationFacade::mate target_device, wait_timeout_ms, scheduling_group_hint, - broadcast_hint, requester_worker_id, - transport_request_id); + transport_request_id, + broadcast_hint); } return view_transport_or; } @@ -4648,9 +4648,9 @@ absl::StatusOr MaterializationFacade::mate target_device, wait_timeout_ms, scheduling_group_hint, - broadcast_hint, requester_worker_id, - transport_request_id); + transport_request_id, + broadcast_hint); }; auto transport_or = request_transport(); diff --git a/core/store/runtime/metadata/metadata_gateway_test.cc b/core/store/runtime/metadata/metadata_gateway_test.cc index 84c670bf..bebf5d72 100644 --- a/core/store/runtime/metadata/metadata_gateway_test.cc +++ b/core/store/runtime/metadata/metadata_gateway_test.cc @@ -484,9 +484,9 @@ class TestGlobalStoreClient final : public tensorcast::store::components::IGloba const DeviceKey&, uint32_t, const std::optional&, - const std::optional&, std::string_view, - std::string_view) override { + std::string_view, + const std::optional&) override { return absl::UnimplementedError("request_replica_transport not used in tests"); } @@ -499,9 +499,9 @@ class TestGlobalStoreClient final : public tensorcast::store::components::IGloba const DeviceKey&, uint32_t, const std::optional&, - const std::optional&, std::string_view, - std::string_view) override { + std::string_view, + const std::optional&) override { return absl::UnimplementedError("request_view_transport not used in tests"); } diff --git a/core/store/testing/global_store_client_stub.h b/core/store/testing/global_store_client_stub.h index c6537fef..c49131b1 100644 --- a/core/store/testing/global_store_client_stub.h +++ b/core/store/testing/global_store_client_stub.h @@ -146,9 +146,9 @@ class GlobalStoreClientStub : public components::IGlobalStoreClient { const DeviceKey&, uint32_t, const std::optional&, - const std::optional&, std::string_view, - std::string_view) override { + std::string_view, + const std::optional&) override { return absl::UnimplementedError("request_replica_transport not supported in GlobalStoreClientStub"); } @@ -161,9 +161,9 @@ class GlobalStoreClientStub : public components::IGlobalStoreClient { const DeviceKey&, uint32_t, const std::optional&, - const std::optional&, std::string_view, - std::string_view) override { + std::string_view, + const std::optional&) override { return absl::UnimplementedError("request_view_transport not supported in GlobalStoreClientStub"); } diff --git a/core/store/testing/recording_global_store_client.h b/core/store/testing/recording_global_store_client.h index b9663861..f47770e2 100644 --- a/core/store/testing/recording_global_store_client.h +++ b/core/store/testing/recording_global_store_client.h @@ -524,9 +524,9 @@ class RecordingGlobalStoreClient final : public components::IGlobalStoreClient { const tensorcast::store::DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, - const std::optional& broadcast_hint, std::string_view requester_worker_id, - std::string_view request_id) override { + std::string_view request_id, + const std::optional& broadcast_hint) override { replica_requests.emplace_back(std::string(artifact_id)); replica_request_groups.push_back(scheduling_group); replica_request_broadcast_hints.push_back(broadcast_hint); @@ -558,9 +558,9 @@ class RecordingGlobalStoreClient final : public components::IGlobalStoreClient { const tensorcast::store::DeviceKey& target_device, uint32_t wait_timeout_ms, const std::optional& scheduling_group, - const std::optional& broadcast_hint, std::string_view requester_worker_id, - std::string_view request_id) override { + std::string_view request_id, + const std::optional& broadcast_hint) override { view_requests.emplace_back(std::string(view_id)); view_request_groups.push_back(scheduling_group); view_request_broadcast_hints.push_back(broadcast_hint); From f21a48dc207082b7a8e402a975e3defc35c8bbc0 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:13:50 +0800 Subject: [PATCH 41/49] feat(daemon): expose broadcast session creation --- core/store/components/global_store_client.cc | 18 ++ core/store/components/global_store_client.h | 12 + daemon/BUILD | 16 ++ daemon/service/grpc_service_impl.h | 5 + ...rpc_service_impl_broadcast_session_test.cc | 149 +++++++++++++ .../grpc_service_impl_rpc_delegates.cc | 62 ++++++ proto/tensorcast/daemon/v2/store_daemon.proto | 26 +++ tensorcast/api/store/__init__.py | 72 ++++++ tensorcast/daemon_ctl.py | 62 ++++++ .../python/api/test_broadcast_session_api.py | 210 ++++++++++++++++++ 10 files changed, 632 insertions(+) create mode 100644 daemon/service/grpc_service_impl_broadcast_session_test.cc create mode 100644 tests/python/api/test_broadcast_session_api.py diff --git a/core/store/components/global_store_client.cc b/core/store/components/global_store_client.cc index 42755e11..c8f57431 100644 --- a/core/store/components/global_store_client.cc +++ b/core/store/components/global_store_client.cc @@ -2565,6 +2565,24 @@ absl::StatusOr GlobalStoreClient::request_view_transport( return session; } +absl::StatusOr GlobalStoreClient::create_broadcast_session( + const global_store::CreateBroadcastSessionRequest& request, + const RpcOptions& rpc_options) { + global_store::CreateBroadcastSessionResponse response; + auto status = execute_rpc_with_retry( + request, + &response, + [this](auto* ctx, const auto& req, auto* resp) { + return cluster_runtime_stub_->CreateBroadcastSession(ctx, req, resp); + }, + "CreateBroadcastSession", + rpc_options); + if (!status.ok()) { + return status; + } + return response; +} + absl::Status GlobalStoreClient::complete_replica_transport( std::string_view transport_id, TransportCompletionOutcome outcome, diff --git a/core/store/components/global_store_client.h b/core/store/components/global_store_client.h index 53df4e48..802be408 100644 --- a/core/store/components/global_store_client.h +++ b/core/store/components/global_store_client.h @@ -604,6 +604,14 @@ class IGlobalStoreClient { std::string_view artifact_id, std::optional view_id = std::nullopt) = 0; + virtual absl::StatusOr create_broadcast_session( + const global_store::CreateBroadcastSessionRequest& request, + const RpcOptions& rpc_options = RpcOptions{}) { + (void)request; + (void)rpc_options; + return absl::UnimplementedError("CreateBroadcastSession not available"); + } + virtual absl::StatusOr> query_chunk_locations( std::string_view artifact_id, const std::vector& chunk_indices) = 0; @@ -1038,6 +1046,10 @@ class GlobalStoreClient : public IGlobalStoreClient { std::string_view artifact_id, std::optional view_id = std::nullopt) override; + absl::StatusOr create_broadcast_session( + const global_store::CreateBroadcastSessionRequest& request, + const RpcOptions& rpc_options = RpcOptions{}) override; + absl::StatusOr> query_chunk_locations( std::string_view artifact_id, const std::vector& chunk_indices) override; diff --git a/daemon/BUILD b/daemon/BUILD index 0f2553f0..79d2f84f 100644 --- a/daemon/BUILD +++ b/daemon/BUILD @@ -2006,6 +2006,22 @@ cc_test( ], ) +cc_test( + name = "grpc_service_impl_broadcast_session_test", + srcs = ["service/grpc_service_impl_broadcast_session_test.cc"], + deps = [ + ":daemon_service_harness_lib", + ":grpc_service_impl", + "//core/store:store_engine", + "//core/store:store_engine_options", + "//core/store:testing_global_store_client_stub", + "//proto/tensorcast/daemon/v2:daemon_grpc_cc", + "//proto/tensorcast/global_store/v1:global_store_cc", + "@abseil-cpp//absl/status", + "@catch2//:catch2_main", + ], +) + cc_test( name = "daemon_shutdown_drain_test", srcs = ["app/daemon_shutdown_drain_test.cc"], diff --git a/daemon/service/grpc_service_impl.h b/daemon/service/grpc_service_impl.h index ec89e667..48e45d6f 100644 --- a/daemon/service/grpc_service_impl.h +++ b/daemon/service/grpc_service_impl.h @@ -226,6 +226,11 @@ class StoreDaemonServiceImpl final : public v2::StoreDaemonService::Service { const v2::UnlockTransportChunksRequest* req, v2::UnlockTransportChunksResponse* resp) override; + grpc::Status CreateBroadcastSession( + grpc::ServerContext* ctx, + const v2::CreateBroadcastSessionRequest* req, + v2::CreateBroadcastSessionResponse* resp) override; + grpc::Status BeginRegisterArtifact( grpc::ServerContext* ctx, const v2::BeginRegisterArtifactRequest* req, diff --git a/daemon/service/grpc_service_impl_broadcast_session_test.cc b/daemon/service/grpc_service_impl_broadcast_session_test.cc new file mode 100644 index 00000000..31c9eac2 --- /dev/null +++ b/daemon/service/grpc_service_impl_broadcast_session_test.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2026, TensorCast Team. + +#include "daemon/testing/daemon_service_harness.h" + +#include +#include +#include +#include +#include + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "core/store/store_engine.h" +#include "core/store/store_engine_options.h" +#include "core/store/testing/global_store_client_stub.h" +#include "grpcpp/server_context.h" +#include "tensorcast/common/v1/common.pb.h" +#include "tensorcast/daemon/v2/store_daemon.grpc.pb.h" +#include "tensorcast/global_store/v1/global_store.pb.h" + +namespace { + +namespace common = tensorcast::common::v1; +namespace daemon = tensorcast::daemon; +namespace daemon_v2 = tensorcast::daemon::v2; +namespace global_store = tensorcast::global_store::v1; +namespace store = tensorcast::store; + +std::filesystem::path test_tmpdir() { + const char* env = std::getenv("TEST_TMPDIR"); + if (env != nullptr && *env != '\0') { + return std::filesystem::path(env); + } + return std::filesystem::temp_directory_path() / "tensorcast_daemon_broadcast_session_test"; +} + +store::StoreEngineOptions make_opts() { + store::StoreEngineOptions opts; + opts.storage_path = (test_tmpdir() / "engine").string(); + std::filesystem::create_directories(opts.storage_path); + opts.p2p_port = 0; + opts.memory_pool_size = 32ULL << 20; + opts.tx_slice_bytes = 1ULL << 20; + opts.num_thread = 2; + opts.global_store_address.clear(); + return opts; +} + +class BroadcastSessionClient final : public store::testing::GlobalStoreClientStub { + public: + std::vector requests; + absl::Status rpc_status = absl::OkStatus(); + global_store::Status response_status = global_store::STATUS_OK; + std::string response_session_id{"session-from-global"}; + + absl::StatusOr create_broadcast_session( + const global_store::CreateBroadcastSessionRequest& request, + const store::components::RpcOptions&) override { + requests.push_back(request); + if (!rpc_status.ok()) { + return rpc_status; + } + global_store::CreateBroadcastSessionResponse response; + response.set_status(response_status); + response.mutable_session()->set_session_id(response_session_id); + return response; + } +}; + +std::unique_ptr make_harness(std::shared_ptr client) { + auto engine = std::make_shared(make_opts()); + engine->set_global_store_client_for_testing(client); + + daemon::DaemonOptions daemon_opts; + daemon_opts.storage_path = test_tmpdir() / "daemon"; + std::filesystem::create_directories(daemon_opts.storage_path); + auto harness_or = daemon::DaemonServiceHarness::create(engine, daemon_opts, nullptr, client); + REQUIRE(harness_or.ok()); + auto harness = std::move(*harness_or); + REQUIRE(harness->start().ok()); + return harness; +} + +} // namespace + +TEST_CASE("CreateBroadcastSession forwards daemon request to Global Store", "[daemon][broadcast]") { + auto client = std::make_shared(); + auto harness = make_harness(client); + + daemon_v2::CreateBroadcastSessionRequest request; + request.set_session_id("session-a"); + request.set_artifact_id("artifact-a"); + request.set_requested_view_id("view-a"); + request.set_epoch(7); + request.set_fanout(2); + request.add_target_worker_ids("worker-a"); + request.add_target_worker_ids("worker-b"); + request.add_target_daemon_ids("daemon-a"); + request.set_root_replica_id("replica-root"); + request.set_strict_parent(true); + request.set_max_attempts(3); + + daemon_v2::CreateBroadcastSessionResponse response; + grpc::ServerContext context; + const auto status = harness->service().CreateBroadcastSession(&context, &request, &response); + + REQUIRE(status.ok()); + REQUIRE(response.status() == daemon_v2::BROADCAST_SESSION_STATUS_OK); + REQUIRE(response.session_id() == "session-from-global"); + REQUIRE(client->requests.size() == 1); + + const auto& global_request = client->requests.front(); + REQUIRE(global_request.session_id() == "session-a"); + REQUIRE(global_request.artifact_id() == "artifact-a"); + REQUIRE(global_request.requested_byte_space().kind() == common::BYTE_SPACE_KIND_VIEW); + REQUIRE(global_request.requested_byte_space().id() == "view-a"); + REQUIRE(global_request.epoch() == 7); + REQUIRE(global_request.fanout() == 2); + REQUIRE(global_request.root_replica_id() == "replica-root"); + REQUIRE(global_request.strict_parent()); + REQUIRE(global_request.max_attempts() == 3); + REQUIRE(global_request.targets_size() == 3); + REQUIRE(global_request.targets(0).worker_id() == "worker-a"); + REQUIRE(global_request.targets(1).worker_id() == "worker-b"); + REQUIRE(global_request.targets(2).daemon_id() == "daemon-a"); +} + +TEST_CASE("CreateBroadcastSession reports daemon status when Global Store is unavailable", "[daemon][broadcast]") { + auto client = std::make_shared(); + client->connected = false; + auto harness = make_harness(client); + + daemon_v2::CreateBroadcastSessionRequest request; + request.set_session_id("session-a"); + request.set_artifact_id("artifact-a"); + request.set_fanout(2); + request.set_max_attempts(3); + request.add_target_daemon_ids("daemon-a"); + + daemon_v2::CreateBroadcastSessionResponse response; + grpc::ServerContext context; + const auto status = harness->service().CreateBroadcastSession(&context, &request, &response); + + REQUIRE(status.ok()); + REQUIRE(response.status() == daemon_v2::BROADCAST_SESSION_STATUS_ERROR); + REQUIRE(client->requests.empty()); +} diff --git a/daemon/service/grpc_service_impl_rpc_delegates.cc b/daemon/service/grpc_service_impl_rpc_delegates.cc index afa3447f..3b710690 100644 --- a/daemon/service/grpc_service_impl_rpc_delegates.cc +++ b/daemon/service/grpc_service_impl_rpc_delegates.cc @@ -6,6 +6,7 @@ namespace tensorcast::daemon { using ::grpc::Status; using ::grpc::StatusCode; +namespace global_store = tensorcast::global_store::v1; Status StoreDaemonServiceImpl::MaterializeReplica( grpc::ServerContext* ctx, @@ -588,6 +589,67 @@ Status StoreDaemonServiceImpl::UnlockTransportChunks( return transport_controller_->unlock(rctx, *req, dummy); } +Status StoreDaemonServiceImpl::CreateBroadcastSession( + grpc::ServerContext* ctx, + const v2::CreateBroadcastSessionRequest* req, + v2::CreateBroadcastSessionResponse* resp) { + if (auto startup_status = block_if_startup_pending(); !startup_status.ok()) { + return startup_status; + } + RpcContext rctx{"CreateBroadcastSession", *ctx, opts_.allow_high_card_attrs}; + if (global_store_client_ == nullptr || !global_store_client_->is_connected()) { + resp->set_status(v2::BROADCAST_SESSION_STATUS_ERROR); + rctx.mark_success(); + return Status::OK; + } + + global_store::CreateBroadcastSessionRequest global_req; + global_req.set_session_id(req->session_id()); + global_req.set_artifact_id(req->artifact_id()); + if (!req->requested_view_id().empty()) { + auto* requested_space = global_req.mutable_requested_byte_space(); + requested_space->set_kind(tensorcast::common::v1::BYTE_SPACE_KIND_VIEW); + requested_space->set_id(req->requested_view_id()); + } + global_req.set_epoch(req->epoch()); + global_req.set_fanout(req->fanout()); + global_req.set_root_replica_id(req->root_replica_id()); + global_req.set_strict_parent(req->strict_parent()); + global_req.set_max_attempts(req->max_attempts()); + for (const auto& worker_id : req->target_worker_ids()) { + global_req.add_targets()->set_worker_id(worker_id); + } + for (const auto& daemon_id : req->target_daemon_ids()) { + global_req.add_targets()->set_daemon_id(daemon_id); + } + + store::components::RpcOptions rpc_options; + rpc_options.max_retries = 0; + rpc_options.cancel_check = [ctx]() { return ctx != nullptr && ctx->IsCancelled(); }; + auto global_resp_or = global_store_client_->create_broadcast_session(global_req, rpc_options); + if (!global_resp_or.ok()) { + resp->set_status(v2::BROADCAST_SESSION_STATUS_ERROR); + rctx.mark_success(); + return Status::OK; + } + + const auto& global_resp = *global_resp_or; + if (global_resp.status() == global_store::STATUS_OK) { + resp->set_status(v2::BROADCAST_SESSION_STATUS_OK); + std::string session_id = req->session_id(); + if (global_resp.has_session() && !global_resp.session().session_id().empty()) { + session_id = global_resp.session().session_id(); + } + resp->set_session_id(session_id); + } else if (global_resp.status() == global_store::STATUS_NOT_FOUND) { + resp->set_status(v2::BROADCAST_SESSION_STATUS_NOT_FOUND); + } else { + resp->set_status(v2::BROADCAST_SESSION_STATUS_ERROR); + } + rctx.mark_success(); + return Status::OK; +} + Status StoreDaemonServiceImpl::BeginRegisterArtifact( grpc::ServerContext* ctx, const v2::BeginRegisterArtifactRequest* req, diff --git a/proto/tensorcast/daemon/v2/store_daemon.proto b/proto/tensorcast/daemon/v2/store_daemon.proto index a392b1bd..5a37233f 100644 --- a/proto/tensorcast/daemon/v2/store_daemon.proto +++ b/proto/tensorcast/daemon/v2/store_daemon.proto @@ -62,6 +62,7 @@ service StoreDaemonService { // Dual-end locking mechanism for P2P transfers rpc LockTransportChunks(LockTransportChunksRequest) returns (LockTransportChunksResponse) {} rpc UnlockTransportChunks(UnlockTransportChunksRequest) returns (UnlockTransportChunksResponse) {} + rpc CreateBroadcastSession(CreateBroadcastSessionRequest) returns (CreateBroadcastSessionResponse) {} // ========== Memory Artifact Registration ========== // New canonical RPC names @@ -252,6 +253,31 @@ enum ExportPolicy { EXPORT_POLICY_FORCE = 3; } +enum BroadcastSessionStatus { + BROADCAST_SESSION_STATUS_UNSPECIFIED = 0; + BROADCAST_SESSION_STATUS_OK = 1; + BROADCAST_SESSION_STATUS_ERROR = 2; + BROADCAST_SESSION_STATUS_NOT_FOUND = 3; +} + +message CreateBroadcastSessionRequest { + string session_id = 1; + string artifact_id = 2; + string requested_view_id = 3; + uint64 epoch = 4; + uint32 fanout = 5; + repeated string target_worker_ids = 6; + repeated string target_daemon_ids = 7; + string root_replica_id = 8; + bool strict_parent = 9; + uint32 max_attempts = 10; +} + +message CreateBroadcastSessionResponse { + BroadcastSessionStatus status = 1; + string session_id = 2; +} + // Source selection policy that can further gate which sources are allowed. message SourcePolicy { SourcePreference preference = 1; diff --git a/tensorcast/api/store/__init__.py b/tensorcast/api/store/__init__.py index 8a9e445e..7da7030b 100644 --- a/tensorcast/api/store/__init__.py +++ b/tensorcast/api/store/__init__.py @@ -11,6 +11,7 @@ import time import weakref from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, cast import grpc @@ -201,6 +202,11 @@ from tensorcast.api.plan import PlanResult, PlanStepRef +@dataclass(frozen=True, slots=True) +class BroadcastSessionHandle: + session_id: str + + def _coerce_representation_publish_closeout( publication: RepresentationPublishSpec | AssemblyCloseoutContract, ) -> AssemblyCloseoutContract: @@ -1098,6 +1104,71 @@ def __init__( ) _LIVE_STORES.add(self) + def create_broadcast_session( + self, + *, + artifact_id: str, + session_id: str | None = None, + requested_view_id: str | None = None, + epoch: int = 0, + fanout: int = 0, + target_worker_ids: Sequence[str] | None = None, + target_daemon_ids: Sequence[str] | None = None, + root_replica_id: str | None = None, + strict_parent: bool = True, + max_attempts: int = 3, + ) -> BroadcastSessionHandle: + if not session_id: + raise ArtifactError( + "session_id is required", + status_code="INVALID_ARGUMENT", + retryable=False, + ) + if not artifact_id: + raise ArtifactError( + "artifact_id is required", + status_code="INVALID_ARGUMENT", + retryable=False, + ) + if int(fanout) <= 0: + raise ArtifactError( + "fanout must be > 0", + status_code="INVALID_ARGUMENT", + retryable=False, + ) + if int(max_attempts) <= 0: + raise ArtifactError( + "max_attempts must be > 0", + status_code="INVALID_ARGUMENT", + retryable=False, + ) + response = self._runtime.ensure_client().create_broadcast_session( + session_id=session_id, + artifact_id=artifact_id, + requested_view_id=requested_view_id, + epoch=epoch, + fanout=fanout, + target_worker_ids=list(target_worker_ids or ()), + target_daemon_ids=list(target_daemon_ids or ()), + root_replica_id=root_replica_id, + strict_parent=strict_parent, + max_attempts=max_attempts, + ) + status = response.status + if status == store_daemon_pb2.BROADCAST_SESSION_STATUS_OK: + return BroadcastSessionHandle(session_id=str(response.session_id)) + if status == store_daemon_pb2.BROADCAST_SESSION_STATUS_NOT_FOUND: + raise ArtifactError( + "broadcast session source artifact or target not found", + status_code="NOT_FOUND", + retryable=False, + ) + raise ArtifactError( + "failed to create broadcast session", + status_code="FAILED_PRECONDITION", + retryable=False, + ) + def set_register_fn(self, register_fn: Callable[..., RegistrationResult]) -> None: self._registration.set_register_fn(register_fn) @@ -4380,6 +4451,7 @@ def realize_into_binding( "BindingValueRef", "BindingLayout", "BindingUpdateEpoch", + "BroadcastSessionHandle", "BuilderMode", "CanonicalIndex", "CanonicalIndexEntry", diff --git a/tensorcast/daemon_ctl.py b/tensorcast/daemon_ctl.py index 402cbe37..b7a849cd 100644 --- a/tensorcast/daemon_ctl.py +++ b/tensorcast/daemon_ctl.py @@ -1245,6 +1245,68 @@ def create_binding( ) from e return response + def create_broadcast_session( + self, + *, + artifact_id: str, + session_id: str | None = None, + requested_view_id: str | None = None, + epoch: int = 0, + fanout: int = 0, + target_worker_ids: Iterable[str] | None = None, + target_daemon_ids: Iterable[str] | None = None, + root_replica_id: str | None = None, + strict_parent: bool = True, + max_attempts: int = 3, + timeout_s: float = 30.0, + ) -> store_daemon_pb2.CreateBroadcastSessionResponse: + if not session_id: + raise ValueError("session_id is required") + if not artifact_id: + raise ValueError("artifact_id is required") + if int(fanout) <= 0: + raise ValueError("fanout must be > 0") + if int(max_attempts) <= 0: + raise ValueError("max_attempts must be > 0") + request = store_daemon_pb2.CreateBroadcastSessionRequest( + session_id=str(session_id), + artifact_id=str(artifact_id), + epoch=int(epoch), + fanout=int(fanout), + strict_parent=bool(strict_parent), + max_attempts=int(max_attempts), + ) + if requested_view_id: + request.requested_view_id = str(requested_view_id) + if target_worker_ids is not None: + request.target_worker_ids.extend(str(item) for item in target_worker_ids) + if target_daemon_ids is not None: + request.target_daemon_ids.extend(str(item) for item in target_daemon_ids) + if root_replica_id: + request.root_replica_id = str(root_replica_id) + with self._client_span("Client/CreateBroadcastSession") as span: + try: + response: store_daemon_pb2.CreateBroadcastSessionResponse = ( + self._unary_call( + self.stub_v2.CreateBroadcastSession, + request, + timeout=float(timeout_s), + span=span, + retries=1, + ) + ) + except grpc.RpcError as e: # noqa: BLE001 + span.record_exception(e) + code = e.code() + if code == grpc.StatusCode.UNAVAILABLE: + raise RuntimeError( + f"Local StoreDaemon ({self.server_address}) is not available." + ) from e + raise RuntimeError( + _grpc_message(e, fallback="CreateBroadcastSession RPC failed") + ) from e + return response + def commit_binding_artifact( self, *, diff --git a/tests/python/api/test_broadcast_session_api.py b/tests/python/api/test_broadcast_session_api.py new file mode 100644 index 00000000..e504835a --- /dev/null +++ b/tests/python/api/test_broadcast_session_api.py @@ -0,0 +1,210 @@ +# Copyright (c) 2026, TensorCast Team. + +from __future__ import annotations + +import pytest + +from tensorcast.api.store import Store +from tensorcast.api.store.types import ArtifactError +from tensorcast.daemon_ctl import DaemonCtl +from tensorcast.proto.daemon.v2 import store_daemon_pb2 + + +class _DaemonClient: + def __init__(self, response: store_daemon_pb2.CreateBroadcastSessionResponse) -> None: + self.response = response + self.calls: list[dict[str, object]] = [] + + def create_broadcast_session(self, **kwargs): + self.calls.append(kwargs) + return self.response + + +class _Runtime: + daemon_endpoint = "daemon" + daemon_id = "daemon-1" + session_id = "sess" + closed = False + + def __init__(self, client: _DaemonClient) -> None: + self._client = client + + def ensure_client(self) -> _DaemonClient: + return self._client + + +def test_store_create_broadcast_session_uses_daemon_client() -> None: + client = _DaemonClient( + store_daemon_pb2.CreateBroadcastSessionResponse( + status=store_daemon_pb2.BROADCAST_SESSION_STATUS_OK, + session_id="session-a", + ) + ) + store = Store("daemon", runtime=_Runtime(client)) + + handle = store.create_broadcast_session( + session_id="session-a", + artifact_id="artifact-a", + target_daemon_ids=["daemon-a", "daemon-b"], + target_worker_ids=["worker-a"], + requested_view_id="view-a", + epoch=7, + fanout=2, + root_replica_id="replica-root", + strict_parent=True, + max_attempts=3, + ) + + assert handle.session_id == "session-a" + assert client.calls == [ + { + "session_id": "session-a", + "artifact_id": "artifact-a", + "requested_view_id": "view-a", + "epoch": 7, + "fanout": 2, + "target_worker_ids": ["worker-a"], + "target_daemon_ids": ["daemon-a", "daemon-b"], + "root_replica_id": "replica-root", + "strict_parent": True, + "max_attempts": 3, + } + ] + + +def test_store_create_broadcast_session_validates_required_fields() -> None: + client = _DaemonClient( + store_daemon_pb2.CreateBroadcastSessionResponse( + status=store_daemon_pb2.BROADCAST_SESSION_STATUS_OK, + session_id="unused", + ) + ) + store = Store("daemon", runtime=_Runtime(client)) + + with pytest.raises(ArtifactError, match="session_id is required"): + store.create_broadcast_session( + session_id="", + artifact_id="artifact-a", + fanout=2, + target_daemon_ids=["daemon-a"], + ) + + with pytest.raises(ArtifactError, match="fanout must be > 0"): + store.create_broadcast_session( + session_id="session-a", + artifact_id="artifact-a", + fanout=0, + target_daemon_ids=["daemon-a"], + ) + + assert client.calls == [] + + +def test_daemon_ctl_create_broadcast_session_builds_request() -> None: + client = DaemonCtl.__new__(DaemonCtl) + client.server_address = "127.0.0.1:1" + captured: list[store_daemon_pb2.CreateBroadcastSessionRequest] = [] + + class _Stub: + def CreateBroadcastSession(self): # pragma: no cover - marker only + raise AssertionError("fake unary call should intercept this") + + def _unary_call(method, request, *, timeout=None, retries=0, span=None): # noqa: ANN001 + del method, timeout, retries, span + captured.append(request) + return store_daemon_pb2.CreateBroadcastSessionResponse( + status=store_daemon_pb2.BROADCAST_SESSION_STATUS_OK, + session_id="session-a", + ) + + client.stub_v2 = _Stub() + client._unary_call = _unary_call + + response = client.create_broadcast_session( + session_id="session-a", + artifact_id="artifact-a", + requested_view_id="view-a", + epoch=9, + fanout=4, + target_worker_ids=["worker-a"], + target_daemon_ids=["daemon-a", "daemon-b"], + root_replica_id="replica-root", + strict_parent=False, + max_attempts=5, + timeout_s=12.0, + ) + + assert response.session_id == "session-a" + assert len(captured) == 1 + request = captured[0] + assert request.session_id == "session-a" + assert request.artifact_id == "artifact-a" + assert request.requested_view_id == "view-a" + assert request.epoch == 9 + assert request.fanout == 4 + assert list(request.target_worker_ids) == ["worker-a"] + assert list(request.target_daemon_ids) == ["daemon-a", "daemon-b"] + assert request.root_replica_id == "replica-root" + assert request.strict_parent is False + assert request.max_attempts == 5 + + +def test_daemon_ctl_create_broadcast_session_defaults_max_attempts() -> None: + client = DaemonCtl.__new__(DaemonCtl) + client.server_address = "127.0.0.1:1" + captured: list[store_daemon_pb2.CreateBroadcastSessionRequest] = [] + + class _Stub: + def CreateBroadcastSession(self): # pragma: no cover - marker only + raise AssertionError("fake unary call should intercept this") + + def _unary_call(method, request, *, timeout=None, retries=0, span=None): # noqa: ANN001 + del method, timeout, retries, span + captured.append(request) + return store_daemon_pb2.CreateBroadcastSessionResponse( + status=store_daemon_pb2.BROADCAST_SESSION_STATUS_OK, + session_id="session-a", + ) + + client.stub_v2 = _Stub() + client._unary_call = _unary_call + + client.create_broadcast_session( + session_id="session-a", + artifact_id="artifact-a", + fanout=2, + target_daemon_ids=["daemon-a"], + ) + + assert len(captured) == 1 + assert captured[0].max_attempts == 3 + + +@pytest.mark.parametrize( + ("status", "status_code"), + [ + (store_daemon_pb2.BROADCAST_SESSION_STATUS_ERROR, "FAILED_PRECONDITION"), + (store_daemon_pb2.BROADCAST_SESSION_STATUS_NOT_FOUND, "NOT_FOUND"), + ], +) +def test_store_create_broadcast_session_maps_daemon_errors( + status: int, + status_code: str, +) -> None: + client = _DaemonClient( + store_daemon_pb2.CreateBroadcastSessionResponse( + status=status, + session_id="", + ) + ) + store = Store("daemon", runtime=_Runtime(client)) + + with pytest.raises(ArtifactError) as exc_info: + store.create_broadcast_session( + session_id="session-a", + artifact_id="artifact-a", + fanout=2, + target_daemon_ids=["daemon-a"], + ) + + assert exc_info.value.status_code == status_code From d1f0b928ce531d5f088019f4804aa109a068e642 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 29 Apr 2026 20:26:15 +0800 Subject: [PATCH 42/49] test: cover broadcast tree dissemination --- tensorcast/global_store/README.md | 11 ++ .../services/broadcast_service.py | 48 +++++- .../python/api/test_broadcast_session_api.py | 4 +- .../python/global_store/test_broadcast_e2e.py | 153 ++++++++++++++++++ .../global_store/test_broadcast_service.py | 63 +++++++- 5 files changed, 270 insertions(+), 9 deletions(-) create mode 100644 tests/python/global_store/test_broadcast_e2e.py diff --git a/tensorcast/global_store/README.md b/tensorcast/global_store/README.md index 23f5e0bd..8ba0cbbe 100644 --- a/tensorcast/global_store/README.md +++ b/tensorcast/global_store/README.md @@ -258,6 +258,17 @@ counts by scope and capability. 2. Select replica with source-balance scoring and memory-tier preference (GPU > RAM > DISK). 3. Finalize assignment transactionally with idempotent request replay protection. +### Broadcast Sessions + +Broadcast sessions coordinate strict tree dissemination for model-weight prefetch. A session records the artifact or +view, epoch, target workers, fanout, and planned parent-child edges. Broadcast-tagged `RequestReplicaTransport` calls +resolve to the parent replica assigned by the active edge; untagged requests continue to use group dispatch or ordinary +source selection. + +Only successful transport completions advance tree progress. After a child materializes and registers an exportable +replica, that child can become a parent for the next layer; failed or exhausted edges are retried or marked terminal by +the Global Store while the artifact bytes still move only between Store Daemons. + ### RecoveryService **Purpose:** Handle high-availability state reconciliation. diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index 23e8e516..a89a0a2d 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -357,10 +357,13 @@ def _mark_session_terminal_if_done(self, session_id: str, *, cursor=None) -> Non session = self._broadcast_repository.find_session(session_id, cursor=cursor) if session is None or session.state is not BroadcastSessionState.ACTIVE: return - if self._broadcast_repository.count_incomplete_targets( - session_id, - cursor=cursor, - ) != 0: + if ( + self._broadcast_repository.count_incomplete_targets( + session_id, + cursor=cursor, + ) + != 0 + ): return targets = self._broadcast_repository.list_targets(session_id, cursor=cursor) if any(target.state is BroadcastTargetState.FAILED for target in targets): @@ -512,7 +515,9 @@ def _plan_more_edges( return [] planned: list[BroadcastEdge] = [] - for target, parent in zip(pending_targets[:capacity], parent_pool * capacity): + for target, parent in zip( + pending_targets[:capacity], parent_pool, strict=False + ): parent_replica, parent_level = parent edge = BroadcastEdge( edge_id=str(uuid4()), @@ -559,6 +564,15 @@ def _parent_pool( cursor=None, ) -> list[tuple[Replica, int]]: parents: list[tuple[Replica, int]] = [] + edge_counts = self._parent_edge_counts(session.session_id, cursor=cursor) + fanout = max(0, int(session.fanout)) + + def add_available_slots(replica: Replica, level: int) -> None: + if fanout <= 0: + return + remaining = fanout - edge_counts.get(str(replica.replica_id), 0) + parents.extend((replica, level) for _ in range(max(0, remaining))) + if session.root_replica_id is not None: root = self._replica_repository.find_by_id( session.root_replica_id, @@ -566,7 +580,7 @@ def _parent_pool( cursor=cursor, ) if root is not None: - parents.append((root, 0)) + add_available_slots(root, 0) completed_targets = self._broadcast_repository.list_targets_by_state( session.session_id, @@ -581,5 +595,25 @@ def _parent_pool( target.completed_replica_id ) if replica is not None: - parents.append((replica, int(target.level or 0))) + add_available_slots(replica, int(target.level or 0)) return parents + + def _parent_edge_counts(self, session_id: str, *, cursor=None) -> dict[str, int]: + owns_cursor = cursor is None + if owns_cursor: + cursor = self._broadcast_repository.get_cursor() + try: + rows = cursor.execute( + """ + SELECT parent_replica_id, COUNT(*) + FROM broadcast_edges + WHERE session_id = ? + AND state IN ('planned', 'assigned', 'materializing', 'completed') + GROUP BY parent_replica_id + """, + [session_id], + ).fetchall() + return {str(row[0]): int(row[1] or 0) for row in rows} + finally: + if owns_cursor: + cursor.close() diff --git a/tests/python/api/test_broadcast_session_api.py b/tests/python/api/test_broadcast_session_api.py index e504835a..b2eb47d5 100644 --- a/tests/python/api/test_broadcast_session_api.py +++ b/tests/python/api/test_broadcast_session_api.py @@ -11,7 +11,9 @@ class _DaemonClient: - def __init__(self, response: store_daemon_pb2.CreateBroadcastSessionResponse) -> None: + def __init__( + self, response: store_daemon_pb2.CreateBroadcastSessionResponse + ) -> None: self.response = response self.calls: list[dict[str, object]] = [] diff --git a/tests/python/global_store/test_broadcast_e2e.py b/tests/python/global_store/test_broadcast_e2e.py new file mode 100644 index 00000000..661db153 --- /dev/null +++ b/tests/python/global_store/test_broadcast_e2e.py @@ -0,0 +1,153 @@ +# Copyright (c) 2026, TensorCast Team. + +from __future__ import annotations + +from tensorcast.proto.common.v1 import common_pb2 +from tensorcast.proto.global_store.v1 import global_store_pb2 + +ARTIFACT_ID = "mi2:model-e2e" + + +def _register_worker(servicer, context, idx: int) -> str: # noqa: ANN001 + response = servicer.RegisterWorker( + global_store_pb2.RegisterWorkerRequest( + daemon_id=f"daemon-e2e-{idx}", + node_id=f"node-e2e-{idx}", + node_address=f"10.30.0.{idx}", + grpc_port=53000 + idx, + p2p_port=54000 + idx, + mem_pool_total_size=4096, + mem_pool_available_size=4096, + ), + context, + ) + assert response.status == global_store_pb2.STATUS_OK + return response.worker_id + + +def _register_exportable_replica( + servicer, # noqa: ANN001 + context, # noqa: ANN001 + *, + worker_id: str, + idx: int, +) -> str: + mem_info = common_pb2.MemoryInfo( + node_id=f"node-e2e-{idx}", + node_address=f"10.30.0.{idx}", + node_port=54000 + idx, + memory_size=1024, + memory_type=common_pb2.MEMORY_TYPE_GPU, + device_id=0, + byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + ) + mem_info.transport.export_state = ( + common_pb2.ReplicaTransportMetadata.EXPORT_STATE_EXPORTABLE + ) + mem_info.transport.export_generation = 1 + mem_info.transport.remote_memory_keys.append(f"rk-e2e-{idx}") + mem_info.transport.buffer_sizes.append(1024) + request = global_store_pb2.RegisterReplicaRequest( + artifact_id=ARTIFACT_ID, + worker_id=worker_id, + mem_info=mem_info, + max_concurrency=4, + ) + response = servicer.RegisterReplica(request, context) + assert response.status == global_store_pb2.STATUS_OK + return response.replica_id + + +def _request_broadcast_transport( + servicer, # noqa: ANN001 + context, # noqa: ANN001 + *, + session_id: str, + worker_id: str, + idx: int, + request_id: str, +) -> global_store_pb2.RequestReplicaTransportResponse: + request = global_store_pb2.RequestReplicaTransportRequest( + artifact_id=ARTIFACT_ID, + source_node_id=f"node-e2e-{idx}", + source_address=f"10.30.0.{idx}", + source_port=54000 + idx, + requester_worker_id=worker_id, + request_id=request_id, + requested_byte_space=common_pb2.ByteSpaceRef( + kind=common_pb2.BYTE_SPACE_KIND_CANONICAL, + ), + ) + request.local_memory_info.memory_type = common_pb2.MEMORY_TYPE_GPU + request.local_memory_info.device_id = 0 + request.broadcast.session_id = session_id + request.broadcast.strict_parent = True + response = servicer.RequestReplicaTransport(request, context) + assert response.status == global_store_pb2.STATUS_OK + return response + + +def test_tree_broadcast_promotes_first_child_to_second_layer_parent( + servicer, # noqa: ANN001 + test_context, # noqa: ANN001 +) -> None: + root = _register_worker(servicer, test_context, 1) + child1 = _register_worker(servicer, test_context, 2) + child2 = _register_worker(servicer, test_context, 3) + root_replica = _register_exportable_replica( + servicer, + test_context, + worker_id=root, + idx=1, + ) + + create = servicer.CreateBroadcastSession( + global_store_pb2.CreateBroadcastSessionRequest( + session_id="session-e2e", + artifact_id=ARTIFACT_ID, + epoch=1, + fanout=1, + root_replica_id=root_replica, + strict_parent=True, + max_attempts=3, + targets=[ + global_store_pb2.BroadcastTargetIdentity(worker_id=child1), + global_store_pb2.BroadcastTargetIdentity(worker_id=child2), + ], + ), + test_context, + ) + assert create.status == global_store_pb2.STATUS_OK + + first = _request_broadcast_transport( + servicer, + test_context, + session_id="session-e2e", + worker_id=child1, + idx=2, + request_id="request-child-1", + ) + assert first.remote_memory_info.node_id == "node-e2e-1" + + _register_exportable_replica(servicer, test_context, worker_id=child1, idx=2) + complete = servicer.CompleteReplicaTransport( + global_store_pb2.CompleteReplicaTransportRequest( + transport_id=first.transport_id, + outcome=global_store_pb2.TRANSPORT_COMPLETION_OUTCOME_SUCCESS, + ), + test_context, + ) + assert complete.status == global_store_pb2.STATUS_OK + + second = _request_broadcast_transport( + servicer, + test_context, + session_id="session-e2e", + worker_id=child2, + idx=3, + request_id="request-child-2", + ) + + assert second.remote_memory_info.node_id == "node-e2e-2" diff --git a/tests/python/global_store/test_broadcast_service.py b/tests/python/global_store/test_broadcast_service.py index 901660f1..39eb5ba9 100644 --- a/tests/python/global_store/test_broadcast_service.py +++ b/tests/python/global_store/test_broadcast_service.py @@ -96,10 +96,71 @@ def test_create_session_plans_first_layer_by_fanout(repositories): ] assert all(edge is not None for edge in edges) assert all(edge.state is BroadcastEdgeState.PLANNED for edge in edges if edge) - assert all(edge.parent_replica_id == root_replica.replica_id for edge in edges if edge) + assert all( + edge.parent_replica_id == root_replica.replica_id for edge in edges if edge + ) assert len(service.list_edges("session-a")) == 2 +def test_completed_child_becomes_parent_after_root_fanout_is_full(repositories): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + ) + root = _worker("worker-root-hop", "daemon-root-hop", "node1") + child1 = _worker("worker-child-hop-1", "daemon-child-hop-1", "node2") + child2 = _worker("worker-child-hop-2", "daemon-child-hop-2", "node3") + for worker in (root, child1, child2): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-hop", root)) + + service.create_session( + session_id="session-hop", + artifact_id="mi2:model-hop", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=["daemon-child-hop-1", "daemon-child-hop-2"], + root_replica_id=str(root_replica.replica_id), + strict_parent=True, + max_attempts=3, + ) + first_edge = broadcast_repo.find_active_edge_for_child( + "session-hop", + child1.worker_id, + ) + assert first_edge is not None + assert first_edge.parent_replica_id == root_replica.replica_id + + child1_replica = replica_repo.create(_exportable_replica("mi2:model-hop", child1)) + with broadcast_repo.transaction() as tx: + assert broadcast_repo.mark_edge_materializing( + first_edge.edge_id, + "request-hop-1", + cursor=tx, + ) + service.complete_transport_edge( + session_id="session-hop", + edge_id=first_edge.edge_id, + transport_outcome=TransportCompletionOutcome.SUCCESS, + outcome_detail=None, + cursor=tx, + ) + + second_edge = broadcast_repo.find_active_edge_for_child( + "session-hop", + child2.worker_id, + ) + assert second_edge is not None + assert second_edge.parent_replica_id == child1_replica.replica_id + assert second_edge.level == 2 + + def test_create_session_duplicate_explicit_root_returns_existing_without_counter_change( repositories, ): From e1a1f706c5522e16e943e880133f3f27f615dc82 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Thu, 30 Apr 2026 13:54:21 +0800 Subject: [PATCH 43/49] fix(core): harden broadcast tree materialization --- .../replica_materialization_service.cc | 48 ++++++++++++------- ..._service_impl_no_lease_materialize_test.cc | 20 ++++++++ proto/gen/python/tensorcast/__init__.py | 0 .../gen/python/tensorcast/common/__init__.py | 0 .../python/tensorcast/common/v1/__init__.py | 0 .../tensorcast/communicator/__init__.py | 0 .../tensorcast/communicator/v1/__init__.py | 0 .../gen/python/tensorcast/config/__init__.py | 0 .../python/tensorcast/config/v1/__init__.py | 0 .../gen/python/tensorcast/daemon/__init__.py | 0 .../python/tensorcast/daemon/v2/__init__.py | 0 .../tensorcast/global_store/__init__.py | 0 .../tensorcast/global_store/v1/__init__.py | 0 .../gen/python/tensorcast/layout/__init__.py | 0 .../python/tensorcast/layout/v1/__init__.py | 0 .../python/tensorcast/memory_tier/__init__.py | 0 .../tensorcast/memory_tier/v1/__init__.py | 2 - .../python/tensorcast/node_agent/__init__.py | 0 .../tensorcast/node_agent/v1/__init__.py | 0 .../python/tensorcast/operation/__init__.py | 0 .../tensorcast/operation/v1/__init__.py | 0 proto/gen/python/tensorcast/plan/__init__.py | 0 .../gen/python/tensorcast/plan/v1/__init__.py | 0 .../python/tensorcast/publication/__init__.py | 0 .../tensorcast/publication/v1/__init__.py | 0 pyproject.toml | 12 ++--- tensorcast/global_store/grpc_service.py | 1 + .../services/broadcast_service.py | 13 +++-- .../global_store/test_broadcast_service.py | 48 +++++++++++++++++++ 29 files changed, 117 insertions(+), 27 deletions(-) create mode 100644 proto/gen/python/tensorcast/__init__.py create mode 100644 proto/gen/python/tensorcast/common/__init__.py create mode 100644 proto/gen/python/tensorcast/common/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/communicator/__init__.py create mode 100644 proto/gen/python/tensorcast/communicator/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/config/__init__.py create mode 100644 proto/gen/python/tensorcast/config/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/daemon/__init__.py create mode 100644 proto/gen/python/tensorcast/daemon/v2/__init__.py create mode 100644 proto/gen/python/tensorcast/global_store/__init__.py create mode 100644 proto/gen/python/tensorcast/global_store/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/layout/__init__.py create mode 100644 proto/gen/python/tensorcast/layout/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/memory_tier/__init__.py create mode 100644 proto/gen/python/tensorcast/node_agent/__init__.py create mode 100644 proto/gen/python/tensorcast/node_agent/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/operation/__init__.py create mode 100644 proto/gen/python/tensorcast/operation/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/plan/__init__.py create mode 100644 proto/gen/python/tensorcast/plan/v1/__init__.py create mode 100644 proto/gen/python/tensorcast/publication/__init__.py create mode 100644 proto/gen/python/tensorcast/publication/v1/__init__.py diff --git a/daemon/service/controllers/replica_materialization_service.cc b/daemon/service/controllers/replica_materialization_service.cc index 1c878973..3396afec 100644 --- a/daemon/service/controllers/replica_materialization_service.cc +++ b/daemon/service/controllers/replica_materialization_service.cc @@ -63,6 +63,7 @@ using materialization_policy::resolve_transform_placement; using materialization_policy::to_hint_export_policy; using materialization_post_seal::check_post_seal_view_reuse_safe; using materialization_replica_handle::bind_replica_handle_for_response; +using materialization_replica_handle::register_session_and_refs; using materialization_request_common::LeaseContext; using materialization_request_common::LipFastPathRequest; using materialization_request_common::materialize_with_shared_disk_retry; @@ -943,22 +944,37 @@ grpc::Status ReplicaMaterializationService::materialize_replica( << handle.replica_key << " cpu_state=" << static_cast(handle.cpu_state) << " gpu_state=" << static_cast(handle.gpu_state); } - auto bind_status = bind_materialized_handle( - d_.engine, - d_.sessions, - d_.refs, - d_.lifecycle, - d_.handle_leases, - handle, - req.replica_uuid(), - effective_pid, - loopback_peer, - cpu_target, - "engine path", - *resp.mutable_mem_handle()); - if (!bind_status.ok()) { - resp.set_status(MaterializeReplicaStatus::MATERIALIZE_REPLICA_STATUS_FAILED); - return to_grpc_status(bind_status); + if (no_lease) { + auto session_status = register_session_and_refs( + d_.sessions, + d_.refs, + handle.replica_key, + handle.ready_signal, + req.replica_uuid(), + effective_pid, + /*allow_pid_ref=*/false); + if (!session_status.ok()) { + resp.set_status(MaterializeReplicaStatus::MATERIALIZE_REPLICA_STATUS_FAILED); + return to_grpc_status(session_status); + } + } else { + auto bind_status = bind_materialized_handle( + d_.engine, + d_.sessions, + d_.refs, + d_.lifecycle, + d_.handle_leases, + handle, + req.replica_uuid(), + effective_pid, + loopback_peer, + cpu_target, + "engine path", + *resp.mutable_mem_handle()); + if (!bind_status.ok()) { + resp.set_status(MaterializeReplicaStatus::MATERIALIZE_REPLICA_STATUS_FAILED); + return to_grpc_status(bind_status); + } } resp.set_status(MaterializeReplicaStatus::MATERIALIZE_REPLICA_STATUS_ALLOCATED); if (handle.view_index_json.has_value()) { diff --git a/daemon/service/grpc_service_impl_no_lease_materialize_test.cc b/daemon/service/grpc_service_impl_no_lease_materialize_test.cc index 1f7e8d2a..f323be35 100644 --- a/daemon/service/grpc_service_impl_no_lease_materialize_test.cc +++ b/daemon/service/grpc_service_impl_no_lease_materialize_test.cc @@ -244,6 +244,26 @@ TEST_CASE("MaterializeReplica honors NO_LEASE semantics", "[daemon][materialize] REQUIRE(resp.has_ticket()); REQUIRE(resp.ticket().replica_uuid() == "op-by-key-no-lease"); } + + { + tensorcast::daemon::v2::MaterializeReplicaRequest req; + req.mutable_selection()->set_artifact_id(artifact_id); + req.set_target_device_type(tensorcast::daemon::v2::DeviceType::DEVICE_TYPE_CPU); + req.mutable_source_policy()->set_preference( + tensorcast::daemon::v2::SourcePreference::SOURCE_PREFERENCE_PREFER_DISK); + req.set_wait_for_completion(false); + req.set_replica_uuid("op-by-key-no-lease-cpu"); + req.set_pid(0); + req.set_lease_mode(tensorcast::daemon::v2::LeaseMode::LEASE_MODE_NO_LEASE); + + grpc::ServerContext ctx; + tensorcast::daemon::v2::MaterializeReplicaResponse resp; + const auto st = svc.MaterializeReplica(&ctx, &req, &resp); + REQUIRE(st.ok()); + REQUIRE_FALSE(resp.has_mem_handle()); + REQUIRE(resp.has_ticket()); + REQUIRE(resp.ticket().replica_uuid() == "op-by-key-no-lease-cpu"); + } } TEST_CASE("MaterializeReplica short-circuits local cache before disk resolution", "[daemon][materialize]") { diff --git a/proto/gen/python/tensorcast/__init__.py b/proto/gen/python/tensorcast/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/common/__init__.py b/proto/gen/python/tensorcast/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/common/v1/__init__.py b/proto/gen/python/tensorcast/common/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/communicator/__init__.py b/proto/gen/python/tensorcast/communicator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/communicator/v1/__init__.py b/proto/gen/python/tensorcast/communicator/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/config/__init__.py b/proto/gen/python/tensorcast/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/config/v1/__init__.py b/proto/gen/python/tensorcast/config/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/daemon/__init__.py b/proto/gen/python/tensorcast/daemon/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/daemon/v2/__init__.py b/proto/gen/python/tensorcast/daemon/v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/global_store/__init__.py b/proto/gen/python/tensorcast/global_store/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/global_store/v1/__init__.py b/proto/gen/python/tensorcast/global_store/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/layout/__init__.py b/proto/gen/python/tensorcast/layout/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/layout/v1/__init__.py b/proto/gen/python/tensorcast/layout/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/memory_tier/__init__.py b/proto/gen/python/tensorcast/memory_tier/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/memory_tier/v1/__init__.py b/proto/gen/python/tensorcast/memory_tier/v1/__init__.py index 502a286e..e69de29b 100644 --- a/proto/gen/python/tensorcast/memory_tier/v1/__init__.py +++ b/proto/gen/python/tensorcast/memory_tier/v1/__init__.py @@ -1,2 +0,0 @@ -# Copyright (c) 2025, TensorCast Team. - diff --git a/proto/gen/python/tensorcast/node_agent/__init__.py b/proto/gen/python/tensorcast/node_agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/node_agent/v1/__init__.py b/proto/gen/python/tensorcast/node_agent/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/operation/__init__.py b/proto/gen/python/tensorcast/operation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/operation/v1/__init__.py b/proto/gen/python/tensorcast/operation/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/plan/__init__.py b/proto/gen/python/tensorcast/plan/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/plan/v1/__init__.py b/proto/gen/python/tensorcast/plan/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/publication/__init__.py b/proto/gen/python/tensorcast/publication/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/proto/gen/python/tensorcast/publication/v1/__init__.py b/proto/gen/python/tensorcast/publication/v1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 72dc6e22..7a4827ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] # Note: The torch version here is the default. When building releases, # use tools/release.sh with --torch-version to specify different versions -requires = ["setuptools>=68.0.0", "pyyaml", "torch==2.8.0+cu128", "toml"] +requires = ["setuptools>=68.0.0", "pyyaml", "torch==2.9.1", "toml"] build-backend = "setuptools.build_meta" [project] @@ -20,7 +20,7 @@ dependencies = [ "pytz>=2025.2", "prometheus-client>=0.21.1", "py-grpc-prometheus>=0.8.0", - "torch==2.8.0+cu128", + "torch==2.9.1", "protobuf==6.31.1", # OpenTelemetry is a required dependency. No automatic downgrade. "opentelemetry-api>=1.36.0", @@ -29,7 +29,7 @@ dependencies = [ "opentelemetry-exporter-otlp-proto-grpc>=1.36.0", "opentelemetry-instrumentation-grpc>=0.57b0", "uvicorn>=0.32.0", - "pyyaml>=6.0.2", + "pyyaml==6.0.1", "duckdb>=1.2.1", "pydantic>=2.0.0", "psutil>=7.0.0", @@ -43,7 +43,7 @@ dev = [ "clang-format>=20.1.0", "clang-tidy>=20.1.0", "einops>=0.8.1", - "grpcio-tools<=1.73.1", + "grpcio-tools==1.75.1", "httpx>=0.28.1", "hypothesis>=6.135.26", "ipdb>=0.13.13", @@ -255,8 +255,8 @@ default = true name = "pypi-public" url = "https://pypi.org/simple" -[tool.uv.sources] -torch = { index = "pytorch" } +# [tool.uv.sources] +# torch = { index = "pytorch" } [[tool.uv.index]] name = "pytorch" diff --git a/tensorcast/global_store/grpc_service.py b/tensorcast/global_store/grpc_service.py index c49cdc28..eda7e352 100644 --- a/tensorcast/global_store/grpc_service.py +++ b/tensorcast/global_store/grpc_service.py @@ -470,6 +470,7 @@ def _rebuild_runtime_services_and_handlers(self) -> None: broadcast_repository=self.broadcast_repository, replica_repository=self.replica_repository, worker_repository=self.worker_repository, + root_heartbeat_timeout_seconds=self.config.heartbeat_timeout_ms / 1000.0, ) self.transport_service = TransportService( self.replica_repository, diff --git a/tensorcast/global_store/services/broadcast_service.py b/tensorcast/global_store/services/broadcast_service.py index a89a0a2d..31c45835 100644 --- a/tensorcast/global_store/services/broadcast_service.py +++ b/tensorcast/global_store/services/broadcast_service.py @@ -35,7 +35,7 @@ class BroadcastService: """Coordinates broadcast session topology state.""" - _ROOT_HEARTBEAT_TIMEOUT_SECONDS = 5.0 + _DEFAULT_ROOT_HEARTBEAT_TIMEOUT_SECONDS = 30.0 def __init__( self, @@ -43,10 +43,17 @@ def __init__( broadcast_repository: BroadcastRepository, replica_repository: ReplicaRepository, worker_repository: WorkerRepository, + root_heartbeat_timeout_seconds: float | None = None, ) -> None: self._broadcast_repository = broadcast_repository self._replica_repository = replica_repository self._worker_repository = worker_repository + timeout = ( + self._DEFAULT_ROOT_HEARTBEAT_TIMEOUT_SECONDS + if root_heartbeat_timeout_seconds is None + else float(root_heartbeat_timeout_seconds) + ) + self._root_heartbeat_timeout_seconds = max(0.0, timeout) def create_session( self, @@ -289,7 +296,7 @@ def complete_transport_edge( artifact_id=session.artifact_id, view_id=session.requested_view_id, worker_id=edge.child_worker_id, - heartbeat_timeout_seconds=self._ROOT_HEARTBEAT_TIMEOUT_SECONDS, + heartbeat_timeout_seconds=self._root_heartbeat_timeout_seconds, cursor=cursor, ) if child_replica is None: @@ -471,7 +478,7 @@ def _resolve_root_replica( result = self._replica_repository.find_available_for_transport( artifact_id=artifact_id, - heartbeat_timeout_seconds=self._ROOT_HEARTBEAT_TIMEOUT_SECONDS, + heartbeat_timeout_seconds=self._root_heartbeat_timeout_seconds, view_id=requested_view_id, ) if result.replica is None: diff --git a/tests/python/global_store/test_broadcast_service.py b/tests/python/global_store/test_broadcast_service.py index 39eb5ba9..1079abd9 100644 --- a/tests/python/global_store/test_broadcast_service.py +++ b/tests/python/global_store/test_broadcast_service.py @@ -259,6 +259,54 @@ def test_create_session_duplicate_auto_root_returns_existing_without_counter_cha assert len(service.list_edges("session-auto")) == 1 +def test_create_session_auto_root_uses_configured_heartbeat_timeout(repositories): + worker_repo = repositories["worker"] + replica_repo = repositories["replica"] + broadcast_repo = repositories["broadcast"] + service = BroadcastService( + broadcast_repository=broadcast_repo, + replica_repository=replica_repo, + worker_repository=worker_repo, + root_heartbeat_timeout_seconds=30.0, + ) + root = _worker("worker-root-timeout", "daemon-root-timeout", "node1") + child = _worker("worker-child-timeout", "daemon-child-timeout", "node2") + for worker in (root, child): + worker_repo.create(worker) + assert worker_repo.update_heartbeat(worker.worker_id, 4096, True) + root_replica = replica_repo.create(_exportable_replica("mi2:model-timeout", root)) + + cursor = worker_repo.get_cursor() + try: + cursor.execute( + """ + UPDATE worker_liveness + SET last_heartbeat = now() - INTERVAL '6 seconds' + WHERE worker_id = ? + """, + [root.worker_id], + ) + finally: + cursor.close() + + session = service.create_session( + session_id="session-timeout", + artifact_id="mi2:model-timeout", + requested_view_id=None, + epoch=1, + fanout=1, + target_daemon_ids=["daemon-child-timeout"], + root_replica_id="", + strict_parent=True, + max_attempts=3, + ) + + assert session.root_replica_id == root_replica.replica_id + assert replica_repo.get_current_requests(root_replica.replica_id) == 0 + assert len(broadcast_repo.list_targets("session-timeout")) == 1 + assert len(service.list_edges("session-timeout")) == 1 + + def test_create_session_auto_root_failure_releases_counter_and_rolls_back( repositories, monkeypatch, From 2f5d6fdd057bbb98bad845cd05ea2dfb0e3f7774 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:25:32 +0800 Subject: [PATCH 44/49] style: format C++ sources --- .../materialize_orchestrator_reselection_test.cc | 11 +++++------ .../store/runtime/ingestion/materialization_facade.cc | 3 +-- .../controllers/materialization_policy_utils.cc | 4 ++-- .../controllers/replica_materialization_service.cc | 5 ++--- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/core/store/materialization/control/materialize_orchestrator_reselection_test.cc b/core/store/materialization/control/materialize_orchestrator_reselection_test.cc index 6797c92f..1e3c2ebe 100644 --- a/core/store/materialization/control/materialize_orchestrator_reselection_test.cc +++ b/core/store/materialization/control/materialize_orchestrator_reselection_test.cc @@ -173,7 +173,9 @@ TEST_CASE("MaterializeOrchestrator accepts local route returned by Global Store" CHECK(gs_client->completed_transport_outcomes[0] == TransportCompletionOutcome::kSuccess); } -TEST_CASE("GlobalStoreClient request transport keeps legacy positional arguments", "[store][materialize][reselection]") { +TEST_CASE( + "GlobalStoreClient request transport keeps legacy positional arguments", + "[store][materialize][reselection]") { RecordingGlobalStoreClient gs_client; gs_client.connected = true; gs_client.allow_replica_transport = true; @@ -200,9 +202,7 @@ TEST_CASE( bool completed_before_register = false; FakeMaterializationBackend backend; backend.register_status = absl::UnavailableError("best-effort register failed"); - backend.on_register = [&]() { - completed_before_register = !gs_client->completed_transport_ids.empty(); - }; + backend.on_register = [&]() { completed_before_register = !gs_client->completed_transport_ids.empty(); }; MaterializeHints hints; hints.artifact_id = "artifact-non-broadcast-register-fails"; @@ -526,8 +526,7 @@ TEST_CASE( gsl::not_null{gs_client.get()}, local_identity); - auto result = - orchestrator.run("artifact-broadcast-terminal-register-fails", make_gpu_target(0), hints, std::nullopt); + auto result = orchestrator.run("artifact-broadcast-terminal-register-fails", make_gpu_target(0), hints, std::nullopt); REQUIRE_FALSE(result.ok()); CHECK(absl::IsInvalidArgument(result.status())); REQUIRE(gs_client->completed_transport_ids.size() == 1); diff --git a/core/store/runtime/ingestion/materialization_facade.cc b/core/store/runtime/ingestion/materialization_facade.cc index 11d12a2a..020154be 100644 --- a/core/store/runtime/ingestion/materialization_facade.cc +++ b/core/store/runtime/ingestion/materialization_facade.cc @@ -1264,8 +1264,7 @@ std::optional to_transport_scheduling_ return out; } -std::optional to_broadcast_transport_hint( - const loading::MaterializeHints& hints) { +std::optional to_broadcast_transport_hint(const loading::MaterializeHints& hints) { if (!hints.broadcast.has_value() || hints.broadcast->session_id.empty()) { return std::nullopt; } diff --git a/daemon/service/controllers/materialization_policy_utils.cc b/daemon/service/controllers/materialization_policy_utils.cc index c7ffe204..15d75f6c 100644 --- a/daemon/service/controllers/materialization_policy_utils.cc +++ b/daemon/service/controllers/materialization_policy_utils.cc @@ -109,8 +109,8 @@ std::optional resolve_collective_group_ std::optional resolve_transport_scheduling_group_hint( const v2::TransportSchedulingGroupHint* group) { - if (group == nullptr || group->group_kind().empty() || group->group_id().empty() || - group->part_id().empty() || group->total_parts() == 0) { + if (group == nullptr || group->group_kind().empty() || group->group_id().empty() || group->part_id().empty() || + group->total_parts() == 0) { return std::nullopt; } return store::loading::TransportSchedulingGroupHint{ diff --git a/daemon/service/controllers/replica_materialization_service.cc b/daemon/service/controllers/replica_materialization_service.cc index 3396afec..9e7d8fe8 100644 --- a/daemon/service/controllers/replica_materialization_service.cc +++ b/daemon/service/controllers/replica_materialization_service.cc @@ -58,8 +58,8 @@ using materialization_policy::NormalizedMaterializationRequestContext; using materialization_policy::resolve_broadcast_materialization_hint; using materialization_policy::resolve_collective_group_hint; using materialization_policy::resolve_materialization_request_context; -using materialization_policy::resolve_transport_scheduling_group_hint; using materialization_policy::resolve_transform_placement; +using materialization_policy::resolve_transport_scheduling_group_hint; using materialization_policy::to_hint_export_policy; using materialization_post_seal::check_post_seal_view_reuse_safe; using materialization_replica_handle::bind_replica_handle_for_response; @@ -691,8 +691,7 @@ grpc::Status ReplicaMaterializationService::materialize_replica( hints.transport_request_id = req.transport_request_id(); } if (req.has_transport_scheduling_group()) { - auto group_hint = - resolve_transport_scheduling_group_hint(&req.transport_scheduling_group()); + auto group_hint = resolve_transport_scheduling_group_hint(&req.transport_scheduling_group()); if (group_hint.has_value()) { hints.transport_scheduling_group = std::move(*group_hint); } From 89f405b3c27a14ef93582c4f2ae8a90dd71d6ddd Mon Sep 17 00:00:00 2001 From: Lyu Wangrunze <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 6 May 2026 10:53:27 +0800 Subject: [PATCH 45/49] fix: Recover torch and pyyaml versions in pyproject.toml --- pyproject.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7a4827ff..72dc6e22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] # Note: The torch version here is the default. When building releases, # use tools/release.sh with --torch-version to specify different versions -requires = ["setuptools>=68.0.0", "pyyaml", "torch==2.9.1", "toml"] +requires = ["setuptools>=68.0.0", "pyyaml", "torch==2.8.0+cu128", "toml"] build-backend = "setuptools.build_meta" [project] @@ -20,7 +20,7 @@ dependencies = [ "pytz>=2025.2", "prometheus-client>=0.21.1", "py-grpc-prometheus>=0.8.0", - "torch==2.9.1", + "torch==2.8.0+cu128", "protobuf==6.31.1", # OpenTelemetry is a required dependency. No automatic downgrade. "opentelemetry-api>=1.36.0", @@ -29,7 +29,7 @@ dependencies = [ "opentelemetry-exporter-otlp-proto-grpc>=1.36.0", "opentelemetry-instrumentation-grpc>=0.57b0", "uvicorn>=0.32.0", - "pyyaml==6.0.1", + "pyyaml>=6.0.2", "duckdb>=1.2.1", "pydantic>=2.0.0", "psutil>=7.0.0", @@ -43,7 +43,7 @@ dev = [ "clang-format>=20.1.0", "clang-tidy>=20.1.0", "einops>=0.8.1", - "grpcio-tools==1.75.1", + "grpcio-tools<=1.73.1", "httpx>=0.28.1", "hypothesis>=6.135.26", "ipdb>=0.13.13", @@ -255,8 +255,8 @@ default = true name = "pypi-public" url = "https://pypi.org/simple" -# [tool.uv.sources] -# torch = { index = "pytorch" } +[tool.uv.sources] +torch = { index = "pytorch" } [[tool.uv.index]] name = "pytorch" From a9663a5571110b14feeb007092adb27b3e52c367 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 6 May 2026 12:00:49 +0800 Subject: [PATCH 46/49] style: fix ruff I001 import ordering --- tensorcast/__init__.py | 2 +- tensorcast/global_store/rpc/broadcast_rpc_handler.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorcast/__init__.py b/tensorcast/__init__.py index 58bd9e27..6ccd7b49 100644 --- a/tensorcast/__init__.py +++ b/tensorcast/__init__.py @@ -425,7 +425,6 @@ def __dir__() -> list[str]: CapabilityDirectoryClient, CapabilityDirectoryOptions, CollectiveLoadGroup, - TransportSchedulingGroup, DirectorySnapshot, ExecutionTopologyContext, GetArtifactOptions, @@ -457,6 +456,7 @@ def __dir__() -> list[str]: TensorCastDirectory, TensorCastSignals, TransformSpec, + TransportSchedulingGroup, Worker, WorkerRoute, WorkerStatus, diff --git a/tensorcast/global_store/rpc/broadcast_rpc_handler.py b/tensorcast/global_store/rpc/broadcast_rpc_handler.py index f2f6c015..b3f1a848 100644 --- a/tensorcast/global_store/rpc/broadcast_rpc_handler.py +++ b/tensorcast/global_store/rpc/broadcast_rpc_handler.py @@ -21,7 +21,6 @@ from tensorcast.proto.common.v1 import common_pb2 from tensorcast.proto.global_store.v1 import global_store_pb2 - _SESSION_STATE_TO_PROTO = { BroadcastSessionState.PLANNING: global_store_pb2.BROADCAST_SESSION_STATE_PLANNING, BroadcastSessionState.ACTIVE: global_store_pb2.BROADCAST_SESSION_STATE_ACTIVE, From 35dcccf9998845a45740c5f390393abccf84a368 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 6 May 2026 12:23:16 +0800 Subject: [PATCH 47/49] style: format global store repositories --- .../global_store/repositories/broadcast_repository.py | 8 ++------ .../global_store/repositories/replica_repository.py | 4 +++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorcast/global_store/repositories/broadcast_repository.py b/tensorcast/global_store/repositories/broadcast_repository.py index 35f1c302..150e72ec 100644 --- a/tensorcast/global_store/repositories/broadcast_repository.py +++ b/tensorcast/global_store/repositories/broadcast_repository.py @@ -789,15 +789,11 @@ def _row_to_target( return BroadcastTarget( session_id=str(row[idx["session_id"]]), target_worker_id=str(row[idx["target_worker_id"]]), - target_daemon_id=cls._normalize_optional_text( - row[idx["target_daemon_id"]] - ), + target_daemon_id=cls._normalize_optional_text(row[idx["target_daemon_id"]]), state=BroadcastTargetState(str(row[idx["state"]])), level=int(raw_level) if raw_level is not None else None, attempt=int(row[idx["attempt"]]), - assigned_edge_id=cls._normalize_optional_text( - row[idx["assigned_edge_id"]] - ), + assigned_edge_id=cls._normalize_optional_text(row[idx["assigned_edge_id"]]), completed_replica_id=cls._uuid_or_none(row[idx["completed_replica_id"]]), failure_reason=cls._normalize_optional_text(row[idx["failure_reason"]]), created_at=cls._coerce_datetime_optional(row[idx["created_at"]]), diff --git a/tensorcast/global_store/repositories/replica_repository.py b/tensorcast/global_store/repositories/replica_repository.py index 6de4ab77..c42cb2e8 100644 --- a/tensorcast/global_store/repositories/replica_repository.py +++ b/tensorcast/global_store/repositories/replica_repository.py @@ -613,7 +613,9 @@ def claim_replica_for_transport( + "AND mr.artifact_id = ? " + "AND COALESCE(mr.view_id, '') = COALESCE(?, '')" ) - result = cursor.execute(query, [str(replica_id), artifact_id, view_id or ""]) + result = cursor.execute( + query, [str(replica_id), artifact_id, view_id or ""] + ) row = result.fetchone() if row is None: return TransportSelectionResult(replica=None, exportable_replicas=0) From 21577ff0faea13d704d17fd419b284525dac8792 Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 6 May 2026 13:26:50 +0800 Subject: [PATCH 48/49] fix(types): resolve pyright errors and __all__ warnings - Add missing TYPE_CHECKING imports in tensorcast/__init__.py for names exposed via _LAZY_ATTRS so they satisfy __all__ - Restructure broadcast transport claim with try/except/else so pyright narrows replica/edge as non-None on the success path - Pass transport_request_id and transport_scheduling_group as typed named arguments instead of via dict[str, object] **kwargs to match materialize_by_artifact_id_v2 overloads --- tensorcast/__init__.py | 18 ++++++++++++ .../services/transport_service.py | 5 +--- tensorcast/node_agent/executor.py | 28 ++++++++++--------- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/tensorcast/__init__.py b/tensorcast/__init__.py index 6ccd7b49..167547d8 100644 --- a/tensorcast/__init__.py +++ b/tensorcast/__init__.py @@ -421,14 +421,25 @@ def __dir__() -> list[str]: ArtifactDescriptor, ArtifactError, ArtifactFuture, + BindingRealizationEntry, + BindingRealizationPlan, + BindingUpdateEpoch, + BindingValueRef, + BroadcastContext, CallContext, + CanonicalIndex, + CanonicalIndexEntry, CapabilityDirectoryClient, CapabilityDirectoryOptions, CollectiveLoadGroup, DirectorySnapshot, + ExecutionDiagnostics, ExecutionTopologyContext, GetArtifactOptions, GovernanceContext, + HashBackend, + HashLocation, + IdentityMintStrategy, Instance, InstanceExecutionRoute, Operation, @@ -442,6 +453,7 @@ def __dir__() -> list[str]: PlanStepResult, PlanType, PreparedServingRegistration, + PublicDiskSourceHandle, RegisterArtifactOptions, RegisteredArtifact, RegisteredLease, @@ -449,7 +461,9 @@ def __dir__() -> list[str]: RegistrationResult, RetentionHandle, Runtime, + ServingPublicationSubject, SignalSnapshot, + SourceBoundCapability, Store, StoreOptions, TargetSpec, @@ -470,21 +484,25 @@ def __dir__() -> list[str]: from tensorcast.api.store import ( # type: ignore[no-redef] # noqa: F401 artifact, artifact_async, + binding_realization_plan_to_proto, build_serving_publication_bundle, build_serving_publication_bundle_from_registered_artifact, complete_pure_transform_publication, deregister_artifact, from_disk, + normalize_binding_realization_plan, persistence_operation, prepare_serving_registration, put, put_async, query_persistence_status, + realize_into_binding, register, register_async, register_pure_transform_publication, register_view, register_vram_region, + resolve_public_disk_source, store, unregister_vram_region, ) diff --git a/tensorcast/global_store/services/transport_service.py b/tensorcast/global_store/services/transport_service.py index c4c78754..022ee527 100644 --- a/tensorcast/global_store/services/transport_service.py +++ b/tensorcast/global_store/services/transport_service.py @@ -451,10 +451,7 @@ def _request_transport_broadcast( ) except (NotFoundError, ValidationError) as exc: claim_error = exc - replica = None - edge = None - - if claim_error is None: + else: transport = self._build_transport( replica=replica, artifact_id=artifact_id, diff --git a/tensorcast/node_agent/executor.py b/tensorcast/node_agent/executor.py index 91dd4d8a..02be0da3 100644 --- a/tensorcast/node_agent/executor.py +++ b/tensorcast/node_agent/executor.py @@ -1273,21 +1273,22 @@ def _materialize_selection( else: target_device_type = store_daemon_pb2.DeviceType.DEVICE_TYPE_GPU device_uuid = device_uuid_for(device_id) - transport_kwargs: dict[str, object] = {} + transport_request_id: str | None = None + transport_scheduling_group: ( + store_daemon_pb2.TransportSchedulingGroupHint | None + ) = None transport_group = call_ctx.transport_group if call_ctx is not None else None if transport_group is not None: - transport_kwargs["transport_request_id"] = ( - _transport_request_id_for_selection( - group=transport_group, - daemon_id=self._daemon_id, - selection=selection, - device_id=device_id, - device_uuid=device_uuid, - ) + transport_request_id = _transport_request_id_for_selection( + group=transport_group, + daemon_id=self._daemon_id, + selection=selection, + device_id=device_id, + device_uuid=device_uuid, + ) + transport_scheduling_group = _transport_group_to_daemon_proto( + transport_group ) - transport_group_proto = _transport_group_to_daemon_proto(transport_group) - if transport_group_proto is not None: - transport_kwargs["transport_scheduling_group"] = transport_group_proto self._client.materialize_by_artifact_id_v2( selection=selection, replica_uuid=replica_uuid, @@ -1297,7 +1298,8 @@ def _materialize_selection( target_device_type=target_device_type, lease_mode=store_daemon_pb2.LeaseMode.LEASE_MODE_NO_LEASE, timeout_s=timeout_s, - **transport_kwargs, + transport_request_id=transport_request_id, + transport_scheduling_group=transport_scheduling_group, ) def _pin( From 4c8be022c3e7ef4017afb27b7c7048be40aa114b Mon Sep 17 00:00:00 2001 From: FernanDAlumin <69069347+FernanDAlumin@users.noreply.github.com> Date: Wed, 6 May 2026 13:51:22 +0800 Subject: [PATCH 49/49] fix(types): narrow cursor type in BroadcastRepository Add after the guard so strict type checkers can narrow to non-optional inside the try/finally bodies. Resolves 22 union-attr errors in broadcast_repository.py. --- .../global_store/repositories/broadcast_repository.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tensorcast/global_store/repositories/broadcast_repository.py b/tensorcast/global_store/repositories/broadcast_repository.py index 150e72ec..b4758524 100644 --- a/tensorcast/global_store/repositories/broadcast_repository.py +++ b/tensorcast/global_store/repositories/broadcast_repository.py @@ -55,6 +55,7 @@ def create_session( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: cursor.execute( """ @@ -93,6 +94,7 @@ def find_session( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: query = cursor.execute( f""" @@ -123,6 +125,7 @@ def update_session_state( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: completed_sql = ( ", completed_at = CURRENT_TIMESTAMP" @@ -161,6 +164,7 @@ def upsert_target( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: updated_at = datetime.now() cursor.execute( @@ -215,6 +219,7 @@ def find_target( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: query = cursor.execute( f""" @@ -244,6 +249,7 @@ def list_targets( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: query = cursor.execute( f""" @@ -274,6 +280,7 @@ def list_targets_by_state( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: query = cursor.execute( f""" @@ -363,6 +370,7 @@ def find_edge( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: query = cursor.execute( f""" @@ -392,6 +400,7 @@ def list_edges( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: query = cursor.execute( f""" @@ -422,6 +431,7 @@ def find_active_edge_for_child( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: query = cursor.execute( f""" @@ -681,6 +691,7 @@ def count_incomplete_targets( owns_cursor = cursor is None if owns_cursor: cursor = self.get_cursor() + assert cursor is not None try: row = cursor.execute( """