diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 27555d70cb..003c1c0098 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -184,7 +184,7 @@ jobs: cd build export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib ldconfig -v || echo "always continue" - MC_METADATA_SERVER=http://127.0.0.1:8080/metadata DEFAULT_KV_LEASE_TTL=500 ctest -j --output-on-failure -E ub_transport_test + MC_METADATA_SERVER=http://127.0.0.1:8080/metadata DEFAULT_KV_LEASE_TTL=500 ctest -j --output-on-failure shell: bash - name: Drain HTTP E2E test diff --git a/mooncake-common/FindUrma.cmake b/mooncake-common/FindUrma.cmake index d2d93cc38b..0af8d1a7ca 100644 --- a/mooncake-common/FindUrma.cmake +++ b/mooncake-common/FindUrma.cmake @@ -4,7 +4,7 @@ include(FetchContent) FetchContent_Declare( urma GIT_REPOSITORY https://atomgit.com/openeuler/umdk.git - GIT_TAG v25.12.0 + GIT_TAG v25.12.0.B081 ) FetchContent_MakeAvailable(urma) diff --git a/mooncake-p2p-store/build.sh b/mooncake-p2p-store/build.sh index 66eff5f107..72aca4b3c6 100644 --- a/mooncake-p2p-store/build.sh +++ b/mooncake-p2p-store/build.sh @@ -53,6 +53,10 @@ if [ -d "/usr/local/musa/lib" ]; then EXT_LDFLAGS+=" -L/usr/local/musa/lib -lmusart" fi +if [ -e "/usr/lib64/liburma.so" ]; then + EXT_LDFLAGS+=" -L/usr/lib64 -lurma" +fi + if [ "$USE_ETCD" = "ON" ]; then if [ "$USE_ETCD_LEGACY" = "ON" ]; then EXT_LDFLAGS+=" -letcd-cpp-api -lprotobuf -lgrpc++ -lgrpc" diff --git a/mooncake-store/include/real_client.h b/mooncake-store/include/real_client.h index bcf9b67465..0f73edddb5 100644 --- a/mooncake-store/include/real_client.h +++ b/mooncake-store/include/real_client.h @@ -766,11 +766,22 @@ class RealClient : public PyClient { } }; + struct UbSegmentDeleter { + size_t size = 0; + std::string protocol = "ub"; + void operator()(void *ptr) const { + if (ptr && size > 0) { + free_memory(protocol.c_str(), ptr); + } + } + }; + std::vector> hugepage_segment_ptrs_; std::vector> segment_ptrs_; std::vector> ascend_segment_ptrs_; + std::vector> ub_segment_ptrs_; std::string protocol; std::string device_name; std::string local_hostname; diff --git a/mooncake-store/src/client_service.cpp b/mooncake-store/src/client_service.cpp index 30abd6af15..f69d25c6c8 100644 --- a/mooncake-store/src/client_service.cpp +++ b/mooncake-store/src/client_service.cpp @@ -569,6 +569,23 @@ ErrorCode Client::InitTransferEngine( LOG(ERROR) << "Failed to install CXL transport"; return ErrorCode::INTERNAL_ERROR; } + } else if (protocol == "ub") { + auto deviceName = device_names.value_or("bonding_dev_0"); + LOG(ERROR) << "ub protocol entable device names is " << deviceName; + auto devices = splitString(deviceName, ',', true); + auto topology = transfer_engine_->getLocalTopology(); + if (topology) { + topology->discover(devices); + LOG(INFO) << "Topology discovery complete with specified " + "devices. Found " + << topology->getHcaList().size() << " HCAs"; + } + transport = transfer_engine_->installTransport("ub", nullptr); + if (!transport) { + LOG(ERROR) << "Failed to install ub transport with specified " + "devices"; + return ErrorCode::INTERNAL_ERROR; + } } else { LOG(ERROR) << "unsupported_protocol protocol=" << protocol; return ErrorCode::INVALID_PARAMS; diff --git a/mooncake-store/src/real_client.cpp b/mooncake-store/src/real_client.cpp index b0cab9369a..0af42bbcc6 100644 --- a/mooncake-store/src/real_client.cpp +++ b/mooncake-store/src/real_client.cpp @@ -850,6 +850,9 @@ tl::expected RealClient::setup_internal( if (this->protocol == "ascend" || this->protocol == "ubshmem") { ascend_segment_ptrs_.emplace_back( ptr, AscendSegmentDeleter{this->protocol}); + } else if (this->protocol == "ub") { + ub_segment_ptrs_.emplace_back(ptr, + UbSegmentDeleter{mapped_size}); } else if (!seg_numa_nodes.empty() || should_use_hugepage) { // NUMA-segmented or hugepage: track as mmap allocation for // munmap cleanup @@ -1106,6 +1109,7 @@ tl::expected RealClient::tearDownAll_internal() { client_buffer_allocator_.reset(); port_binder_.reset(); hugepage_segment_ptrs_.clear(); + ub_segment_ptrs_.clear(); segment_ptrs_.clear(); local_hostname = ""; device_name = ""; diff --git a/mooncake-store/src/utils.cpp b/mooncake-store/src/utils.cpp index 2b9639fdce..810a7cf871 100644 --- a/mooncake-store/src/utils.cpp +++ b/mooncake-store/src/utils.cpp @@ -2,6 +2,7 @@ #include "mmap_arena.h" #include "config.h" #include "common.h" +#include "ub_allocator.h" #include #include @@ -117,6 +118,11 @@ void *allocate_buffer_allocator_memory(size_t total_size, return ascend_allocate_memory(total_size, protocol); } #endif +#if defined(USE_UB) + if (protocol == "ub") { + return mooncake::ub_allocate_memory(alignment, total_size); + } +#endif #ifdef USE_NOF if (use_spdk_dma && total_size > 0) { return mooncake::SpdkWrapper::GetInstance().Alloc(total_size, alignment, @@ -371,7 +377,12 @@ void free_memory(const std::string &protocol, void *ptr) { return ascend_free_memory(protocol, ptr); } #endif - +#if defined(USE_UB) + if (protocol == "ub") { + mooncake::ub_free_memory(ptr); + return; + } +#endif free(ptr); } diff --git a/mooncake-transfer-engine/include/CMakeLists.txt b/mooncake-transfer-engine/include/CMakeLists.txt index 4f52e1e5cd..56929077df 100644 --- a/mooncake-transfer-engine/include/CMakeLists.txt +++ b/mooncake-transfer-engine/include/CMakeLists.txt @@ -7,5 +7,6 @@ install(FILES multi_transport.h DESTINATION include) install(FILES topology.h DESTINATION include) install(FILES transfer_engine.h DESTINATION include) install(FILES transfer_metadata.h DESTINATION include) +install(FILES ub_allocator.h DESTINATION include) install(FILES common/base/status.h DESTINATION include/common/base) install(FILES transport/transport.h DESTINATION include/transport) diff --git a/mooncake-transfer-engine/include/transport/kunpeng_transport/ub_context.h b/mooncake-transfer-engine/include/transport/kunpeng_transport/ub_context.h index 7240ff7bbd..fd7f937141 100644 --- a/mooncake-transfer-engine/include/transport/kunpeng_transport/ub_context.h +++ b/mooncake-transfer-engine/include/transport/kunpeng_transport/ub_context.h @@ -83,6 +83,7 @@ class UbEndpointStore { virtual std::shared_ptr insertEndpoint( const std::string& peer_nic_path, UbContext* context) = 0; virtual int deleteEndpoint(const std::string& peer_nic_path) = 0; + virtual int deleteEndpointByPtr(UbEndPoint* point_ptr) = 0; virtual void evictEndpoint() = 0; virtual void reclaimEndpoint() = 0; virtual size_t getSize() = 0; @@ -102,6 +103,7 @@ class UbSIEVEEndpointStore : public UbEndpointStore { std::shared_ptr insertEndpoint(const std::string& peer_nic_path, UbContext* context) override; int deleteEndpoint(const std::string& peer_nic_path) override; + int deleteEndpointByPtr(UbEndPoint* point_ptr) override; void evictEndpoint() override; void reclaimEndpoint() override; size_t getSize() override; @@ -229,6 +231,10 @@ class UbContext { return endpoint_store_->deleteEndpoint(peer_nic_path); } + int deleteEndpointByPtr(UbEndPoint* point_ptr) { + return endpoint_store_->deleteEndpointByPtr(point_ptr); + } + int disconnectAllEndpoints() { return endpoint_store_->disconnect(); } // Device name, such as `mlx5_3` diff --git a/mooncake-transfer-engine/include/transport/transport.h b/mooncake-transfer-engine/include/transport/transport.h index 60f46fba39..aca0a0d15c 100644 --- a/mooncake-transfer-engine/include/transport/transport.h +++ b/mooncake-transfer-engine/include/transport/transport.h @@ -132,6 +132,7 @@ class Transport { uint32_t max_retry_cnt; void *r_seg; void *l_seg; + void *endpoint; } ub; struct { void *dest_addr; diff --git a/mooncake-transfer-engine/include/ub_allocator.h b/mooncake-transfer-engine/include/ub_allocator.h new file mode 100644 index 0000000000..a753f42165 --- /dev/null +++ b/mooncake-transfer-engine/include/ub_allocator.h @@ -0,0 +1,11 @@ +#pragma once + +namespace mooncake { + +void* ub_allocate_memory(size_t alignment, size_t total_size); + +void ub_free_memory(void* ptr); + +bool ub_is_store_memory(void* addr, size_t length); + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/kunpeng_transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/kunpeng_transport/CMakeLists.txt index c861696859..5583b7d813 100644 --- a/mooncake-transfer-engine/src/transport/kunpeng_transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/kunpeng_transport/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB UB_SOURCES "*.cpp" "urma/urma_endpoint.cpp") +file(GLOB UB_SOURCES "*.cpp" "urma/urma_endpoint.cpp" "ub_allocator.cpp") # Check if liburma.so exists find_library(URMA_LIBRARY urma PATHS /usr/lib64) diff --git a/mooncake-transfer-engine/src/transport/kunpeng_transport/ub_allocator.cpp b/mooncake-transfer-engine/src/transport/kunpeng_transport/ub_allocator.cpp new file mode 100644 index 0000000000..609ecf9843 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/kunpeng_transport/ub_allocator.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "ub_allocator.h" + +namespace mooncake { +struct UbStoreMemRange { + void* base; + size_t size; +}; +std::mutex g_ub_store_mem_mutex; +std::vector g_ub_store_mem_ranges; + +size_t remove_store_memory_range(void* ptr) { + std::lock_guard store_lock(g_ub_store_mem_mutex); + + auto it = std::find_if( + g_ub_store_mem_ranges.begin(), g_ub_store_mem_ranges.end(), + [ptr](const UbStoreMemRange& range) { return range.base == ptr; }); + + if (it == g_ub_store_mem_ranges.end()) { + LOG(ERROR) << "failed for UB protocol, addr at " << ptr; + return 0; + } + + size_t sz = it->size; + g_ub_store_mem_ranges.erase(it); + return sz; +} + +void* ub_allocate_memory(size_t alignment, size_t total_size) { + void* ptr = numa_alloc_local(total_size); + if (!ptr) { + LOG(ERROR) << "failed for UB protocol, size=" << total_size + << ", alignment : " << alignment; + return nullptr; + } + LOG(INFO) << "UB: allocated total size : " << total_size + << ", alignment : " << alignment << " addr at " << ptr; + + std::lock_guard store_lock(g_ub_store_mem_mutex); + g_ub_store_mem_ranges.push_back({ptr, total_size}); + + return ptr; +} + +void ub_free_memory(void* ptr) { + if (!ptr) { + return; + } + auto size = remove_store_memory_range(ptr); + numa_free(ptr, size); + LOG(INFO) << "UB: freed bytes at " << ptr; +} + +bool ub_is_store_memory(void* addr, size_t length) { + if (!addr || length == 0) return false; + auto addr_start = reinterpret_cast(addr); + uintptr_t addr_end = addr_start + length; + std::lock_guard lock(g_ub_store_mem_mutex); + for (const auto& range : g_ub_store_mem_ranges) { + auto range_start = reinterpret_cast(range.base); + uintptr_t range_end = range_start + range.size; + if (addr_start >= range_start && addr_end <= range_end) { + return true; + } + } + return false; +} + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/kunpeng_transport/ub_context.cpp b/mooncake-transfer-engine/src/transport/kunpeng_transport/ub_context.cpp index 4814a183f8..ac10faa515 100644 --- a/mooncake-transfer-engine/src/transport/kunpeng_transport/ub_context.cpp +++ b/mooncake-transfer-engine/src/transport/kunpeng_transport/ub_context.cpp @@ -80,6 +80,29 @@ int UbSIEVEEndpointStore::deleteEndpoint(const std::string& peer_nic_path) { return 0; } +int UbSIEVEEndpointStore::deleteEndpointByPtr(UbEndPoint* point_ptr) { + RWSpinlock::WriteGuard guard(endpoint_map_lock_); + for (auto iter = endpoint_map_.begin(); iter != endpoint_map_.end(); + iter++) { + if (iter->second.first.get() == point_ptr) { + std::string peer_nic_path = iter->first; + iter->second.first->deconstruct(); + waiting_list_len_++; + waiting_list_.insert(iter->second.first); + auto fifo_iter = fifo_map_[peer_nic_path]; + if (hand_.has_value() && hand_.value() == fifo_iter) { + fifo_iter == fifo_list_.begin() ? hand_ = std::nullopt + : hand_ = std::prev(fifo_iter); + } + fifo_list_.erase(fifo_iter); + fifo_map_.erase(peer_nic_path); + endpoint_map_.erase(iter); + return 0; + } + } + return 0; +} + void UbSIEVEEndpointStore::evictEndpoint() { if (fifo_list_.empty()) { return; @@ -246,6 +269,12 @@ int UbWorkerPool::submitPostSend( auto targetSegment = peer_segment_desc->buffers[buffer_id].tseg[device_id]; slice->ub.r_seg = context_.retrieveRemoteSeg(targetSegment); + if (!slice->ub.r_seg) { + LOG(ERROR) << "[UB] retrieveRemoteSeg failed for target_id=" + << slice->target_id << " buffer_id=" << buffer_id + << " device_id" << device_id + << " dest_addr=" << slice->ub.dest_addr; + } auto peer_nic_path = MakeNicPath(peer_segment_desc->name, peer_segment_desc->devices[device_id].name); @@ -333,7 +362,7 @@ void UbWorkerPool::performPostSend(int thread_id) { } if (!endpoint->active()) { if (endpoint->inactiveTime() > 1.0) - context_.deleteEndpoint(entry.first); + context_.deleteEndpointByPtr(endpoint.get()); // enable for re-establishation for (auto& slice : entry.second) failed_slice_list.push_back(slice); entry.second.clear(); @@ -355,6 +384,10 @@ void UbWorkerPool::performPostSend(int thread_id) { entry.second.clear(); continue; } + // Set endpoint pointer for each slice before submitting + for (auto& slice : entry.second) { + slice->ub.endpoint = endpoint.get(); + } endpoint->submitPostSend(entry.second, failed_slice_list); #endif } @@ -392,9 +425,12 @@ void UbWorkerPool::performPoll(int thread_id) { << context_.nicPath() << ", mark it inactive"; context_.set_active(false); } - context_.deleteEndpoint(slice->peer_nic_path); slice->ub.retry_cnt++; if (slice->ub.retry_cnt >= slice->ub.max_retry_cnt) { + if (slice->ub.endpoint) { + auto ptr = static_cast(slice->ub.endpoint); + context_.deleteEndpointByPtr(ptr); + } slice->markFailed(); processed_slice_count_++; } else { diff --git a/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/mock_urma.cpp b/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/mock_urma.cpp index 020c3a6b19..c392577a6b 100644 --- a/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/mock_urma.cpp +++ b/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/mock_urma.cpp @@ -1,16 +1,26 @@ #include "urma_api.h" -#include -#include +#include +#include #include +#include +#include #include +#include +#include namespace { -std::mutex mock_mutex; + +struct JfcState { + std::mutex mutex; + std::deque pending_ctx; +}; + +std::shared_mutex g_rw_mutex; bool initialized = false; std::vector device_list; std::map context_map; std::map jfce_map; -std::map> jfc_user_ctx_map; +std::map jfc_state_map; std::map jfr_map; std::map seg_map; std::map jetty_map; @@ -32,10 +42,11 @@ urma_eid_info_t mock_eid_info = { .eid = {{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}}, .eid_index = 0}; + } // namespace urma_status_t urma_init(urma_init_attr_t *init_attr) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (initialized) { return URMA_EEXIST; } @@ -44,7 +55,7 @@ urma_status_t urma_init(urma_init_attr_t *init_attr) { } urma_status_t urma_uninit(void) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); initialized = false; for (auto device : device_list) { delete device; @@ -52,7 +63,10 @@ urma_status_t urma_uninit(void) { device_list.clear(); context_map.clear(); jfce_map.clear(); - jfc_user_ctx_map.clear(); + for (auto &kv : jfc_state_map) { + delete kv.second; + } + jfc_state_map.clear(); jfr_map.clear(); seg_map.clear(); jetty_map.clear(); @@ -61,53 +75,81 @@ urma_status_t urma_uninit(void) { } urma_device_t **urma_get_device_list(int *num_devices) { - std::lock_guard lock(mock_mutex); - if (!initialized) { - *num_devices = 0; - return nullptr; - } - - if (device_list.empty()) { - urma_device_t *device = new urma_device_t; - strcpy(device->name, "mock_urma_device"); - strcpy(device->path, "/sys/class/infiniband/mock_device"); - device->type = URMA_TRANSPORT_UB; - device->ops = nullptr; - device->sysfs_dev = nullptr; - device_list.push_back(device); + { + std::shared_lock lock(g_rw_mutex); + if (!initialized) { + *num_devices = 0; + return nullptr; + } + if (!device_list.empty()) { + *num_devices = device_list.size(); + urma_device_t **devices = new urma_device_t *[device_list.size()]; + for (size_t i = 0; i < device_list.size(); ++i) { + devices[i] = device_list[i]; + } + return devices; + } } - - *num_devices = device_list.size(); - urma_device_t **devices = new urma_device_t *[device_list.size()]; - for (size_t i = 0; i < device_list.size(); ++i) { - devices[i] = device_list[i]; + { + std::unique_lock write_lock(g_rw_mutex); + if (!initialized) { + *num_devices = 0; + return nullptr; + } + if (device_list.empty()) { + urma_device_t *device = new urma_device_t; + strcpy(device->name, "mock_urma_device"); + strcpy(device->path, "/sys/class/infiniband/mock_device"); + device->type = URMA_TRANSPORT_UB; + device->ops = nullptr; + device->sysfs_dev = nullptr; + device_list.push_back(device); + } + *num_devices = device_list.size(); + urma_device_t **devices = new urma_device_t *[device_list.size()]; + for (size_t i = 0; i < device_list.size(); ++i) { + devices[i] = device_list[i]; + } + return devices; } - return devices; } urma_device_t *urma_get_device_by_name(const char *name) { - std::lock_guard lock(mock_mutex); - if (!initialized) { - return nullptr; - } - - if (device_list.empty()) { - auto *device = new urma_device_t; - strcpy(device->name, "mock_urma_device"); - strcpy(device->path, "/sys/class/infiniband/mock_device"); - device->type = URMA_TRANSPORT_UB; - device->ops = nullptr; - device->sysfs_dev = nullptr; - device_list.push_back(device); + { + std::shared_lock lock(g_rw_mutex); + if (!initialized) { + return nullptr; + } + if (!device_list.empty()) { + for (auto device : device_list) { + if (strcmp(device->name, name) == 0) { + return device; + } + } + return device_list[0]; + } } - - for (auto device : device_list) { - if (strcmp(device->name, name) == 0) { - return device; + { + std::unique_lock write_lock(g_rw_mutex); + if (!initialized) { + return nullptr; + } + if (device_list.empty()) { + auto *device = new urma_device_t; + strcpy(device->name, "mock_urma_device"); + strcpy(device->path, "/sys/class/infiniband/mock_device"); + device->type = URMA_TRANSPORT_UB; + device->ops = nullptr; + device->sysfs_dev = nullptr; + device_list.push_back(device); } + for (auto device : device_list) { + if (strcmp(device->name, name) == 0) { + return device; + } + } + return device_list.empty() ? nullptr : device_list[0]; } - - return device_list.empty() ? nullptr : device_list[0]; } void urma_free_device_list(urma_device_t **device_list) { @@ -118,7 +160,6 @@ void urma_free_device_list(urma_device_t **device_list) { urma_status_t urma_query_device(urma_device_t *device, urma_device_attr_t *attr) { - std::lock_guard lock(mock_mutex); if (!device || !attr) { return URMA_EINVAL; } @@ -129,7 +170,6 @@ urma_status_t urma_query_device(urma_device_t *device, } urma_eid_info_t *urma_get_eid_list(urma_device_t *device, uint32_t *eid_cnt) { - std::lock_guard lock(mock_mutex); if (!device || !eid_cnt) { return nullptr; } @@ -146,7 +186,7 @@ void urma_free_eid_list(urma_eid_info_t *eid_list) { } urma_context_t *urma_create_context(urma_device_t *device, uint32_t eid_index) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!device) { return nullptr; } @@ -158,7 +198,7 @@ urma_context_t *urma_create_context(urma_device_t *device, uint32_t eid_index) { } urma_status_t urma_delete_context(urma_context_t *ctx) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || context_map.find(ctx) == context_map.end()) { return URMA_EINVAL; } @@ -168,7 +208,7 @@ urma_status_t urma_delete_context(urma_context_t *ctx) { } urma_jfce_t *urma_create_jfce(urma_context_t *ctx) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || context_map.find(ctx) == context_map.end()) { return nullptr; } @@ -178,7 +218,7 @@ urma_jfce_t *urma_create_jfce(urma_context_t *ctx) { } urma_status_t urma_delete_jfce(urma_jfce_t *jfce) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!jfce || jfce_map.find(jfce) == jfce_map.end()) { return URMA_EINVAL; } @@ -188,7 +228,7 @@ urma_status_t urma_delete_jfce(urma_jfce_t *jfce) { } urma_jfc_t *urma_create_jfc(urma_context_t *ctx, urma_jfc_cfg_t *cfg) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || !cfg || context_map.find(ctx) == context_map.end()) { return nullptr; } @@ -201,22 +241,23 @@ urma_jfc_t *urma_create_jfc(urma_context_t *ctx, urma_jfc_cfg_t *cfg) { jfc->comp_events_acked = 0; jfc->async_events_acked = 0; jfc->jfc_cfg = *cfg; - jfc_user_ctx_map[jfc] = std::vector(); + jfc_state_map[jfc] = new JfcState(); return jfc; } urma_status_t urma_delete_jfc(urma_jfc_t *jfc) { - std::lock_guard lock(mock_mutex); - if (!jfc || jfc_user_ctx_map.find(jfc) == jfc_user_ctx_map.end()) { + std::unique_lock lock(g_rw_mutex); + if (!jfc || jfc_state_map.find(jfc) == jfc_state_map.end()) { return URMA_EINVAL; } - jfc_user_ctx_map.erase(jfc); + delete jfc_state_map[jfc]; + jfc_state_map.erase(jfc); delete jfc; return URMA_SUCCESS; } urma_jfr_t *urma_create_jfr(urma_context_t *ctx, urma_jfr_cfg_t *cfg) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || !cfg || context_map.find(ctx) == context_map.end()) { return nullptr; } @@ -226,7 +267,7 @@ urma_jfr_t *urma_create_jfr(urma_context_t *ctx, urma_jfr_cfg_t *cfg) { } urma_status_t urma_delete_jfr(urma_jfr_t *jfr) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!jfr || jfr_map.find(jfr) == jfr_map.end()) { return URMA_EINVAL; } @@ -236,7 +277,7 @@ urma_status_t urma_delete_jfr(urma_jfr_t *jfr) { } urma_target_seg_t *urma_register_seg(urma_context_t *ctx, urma_seg_cfg_t *cfg) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || !cfg || context_map.find(ctx) == context_map.end()) { return nullptr; } @@ -252,7 +293,7 @@ urma_target_seg_t *urma_register_seg(urma_context_t *ctx, urma_seg_cfg_t *cfg) { } urma_status_t urma_unregister_seg(urma_target_seg_t *seg) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!seg || seg_map.find(seg) == seg_map.end()) { return URMA_EINVAL; } @@ -264,7 +305,7 @@ urma_status_t urma_unregister_seg(urma_target_seg_t *seg) { urma_target_seg_t *urma_import_seg(urma_context_t *ctx, urma_seg_t *seg, urma_token_t *token_value, uint64_t addr, urma_import_seg_flag_t flag) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || !seg || !token_value || context_map.find(ctx) == context_map.end()) { return nullptr; @@ -277,7 +318,7 @@ urma_target_seg_t *urma_import_seg(urma_context_t *ctx, urma_seg_t *seg, } urma_status_t urma_unimport_seg(urma_target_seg_t *tseg) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!tseg || seg_map.find(tseg) == seg_map.end()) { return URMA_EINVAL; } @@ -288,8 +329,11 @@ urma_status_t urma_unimport_seg(urma_target_seg_t *tseg) { urma_status_t urma_get_async_event(urma_context_t *ctx, urma_async_event_t *event) { - std::lock_guard lock(mock_mutex); - if (!ctx || !event || context_map.find(ctx) == context_map.end()) { + if (!ctx || !event) { + return URMA_EINVAL; + } + std::shared_lock lock(g_rw_mutex); + if (context_map.find(ctx) == context_map.end()) { return URMA_EINVAL; } return URMA_ETIMEOUT; @@ -298,7 +342,7 @@ urma_status_t urma_get_async_event(urma_context_t *ctx, void urma_ack_async_event(urma_async_event_t *event) {} urma_jetty_t *urma_create_jetty(urma_context_t *ctx, urma_jetty_cfg_t *cfg) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || !cfg || context_map.find(ctx) == context_map.end()) { return nullptr; } @@ -314,7 +358,7 @@ urma_jetty_t *urma_create_jetty(urma_context_t *ctx, urma_jetty_cfg_t *cfg) { } urma_status_t urma_delete_jetty(urma_jetty_t *jetty) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!jetty || jetty_map.find(jetty) == jetty_map.end()) { return URMA_EINVAL; } @@ -324,7 +368,7 @@ urma_status_t urma_delete_jetty(urma_jetty_t *jetty) { } urma_status_t urma_unbind_jetty(urma_jetty_t *jetty) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!jetty || jetty_map.find(jetty) == jetty_map.end()) { return URMA_EINVAL; } @@ -335,7 +379,7 @@ urma_status_t urma_unbind_jetty(urma_jetty_t *jetty) { urma_target_jetty_t *urma_import_jetty(urma_context_t *ctx, urma_rjetty_t *rjetty, urma_token_t *token_value) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!ctx || !rjetty || !token_value || context_map.find(ctx) == context_map.end()) { return nullptr; @@ -348,7 +392,7 @@ urma_target_jetty_t *urma_import_jetty(urma_context_t *ctx, } urma_status_t urma_unimport_jetty(urma_target_jetty_t *tjetty) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!tjetty || target_jetty_map.find(tjetty) == target_jetty_map.end()) { return URMA_EINVAL; } @@ -359,7 +403,7 @@ urma_status_t urma_unimport_jetty(urma_target_jetty_t *tjetty) { urma_status_t urma_bind_jetty(urma_jetty_t *jetty, urma_target_jetty_t *tjetty) { - std::lock_guard lock(mock_mutex); + std::unique_lock lock(g_rw_mutex); if (!jetty || !tjetty || jetty_map.find(jetty) == jetty_map.end() || target_jetty_map.find(tjetty) == target_jetty_map.end()) { return URMA_EINVAL; @@ -369,7 +413,7 @@ urma_status_t urma_bind_jetty(urma_jetty_t *jetty, } urma_status_t urma_modify_jetty(urma_jetty_t *jetty, urma_jetty_attr_t *attr) { - std::lock_guard lock(mock_mutex); + std::shared_lock lock(g_rw_mutex); if (!jetty || !attr || jetty_map.find(jetty) == jetty_map.end()) { return URMA_EINVAL; } @@ -378,19 +422,37 @@ urma_status_t urma_modify_jetty(urma_jetty_t *jetty, urma_jetty_attr_t *attr) { urma_status_t urma_post_jetty_send_wr(urma_jetty_t *jetty, urma_jfs_wr_t *wr, urma_jfs_wr_t **bad_wr) { - std::lock_guard lock(mock_mutex); - if (!jetty || !wr || jetty_map.find(jetty) == jetty_map.end()) { - if (bad_wr) { - *bad_wr = wr; + { + std::shared_lock lock(g_rw_mutex); + if (!jetty || !wr || jetty_map.find(jetty) == jetty_map.end()) { + if (bad_wr) { + *bad_wr = wr; + } + return URMA_EINVAL; } - return URMA_EINVAL; } - urma_jfs_wr_t *current_wr = wr; - while (current_wr) { - jfc_user_ctx_map[jetty->jetty_cfg.jfs_cfg.jfc].push_back( - current_wr->user_ctx); - current_wr = current_wr->next; + urma_jfc_t *jfc = jetty->jetty_cfg.jfs_cfg.jfc; + JfcState *state = nullptr; + { + std::shared_lock lock(g_rw_mutex); + auto it = jfc_state_map.find(jfc); + if (it == jfc_state_map.end()) { + if (bad_wr) { + *bad_wr = wr; + } + return URMA_EINVAL; + } + state = it->second; + } + + { + std::lock_guard jfc_lock(state->mutex); + urma_jfs_wr_t *current_wr = wr; + while (current_wr) { + state->pending_ctx.push_back(current_wr->user_ctx); + current_wr = current_wr->next; + } } if (bad_wr) { @@ -400,18 +462,27 @@ urma_status_t urma_post_jetty_send_wr(urma_jetty_t *jetty, urma_jfs_wr_t *wr, } int urma_poll_jfc(urma_jfc_t *jfc, int num_entries, urma_cr_t *cr_list) { - std::lock_guard lock(mock_mutex); - if (!jfc || !cr_list || - jfc_user_ctx_map.find(jfc) == jfc_user_ctx_map.end()) { - return -1; - } - int available = jfc_user_ctx_map[jfc].size(); - int num_completed = std::min(num_entries, available); - for (int i = 0; i < num_completed; ++i) { - cr_list[i].status = URMA_CR_SUCCESS; - cr_list[i].user_ctx = jfc_user_ctx_map[jfc][i]; - } - jfc_user_ctx_map[jfc].erase(jfc_user_ctx_map[jfc].begin(), - jfc_user_ctx_map[jfc].begin() + num_completed); + JfcState *state = nullptr; + { + std::shared_lock lock(g_rw_mutex); + auto it = jfc_state_map.find(jfc); + if (it == jfc_state_map.end()) { + return -1; + } + state = it->second; + } + + int num_completed = 0; + { + std::lock_guard jfc_lock(state->mutex); + int available = static_cast(state->pending_ctx.size()); + num_completed = std::min(num_entries, available); + for (int i = 0; i < num_completed; ++i) { + cr_list[i].status = URMA_CR_SUCCESS; + cr_list[i].user_ctx = state->pending_ctx[i]; + } + state->pending_ctx.erase(state->pending_ctx.begin(), + state->pending_ctx.begin() + num_completed); + } return num_completed; } diff --git a/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/urma_endpoint.cpp b/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/urma_endpoint.cpp index f7080a764b..0d9b3c61a1 100644 --- a/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/urma_endpoint.cpp +++ b/mooncake-transfer-engine/src/transport/kunpeng_transport/urma/urma_endpoint.cpp @@ -420,13 +420,16 @@ int UrmaContext::openDevice(const std::string& device_name, uint8_t port, return ERR_CONTEXT; } for (int p = 0; p < MAX_PORT_CNT; p++) { - if (dev_attr_.port_attr[p].state == URMA_PORT_ACTIVE) { + auto port_attr = dev_attr_.port_attr[p]; + if (port_attr.state == URMA_PORT_ACTIVE || + port_attr.state == URMA_PORT_ACTIVE_DEFER) { port_ = p; break; } } if (dev_attr_.port_cnt != 0 && - dev_attr_.port_attr[port_].state != URMA_PORT_ACTIVE) { + dev_attr_.port_attr[port_].state != URMA_PORT_ACTIVE && + dev_attr_.port_attr[port_].state != URMA_PORT_ACTIVE_DEFER) { LOG(WARNING) << "Device " << device_name << " not found active port"; if (urma_delete_context(context)) { @@ -528,9 +531,9 @@ int UrmaContext::poll(int num_entries, Transport::Slice** slices, if (!slice) { continue; } + slices[i] = slice; if (cr[i].status == URMA_CR_SUCCESS) { slice->markSuccess(); - slices[i] = slice; continue; } if (cr[i].status != URMA_CR_WR_FLUSH_ERR || @@ -886,6 +889,10 @@ int UrmaEndpoint::submitPostSend( wr.flag.bs.inline_flag = 0; // Check if the jetty is in the imported_jetty_map_ auto it = imported_jetty_map_.find(jetty_list_[jetty_index]); + if (it == imported_jetty_map_.end()) { + LOG(ERROR) << "Jetty not imported for endpoint, tjetty is nullptr" + << jetty_index << ", local_nic="; + } if (it != imported_jetty_map_.end()) { wr.tjetty = it->second; } else { @@ -895,6 +902,8 @@ int UrmaEndpoint::submitPostSend( slice->ts = getCurrentTimeInNano(); slice->status = Transport::Slice::POSTED; slice->ub.jetty_depth = &wr_depth_list_[jetty_index]; + // Set endpoint pointer for each slice before submitting + slice->ub.endpoint = this; } __sync_fetch_and_add(&wr_depth_list_[jetty_index], wr_count); __sync_fetch_and_add(jfc_outstanding_, wr_count); @@ -966,6 +975,8 @@ int UrmaEndpoint::doSetupConnection(int jetty_index, rjetty.jetty_id.eid = eid; rjetty.trans_mode = URMA_TM_RC; rjetty.type = URMA_JETTY; + rjetty.tp_type = URMA_CTP; + rjetty.flag.value = 0; LOG(INFO) << "Peer jetty id = " << peer_jetty_num; urma_target_jetty_t* imported_jetty = urma_import_jetty(context_->urma_context_, &rjetty, &urma_token); diff --git a/scripts/build_wheel.sh b/scripts/build_wheel.sh index feff9662be..fdd5737f0c 100755 --- a/scripts/build_wheel.sh +++ b/scripts/build_wheel.sh @@ -363,6 +363,7 @@ ${AUDITWHEEL_CMD} repair ${OUTPUT_DIR}/*.whl \ --exclude libllm_datadist*.so \ --exclude ascend_transport*.so \ --exclude libaccl_barex.so* \ + --exclude liburma.so* \ -w ${REPAIRED_DIR}/ --plat ${PLATFORM_TAG} # Inject CUDA extensions into the repaired wheel. patchelf (used by auditwheel)