From ac593143d1db2c38861c95115c669ce94cb30354 Mon Sep 17 00:00:00 2001 From: "henry.guo" Date: Sat, 15 Mar 2025 08:38:30 -0700 Subject: [PATCH 01/13] Add Dockerfile.build for a docker image to build infinistore project --- Dockerfile.build | 54 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 Dockerfile.build diff --git a/Dockerfile.build b/Dockerfile.build new file mode 100644 index 0000000..b1530d1 --- /dev/null +++ b/Dockerfile.build @@ -0,0 +1,54 @@ +FROM quay.io/pypa/manylinux_2_28_x86_64 + +RUN yum -y install rdma-core-devel libuv-devel + +RUN dnf clean all +RUN dnf makecache + +# build spdlog +WORKDIR /tmp +RUN git clone --branch v1.15.1 --recurse-submodules https://github.com/gabime/spdlog.git +WORKDIR /tmp/spdlog +RUN cmake -G "Unix Makefiles" && \ + make && \ + make install +RUN rm -rf /tmp/spdlog + +# build fmt +WORKDIR /tmp +RUN git clone --branch 11.1.3 https://github.com/fmtlib/fmt.git +WORKDIR /tmp/fmt +RUN cmake -G "Unix Makefiles" && \ + make && \ + make install +RUN rm -rf /tmp/fmt + +# build flatbuffer +WORKDIR /tmp +RUN git clone --branch v25.2.10 https://github.com/google/flatbuffers.git +WORKDIR /tmp/flatbuffers +RUN cmake -G "Unix Makefiles" && \ + make && \ + make install + +ENV PATH=/usr/local/flatbuffers/bin:$PATH +RUN rm -rf /tmp/flatbuffers + +# Install boost +RUN dnf install -y boost boost-devel + +# Install pybind11 for different versions of built-in python3 by almalinux +RUN /opt/python/cp310-cp310/bin/pip3 install pybind11 +RUN /opt/python/cp311-cp311/bin/pip3 install pybind11 + +# In almalinux, setuptools for python3.12 is not installed +# so install it +RUN /opt/python/cp312-cp312/bin/pip3 install setuptools +RUN /opt/python/cp312-cp312/bin/pip3 install pybind11 + +# The above get the build environment ready! 
+WORKDIR /app +RUN git config --global --add safe.directory /app + +# Optional: Define an entry point to run the executable directly +# ENTRYPOINT ["/app/build/my_executable"] From e0bafa7e253c986c9fd9f01a0e540284a29e4df6 Mon Sep 17 00:00:00 2001 From: dongmao zhang Date: Sat, 15 Mar 2025 00:27:04 -0700 Subject: [PATCH 02/13] Update README.md Fix badges since previous URLs are deprecated --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c2c10a9..7c4a4d7 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -[![Run pre-commit checks](https://github.com/bd-iaas-us/InfiniStore/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/bd-iaas-us/InfiniStore/actions/workflows/pre-commit.yml) +[![Run pre-commit checks](https://github.com/bytedance/InfiniStore/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/bytedance/InfiniStore/actions/workflows/pre-commit.yml) [![Slack](https://img.shields.io/badge/Slack-Join%20Us-blue?logo=slack)](https://vllm-dev.slack.com/archives/C07VCUQLE1F) -[![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://bd-iaas-us.github.io/InfiniStore/) +[![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://bytedance.github.io/InfiniStore/) # What's InfiniStore From 39ad73196c4ff38e3775073f3b6a3025749a9149 Mon Sep 17 00:00:00 2001 From: thesues Date: Thu, 13 Mar 2025 20:39:21 +0000 Subject: [PATCH 03/13] Refactor APIs 1. Add TCP blocking APIs 2. Simplify RDMA APIs 3. 
rewrite all unittest and benchmark --- infinistore/__init__.py | 4 +- infinistore/benchmark.py | 116 ++++- infinistore/example/client.py | 52 +- infinistore/example/client_async.py | 31 +- infinistore/example/client_async_single.py | 42 +- infinistore/example/demo_prefill.py | 10 +- infinistore/lib.py | 483 +++---------------- infinistore/server.py | 72 --- infinistore/test_infinistore.py | 303 ++++++------ src/Makefile | 2 +- src/infinistore.cpp | 528 +++++++++++--------- src/libinfinistore.cpp | 536 +++++++-------------- src/libinfinistore.h | 77 +-- src/mempool.cpp | 6 +- src/mempool.h | 4 +- src/meta_request.fbs | 2 +- src/protocol.cpp | 17 +- src/protocol.h | 28 +- src/pybind.cpp | 117 +---- src/tcp_payload_request.fbs | 7 + src/tcp_payload_request_generated.h | 94 ++++ 21 files changed, 1012 insertions(+), 1519 deletions(-) create mode 100644 src/tcp_payload_request.fbs create mode 100644 src/tcp_payload_request_generated.h diff --git a/infinistore/__init__.py b/infinistore/__init__.py index 5312db8..e96dc30 100644 --- a/infinistore/__init__.py +++ b/infinistore/__init__.py @@ -3,8 +3,8 @@ ClientConfig, ServerConfig, TYPE_RDMA, + TYPE_TCP, Logger, - check_supported, LINK_ETHERNET, LINK_IB, register_server, @@ -20,8 +20,8 @@ "ClientConfig", "ServerConfig", "TYPE_RDMA", + "TYPE_TCP", "Logger", - "check_supported", "LINK_ETHERNET", "LINK_IB", "purge_kv_map", diff --git a/infinistore/benchmark.py b/infinistore/benchmark.py index 7274bb3..7a12357 100644 --- a/infinistore/benchmark.py +++ b/infinistore/benchmark.py @@ -7,6 +7,8 @@ import string import argparse import uuid +import asyncio +import threading def parse_args(): @@ -17,7 +19,7 @@ def parse_args(): required=False, action="store_true", help="use rdma connection, default False", - default=True, + default=False, ) parser.add_argument( @@ -110,6 +112,16 @@ def generate_uuid(): return str(uuid.uuid4()) +def start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +loop = 
asyncio.new_event_loop() +t = threading.Thread(target=start_loop, args=(loop,)) +t.start() + + def run(args): config = infinistore.ClientConfig( host_addr=args.server, @@ -120,16 +132,24 @@ def run(args): log_level="warning", ) - config.connection_type = infinistore.TYPE_RDMA + if args.rdma: + config.connection_type = infinistore.TYPE_RDMA + else: + config.connection_type = infinistore.TYPE_TCP conn = infinistore.InfinityConnection(config) try: conn.connect() - src_device = "cuda:" + str(args.src_gpu) - dst_device = "cuda:" + str(args.dst_gpu) + # rdma support GPUDirect RDMA, so we can use cuda tensor + if args.rdma: + src_device = "cuda:" + str(args.src_gpu) + dst_device = "cuda:" + str(args.dst_gpu) + else: + src_device = "cpu" + dst_device = "cpu" - block_size = args.block_size * 1024 // 4 + block_size = args.block_size * 1024 // 4 # element size is 4 bytes num_of_blocks = args.size * 1024 * 1024 // (args.block_size * 1024) src_tensor = torch.rand( @@ -140,25 +160,28 @@ def run(args): num_of_blocks * block_size, device=dst_device, dtype=torch.float32 ) - torch.cuda.synchronize(src_tensor.device) - torch.cuda.synchronize(dst_tensor.device) if args.rdma: - conn.register_mr(src_tensor) - conn.register_mr(dst_tensor) + torch.cuda.synchronize(src_tensor.device) + torch.cuda.synchronize(dst_tensor.device) + conn.register_mr( + src_tensor.data_ptr(), src_tensor.numel() * src_tensor.element_size() + ) + conn.register_mr( + dst_tensor.data_ptr(), dst_tensor.numel() * dst_tensor.element_size() + ) # blocks = [(keys[i], offset_blocks[i]) for i in range(num_of_blocks)] write_sum = 0.0 read_sum = 0.0 + element_size = src_tensor.element_size() + for _ in range(args.iteration): keys = [generate_uuid() for i in range(num_of_blocks)] - offset_blocks = [i * block_size for i in range(num_of_blocks)] + offsets = [i * block_size * element_size for i in range(num_of_blocks)] # zip keys and offset_blocks - blocks = list(zip(keys, offset_blocks)) - - if args.rdma: - remote_addrs = 
conn.allocate_rdma(keys, block_size * 4) - + # blocks = list(zip(keys, offset_blocks)) + blocks = list(zip(keys, offsets)) steps = args.steps # simulate we have layers, this steps should be less then MAX_WR_SIZE while len(blocks) % steps != 0 and steps > 1: @@ -169,23 +192,64 @@ def run(args): start = time.time() + futures = [] for i in range(steps): if args.rdma: - conn.rdma_write_cache( - src_tensor, - offset_blocks[i * n : i * n + n], - block_size, - remote_addrs[i * n : i * n + n], + future = asyncio.run_coroutine_threadsafe( + conn.rdma_write_cache_async( + blocks[i * n : i * n + n], + block_size * element_size, + src_tensor.data_ptr(), + ), + loop, ) - conn.sync() - # print(f"write takes {time.time() - start} seconds") + futures.append(future) + else: + for j in range(n): + key = blocks[i * n + j][0] + ptr = src_tensor.data_ptr() + blocks[i * n + j][1] + conn.tcp_write_cache(key, ptr, block_size * element_size) + + # wait for all the futures to finish + if args.rdma: + for future in futures: + future.result() + else: # TCP + pass mid = time.time() write_sum += mid - start + futures = [] for i in range(steps): - conn.read_cache(dst_tensor, blocks[i * n : i * n + n], block_size) + if args.rdma: + future = asyncio.run_coroutine_threadsafe( + conn.rdma_read_cache_async( + blocks[i * n : i * n + n], + block_size * element_size, + dst_tensor.data_ptr(), + ), + loop, + ) + futures.append(future) + else: + for j in range(n): + key = blocks[i * n + j][0] + # ptr = dst_tensor.data_ptr() + blocks[i*n + j][1] + ret = conn.tcp_read_cache(key) + assert len(ret) == block_size * element_size + # copy data from ret_value to dst_tensor + ret_tensor = torch.from_numpy(ret).view(torch.float32) + offset_in_tensor = blocks[i * n + j][1] // element_size + dst_tensor[offset_in_tensor : offset_in_tensor + block_size] = ( + ret_tensor + ) + + if args.rdma: + for future in futures: + future.result() + else: # TCP + pass - conn.sync() end = time.time() read_sum += end - mid @@ -200,10 
+264,12 @@ def run(args): args.size * args.iteration / read_sum, ) ) - + # super important to compare the data assert torch.equal(src_tensor.cpu(), dst_tensor.cpu()) finally: conn.close() + loop.call_soon_threadsafe(loop.stop) + t.join() if __name__ == "__main__": diff --git a/infinistore/example/client.py b/infinistore/example/client.py index 9be379d..1255100 100644 --- a/infinistore/example/client.py +++ b/infinistore/example/client.py @@ -1,11 +1,12 @@ from infinistore import ( ClientConfig, - check_supported, InfinityConnection, ) import infinistore import torch import time +import asyncio +import threading def generate_random_string(length): @@ -17,36 +18,45 @@ def generate_random_string(length): return random_string +def start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +loop = asyncio.new_event_loop() +t = threading.Thread(target=start_loop, args=(loop,)) +t.start() + + +# run is a blocking function, but it could invoke RDMA operations asynchronously +# by using asyncio.run_coroutine_threadsafe and wait for the result by future.result() def run(conn, src_device="cuda:0", dst_device="cuda:2"): - check_supported() src_tensor = torch.tensor( [i for i in range(4096)], device=src_device, dtype=torch.float32 ) - conn.register_mr(src_tensor) + conn.register_mr( + src_tensor.data_ptr(), src_tensor.numel() * src_tensor.element_size() + ) - keys = ["key1", "key2", "key3"] - remote_addr = conn.allocate_rdma( - keys, 1024 * 4 - ) # 1024(block_size) * 4(element size) - # print(f"remote_addr: {remote_addr}") + keys_offsets = [("key1", 0), ("key2", 1024 * 4), ("key3", 2048 * 4)] now = time.time() - conn.rdma_write_cache(src_tensor, [0, 1024, 2048], 1024, remote_addr) - + future = asyncio.run_coroutine_threadsafe( + conn.rdma_write_cache_async(keys_offsets, 1024 * 4, src_tensor.data_ptr()), loop + ) + future.result() print(f"write elapse time is {time.time() - now}") - before_sync = time.time() - conn.sync() - print(f"sync elapse time is {time.time() 
- before_sync}") - dst_tensor = torch.zeros(4096, device=dst_device, dtype=torch.float32) + conn.register_mr( + dst_tensor.data_ptr(), dst_tensor.numel() * dst_tensor.element_size() + ) - conn.register_mr(dst_tensor) now = time.time() - - conn.read_cache(dst_tensor, [("key1", 0), ("key2", 1024)], 1024) - - conn.sync() + future = asyncio.run_coroutine_threadsafe( + conn.rdma_read_cache_async(keys_offsets, 1024 * 4, dst_tensor.data_ptr()), loop + ) + future.result() print(f"read elapse time is {time.time() - now}") assert torch.equal(src_tensor[0:1024].cpu(), dst_tensor[0:1024].cpu()) @@ -64,6 +74,7 @@ def run(conn, src_device="cuda:0", dst_device="cuda:2"): dev_name="mlx5_0", ) rdma_conn = InfinityConnection(config) + try: rdma_conn.connect() m = [ @@ -75,5 +86,8 @@ def run(conn, src_device="cuda:0", dst_device="cuda:2"): for src, dst in m: print(f"rdma connection: {src} -> {dst}") run(rdma_conn, src, dst) + finally: rdma_conn.close() + loop.call_soon_threadsafe(loop.stop) + t.join() diff --git a/infinistore/example/client_async.py b/infinistore/example/client_async.py index 0aff240..aef479e 100644 --- a/infinistore/example/client_async.py +++ b/infinistore/example/client_async.py @@ -34,31 +34,28 @@ async def main(): dst_tensor = torch.zeros(4096, device="cpu", dtype=torch.float32) - rdma_conn.register_mr(src_tensor) - rdma_conn.register_mr(dst_tensor) + rdma_conn.register_mr( + src_tensor.data_ptr(), src_tensor.numel() * src_tensor.element_size() + ) + rdma_conn.register_mr( + dst_tensor.data_ptr(), dst_tensor.numel() * src_tensor.element_size() + ) keys = [generate_uuid() for _ in range(3)] - remote_addr = await rdma_conn.allocate_rdma_async(keys, 1024 * 4) - print(f"remote addrs is {remote_addr}") - - # await rdma_conn.rdma_write_cache_async( - # src_tensor, [0, 1024], 1024, remote_addr[:2] - # ) - # await rdma_conn.rdma_write_cache_async( - # src_tensor, [2048], 1024, remote_addr[2:] - # ) await asyncio.gather( rdma_conn.rdma_write_cache_async( - src_tensor, 
[0, 1024], 1024, remote_addr[:2] + [(keys[0], 0), (keys[1], 1024 * 4)], 1024 * 4, src_tensor.data_ptr() + ), + rdma_conn.rdma_write_cache_async( + [(keys[2], 2048 * 4)], 1024 * 4, src_tensor.data_ptr() ), - rdma_conn.rdma_write_cache_async(src_tensor, [2048], 1024, remote_addr[2:]), ) - - await rdma_conn.read_cache_async( - dst_tensor, [(keys[0], 0), (keys[1], 1024), (keys[2], 2048)], 1024 + await rdma_conn.rdma_read_cache_async( + [(keys[0], 0), (keys[1], 1024 * 4), (keys[2], 2048 * 4)], + 1024 * 4, + dst_tensor.data_ptr(), ) - assert torch.equal(src_tensor[0:3072].cpu(), dst_tensor[0:3072].cpu()) rdma_conn.close() diff --git a/infinistore/example/client_async_single.py b/infinistore/example/client_async_single.py index 6aeca33..61c8796 100644 --- a/infinistore/example/client_async_single.py +++ b/infinistore/example/client_async_single.py @@ -2,6 +2,7 @@ import uuid import asyncio import ctypes +import time def generate_uuid(): @@ -11,7 +12,7 @@ def generate_uuid(): config = infinistore.ClientConfig( host_addr="127.0.0.1", service_port=12345, - log_level="info", + log_level="warning", connection_type=infinistore.TYPE_RDMA, ib_port=1, link_type=infinistore.LINK_ETHERNET, @@ -33,26 +34,45 @@ async def main(): # src = torch.randn(4096, device="cpu", dtype=torch.float32) # dst = torch.zeros(4096, device="cpu", dtype=torch.float32) - src = bytearray(100) - dst = memoryview(bytearray(100)) + size = 128 * 1024 + src = bytearray(size) + dst = memoryview(bytearray(size)) def register_mr(): - rdma_conn.register_mr(get_ptr(src), 100) - rdma_conn.register_mr(get_ptr(dst), 100) + rdma_conn.register_mr(get_ptr(src), len(src)) + rdma_conn.register_mr(get_ptr(dst), len(dst)) await asyncio.to_thread(register_mr) + # set src + for i in range(size): + src[i] = i % 256 + is_exist = await asyncio.to_thread(rdma_conn.check_exist, key) assert not is_exist - await rdma_conn.rdma_write_cache_single_async(key, get_ptr(src), 100) - try: - await rdma_conn.read_cache_single_async(key, 
get_ptr(dst), 100) - except infinistore.InfiniStoreKeyNotFound: - print("Key not found") + now = time.time() + tasks = [] + N = 1000 + for i in range(N): + tasks.append( + rdma_conn.rdma_write_cache_async( + [(key + str(i), 0)], len(src), get_ptr(src) + ) + ) + await asyncio.gather(*tasks, return_exceptions=True) + print("write Time taken: ", time.time() - now) + + now = time.time() + tasks = [] + for i in range(N): + tasks.append( + rdma_conn.rdma_read_cache_async([(key + str(i), 0)], len(dst), get_ptr(dst)) + ) + await asyncio.gather(*tasks, return_exceptions=True) + print("read Time taken: ", time.time() - now) assert src == dst - rdma_conn.close() diff --git a/infinistore/example/demo_prefill.py b/infinistore/example/demo_prefill.py index 8f6e5e9..2f8d627 100644 --- a/infinistore/example/demo_prefill.py +++ b/infinistore/example/demo_prefill.py @@ -1,7 +1,5 @@ from infinistore import ( ClientConfig, - check_supported, - DisableTorchCaching, InfinityConnection, ) import infinistore @@ -41,11 +39,9 @@ def forward(self, x): def run(conn): - check_supported() - with DisableTorchCaching(): - model = nn.Sequential( - *[TransformerLayer(N, num_heads) for _ in range(num_layers)] - ).cuda() + model = nn.Sequential( + *[TransformerLayer(N, num_heads) for _ in range(num_layers)] + ).cuda() input = torch.randn(seq_length, 1, N, device="cuda:0", dtype=torch.float16) diff --git a/infinistore/lib.py b/infinistore/lib.py index 9176964..cb23fbb 100644 --- a/infinistore/lib.py +++ b/infinistore/lib.py @@ -2,16 +2,17 @@ # sphinx-doc will mock infinistore._infinistore, it has to be written like this -import torch import os import subprocess import asyncio from functools import singledispatchmethod from typing import Optional, Union, List, Tuple +import numpy as np # connection type: default is RDMA TYPE_RDMA = "RDMA" +TYPE_TCP = "TCP" # rdma link type LINK_ETHERNET = "Ethernet" LINK_IB = "IB" @@ -66,7 +67,7 @@ def __repr__(self): ) def verify(self): - if self.connection_type not 
in [TYPE_RDMA]: + if self.connection_type not in [TYPE_RDMA, TYPE_TCP]: raise Exception("Invalid connection type") if self.host_addr == "": raise Exception("Host address is empty") @@ -251,16 +252,6 @@ def _check_rdma_devices_ibv(): ) -def check_supported(): - # check if kernel module nv_peer_mem is available - if ( - "nv_peer_mem" not in _kernel_modules() - and "nvidia_peermem" not in _kernel_modules() # noqa: W503 - ): - Logger.warn("nv_peer_mem or nvidia_peermem module is not loaded") - _check_rdma_devices_ibv() - - class InfinityConnection: """ A class to manage connections and data transfers with an Infinistore instance using RDMA connections. @@ -280,7 +271,7 @@ def __init__(self, config: ClientConfig): self.config = config # used for async io - self.semaphore = asyncio.BoundedSemaphore(32) + self.semaphore = asyncio.BoundedSemaphore(128) Logger.set_log_level(config.log_level) async def connect_async(self): @@ -301,9 +292,10 @@ async def connect_async(self): def blocking_connect(): if self.conn.init_connection(self.config) < 0: raise Exception("Failed to initialize remote connection") - if self.conn.setup_rdma(self.config) < 0: - raise Exception("Failed to setup RDMA connection") - self.rdma_connected = True + if self.config.connection_type == TYPE_RDMA: + if self.conn.setup_rdma(self.config) < 0: + raise Exception("Failed to setup RDMA connection") + self.rdma_connected = True await loop.run_in_executor(None, blocking_connect) @@ -324,126 +316,46 @@ def connect(self): if ret < 0: raise Exception("Failed to initialize remote connection") - ret = self.conn.setup_rdma(self.config) - if ret < 0: - raise Exception(f"Failed to write to infinistore, ret = {ret}") - self.rdma_connected = True + if self.config.connection_type == TYPE_RDMA: + ret = self.conn.setup_rdma(self.config) + if ret < 0: + raise Exception(f"Failed to write to infinistore, ret = {ret}") + self.rdma_connected = True - async def rdma_write_cache_async( - self, cache: torch.Tensor, offsets: 
List[int], page_size, remote_blocks: List - ): + def close(self): """ - Asynchronously writes a cache tensor to remote memory using RDMA. - - Args: - cache (torch.Tensor): The tensor to be written to remote memory. - offsets (List[int]): List of offsets where the tensor data should be written. - page_size (int): The size of each page in the remote memory. - remote_blocks (List): List of remote memory blocks where the data will be written. - - Raises: - Exception: If RDMA is not connected. - - Returns: - asyncio.Future: A future that will be set to 0 when the write operation is complete. + Closes the connection to the Infinistore instance. """ - if not self.rdma_connected: - raise Exception("this function is only valid for connected rdma") - - self._verify(cache) - element_size = cache.element_size() - - # each offset should multiply by the element size - offsets_in_bytes = [offset * element_size for offset in offsets] - - loop = asyncio.get_running_loop() - - future = loop.create_future() - - await self.semaphore.acquire() - - def _callback(): - loop.call_soon_threadsafe(future.set_result, 0) - self.semaphore.release() - - self.conn.w_rdma_async( - offsets_in_bytes, - page_size * element_size, - remote_blocks, - cache.data_ptr(), - _callback, - ) - return await future + self.conn.close() - async def rdma_write_cache_single_async( - self, key: str, ptr: int, size: int, **kwargs - ): + def tcp_read_cache(self, key: str, **kwargs) -> np.ndarray: """ - Asynchronously writes data to the RDMA cache. - - This function writes data to the RDMA cache using the provided key, pointer, and size. - It ensures that the RDMA connection is established and the input parameters are valid. + Retrieve a single cached item from the TCP connection. - Args: - key (str): The key associated with the data to be written. - ptr (int): The memory address of the data to be written. - size (int): The size of the data to be written. - **kwargs: Additional keyword arguments. 
- - Raises: - Exception: If the RDMA connection is not established. - Exception: If the key is empty. - Exception: If the size is 0. - Exception: If the pointer is 0. - Exception: If writing to Infinistore fails. + Parameters: + key (str): The key associated with the cached item. + **kwargs: Additional keyword arguments. Returns: - int: A future that resolves to 0 upon successful completion of the write operation. + np.ndarray: The cached item retrieved from the TCP connection. """ - if not self.rdma_connected: - raise Exception("this function is only valid for connected rdma") - if key == "": - raise Exception("key is empty") - if size == 0: - raise Exception("size is 0") - if ptr == 0: - raise Exception("ptr is 0") + return self.conn.r_tcp(key) - await self.semaphore.acquire() - - remote_addrs = await self.allocate_rdma_async([key], size) - - loop = asyncio.get_running_loop() - future = loop.create_future() - - def _callback(): - loop.call_soon_threadsafe(future.set_result, 0) - self.semaphore.release() - - ret = self.conn.w_rdma_async([0], size, remote_addrs, ptr, _callback) - if ret < 0: - raise Exception(f"Failed to write to infinistore, ret = {ret}") - - return await future - - def rdma_write_cache_single(self, key: str, ptr: int, size: int, **kwargs): + def tcp_write_cache(self, key: str, ptr: int, size: int, **kwargs): """ - Perform an RDMA write operation to cache a single item in the remote memory. + Writes a single cache entry to the remote memory using TCP. Args: - key (str): The key associated with the data to be written. - ptr (int): The local memory pointer to the data to be written. - size (int): The size of the data to be written. + key (str): The key of the cache entry to write. + ptr (int): The pointer to the memory location where the data should be written. + size (int): The size of the data to write. **kwargs: Additional keyword arguments. Raises: Exception: If the key is empty. Exception: If the size is 0. - Exception: If the ptr is 0. 
- Exception: If the RDMA write operation fails. - - Returns: - None + Exception: If the pointer is 0. + Exception: If the write operation fails. """ if key == "": raise Exception("key is empty") @@ -451,157 +363,58 @@ def rdma_write_cache_single(self, key: str, ptr: int, size: int, **kwargs): raise Exception("size is 0") if ptr == 0: raise Exception("ptr is 0") - # allocate remote rdma memory - remote_addrs = self.allocate_rdma([key], size) - - assert len(remote_addrs) == 1 - ret = self.conn.w_rdma( - [0], - size, - remote_addrs, - ptr, - ) - if ret < 0: - raise Exception(f"Failed to write to infinistore, ret = {ret}") - return - - def close(self): - """ - Closes the connection to the Infinistore instance. - """ - self.conn.close() - - def rdma_write_cache( - self, cache: torch.Tensor, offsets: List[int], page_size, remote_blocks: List - ): - """ - Writes the given cache tensor to remote memory using RDMA (Remote Direct Memory Access). - - Args: - cache (torch.Tensor): The tensor containing the data to be written to remote memory. - offsets (List[int]): A list of offsets (in elements) where the data should be written. - page_size (int): The size of each page to be written, in elements. - remote_blocks (List): A list of remote memory blocks where the data should be written. - - Raises: - AssertionError: If RDMA is not connected. - Exception: If the RDMA write operation fails. - - Returns: - int: Returns 0 on success. 
- """ - - assert self.rdma_connected - self._verify(cache) - ptr = cache.data_ptr() - element_size = cache.element_size() - - # each offset should multiply by the element size - offsets_in_bytes = [offset * element_size for offset in offsets] - - ret = self.conn.w_rdma( - offsets_in_bytes, - page_size * element_size, - remote_blocks, - ptr, - ) + ret = self.conn.w_tcp(key, ptr, size) if ret < 0: raise Exception(f"Failed to write to infinistore, ret = {ret}") - return 0 - async def read_cache_async( - self, cache: torch.Tensor, blocks: List[Tuple[str, int]], page_size: int + async def rdma_write_cache_async( + self, blocks: List[Tuple[str, int]], block_size: int, ptr: int ): - """ - Asynchronously reads data from the RDMA cache into the provided tensor. - - Args: - cache (torch.Tensor): The tensor to read data into. - blocks (List[Tuple[str, int]]): A list of tuples where each tuple contains a key and an offset. - page_size (int): The size of each page to read. - - Raises: - Exception: If RDMA is not connected or if reading from Infinistore fails. - Exception: If the tensor is not contiguous. - - - Returns: - None: This function returns None but completes the future when the read operation is done. 
- """ if not self.rdma_connected: raise Exception("this function is only valid for connected rdma") - self._verify(cache) - ptr = cache.data_ptr() - element_size = cache.element_size() - blocks_in_bytes = [(key, offset * element_size) for key, offset in blocks] + await self.semaphore.acquire() loop = asyncio.get_running_loop() future = loop.create_future() - await self.semaphore.acquire() + keys, offsets = zip(*blocks) def _callback(code): - if code == 404: - loop.call_soon_threadsafe( - future.set_exception, InfiniStoreKeyNotFound("some keys not found") - ) - elif code != 200: + if code != 200: loop.call_soon_threadsafe( future.set_exception, - Exception(f"Failed to read to infinistore, ret = {code}"), + Exception(f"Failed to write to infinistore, ret = {code}"), ) else: loop.call_soon_threadsafe(future.set_result, code) self.semaphore.release() - ret = self.conn.r_rdma_async( - blocks_in_bytes, - page_size * element_size, + ret = self.conn.w_rdma_async( + keys, + offsets, + block_size, ptr, _callback, ) - if ret < 0: - raise Exception(f"Failed to read to infinistore, ret = {ret}") + raise Exception(f"Failed to write to infinistore, ret = {ret}") return await future - async def read_cache_single_async(self, key: str, ptr: int, size: int, **kwargs): - """ - Asynchronously reads a single cache entry from the InfiniStore. - - Args: - key (str): The key of the cache entry to read. - ptr (int): The pointer to the memory location where the data should be read. - size (int): The size of the data to read. - **kwargs: Additional keyword arguments. - - Raises: - Exception: If the key is empty. - Exception: If the size is 0. - Exception: If the ptr is 0. - Exception: If async read for local GPU is not supported. - InfiniStoreKeyNotFound: If the key is not found in the InfiniStore. - Exception: If there is a failure in reading from the InfiniStore. - - Returns: - int: The result code of the read operation. 
- """ - if key == "": - raise Exception("key is empty") - if size == 0: - raise Exception("size is 0") - if ptr == 0: - raise Exception("ptr is 0") + async def rdma_read_cache_async( + self, blocks: List[Tuple[str, int]], block_size: int, ptr: int + ): + if not self.rdma_connected: + raise Exception("this function is only valid for connected rdma") + pass + await self.semaphore.acquire() loop = asyncio.get_running_loop() future = loop.create_future() - await self.semaphore.acquire() - def _callback(code): if code == 404: loop.call_soon_threadsafe( - future.set_exception, InfiniStoreKeyNotFound(f"Key {key} not found") + future.set_exception, InfiniStoreKeyNotFound("some keys not found") ) elif code != 200: loop.call_soon_threadsafe( @@ -612,103 +425,18 @@ def _callback(code): loop.call_soon_threadsafe(future.set_result, code) self.semaphore.release() - ret = self.conn.r_rdma_async([(key, 0)], size, ptr, _callback) - + keys, offsets = zip(*blocks) + ret = self.conn.r_rdma_async( + keys, + offsets, + block_size, + ptr, + _callback, + ) if ret < 0: raise Exception(f"Failed to read to infinistore, ret = {ret}") - return await future - def read_cache_single(self, key: str, ptr: int, size: int, **kwargs): - """ - Reads a single cache entry from the infinistore. - - Parameters: - key (str): The key of the cache entry to read. - ptr (int): The pointer to the memory location where the data should be read. - size (int): The size of the data to read. - kwargs: Additional keyword arguments. - - Raises: - Exception: If the key is empty. - Exception: If the size is 0. - Exception: If the ptr is 0. - Exception: If not connected to any instance. - Exception: If the read operation fails. 
- """ - if key == "": - raise Exception("key is empty") - if size == 0: - raise Exception("size is 0") - if ptr == 0: - raise Exception("ptr is 0") - ret = 0 - if self.rdma_connected: - ret = self.conn.r_rdma( - [(key, 0)], - size, - ptr, - ) - else: - raise Exception("Not connected to any instance") - if ret < 0: - raise Exception(f"Failed to read to infinistore, ret = {ret}") - - def read_cache( - self, cache: torch.Tensor, blocks: List[Tuple[str, int]], page_size: int - ): - """ - Reads data from the cache using either RDMA connection. - - Args: - cache (torch.Tensor): The tensor containing the cache data. - blocks (List[Tuple[str, int]]): A list of tuples where each tuple contains a key and an offset. - each pair represents a page to be written to. The page is fixed size and is specified by the page_size parameter. - page_size (int): The size of the page to read. - - Raises: - Exception: If the read operation fails or if not connected to any instance. - """ - self._verify(cache) - ptr = cache.data_ptr() - element_size = cache.element_size() - # each offset should multiply by the element size - blocks_in_bytes = [(key, offset * element_size) for key, offset in blocks] - if self.rdma_connected: - ret = self.conn.r_rdma( - blocks_in_bytes, - page_size * element_size, - ptr, - ) - if ret < 0: - raise Exception(f"Failed to read to infinistore, ret = {ret}") - else: - raise Exception("Not connected to any instance") - - def sync(self): - """ - Synchronizes the current instance with the connected infinistore instance. - This method attempts to synchronize the current instance using either a local - connection or an RDMA connection. If neither connection is available, it raises - an exception. - Raises: - Exception: If not connected to any instance. - Exception: If synchronization fails with a negative return code. 
- """ - ret = 0 - if self.rdma_connected: - ret = self.conn.sync_rdma() - else: - raise Exception("Not connected to any instance") - - if ret < 0: - raise Exception(f"Failed to sync to infinistore, ret = {ret}") - return - - def _verify(self, cache: torch.Tensor): - if cache.is_contiguous() is False: - raise Exception("Tensor must be contiguous") - def check_exist(self, key: str): """ Check if a given key exists in the store. @@ -746,13 +474,12 @@ def get_match_last_index(self, keys: List[str]): return ret @singledispatchmethod - def register_mr(self, arg: Union[torch.Tensor, int], size: Optional[int] = None): + def register_mr(self, arg: Union[int], size: Optional[int] = None): """ Registers a memory region (MR) for the given argument. Args: - arg (Union[torch.Tensor, int]): The argument for which the memory region is to be registered. - It can be either a torch.Tensor or an pointer. + arg (Union[int]): The argument for which the memory region is to be registered. size (Optional[int], optional): The size of the memory region to be registered. Defaults to None. Raises: @@ -784,96 +511,6 @@ def _(self, ptr: int, size): raise Exception("register memory region failed") return ret - @register_mr.register - def _(self, cache: torch.Tensor, size: Optional[int] = None): - """ - Registers a memory region (MR) for a torch.Tensor. - - Args: - cache (torch.Tensor): The tensor for which the memory region is to be registered. - size (Optional[int], optional): The size of the memory region. Defaults to None. - - Raises: - Exception: If the RDMA connection is not established. - Exception: If the memory region registration fails. - - Returns: - int: The result of the memory region registration. 
- """ - self._verify(cache) - ptr = cache.data_ptr() - element_size = cache.element_size() - if not self.rdma_connected: - raise Exception("this function is only valid for connected rdma") - - ret = self.conn.register_mr(ptr, cache.numel() * element_size) - if ret < 0: - raise Exception("register memory region failed") - return ret - - async def allocate_rdma_async(self, keys: List[str], page_size_in_bytes: int): - """ - Asynchronously allocate RDMA (Remote Direct Memory Access) resources for the given keys. - - This function initiates an asynchronous RDMA allocation request and returns a future - that will be completed when the allocation is done. The allocation is performed by - invoking a callback function from the C++ code. - - Args: - keys (List[str]): A list of keys for which RDMA resources are to be allocated. - page_size_in_bytes (int): The size of each page in bytes. - - Raises: - Exception: If the RDMA connection is not established. - - Returns: - Awaitable: A future that will be set with the remote addresses once the allocation is complete. - """ - if not self.rdma_connected: - raise Exception("this function is only valid for connected rdma") - - loop = asyncio.get_running_loop() - future = loop.create_future() - - def _callback(remote_addrs, error_code): - # _callback is invoked by the C++ code in cq_thread, - # so we need to call_soon_threadsafe - if error_code != 200: - loop.call_soon_threadsafe( - future.set_exception, - Exception("allocate memory failed, error code: " + str(error_code)), - ) - else: - loop.call_soon_threadsafe(future.set_result, remote_addrs) - - self.conn.allocate_rdma_async(keys, page_size_in_bytes, _callback) - - return await future - - def allocate_rdma(self, keys: List[str], page_size_in_bytes: int) -> List[Tuple]: - """ - Allocates RDMA memory for the given keys. For RDMA writes, user must first allocate RDMA memory. - and then use the allocated RDMA memory address to write data to the remote memory. 
- - Args: - keys (List[str]): A list of keys for which RDMA memory is to be allocated. - page_size_in_bytes (int): The size of each page in bytes. - - Returns: - List: A list of allocated RDMA memory addresses. - - Raises: - Exception: If RDMA is not connected. - Exception: If memory allocation fails. - """ - if not self.rdma_connected: - raise Exception("this function is only valid for connected rdma") - - ret = self.conn.allocate_rdma(keys, page_size_in_bytes) - if len(ret) == 0: - raise Exception("allocate memory failed") - return ret - def delete_keys(self, keys: List[str]): """ Delete a list of keys diff --git a/infinistore/server.py b/infinistore/server.py index e950ff4..cd67eed 100644 --- a/infinistore/server.py +++ b/infinistore/server.py @@ -1,19 +1,15 @@ -import infinistore import uuid from infinistore import ( register_server, purge_kv_map, get_kvmap_len, - check_supported, ServerConfig, Logger, ) - import asyncio import uvloop from fastapi import FastAPI import uvicorn -import torch import argparse import logging import os @@ -37,77 +33,11 @@ def generate_uuid(): return str(uuid.uuid4()) -@app.post("/selftest/{number}") -async def selftest(number: int): - Logger.info("selftest") - - config = infinistore.ClientConfig( - host_addr="127.0.0.1", - service_port=number, - log_level="info", - connection_type=infinistore.TYPE_RDMA, - ib_port=1, - link_type=infinistore.LINK_ETHERNET, - dev_name="mlx5_2", - ) - - rdma_conn = infinistore.InfinityConnection(config) - - await rdma_conn.connect_async() - - def blocking_io(rdma_conn): - src_tensor = torch.tensor( - [i for i in range(4096)], device="cpu", dtype=torch.float32 - ) - dst_tensor = torch.zeros(4096, device="cpu", dtype=torch.float32) - rdma_conn.register_mr(src_tensor) - rdma_conn.register_mr(dst_tensor) - return src_tensor, dst_tensor - - src_tensor, dst_tensor = await asyncio.to_thread(blocking_io, rdma_conn) - - # keys = ["key1", "key2", "key3"] - keys = [generate_uuid() for i in range(3)] - remote_addr 
= await rdma_conn.allocate_rdma_async(keys, 1024 * 4) - print(f"remote addrs is {remote_addr}") - - await rdma_conn.rdma_write_cache_async(src_tensor, [0, 1024], 1024, remote_addr[:2]) - await rdma_conn.rdma_write_cache_async(src_tensor, [2048], 1024, remote_addr[2:]) - - # # await asyncio.gather(rdma_conn.rdma_write_cache_async(src_tensor, [0, 1024], 1024, remote_addr[:2]), - # # rdma_conn.rdma_write_cache_async(src_tensor, [2048], 1024, remote_addr[2:])) - - await rdma_conn.read_cache_async( - dst_tensor, [(keys[0], 0), (keys[1], 1024), (keys[2], 2048)], 1024 - ) - - # put assert into asyncio.to_thread - - assert await asyncio.to_thread(torch.equal, src_tensor[0:3072], dst_tensor[0:3072]) - - # assert torch.equal(, - rdma_conn.close() - return {"status": "ok"} - - @app.get("/kvmap_len") async def kvmap_len(): return {"len": get_kvmap_len()} -def check_p2p_access(): - num_devices = torch.cuda.device_count() - for i in range(num_devices): - for j in range(num_devices): - if i != j: - can_access = torch.cuda.can_device_access_peer(i, j) - if can_access: - # print(f"Peer access supported between device {i} and {j}") - pass - else: - Logger.warn(f"Peer access NOT supported between device {i} and {j}") - - def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -213,8 +143,6 @@ def main(): auto_increase=args.auto_increase, ) config.verify() - # check_p2p_access() - check_supported() Logger.set_log_level(config.log_level) Logger.info(config) diff --git a/infinistore/test_infinistore.py b/infinistore/test_infinistore.py index 6d4f768..ea4453e 100644 --- a/infinistore/test_infinistore.py +++ b/infinistore/test_infinistore.py @@ -9,6 +9,7 @@ import string import asyncio import json +import ctypes from multiprocessing import Process @@ -57,7 +58,7 @@ def server(): def generate_random_string(length): - letters_and_digits = string.ascii_letters + string.digits # 字母和数字的字符集 + letters_and_digits = string.ascii_letters + string.digits random_string = 
"".join(random.choice(letters_and_digits) for i in range(length)) return random_string @@ -95,8 +96,7 @@ def get_gpu_count(): @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -@pytest.mark.parametrize("new_connection", [True, False]) -def test_basic_read_write_cache(server, dtype, new_connection): +def test_basic_read_write_cache(server, dtype): config = infinistore.ClientConfig( host_addr="127.0.0.1", service_port=92345, @@ -113,18 +113,21 @@ def test_basic_read_write_cache(server, dtype, new_connection): key = generate_random_string(10) src = [i for i in range(4096)] - # local GPU write is tricky, we need to disable the pytorch allocator's caching src_tensor = torch.tensor(src, device="cuda:0", dtype=dtype) torch.cuda.synchronize(src_tensor.device) - conn.register_mr(src_tensor) + conn.register_mr( + src_tensor.data_ptr(), src_tensor.numel() * src_tensor.element_size() + ) element_size = torch._utils._element_size(dtype) - remote_addrs = conn.allocate_rdma([key], 4096 * element_size) - conn.rdma_write_cache(src_tensor, [0], 4096, remote_addrs) + async def run_write(): + await conn.rdma_write_cache_async( + [(key, 0)], len(src) * element_size, src_tensor.data_ptr() + ) - conn.sync() + asyncio.run(run_write()) conn.close() conn = infinistore.InfinityConnection(config) @@ -132,9 +135,14 @@ def test_basic_read_write_cache(server, dtype, new_connection): dst = torch.zeros(4096, device="cuda:0", dtype=dtype) - conn.register_mr(dst) - conn.read_cache(dst, [(key, 0)], 4096) - conn.sync() + conn.register_mr(dst.data_ptr(), dst.numel() * dst.element_size()) + + async def run_read(): + await conn.rdma_read_cache_async( + [(key, 0)], len(dst) * element_size, dst.data_ptr() + ) + + asyncio.run(run_read()) assert torch.equal(src_tensor, dst) conn.close() @@ -172,30 +180,37 @@ def test_batch_read_write_cache(server, separated_gpu): src_tensor = torch.tensor(src, device=src_device, dtype=torch.float32) torch.cuda.synchronize(src_tensor.device) - # write/read 3 
times - for i in range(3): - keys = [generate_random_string(num_of_blocks) for i in range(10)] - blocks = [(keys[i], i * block_size) for i in range(num_of_blocks)] - conn.register_mr(src_tensor) - remote_addrs = conn.allocate_rdma(keys, block_size * 4) - conn.rdma_write_cache( - src_tensor, - [i * block_size for i in range(num_of_blocks)], - block_size, - remote_addrs, - ) - - conn.sync() - - dst = torch.zeros( - num_of_blocks * block_size, device=dst_device, dtype=torch.float32 - ) + async def run(): + # write/read 3 times + for i in range(3): + keys = [generate_random_string(num_of_blocks) for i in range(10)] + await asyncio.to_thread( + conn.register_mr, + src_tensor.data_ptr(), + src_tensor.numel() * src_tensor.element_size(), + ) + + blocks_offsets = [ + (keys[i], i * block_size * 4) for i in range(num_of_blocks) + ] + + await conn.rdma_write_cache_async( + blocks_offsets, block_size * 4, src_tensor.data_ptr() + ) + + dst = torch.zeros( + num_of_blocks * block_size, device=dst_device, dtype=torch.float32 + ) + await asyncio.to_thread( + conn.register_mr, dst.data_ptr(), dst.numel() * dst.element_size() + ) + + await conn.rdma_read_cache_async( + blocks_offsets, block_size * 4, dst.data_ptr() + ) + assert torch.equal(src_tensor.cpu(), dst.cpu()) - conn.register_mr(dst) - conn.read_cache(dst, blocks, block_size) - conn.sync() - # import pdb; pdb.set_trace() - assert torch.equal(src_tensor.cpu(), dst.cpu()) + asyncio.run(run()) conn.close() @@ -221,22 +236,26 @@ def run(): src_tensor = torch.tensor(src, device="cuda:0", dtype=torch.float32) torch.cuda.synchronize(src_tensor.device) - conn.register_mr(src_tensor) + conn.register_mr( + src_tensor.data_ptr(), src_tensor.numel() * src_tensor.element_size() + ) element_size = torch._utils._element_size(torch.float32) - remote_addrs = conn.allocate_rdma([key], 4096 * element_size) - conn.rdma_write_cache(src_tensor, [0], 4096, remote_addrs) - - conn.sync() + asyncio.run( + conn.rdma_write_cache_async( + [(key, 0)], 
4096 * element_size, src_tensor.data_ptr() + ) + ) conn.close() conn = infinistore.InfinityConnection(config) conn.connect() dst = torch.zeros(4096, device="cuda:0", dtype=torch.float32) - conn.register_mr(dst) - conn.read_cache(dst, [(key, 0)], 4096) - conn.sync() + conn.register_mr(dst.data_ptr(), dst.numel() * dst.element_size()) + asyncio.run( + conn.rdma_read_cache_async([(key, 0)], 4096 * element_size, dst.data_ptr()) + ) assert torch.equal(src_tensor, dst) conn.close() @@ -261,13 +280,10 @@ def test_key_check(server): conn.connect() key = generate_random_string(5) src = torch.randn(4096, device="cuda", dtype=torch.float32) - conn.register_mr(src) - remote_addrs = conn.allocate_rdma([key], 4096 * 4) - + conn.register_mr(src.data_ptr(), src.numel() * src.element_size()) torch.cuda.synchronize(src.device) - conn.rdma_write_cache(src, [0], 4096, remote_addrs) - conn.sync() + asyncio.run(conn.rdma_write_cache_async([(key, 0)], 4096 * 4, src.data_ptr())) assert conn.check_exist(key) conn.close() @@ -283,12 +299,14 @@ def test_get_match_last_index(server): conn = infinistore.InfinityConnection(config) conn.connect() src = torch.randn(4096, device="cuda", dtype=torch.float32) - conn.register_mr(src) - remote_addrs = conn.allocate_rdma(["key1", "key2", "key3"], 4096 * 4) - torch.cuda.synchronize(src.device) - conn.rdma_write_cache(src, [0, 1024, 2048], 4096, remote_addrs) + conn.register_mr(src.data_ptr(), src.numel() * src.element_size()) + asyncio.run( + conn.rdma_write_cache_async( + [("key1", 0), ("key2", 1024), ("key3", 2048)], 1024 * 4, src.data_ptr() + ) + ) assert conn.get_match_last_index(["A", "B", "C", "key1", "D", "E"]) == 3 conn.close() @@ -308,10 +326,10 @@ async def run(): await conn.connect_async() key = "not_exist_key" dst = torch.randn(4096, device="cuda", dtype=torch.float32) - conn.register_mr(dst) + conn.register_mr(dst.data_ptr(), dst.numel() * dst.element_size()) # expect raise exception with pytest.raises(Exception): - await 
conn.read_cache_async(dst, [(key, 0)], 4096) + await conn.rdma_read_cache_async([(key, 0)], 4096 * 4, dst.data_ptr()) finally: conn.close() @@ -336,107 +354,33 @@ def test_upload_cpu_download_gpu(server): src_conn = infinistore.InfinityConnection(src_config) src_conn.connect() + dst_conn = infinistore.InfinityConnection(dst_config) + dst_conn.connect() + key = generate_random_string(5) src = torch.randn(4096, dtype=torch.float32, device="cpu") # NOTE: not orch.cuda.synchronize required for CPU tensor - src_conn.register_mr(src) - remote_addrs = src_conn.allocate_rdma([key], 4096 * 4) - src_conn.rdma_write_cache(src, [0], 4096, remote_addrs) - src_conn.sync() - src_conn.close() - - dst_conn = infinistore.InfinityConnection(dst_config) - dst_conn.connect() + src_conn.register_mr(src.data_ptr(), src.numel() * src.element_size()) dst = torch.zeros(4096, dtype=torch.float32, device="cuda:0") - dst_conn.register_mr(dst) - dst_conn.read_cache(dst, [(key, 0)], 4096) - dst_conn.sync() - assert torch.equal(src, dst.cpu()) - dst_conn.close() + dst_conn.register_mr(dst.data_ptr(), dst.numel() * dst.element_size()) - -def test_deduplicate(server): - config = infinistore.ClientConfig( - host_addr="127.0.0.1", - service_port=92345, - link_type=infinistore.LINK_ETHERNET, - dev_name=f"{RDMA_DEV[0]}", - ) - - config.connection_type = infinistore.TYPE_RDMA - - conn = infinistore.InfinityConnection(config) - conn.connect() - - key = "duplicate_key" - src = [i for i in range(4096)] - src_tensor = torch.tensor(src, device="cuda:0", dtype=torch.float32) - - torch.cuda.synchronize(src_tensor.device) - conn.register_mr(src_tensor) - element_size = torch._utils._element_size(torch.float32) - - remote_addrs = conn.allocate_rdma([key], 4096 * element_size) - print(remote_addrs) - conn.rdma_write_cache(src_tensor, [0], 4096, remote_addrs) - - conn.sync() - - src2_tensor = torch.randn(4096, device="cuda:0", dtype=torch.float32) - - # test_deduplicate - conn.register_mr(src2_tensor) - 
element_size = torch._utils._element_size(torch.float32) - - remote_addrs = conn.allocate_rdma([key], 4096 * element_size) - conn.rdma_write_cache(src_tensor, [0], 4096, remote_addrs) - - conn.sync() - - dst_tensor = torch.zeros(4096, dtype=torch.float32, device="cpu") - conn.register_mr(dst_tensor) - - conn.read_cache(dst_tensor, [(key, 0)], 4096) - conn.sync() - - assert torch.equal(src_tensor.cpu(), dst_tensor.cpu()) - assert not torch.equal(src2_tensor.cpu(), dst_tensor.cpu()) - conn.close() - - -def test_async_api(server): - config = infinistore.ClientConfig( - host_addr="127.0.0.1", - service_port=92345, - link_type=infinistore.LINK_ETHERNET, - dev_name=f"{RDMA_DEV[0]}", - connection_type=infinistore.TYPE_RDMA, - ) - conn = infinistore.InfinityConnection(config) - - # use asyncio async def run(): - await conn.connect_async() - key = generate_random_string(5) - src = torch.randn(4096, device="cuda", dtype=torch.float32) - dst = torch.zeros(4096, device="cuda", dtype=torch.float32) + await src_conn.rdma_write_cache_async([(key, 0)], 4096 * 4, src.data_ptr()) + await dst_conn.rdma_read_cache_async([(key, 0)], 4096 * 4, dst.data_ptr()) + assert torch.equal(src, dst.cpu()) - def register_mr(): - conn.register_mr(src) - conn.register_mr(dst) + asyncio.run(run()) + src_conn.close() + dst_conn.close() - await asyncio.to_thread(register_mr) - remote_addrs = await conn.allocate_rdma_async([key], 4096 * 4) - await conn.rdma_write_cache_async(src, [0], 4096, remote_addrs) - await conn.read_cache_async(dst, [(key, 0)], 4096) - assert torch.equal(src, dst) - conn.close() - asyncio.run(run()) +def test_overwrite(server): + # FIXME: implement this test + pass -def test_single_async_api(server): +def test_async_api(server): config = infinistore.ClientConfig( host_addr="127.0.0.1", service_port=92345, @@ -444,7 +388,6 @@ def test_single_async_api(server): dev_name=f"{RDMA_DEV[0]}", connection_type=infinistore.TYPE_RDMA, ) - conn = infinistore.InfinityConnection(config) # use 
asyncio @@ -455,14 +398,12 @@ async def run(): dst = torch.zeros(4096, device="cuda", dtype=torch.float32) def register_mr(): - conn.register_mr(src) - conn.register_mr(dst) + conn.register_mr(src.data_ptr(), src.numel() * src.element_size()) + conn.register_mr(dst.data_ptr(), dst.numel() * dst.element_size()) await asyncio.to_thread(register_mr) - - await conn.rdma_write_cache_single_async(key, src.data_ptr(), 4096 * 4) - - await conn.read_cache_single_async(key, dst.data_ptr(), 4096 * 4) + await conn.rdma_write_cache_async([(key, 0)], 4096 * 4, src.data_ptr()) + await conn.rdma_read_cache_async([(key, 0)], 4096 * 4, dst.data_ptr()) assert torch.equal(src, dst) conn.close() @@ -484,12 +425,12 @@ async def run(): try: await conn.connect_async() dst = torch.zeros(4096, device="cuda", dtype=torch.float32) - await asyncio.to_thread(conn.register_mr, dst) - with pytest.raises(infinistore.InfiniStoreKeyNotFound): - await conn.read_cache_async(dst, [("non_exist_key", 0)], 4096) + await asyncio.to_thread( + conn.register_mr, dst.data_ptr(), dst.numel() * dst.element_size() + ) with pytest.raises(infinistore.InfiniStoreKeyNotFound): - await conn.read_cache_single_async( - "non_exist_key", dst.data_ptr(), 4096 * 4 + await conn.rdma_read_cache_async( + [("non_exist_key", 0)], 4096 * 4, dst.data_ptr() ) finally: conn.close() @@ -543,22 +484,22 @@ def test_delete_keys(server, test_dtype): conn.connect() src_tensor = torch.randn(BLOCK_SIZE, device="cuda", dtype=test_dtype) - # Generate the names of the keys keys = [generate_random_string(10) for i in range(KEY_COUNT)] - conn.register_mr(src_tensor) - # Allocate BLOB_SIZE elements for each key - remote_addrs = conn.allocate_rdma(keys, BLOB_SIZE) - - conn.rdma_write_cache( - src_tensor, - [i * BLOB_SIZE for i in range(KEY_COUNT)], - BLOB_SIZE, - remote_addrs, + conn.register_mr( + src_tensor.data_ptr(), src_tensor.numel() * src_tensor.element_size() ) + element_size = torch._utils._element_size(test_dtype) - 
torch.cuda.synchronize(src_tensor.device) - conn.sync() + async def run(): + block_offsets = [ + (keys[i], i * BLOB_SIZE * element_size) for i in range(KEY_COUNT) + ] + await conn.rdma_write_cache_async( + block_offsets, BLOB_SIZE * element_size, src_tensor.data_ptr() + ) + + asyncio.run(run()) # Check all the keys exist for i in range(KEY_COUNT): @@ -571,3 +512,33 @@ def test_delete_keys(server, test_dtype): assert conn.check_exist(keys[1]) assert not conn.check_exist(keys[0]) assert not conn.check_exist(keys[2]) + conn.close() + + +def get_ptr(mv: memoryview): + return ctypes.addressof(ctypes.c_char.from_buffer(mv)) + + +def test_simple_tcp_read_write(server): + config = infinistore.ClientConfig( + host_addr="127.0.0.1", + service_port=92345, + connection_type=infinistore.TYPE_TCP, + ) + + try: + conn = infinistore.InfinityConnection(config) + conn.connect() + key = generate_random_string(10) + size = 256 * 1024 + src = bytearray(size) + for i in range(size): + src[i] = i % 200 + conn.tcp_write_cache(key, get_ptr(src), len(src)) + + dst = conn.tcp_read_cache(key) + assert len(dst) == len(src) + for i in range(len(src)): + assert dst[i] == src[i] + finally: + conn.close() diff --git a/src/Makefile b/src/Makefile index d126955..7fda7af 100644 --- a/src/Makefile +++ b/src/Makefile @@ -1,5 +1,5 @@ CXX = g++ -CXXFLAGS = -std=c++17 -Wall -O2 -g +CXXFLAGS = -std=c++17 -Wall -O3 -g INCLUDES = -I/usr/local/ LDFLAGS = -rdynamic diff --git a/src/infinistore.cpp b/src/infinistore.cpp index 2ff3b13..8dadac6 100644 --- a/src/infinistore.cpp +++ b/src/infinistore.cpp @@ -48,8 +48,12 @@ std::unordered_map> kv_map; typedef enum { READ_HEADER, READ_BODY, + READ_VALUE_THROUGH_TCP, } read_state_t; +// the max data could be send in uv_write +static const size_t MAX_SEND_SIZE = 256 << 10; + struct Client { uv_tcp_t *handle_ = NULL; // uv_stream_t read_state_t state_; // state of the client, for parsing the request @@ -57,6 +61,8 @@ struct Client { size_t expected_bytes_ = 0; // 
expected size of the body header_t header_; + boost::intrusive_ptr current_tcp_task_; + // RDMA recv buffer char *recv_buffer_[MAX_RECV_WR] = {}; struct ibv_mr *recv_mr_[MAX_RECV_WR] = {}; @@ -64,8 +70,8 @@ struct Client { // RDMA send buffer char *send_buffer_ = NULL; struct ibv_mr *send_mr_ = NULL; - int outstanding_rdma_writes_ = 0; - std::deque> outstanding_rdma_writes_queue_; + int outstanding_rdma_ops_ = 0; + std::deque> outstanding_rdma_ops_queue_; // TCP send buffer char *tcp_send_buffer_ = NULL; @@ -90,9 +96,12 @@ struct Client { void cq_poll_handle(uv_poll_t *handle, int status, int events); int read_rdma_cache(const RemoteMetaRequest *req); + int write_rdma_cache(const RemoteMetaRequest *req); + void post_ack(int return_code); int allocate_rdma(const RemoteMetaRequest *req); // send response to client through TCP void send_resp(int return_code, void *buf, size_t size); + int tcp_payload_request(const TCPPayloadRequest *request); int sync_stream(); void reset_client_read_state(); int check_key(const std::string &key_to_check); @@ -100,6 +109,9 @@ struct Client { int delete_keys(const DeleteKeysRequest *request); int rdma_exchange(); int prepare_recv_rdma_request(int buf_idx); + void perform_batch_rdma(const RemoteMetaRequest *remote_meta_req, + std::vector> *inflight_rdma_ops, + enum ibv_wr_opcode opcode); }; typedef struct Client client_t; @@ -172,9 +184,144 @@ Client::~Client() { } } -void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { - DEBUG("Polling CQ"); +void on_close(uv_handle_t *handle) { + client_t *client = (client_t *)handle->data; + delete client; +} + +struct BulkWriteCtx { + client_t *client; + uint32_t *header_buf; + boost::intrusive_ptr ptr; + size_t offset; + size_t total_size; +}; + +void on_chunk_write(uv_write_t *req, int status) { + BulkWriteCtx *ctx = (BulkWriteCtx *)req->data; + if (status < 0) { + ERROR("Write error {}", uv_strerror(status)); + uv_close((uv_handle_t *)req->handle, on_close); + free(req); + 
delete ctx; + return; + } + + if (ctx->offset == ctx->total_size) { + DEBUG("write done"); + ctx->client->reset_client_read_state(); + free(req); + delete ctx; + return; + } + size_t remain = ctx->total_size - ctx->offset; + size_t send_size = MIN(remain, MAX_SEND_SIZE); + uv_buf_t buf = uv_buf_init((char *)ctx->ptr->ptr + ctx->offset, send_size); + ctx->offset += send_size; + uv_write_t *write_req = (uv_write_t *)malloc(sizeof(uv_write_t)); + write_req->data = ctx; + uv_write(write_req, (uv_stream_t *)ctx->client->handle_, &buf, 1, on_chunk_write); + free(req); +} + +void on_head_write(uv_write_t *req, int status) { + BulkWriteCtx *ctx = (BulkWriteCtx *)req->data; + if (status < 0) { + ERROR("Write error {}", uv_strerror(status)); + free(ctx->header_buf); + delete ctx; + free(req); + uv_close((uv_handle_t *)req->handle, on_close); + return; + } + + DEBUG("header write done"); + size_t remain = ctx->total_size; + size_t send_size = MIN(remain, MAX_SEND_SIZE); + uv_buf_t buf = uv_buf_init((char *)ctx->ptr->ptr, send_size); + ctx->offset += send_size; + uv_write_t *write_req = (uv_write_t *)malloc(sizeof(uv_write_t)); + write_req->data = ctx; + uv_write(write_req, (uv_stream_t *)ctx->client->handle_, &buf, 1, on_chunk_write); + free(req); +} + +int Client::tcp_payload_request(const TCPPayloadRequest *req) { + DEBUG("do tcp_payload_request... 
{}", op_name(req->op())); + + switch (req->op()) { + case OP_TCP_PUT: { + int ret = mm->allocate(req->value_length(), 1, + [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { + current_tcp_task_ = boost::intrusive_ptr( + new PTR(addr, req->value_length(), pool_idx, false)); + }); + if (ret < 0) { + ERROR("Failed to allocate memory"); + return OUT_OF_MEMORY; + } + DEBUG("allocated memory: addr: {}, lkey: {}, rkey: {}", current_tcp_task_->ptr, + mm->get_lkey(current_tcp_task_->pool_idx), + mm->get_rkey(current_tcp_task_->pool_idx)); + + kv_map[req->key()->str()] = current_tcp_task_; + // set state machine + state_ = READ_VALUE_THROUGH_TCP; + bytes_read_ = 0; + expected_bytes_ = req->value_length(); + break; + } + case OP_TCP_GET: { + auto it = kv_map.find(req->key()->str()); + if (it == kv_map.end()) { + return KEY_NOT_FOUND; + } + if (!it->second->committed) { + return KEY_NOT_FOUND; + } + auto ptr = it->second; + + uint32_t *header_buf = (uint32_t *)malloc(sizeof(uint32_t) * 2); + header_buf[0] = FINISH; + header_buf[1] = static_cast(ptr->size); + + uv_write_t *write_req = (uv_write_t *)malloc(sizeof(uv_write_t)); + + // safe PTR to prevent it from being deleted early. 
+ write_req->data = new BulkWriteCtx{.client = this, + .header_buf = header_buf, + .ptr = ptr, + .offset = 0, + .total_size = ptr->size}; + + uv_buf_t buf = uv_buf_init((char *)header_buf, sizeof(uint32_t) * 2); + + uv_write(write_req, (uv_stream_t *)handle_, &buf, 1, on_head_write); + + break; + } + } + return 0; +} + +void Client::post_ack(int return_code) { + // send an error code back + struct ibv_send_wr wr = {0}; + struct ibv_send_wr *bad_wr = NULL; + wr.wr_id = 0; + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.imm_data = return_code; + wr.send_flags = 0; + wr.sg_list = NULL; + wr.num_sge = 0; + wr.next = NULL; + int ret = ibv_post_send(qp_, &wr, &bad_wr); + if (ret) { + ERROR("Failed to send WITH_IMM message: {}", strerror(ret)); + } +} +void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { // TODO: handle completion if (status < 0) { ERROR("Poll error: {}", uv_strerror(status)); @@ -197,69 +344,23 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { while (ibv_poll_cq(cq, 1, &wc) > 0) { if (wc.status == IBV_WC_SUCCESS) { if (wc.opcode == IBV_WC_RECV) { // recv RDMA read/write request - INFO("RDMA Send completed successfully, recved {}", wc.byte_len); const RemoteMetaRequest *request = GetRemoteMetaRequest(recv_buffer_[wc.wr_id]); - INFO("Received remote meta request OP {}", op_name(request->op())); + DEBUG("Received remote meta request OP {}", op_name(request->op())); switch (request->op()) { - case OP_RDMA_READ: { - int ret = read_rdma_cache(request); - // send an error code back - struct ibv_send_wr wr = {0}; - struct ibv_send_wr *bad_wr = NULL; - wr.wr_id = 1; - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.imm_data = ret; - wr.send_flags = 0; - wr.sg_list = NULL; - wr.num_sge = 0; - wr.next = NULL; - ret = ibv_post_send(qp_, &wr, &bad_wr); - if (ret) { - ERROR("Failed to send WITH_IMM message: {}", strerror(ret)); + case OP_RDMA_WRITE: { + int ret = write_rdma_cache(request); + if (ret != 0) { + post_ack(ret); } break; } - 
case OP_RDMA_ALLOCATE: { - auto start = std::chrono::high_resolution_clock::now(); - allocate_rdma(request); - INFO("allocate_rdma time: {} micro seconds", - std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start) - .count()); - break; - } - case OP_RDMA_WRITE_COMMIT: { - INFO("RDMA write commit, #addrs: {}", request->remote_addrs()->size()); - if (request->remote_addrs()->size() == 0) { - ERROR("remote_addrs size should not be 0"); - } - for (auto addr : *request->remote_addrs()) { - auto it = inflight_rdma_writes.find(addr); - if (it == inflight_rdma_writes.end()) { - ERROR("commit msg: Key not found: {}", addr); - continue; - } - it->second->committed = true; - inflight_rdma_writes.erase(it); - } - - // send ACK - struct ibv_send_wr wr = {0}; - struct ibv_send_wr *bad_wr = NULL; - wr.wr_id = 1; - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.imm_data = 1234; - wr.send_flags = 0; - wr.sg_list = NULL; - wr.num_sge = 0; - wr.next = NULL; - int ret = ibv_post_send(qp_, &wr, &bad_wr); - if (ret) { - ERROR("Failed to send WITH_IMM message: {}", strerror(ret)); + case OP_RDMA_READ: { + int ret = read_rdma_cache(request); + if (ret != 0) { + post_ack(ret); } - DEBUG("inflight_rdma_kv_map size: {}", inflight_rdma_writes.size()); break; } default: @@ -267,7 +368,7 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { break; } - INFO("ready for next request"); + DEBUG("ready for next request"); if (prepare_recv_rdma_request(wc.wr_id) < 0) { ERROR("Failed to prepare recv rdma request"); return; @@ -286,15 +387,14 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { return; } } - else if (wc.opcode == IBV_WC_RDMA_WRITE) { + else if (wc.opcode == IBV_WC_RDMA_WRITE || wc.opcode == IBV_WC_RDMA_READ) { // some RDMA write(read cache WRs) is finished - DEBUG("RDMA_WRITE done wr_id: {}", wc.wr_id); - assert(outstanding_rdma_writes_ >= 0); - outstanding_rdma_writes_ -= MAX_WR_BATCH; + assert(outstanding_rdma_ops_ >= 0); 
+ outstanding_rdma_ops_ -= MAX_WR_BATCH; - if (!outstanding_rdma_writes_queue_.empty()) { - auto item = outstanding_rdma_writes_queue_.front(); + if (!outstanding_rdma_ops_queue_.empty()) { + auto item = outstanding_rdma_ops_queue_.front(); struct ibv_send_wr *wrs = item.first; struct ibv_sge *sges = item.second; ibv_send_wr *bad_wr = nullptr; @@ -304,16 +404,29 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { ERROR("Failed to post RDMA write {}", strerror(ret)); throw std::runtime_error("Failed to post RDMA write"); } - outstanding_rdma_writes_ += MAX_WR_BATCH; + outstanding_rdma_ops_ += MAX_WR_BATCH; delete[] wrs; delete[] sges; - outstanding_rdma_writes_queue_.pop_front(); + outstanding_rdma_ops_queue_.pop_front(); } if (wc.wr_id > 0) { // last WR will inform that all RDMA write is finished,so we can dereference PTR - auto inflight_rdma_reads = (std::vector> *)wc.wr_id; - delete inflight_rdma_reads; + if (wc.opcode == IBV_WC_RDMA_READ) { + auto inflight_rdma_writes = + (std::vector> *)wc.wr_id; + for (auto ptr : *inflight_rdma_writes) { + ptr->committed = true; + } + delete inflight_rdma_writes; + post_ack(FINISH); + } + else if (wc.opcode == IBV_WC_RDMA_WRITE) { + post_ack(FINISH); + auto inflight_rdma_reads = + (std::vector> *)wc.wr_id; + delete inflight_rdma_reads; + } } } else { @@ -334,84 +447,6 @@ void add_mempool_completion(uv_work_t *req, int status) { delete req; } -int Client::allocate_rdma(const RemoteMetaRequest *req) { - INFO("do allocate_rdma..."); - - FixedBufferAllocator allocator(send_buffer_, PROTOCOL_BUFFER_SIZE); - FlatBufferBuilder builder(64 << 10, &allocator); - - int key_idx = 0; - int block_size = req->block_size(); - std::vector blocks; - blocks.reserve(req->keys()->size()); - - unsigned int error_code = FINISH; - if (!mm->allocate(block_size, req->keys()->size(), - [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { - // FIXME: rdma write should have a msg to update committed to true - - const auto 
*key = req->keys()->Get(key_idx); - - if (kv_map.count(key->str()) != 0) { - // WARN("rdma_write: Key already exists: {}", key->str()); - // put fake addr, and send to client - blocks.push_back(FAKE_REMOTE_BLOCK); - key_idx++; - return; - } - - auto ptr = - boost::intrusive_ptr(new PTR(addr, block_size, pool_idx, false)); - - // save in kv_map, but committed is false, no one can read it - kv_map[key->str()] = ptr; - - // save in inflight_rdma_kv_map, when write is finished, we can merge it - // into kv_map - inflight_rdma_writes[(uintptr_t)addr] = ptr; - - blocks.push_back(RemoteBlock(rkey, (uint64_t)addr)); - key_idx++; - })) { - ERROR("Failed to allocate memory"); - error_code = OUT_OF_MEMORY; - blocks.clear(); - } - - if (global_config.auto_increase && mm->need_extend && !extend_in_flight) { - INFO("Extend another mempool"); - uv_work_t *req = new uv_work_t(); - uv_queue_work(loop, req, add_mempool, add_mempool_completion); - extend_in_flight = true; - } - - auto resp = CreateRdmaAllocateResponseDirect(builder, &blocks, error_code); - builder.Finish(resp); - - // send RDMA request - struct ibv_sge sge = {0}; - struct ibv_send_wr wr = {0}; - struct ibv_send_wr *bad_wr = NULL; - - sge.addr = (uintptr_t)builder.GetBufferPointer(); - sge.length = builder.GetSize(); - sge.lkey = send_mr_->lkey; - - wr.wr_id = 0; - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - wr.send_flags = IBV_SEND_SIGNALED; - - int ret = ibv_post_send(qp_, &wr, &bad_wr); - if (ret) { - ERROR("Failed to post RDMA send :{}", strerror(ret)); - return -1; - } - - return 0; -} - int Client::prepare_recv_rdma_request(int buf_idx) { struct ibv_sge sge = {0}; struct ibv_recv_wr rwr = {0}; @@ -431,37 +466,10 @@ int Client::prepare_recv_rdma_request(int buf_idx) { return 0; } -int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) { - INFO("do rdma read... 
num of keys: {}", remote_meta_req->keys()->size()); - - if (remote_meta_req->keys()->size() != remote_meta_req->remote_addrs()->size()) { - ERROR("keys size and remote_addrs size mismatch"); - return INVALID_REQ; - } - - auto *inflight_rdma_reads = new std::vector>; - - inflight_rdma_reads->reserve(remote_meta_req->keys()->size()); - - for (const auto *key : *remote_meta_req->keys()) { - auto it = kv_map.find(key->str()); - if (it == kv_map.end()) { - WARN("Key not found: {}", key->str()); - return KEY_NOT_FOUND; - } - - if (!it->second->committed) { - WARN("Key not committed: {}, return KEY_NOT_FOUND", key->str()); - return KEY_NOT_FOUND; - } - - const auto &ptr = it->second; - - DEBUG("rkey: {}, local_addr: {}, size : {}", mm->get_lkey(ptr->pool_idx), - (uintptr_t)ptr->ptr, ptr->size); - - inflight_rdma_reads->push_back(ptr); - } +void Client::perform_batch_rdma(const RemoteMetaRequest *remote_meta_req, + std::vector> *inflight_rdma_ops, + enum ibv_wr_opcode opcode) { + assert(opcode == IBV_WR_RDMA_READ || opcode == IBV_WR_RDMA_WRITE); const size_t max_wr = MAX_WR_BATCH; struct ibv_send_wr local_wrs[max_wr]; @@ -473,32 +481,37 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) { size_t num_wr = 0; bool wr_full = false; - if (outstanding_rdma_writes_ + max_wr > MAX_RDMA_WRITE_WR) { + if (outstanding_rdma_ops_ + max_wr > MAX_RDMA_OPS_WR) { wr_full = true; wrs = new struct ibv_send_wr[max_wr]; sges = new struct ibv_sge[max_wr]; } - for (size_t i = 0; i < remote_meta_req->keys()->size(); i++) { - sges[num_wr].addr = (uintptr_t)(*inflight_rdma_reads)[i]->ptr; + + int n = remote_meta_req->keys()->size(); + for (int i = 0; i < n; i++) { + sges[num_wr].addr = (uintptr_t)(*inflight_rdma_ops)[i]->ptr; sges[num_wr].length = remote_meta_req->block_size(); - sges[num_wr].lkey = mm->get_lkey((*inflight_rdma_reads)[i]->pool_idx); + sges[num_wr].lkey = mm->get_lkey((*inflight_rdma_ops)[i]->pool_idx); wrs[num_wr].wr_id = 0; - wrs[num_wr].opcode = 
IBV_WR_RDMA_WRITE; + wrs[num_wr].opcode = opcode; wrs[num_wr].sg_list = &sges[num_wr]; wrs[num_wr].num_sge = 1; wrs[num_wr].wr.rdma.remote_addr = remote_meta_req->remote_addrs()->Get(i); wrs[num_wr].wr.rdma.rkey = remote_meta_req->rkey(); - wrs[num_wr].next = (num_wr == max_wr - 1 || i == remote_meta_req->keys()->size() - 1) + + // wrs[num_wr].wr.rdma.rkey = remote_meta_req->rkey(); + wrs[num_wr].next = (num_wr == max_wr - 1 || i == (int)remote_meta_req->keys()->size() - 1) ? nullptr : &wrs[num_wr + 1]; - wrs[num_wr].send_flags = (num_wr == max_wr - 1 || i == remote_meta_req->keys()->size() - 1) - ? IBV_SEND_SIGNALED - : 0; + wrs[num_wr].send_flags = + (num_wr == max_wr - 1 || i == (int)remote_meta_req->keys()->size() - 1) + ? IBV_SEND_SIGNALED + : 0; if (i == remote_meta_req->keys()->size() - 1) { - wrs[num_wr].wr_id = (uintptr_t)inflight_rdma_reads; + wrs[num_wr].wr_id = (uintptr_t)inflight_rdma_ops; } num_wr++; @@ -506,16 +519,15 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) { if (num_wr == max_wr || i == remote_meta_req->keys()->size() - 1) { if (!wr_full) { struct ibv_send_wr *bad_wr = nullptr; - DEBUG("local write"); int ret = ibv_post_send(qp_, &wrs[0], &bad_wr); if (ret) { ERROR("Failed to post RDMA write {}", strerror(ret)); - return -1; + return; } - outstanding_rdma_writes_ += max_wr; + outstanding_rdma_ops_ += max_wr; // check if next iteration will exceed the limit - if (outstanding_rdma_writes_ + max_wr > MAX_RDMA_WRITE_WR) { + if (outstanding_rdma_ops_ + max_wr > MAX_RDMA_OPS_WR) { wr_full = true; } } @@ -526,7 +538,7 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) { "last op code: {} ", num_wr, wrs[0].wr_id, wrs[num_wr - 1].wr_id, static_cast(wrs[num_wr - 1].opcode)); - outstanding_rdma_writes_queue_.push_back({&wrs[0], &sges[0]}); + outstanding_rdma_ops_queue_.push_back({&wrs[0], &sges[0]}); } if (wr_full) { @@ -537,8 +549,77 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) 
{ num_wr = 0; // Reset the counter for the next batch } } +} - return FINISH; +int Client::write_rdma_cache(const RemoteMetaRequest *remote_meta_req) { + DEBUG("do rdma write... num of keys: {}", remote_meta_req->keys()->size()); + if (remote_meta_req->keys()->size() != remote_meta_req->remote_addrs()->size()) { + ERROR("keys size and remote_addrs size mismatch"); + return INVALID_REQ; + } + + // allocate memory + int block_size = remote_meta_req->block_size(); + int n = remote_meta_req->keys()->size(); + + // create something. + + auto *inflight_rdma_writes = new std::vector>; + inflight_rdma_writes->reserve(n); + + int key_idx = 0; + bool allocated = + mm->allocate(block_size, n, [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { + const auto *key = remote_meta_req->keys()->Get(key_idx); + auto ptr = boost::intrusive_ptr(new PTR(addr, block_size, pool_idx, false)); + inflight_rdma_writes->push_back(ptr); + kv_map[key->str()] = ptr; + key_idx++; + }); + + if (!allocated) { + ERROR("Failed to allocate memory"); + delete inflight_rdma_writes; + return OUT_OF_MEMORY; + } + // perform rdma read to receive data from client + // read remote address data to local address + perform_batch_rdma(remote_meta_req, inflight_rdma_writes, IBV_WR_RDMA_READ); + return 0; +} + +int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) { + DEBUG("do rdma read... 
num of keys: {}", remote_meta_req->keys()->size()); + + if (remote_meta_req->keys()->size() != remote_meta_req->remote_addrs()->size()) { + ERROR("keys size and remote_addrs size mismatch"); + return INVALID_REQ; + } + + auto *inflight_rdma_reads = new std::vector>; + + inflight_rdma_reads->reserve(remote_meta_req->keys()->size()); + + for (const auto *key : *remote_meta_req->keys()) { + auto it = kv_map.find(key->str()); + if (it == kv_map.end()) { + WARN("Key not found: {}", key->str()); + return KEY_NOT_FOUND; + } + + if (!it->second->committed) { + WARN("Key not committed: {}, return KEY_NOT_FOUND", key->str()); + return KEY_NOT_FOUND; + } + + const auto &ptr = it->second; + + inflight_rdma_reads->push_back(ptr); + } + + // write to remote address data from local address + perform_batch_rdma(remote_meta_req, inflight_rdma_reads, IBV_WR_RDMA_WRITE); + return 0; } // FIXME: @@ -550,11 +631,6 @@ void Client::reset_client_read_state() { // keep the tcp_recv_buffer/tcp_send_buffer as it is } -void on_close(uv_handle_t *handle) { - client_t *client = (client_t *)handle->data; - delete client; -} - void alloc_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf) { buf->base = (char *)malloc(suggested_size); buf->len = suggested_size; @@ -905,7 +981,6 @@ int Client::delete_keys(const DeleteKeysRequest *request) { void handle_request(uv_stream_t *stream, client_t *client) { auto start = std::chrono::high_resolution_clock::now(); int error_code = 0; - int op = client->header_.op; // if error_code is not 0, close the connection switch (client->header_.op) { case OP_RDMA_EXCHANGE: { @@ -916,7 +991,7 @@ void handle_request(uv_stream_t *stream, client_t *client) { } case OP_CHECK_EXIST: { std::string key_to_check(client->tcp_recv_buffer_, client->expected_bytes_); - INFO("check key: {}", key_to_check); + DEBUG("check key: {}", key_to_check); error_code = client->check_key(key_to_check); break; } @@ -927,11 +1002,17 @@ void handle_request(uv_stream_t *stream, 
client_t *client) { break; } case OP_DELETE_KEYS: { - INFO("delete keys..."); + DEBUG("delete keys..."); const DeleteKeysRequest *request = GetDeleteKeysRequest(client->tcp_recv_buffer_); error_code = client->delete_keys(request); break; } + case OP_TCP_PAYLOAD: { + DEBUG("TCP GET/PUT data..."); + const TCPPayloadRequest *request = GetTCPPayloadRequest(client->tcp_recv_buffer_); + error_code = client->tcp_payload_request(request); + break; + } default: ERROR("Invalid request"); error_code = INVALID_REQ; @@ -969,23 +1050,19 @@ void on_read(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) { if (client->bytes_read_ == FIXED_HEADER_SIZE) { DEBUG("HEADER: op: {}, body_size :{}", client->header_.op, (unsigned int)client->header_.body_size); - if (client->header_.op == OP_CHECK_EXIST || - client->header_.op == OP_GET_MATCH_LAST_IDX || - client->header_.op == OP_RDMA_EXCHANGE || - client->header_.op == OP_DELETE_KEYS) { - int ret = verify_header(&client->header_); - if (ret != 0) { - ERROR("Invalid header"); - uv_close((uv_handle_t *)stream, on_close); - goto clean_up; - } - // prepare for reading body - client->expected_bytes_ = client->header_.body_size; - client->bytes_read_ = 0; - client->tcp_recv_buffer_ = - (char *)realloc(client->tcp_recv_buffer_, client->expected_bytes_); - client->state_ = READ_BODY; + + int ret = verify_header(&client->header_); + if (ret != 0) { + ERROR("Invalid header"); + uv_close((uv_handle_t *)stream, on_close); + goto clean_up; } + // prepare for reading body + client->expected_bytes_ = client->header_.body_size; + client->bytes_read_ = 0; + client->tcp_recv_buffer_ = + (char *)realloc(client->tcp_recv_buffer_, client->expected_bytes_); + client->state_ = READ_BODY; } break; } @@ -1006,6 +1083,19 @@ void on_read(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) { } break; } + case READ_VALUE_THROUGH_TCP: { + size_t to_copy = MIN(nread - offset, client->expected_bytes_ - client->bytes_read_); + 
memcpy(client->current_tcp_task_->ptr + client->bytes_read_, buf->base + offset, + to_copy); + client->bytes_read_ += to_copy; + offset += to_copy; + if (client->bytes_read_ == client->expected_bytes_) { + client->current_tcp_task_->committed = true; + client->current_tcp_task_.reset(); + client->send_resp(FINISH, NULL, 0); + client->reset_client_read_state(); + } + } } } clean_up: diff --git a/src/libinfinistore.cpp b/src/libinfinistore.cpp index 37855c7..fc234a2 100644 --- a/src/libinfinistore.cpp +++ b/src/libinfinistore.cpp @@ -65,10 +65,9 @@ void Connection::close_conn() { send_wr.send_flags = IBV_SEND_SIGNALED; struct ibv_send_wr *bad_send_wr; - { - std::unique_lock lock(rdma_post_send_mutex_); - ibv_post_send(qp_, &send_wr, &bad_send_wr); - } + + ibv_post_send(qp_, &send_wr, &bad_send_wr); + // wait thread done cq_future_.get(); } @@ -97,13 +96,13 @@ Connection::~Connection() { } local_mr_.clear(); - if (recv_mr_) { - ibv_dereg_mr(recv_mr_); - } + // if (recv_mr_) { + // ibv_dereg_mr(recv_mr_); + // } - if (recv_buffer_) { - free(recv_buffer_); - } + // if (recv_buffer_) { + // free(recv_buffer_); + // } if (qp_) { struct ibv_qp_attr attr; @@ -279,18 +278,6 @@ int Connection::modify_qp_to_init() { return 0; } -int Connection::sync_rdma() { - std::unique_lock lock(mutex_); - bool ret = - cv_.wait_for(lock, std::chrono::seconds(10), [this] { return rdma_inflight_count_ == 0; }); - - if (!ret) { - ERROR("timeout to sync RDMA"); - return -1; - } - return 0; -} - void Connection::cq_handler() { assert(comp_channel_ != NULL); while (!stop_) { @@ -328,101 +315,25 @@ void Connection::cq_handler() { else if (wc[i].opcode == IBV_WC_RECV) { // allocate msg recved. 
rdma_info_base *ptr = reinterpret_cast(wc[i].wr_id); switch (ptr->get_wr_type()) { - case WrType::ALLOCATE: { - auto *info = reinterpret_cast(ptr); - info->callback(); - delete info; - break; - } - case WrType::READ_COMMIT: { - INFO("read cache done: Received IMM, imm_data: {}", wc[i].imm_data); - auto *info = reinterpret_cast(ptr); + case WrType::RDMA_READ_ACK: { + DEBUG("read cache done: Received IMM, imm_data: {}", + wc[i].imm_data); + auto *info = reinterpret_cast(ptr); info->callback(wc[i].imm_data); delete info; - rdma_inflight_count_--; - cv_.notify_all(); break; } - case WrType::WRITE_ACK: { - INFO("write cache done: Received IMM, imm_data: {}", - wc[i].imm_data); - auto *info = reinterpret_cast(ptr); - info->callback(); + case WrType::RDMA_WRITE_ACK: { + DEBUG("RDMA write cache done: Received IMM, imm_data: {}", + wc[i].imm_data); + auto *info = reinterpret_cast(ptr); + info->callback(wc->imm_data); delete info; - rdma_inflight_count_--; - cv_.notify_all(); break; } - } - } - else if (wc[i].opcode == IBV_WC_RDMA_WRITE) { // write cache done - - assert(outstanding_rdma_writes_ >= 0); - - std::unique_lock lock(rdma_post_send_mutex_); - - outstanding_rdma_writes_ -= MAX_WR_BATCH; - DEBUG("RDMA_WRITE completed, wr_id: {}, outstanding_rdma_writes: {}", - wc[i].wr_id, outstanding_rdma_writes_.load()); - - // drain the queue - if (!outstanding_rdma_writes_queue_.empty()) { - auto item = outstanding_rdma_writes_queue_.front(); - struct ibv_send_wr *wrs = item.first; - struct ibv_sge *sges = item.second; - ibv_send_wr *bad_wr = nullptr; - DEBUG("IBV POST SEND, wr_id: {}", wrs[0].wr_id); - int ret = ibv_post_send(qp_, &wrs[0], &bad_wr); - if (ret) { - ERROR("Failed to post RDMA write {}", strerror(ret)); - throw std::runtime_error("Failed to post RDMA write"); - } - outstanding_rdma_writes_ += MAX_WR_BATCH; - delete[] wrs; - delete[] sges; - outstanding_rdma_writes_queue_.pop_front(); - } - - // If this is the last WR of w_rdma, send RDMA COMMIT msg to server - if 
(wc[i].wr_id != 0) { - SendBuffer *send_buffer = get_send_buffer(); - FixedBufferAllocator allocator(send_buffer->buffer_, - PROTOCOL_BUFFER_SIZE); - FlatBufferBuilder builder(64 << 10, &allocator); - auto *info = reinterpret_cast(wc[i].wr_id); - - auto remote_addrs_offset = builder.CreateVector(info->remote_addrs); - - auto req = CreateRemoteMetaRequest( - builder, 0, 0, 0, remote_addrs_offset, OP_RDMA_WRITE_COMMIT); - builder.Finish(req); - - // recv RDMA COMMIT's ACK from server - post_recv(NULL, info); - - // send RDMA COMMIT msg to server - struct ibv_sge sge = {0}; - struct ibv_send_wr wr = {0}; - struct ibv_send_wr *bad_wr = NULL; - - sge.addr = (uintptr_t)builder.GetBufferPointer(); - sge.length = builder.GetSize(); - sge.lkey = send_buffer->mr_->lkey; - - wr.wr_id = (uintptr_t)send_buffer; - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - wr.send_flags = IBV_SEND_SIGNALED; - - int ret = ibv_post_send(qp_, &wr, &bad_wr); - if (ret) { - ERROR("Failed to post RDMA send :{}", strerror(ret)); + default: + ERROR("Unexpected wr type: {}", (int)ptr->get_wr_type()); return; - } - - // release lock before callback to prevent deadlock - lock.unlock(); } } else { @@ -480,17 +391,6 @@ int Connection::setup_rdma(client_config_t config) { return -1; } - if (posix_memalign(&recv_buffer_, 4096, PROTOCOL_BUFFER_SIZE) != 0) { - ERROR("Failed to allocate recv buffer"); - return -1; - } - recv_mr_ = ibv_reg_mr(pd_, recv_buffer_, PROTOCOL_BUFFER_SIZE, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); - if (!recv_mr_) { - ERROR("Failed to register recv MR"); - return -1; - } - /* This is MAX_RECV_WR not MAX_SEND_WR, because server also has the same number of buffers @@ -499,7 +399,6 @@ int Connection::setup_rdma(client_config_t config) { send_buffers_.push(new SendBuffer(pd_, PROTOCOL_BUFFER_SIZE)); } - rdma_inflight_count_ = 0; stop_ = false; cq_future_ = std::async(std::launch::async, [this]() { cq_handler(); }); @@ -645,7 +544,8 @@ int 
Connection::exchange_conn_info() { int Connection::check_exist(std::string key) { header_t header; - header = {.magic = MAGIC, .op = OP_CHECK_EXIST, .body_size = key.size()}; + header = { + .magic = MAGIC, .op = OP_CHECK_EXIST, .body_size = static_cast(key.size())}; struct iovec iov[2]; struct msghdr msg; @@ -793,149 +693,136 @@ int Connection::delete_keys(const std::vector &keys) { return count; } -std::vector *Connection::allocate_rdma(std::vector &keys, - int block_size) { - // convert allocate_rdma_async to sync version - std::promise promise; - auto future = promise.get_future(); - std::vector *ret_blocks; - allocate_rdma_async( - keys, block_size, - [&promise, &ret_blocks](std::vector *blocks, unsigned int error_code) { - ret_blocks = blocks; - if (error_code != FINISH) { - ERROR("allocate_rdma failed, error_code: {}", error_code); - } - promise.set_value(); - }); - - auto status = future.wait_for(std::chrono::seconds(5)); // timeout 5s - if (status == std::future_status::timeout) { - ERROR("allocate_rdma timeout"); - return nullptr; - } - else { - future.get(); - } - return ret_blocks; -} - -void Connection::post_recv(struct ibv_sge *recv_sge, rdma_info_base *info) { +void Connection::post_recv_ack(rdma_info_base *info) { struct ibv_recv_wr recv_wr = {0}; struct ibv_recv_wr *bad_recv_wr = NULL; recv_wr.wr_id = (uintptr_t)info; - if (recv_sge != NULL) { - recv_wr.next = NULL; - recv_wr.sg_list = recv_sge; - recv_wr.num_sge = 1; - } - else { - recv_wr.next = NULL; - recv_wr.sg_list = NULL; - recv_wr.num_sge = 0; - } + + recv_wr.next = NULL; + recv_wr.sg_list = NULL; + recv_wr.num_sge = 0; int ret = ibv_post_recv(qp_, &recv_wr, &bad_recv_wr); if (ret) { ERROR("Failed to post recv wr :{}", strerror(ret)); } } -// send a message to allocate memory and return the address -int Connection::allocate_rdma_async( - std::vector &keys, int block_size, - std::function *, unsigned int error_code)> callback) { - /* - ENCODING - remote_meta_request req = { - .keys = keys, - 
.block_size = block_size, - .op = OP_RDMA_ALLOCATE, + +std::vector *Connection::r_tcp(const std::string &key) { + FlatBufferBuilder builder(64 << 10); + auto req = CreateTCPPayloadRequestDirect(builder, key.c_str(), 0, OP_TCP_GET); + builder.Finish(req); + + header_t header = { + .magic = MAGIC, + .op = OP_TCP_PAYLOAD, + .body_size = builder.GetSize(), + }; + + struct iovec iov[2]; + struct msghdr msg; + memset(&msg, 0, sizeof(msg)); + + iov[0].iov_base = &header; + iov[0].iov_len = FIXED_HEADER_SIZE; + iov[1].iov_base = builder.GetBufferPointer(); + iov[1].iov_len = builder.GetSize(); + + msg.msg_iov = iov; + msg.msg_iovlen = 2; + + if (sendmsg(sock_, &msg, 0) < 0) { + ERROR("r_tcp: Failed to send header"); + return nullptr; } - */ - int ret; - // post recv msg first - struct ibv_sge recv_sge = {0}; - struct ibv_recv_wr *bad_recv_wr = NULL; - struct ibv_recv_wr recv_wr = {0}; + uint32_t buf[2]; + if (recv(sock_, &buf, RETURN_CODE_SIZE * 2, MSG_WAITALL) != RETURN_CODE_SIZE * 2) { + ERROR("r_tcp: Failed to receive return code"); + return nullptr; + } - // recv all remote addresses - recv_sge.addr = (uintptr_t)recv_buffer_; - recv_sge.length = PROTOCOL_BUFFER_SIZE; - recv_sge.lkey = recv_mr_->lkey; - - auto *info = new rdma_allocate_info([this, callback]() { - const RdmaAllocateResponse *resp = GetRdmaAllocateResponse(recv_buffer_); - INFO("Received allocate response, #keys: {}", resp->blocks()->size()); - - std::vector *blocks = new std::vector(); - blocks->reserve(resp->blocks()->size()); - for (const auto *block : *resp->blocks()) { - remote_block_t remote_block = { - .rkey = block->rkey(), - .remote_addr = block->remote_addr(), - }; - blocks->push_back(remote_block); - } - callback(blocks, resp->error_code()); - }); - // build a new callback function: + int return_code = buf[0]; + int size = buf[1]; - { - std::unique_lock lock(rdma_post_send_mutex_); - post_recv(&recv_sge, info); + if (return_code != FINISH) { + ERROR("r_tcp: Failed to get value, return code: 
{}", key, return_code); + return nullptr; } - // Send RDMA request - SendBuffer *send_buffer = get_send_buffer(); + if (size == 0) { + ERROR("r_tcp: size is 0"); + return nullptr; + } - FixedBufferAllocator allocator(send_buffer->buffer_, PROTOCOL_BUFFER_SIZE); - FlatBufferBuilder builder(64 << 10, &allocator); - auto keys_offset = builder.CreateVectorOfStrings(keys); + auto ret_buf = new std::vector(size); - auto req = CreateRemoteMetaRequest(builder, keys_offset, block_size, 0, 0, OP_RDMA_ALLOCATE); + if (recv(sock_, ret_buf->data(), size, MSG_WAITALL) != size) { + ERROR("r_tcp: Failed to receive payload"); + return nullptr; + } + return ret_buf; +} + +int Connection::w_tcp(const std::string &key, void *ptr, size_t size) { + assert(ptr != NULL); + FlatBufferBuilder builder(64 << 10); + auto req = CreateTCPPayloadRequestDirect(builder, key.c_str(), size, OP_TCP_PUT); builder.Finish(req); - struct ibv_sge sge = {0}; - struct ibv_send_wr wr = {0}; - struct ibv_send_wr *bad_wr = NULL; + header_t header = { + .magic = MAGIC, + .op = OP_TCP_PAYLOAD, + .body_size = builder.GetSize(), + }; - sge.addr = (uintptr_t)builder.GetBufferPointer(); - sge.length = builder.GetSize(); - sge.lkey = send_buffer->mr_->lkey; + struct iovec iov[2]; + struct msghdr msg; + memset(&msg, 0, sizeof(msg)); - wr.wr_id = (uintptr_t)send_buffer; - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - wr.send_flags = IBV_SEND_SIGNALED; - { - std::unique_lock lock(rdma_post_send_mutex_); - ret = ibv_post_send(qp_, &wr, &bad_wr); + iov[0].iov_base = &header; + iov[0].iov_len = FIXED_HEADER_SIZE; + iov[1].iov_base = builder.GetBufferPointer(); + iov[1].iov_len = builder.GetSize(); + + msg.msg_iov = iov; + msg.msg_iovlen = 2; + + if (sendmsg(sock_, &msg, MSG_MORE) < 0) { + ERROR("w_tcp: Failed to send header"); + return -1; } - if (ret) { - ERROR("Failed to post RDMA send :{}", strerror(ret)); + + // reuse iov[0] and msghdr + iov[0].iov_base = ptr; + iov[0].iov_len = size; + msg.msg_iov = 
iov; + msg.msg_iovlen = 1; + if (sendmsg(sock_, &msg, 0) < 0) { + ERROR("w_tcp: Failed to send payload"); + return -1; + } + + int return_code = 0; + if (recv(sock_, &return_code, RETURN_CODE_SIZE, MSG_WAITALL) != RETURN_CODE_SIZE) { + ERROR("w_tcp: Failed to receive return code"); + return -1; + } + if (return_code != FINISH) { + ERROR("w_tcp: Failed to put key: {}, return code: {}", key, return_code); return -1; } - return 0; -} -int Connection::w_rdma(unsigned long *p_offsets, size_t offsets_len, int block_size, - remote_block_t *p_remote_blocks, size_t remote_blocks_len, void *base_ptr) { - return w_rdma_async(p_offsets, offsets_len, block_size, p_remote_blocks, remote_blocks_len, - base_ptr, []() {}); + return 0; } -int Connection::w_rdma_async(unsigned long *p_offsets, size_t offsets_len, int block_size, - remote_block_t *p_remote_blocks, size_t remote_blocks_len, - void *base_ptr, std::function callback) { +int Connection::w_rdma_async(const std::vector &keys, + const std::vector offsets, int block_size, void *base_ptr, + std::function callback) { assert(base_ptr != NULL); - assert(p_remote_blocks != NULL); - assert(offsets_len == remote_blocks_len); - - INFO("w_rdma, block_size: {}, base_ptr: {}", block_size, base_ptr); + assert(offsets.size() == keys.size()); if (!local_mr_.count((uintptr_t)base_ptr)) { ERROR("Please register memory first {}", (uint64_t)base_ptr); @@ -944,138 +831,58 @@ int Connection::w_rdma_async(unsigned long *p_offsets, size_t offsets_len, int b struct ibv_mr *mr = local_mr_[(uintptr_t)base_ptr]; - std::unique_lock lock(rdma_post_send_mutex_); - - const size_t max_wr = MAX_WR_BATCH; - - struct ibv_send_wr local_wrs[max_wr]; - struct ibv_sge local_sges[max_wr]; - - struct ibv_send_wr *wrs = local_wrs; - struct ibv_sge *sges = local_sges; - - size_t num_wr = 0; + // remote_meta_request req = { + // .keys = keys, + // .block_size = block_size, + // .op = OP_RDMA_WRITE, + // .remote_addrs = remote_addrs, + // } - bool wr_full = false; - - 
auto *info = new rdma_write_commit_info([callback]() { callback(); }, remote_blocks_len); + SendBuffer *send_buffer = get_send_buffer(); + FixedBufferAllocator allocator(send_buffer->buffer_, PROTOCOL_BUFFER_SIZE); + FlatBufferBuilder builder(64 << 10, &allocator); + auto keys_offset = builder.CreateVectorOfStrings(keys); - if (outstanding_rdma_writes_ + max_wr > MAX_RDMA_WRITE_WR) { - wr_full = true; - wrs = new struct ibv_send_wr[max_wr]; - sges = new struct ibv_sge[max_wr]; + // address is base_ptr + offset + std::vector remote_addrs; + for (size_t i = 0; i < offsets.size(); i++) { + remote_addrs.push_back((unsigned long)base_ptr + offsets[i]); } + auto remote_addrs_offset = builder.CreateVector(remote_addrs); + auto req = CreateRemoteMetaRequest(builder, keys_offset, block_size, mr->rkey, + remote_addrs_offset, OP_RDMA_WRITE); - size_t skipped = 0; - for (size_t i = 0; i < remote_blocks_len; i++) { - // skip duplicated remote blocks - if (is_fake_remote_block(p_remote_blocks[i])) { - skipped++; - continue; - } - - sges[num_wr].addr = (uintptr_t)(base_ptr + p_offsets[i]); - sges[num_wr].length = block_size; - sges[num_wr].lkey = mr->lkey; - - wrs[num_wr].opcode = IBV_WR_RDMA_WRITE; - if (i == remote_blocks_len - 1) { - // save all the remote addresses for committing keys - for (size_t j = 0; j < remote_blocks_len; j++) { - info->remote_addrs.push_back(p_remote_blocks[j].remote_addr); - } - - wrs[num_wr].wr_id = reinterpret_cast(info); - } - else { - wrs[num_wr].wr_id = 0; - } - - wrs[num_wr].sg_list = &sges[num_wr]; - wrs[num_wr].num_sge = 1; - wrs[num_wr].send_flags = - (num_wr == max_wr - 1 || i == remote_blocks_len - 1) ? IBV_SEND_SIGNALED : 0; - - wrs[num_wr].wr.rdma.remote_addr = p_remote_blocks[i].remote_addr; - wrs[num_wr].wr.rdma.rkey = p_remote_blocks[i].rkey; - wrs[num_wr].next = - (num_wr == max_wr - 1 || i == remote_blocks_len - 1) ? 
nullptr : &wrs[num_wr + 1]; - num_wr++; - - if (num_wr == max_wr || i == remote_blocks_len - 1) { - if (!wr_full) { - struct ibv_send_wr *bad_wr = nullptr; - int ret = ibv_post_send(qp_, &wrs[0], &bad_wr); - if (ret) { - ERROR("Failed to post RDMA write {}", strerror(ret)); - return -1; - } - outstanding_rdma_writes_ += max_wr; + builder.Finish(req); - // check if next iteration will exceed the limit - if (outstanding_rdma_writes_ + max_wr > MAX_RDMA_WRITE_WR) { - wr_full = true; - } - } - else { - // if WR queue is full, we need to put them into queue - DEBUG( - "WR queue full: push into temp queue, len: {}, first wr_id: {}, last wr_id: " - "{}", - num_wr, wrs[0].wr_id, wrs[num_wr - 1].wr_id); - outstanding_rdma_writes_queue_.push_back({&wrs[0], &sges[0]}); - } + // post recv msg first + auto *info = new rdma_write_info(callback); + post_recv_ack(info); - if (wr_full) { - wrs = new struct ibv_send_wr[max_wr]; - sges = new struct ibv_sge[max_wr]; - } - num_wr = 0; // Reset the counter for the next batch - } - } + // send msg + struct ibv_sge sge = {0}; + struct ibv_send_wr wr = {0}; + struct ibv_send_wr *bad_wr = NULL; + sge.addr = (uintptr_t)builder.GetBufferPointer(); + sge.length = builder.GetSize(); + sge.lkey = send_buffer->mr_->lkey; - // Check if there are remaining WRs to be sent - if (num_wr > 0) { - if (wr_full) { - DEBUG("WR queue full: push into temp queue, len: {}, first wr_id: {}, last wr_id: {}", - num_wr, wrs[0].wr_id, wrs[num_wr - 1].wr_id); - outstanding_rdma_writes_queue_.push_back({&wrs[0], &sges[0]}); - } - else { - struct ibv_send_wr *bad_wr = nullptr; - int ret = ibv_post_send(qp_, &wrs[0], &bad_wr); - if (ret) { - ERROR("Failed to post RDMA write {}", strerror(ret)); - return -1; - } - } - } + wr.wr_id = (uintptr_t)send_buffer; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.send_flags = IBV_SEND_SIGNALED; - if (skipped > 0) { - WARN("Skipped {} duplicated keys", skipped); - if (skipped == remote_blocks_len) { - // All 
keys are duplicated, skip RDMA write - lock.unlock(); - info->callback(); - delete info; - return 0; - } + int ret = ibv_post_send(qp_, &wr, &bad_wr); + if (ret) { + ERROR("Failed to post RDMA send :{}", strerror(ret)); + return -1; } - rdma_inflight_count_++; - DEBUG("rdma_inflight_count: {}", rdma_inflight_count_.load()); return 0; } -int Connection::r_rdma(std::vector &blocks, int block_size, void *base_ptr) { - return r_rdma_async(blocks, block_size, base_ptr, [](unsigned int code) { - if (code != FINISH) { - ERROR("Failed to read cache, error code: {}", code); - } - }); -} - -int Connection::r_rdma_async(std::vector &blocks, int block_size, void *base_ptr, +int Connection::r_rdma_async(const std::vector &keys, + const std::vector offsets, int block_size, void *base_ptr, std::function callback) { assert(base_ptr != NULL); @@ -1088,18 +895,15 @@ int Connection::r_rdma_async(std::vector &blocks, int block_size, void struct ibv_mr *mr = local_mr_[(uintptr_t)base_ptr]; assert(mr != NULL); - auto *info = new rdma_read_commit_info([callback](unsigned int code) { callback(code); }); - { - std::unique_lock lock(rdma_post_send_mutex_); - post_recv(NULL, info); - } + auto *info = new rdma_read_info([callback](unsigned int code) { callback(code); }); + post_recv_ack(info); - std::vector keys; + // std::vector keys; std::vector remote_addrs; - for (auto &block : blocks) { - keys.push_back(block.key); - remote_addrs.push_back((uintptr_t)(base_ptr + block.offset)); + for (auto &offset : offsets) { + remote_addrs.push_back((uintptr_t)(base_ptr + offset)); } + /* remote_meta_req = { .keys = keys, @@ -1136,16 +940,12 @@ int Connection::r_rdma_async(std::vector &blocks, int block_size, void wr.send_flags = IBV_SEND_SIGNALED; int ret; - { - std::unique_lock lock(rdma_post_send_mutex_); - ret = ibv_post_send(qp_, &wr, &bad_wr); - } + ret = ibv_post_send(qp_, &wr, &bad_wr); if (ret) { ERROR("Failed to post RDMA send :{}", strerror(ret)); return -1; } - rdma_inflight_count_++; 
return 0; } diff --git a/src/libinfinistore.h b/src/libinfinistore.h index d0aa959..abed0eb 100644 --- a/src/libinfinistore.h +++ b/src/libinfinistore.h @@ -33,9 +33,8 @@ struct SendBuffer { enum class WrType { BASE, - ALLOCATE, - READ_COMMIT, - WRITE_ACK, + RDMA_READ_ACK, + RDMA_WRITE_ACK, }; struct rdma_info_base { @@ -48,29 +47,17 @@ struct rdma_info_base { WrType get_wr_type() const { return wr_type; } }; -struct rdma_allocate_info : rdma_info_base { - std::function callback; - rdma_allocate_info(std::function callback) - : rdma_info_base(WrType::ALLOCATE), callback(callback) {} +struct rdma_write_info : rdma_info_base { + std::function callback; + rdma_write_info(std::function callback) + : rdma_info_base(WrType::RDMA_WRITE_ACK), callback(callback) {} }; -struct rdma_read_commit_info : rdma_info_base { +struct rdma_read_info : rdma_info_base { // call back function. std::function callback; - rdma_read_commit_info(std::function callback) - : rdma_info_base(WrType::READ_COMMIT), callback(callback) {} -}; - -struct rdma_write_commit_info : rdma_info_base { - // call back function. - std::function callback; - // the number of blocks that have been written. - std::vector remote_addrs; - - rdma_write_commit_info(std::function callback, int n) - : rdma_info_base(WrType::WRITE_ACK), callback(callback), remote_addrs() { - remote_addrs.reserve(n); - } + rdma_read_info(std::function callback) + : rdma_info_base(WrType::RDMA_READ_ACK), callback(callback) {} }; class Connection { @@ -101,26 +88,10 @@ class Connection { */ boost::lockfree::spsc_queue send_buffers_{MAX_RECV_WR}; - // this recv buffer is used in - // 1. allocate rdma - // 2. recv IMM data, although IMM DATA is not put into recv_buffer, - // but for compatibility, we still use a zero-length recv_buffer. 
- void *recv_buffer_ = NULL; - struct ibv_mr *recv_mr_ = NULL; - struct ibv_comp_channel *comp_channel_ = NULL; std::future cq_future_; // cq thread - std::atomic rdma_inflight_count_{0}; std::atomic stop_{false}; - // protect rdma_inflight_count - std::mutex mutex_; - std::condition_variable cv_; - - // protect ibv_post_send, outstanding_rdma_writes_queue - std::mutex rdma_post_send_mutex_; - std::atomic outstanding_rdma_writes_{0}; - std::deque> outstanding_rdma_writes_queue_; public: Connection() = default; @@ -131,25 +102,14 @@ class Connection { // close cq_handler thread void close_conn(); int init_connection(client_config_t config); - // async rw local cpu memory, even rw_local returns, it is not guaranteed that - // the operation is completed until sync_local is recved. - int rw_local(char op, const std::vector &blocks, int block_size, void *ptr, - int device_id); - int sync_local(); int setup_rdma(client_config_t config); - int r_rdma(std::vector &blocks, int block_size, void *base_ptr); - int r_rdma_async(std::vector &blocks, int block_size, void *base_ptr, - std::function callback); - int w_rdma(unsigned long *p_offsets, size_t offsets_len, int block_size, - remote_block_t *p_remote_blocks, size_t remote_blocks_len, void *base_ptr); - int w_rdma_async(unsigned long *p_offsets, size_t offsets_len, int block_size, - remote_block_t *p_remote_blocks, size_t remote_blocks_len, void *base_ptr, - std::function callback); - int sync_rdma(); - std::vector *allocate_rdma(std::vector &keys, int block_size); - int allocate_rdma_async( - std::vector &keys, int block_size, - std::function *, unsigned int error_code)> callback); + int r_rdma_async(const std::vector &keys, const std::vector offsets, + int block_size, void *base_ptr, std::function callback); + int w_rdma_async(const std::vector &keys, const std::vector offsets, + int block_size, void *base_ptr, std::function callback); + int w_tcp(const std::string &key, void *ptr, size_t size); + std::vector 
*r_tcp(const std::string &key); + int check_exist(std::string key); int get_match_last_index(std::vector &keys); int delete_keys(const std::vector &keys); @@ -161,12 +121,15 @@ class Connection { int exchange_conn_info(); int init_rdma_resources(client_config_t config); - void post_recv(struct ibv_sge *recv_sge, rdma_info_base *info); + void post_recv_ack(rdma_info_base *info); void cq_handler(); // TODO: refactor to c++ style SendBuffer *get_send_buffer(); void release_send_buffer(SendBuffer *buffer); + + SendBuffer *get_recv_buffer(); + void release_recv_buffer(SendBuffer *buffer); }; #endif // LIBINFINISTORE_H diff --git a/src/mempool.cpp b/src/mempool.cpp index 1ea3536..a805484 100644 --- a/src/mempool.cpp +++ b/src/mempool.cpp @@ -54,7 +54,7 @@ MemoryPool::~MemoryPool() { int MemoryPool::allocate(size_t size, size_t n, SimpleAllocationCallback callback) { size_t required_blocks = (size + block_size_ - 1) / block_size_; // round up - int num_allocated = 0; + size_t num_allocated = 0; if (required_blocks > total_blocks_) { return 0; @@ -157,8 +157,8 @@ void MM::add_mempool(size_t pool_size, size_t block_size, struct ibv_pd* pd) { bool MM::allocate(size_t size, size_t n, AllocationCallback callback) { bool allocated = false; - int mempool_cnt = mempools_.size(); - for (int i = 0; i < mempool_cnt; ++i) { + size_t mempool_cnt = mempools_.size(); + for (size_t i = 0; i < mempool_cnt; ++i) { // create a new callback from the original callback auto simple_callback = [callback, i](void* ptr, uint32_t lkey, uint32_t rkey) { callback(ptr, lkey, rkey, i); diff --git a/src/mempool.h b/src/mempool.h index 5e4c287..f2dd801 100644 --- a/src/mempool.h +++ b/src/mempool.h @@ -66,11 +66,11 @@ class MM { bool allocate(size_t size, size_t n, AllocationCallback callback); void deallocate(void* ptr, size_t size, int pool_idx); uint32_t get_lkey(int pool_idx) const { - assert(pool_idx >= 0 && pool_idx < mempools_.size()); + assert(pool_idx >= 0 && (size_t)pool_idx < 
mempools_.size()); return mempools_[pool_idx]->get_lkey(); } uint32_t get_rkey(int pool_idx) const { - assert(pool_idx >= 0 && pool_idx < mempools_.size()); + assert(pool_idx >= 0 && (size_t)pool_idx < mempools_.size()); return mempools_[pool_idx]->get_rkey(); } diff --git a/src/meta_request.fbs b/src/meta_request.fbs index 2b7e0e9..b0b538b 100644 --- a/src/meta_request.fbs +++ b/src/meta_request.fbs @@ -1,4 +1,4 @@ -//RDMA read/allocate request +//RDMA read/write request table RemoteMetaRequest { keys: [string]; block_size: int; diff --git a/src/protocol.cpp b/src/protocol.cpp index 4997e6e..7079ae1 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -2,13 +2,12 @@ #include -std::unordered_map op_map = {{OP_RDMA_EXCHANGE, "RDMA_EXCHANGE"}, - {OP_RDMA_READ, "RDMA_READ"}, - {OP_RDMA_WRITE_COMMIT, "RDMA_WRITE_COMMIT"}, - {OP_RDMA_ALLOCATE, "RDMA_ALLOCATE"}, - {OP_CHECK_EXIST, "CHECK_EXIST"}, - {OP_GET_MATCH_LAST_IDX, "GET_MATCH_LAST_IDX"}, - {OP_DELETE_KEYS, "DELETE_KEYS"}}; +std::unordered_map op_map = { + {OP_TCP_PAYLOAD, "TCP_PAYLOAD"}, {OP_TCP_PUT, "TCP_PUT"}, + {OP_TCP_GET, "TCP_GET"}, {OP_RDMA_EXCHANGE, "RDMA_EXCHANGE"}, + {OP_RDMA_READ, "RDMA_READ"}, {OP_RDMA_WRITE, "RDMA_WRITE"}, + {OP_CHECK_EXIST, "CHECK_EXIST"}, {OP_GET_MATCH_LAST_IDX, "GET_MATCH_LAST_IDX"}, + {OP_DELETE_KEYS, "DELETE_KEYS"}}; std::string op_name(char op_code) { auto it = op_map.find(op_code); if (it != op_map.end()) { @@ -29,7 +28,3 @@ uint8_t* FixedBufferAllocator::allocate(size_t size) { void FixedBufferAllocator::deallocate(uint8_t*, size_t) { // no-op } - -bool is_fake_remote_block(remote_block_t& block) { - return block.remote_addr == 0 && block.rkey == 0; -} diff --git a/src/protocol.h b/src/protocol.h index 53cada7..db6f79d 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -14,6 +14,7 @@ // local TCP protocols #include "delete_keys_generated.h" #include "get_match_last_index_generated.h" +#include "tcp_payload_request_generated.h" using namespace flatbuffers; @@ -22,11 +23,11 
@@ using namespace flatbuffers; // this is only used for recving RDMA_SEND or IMM data. this should be bigger than max layers of // model. -#define MAX_RECV_WR 64 +#define MAX_RECV_WR 128 // how many RDMA write requests can be outstanding, this should be bigger than MAX_WR_BATCH and less // than MAX_SEND_WR -#define MAX_RDMA_WRITE_WR 4096 +#define MAX_RDMA_OPS_WR 8000 // every MAX_WR_BATCH RDMA write requests will have a RDMA_SIGNAL #define MAX_WR_BATCH 32 @@ -35,13 +36,17 @@ using namespace flatbuffers; #define MAGIC_SIZE 4 #define OP_RDMA_EXCHANGE 'E' -#define OP_RDMA_ALLOCATE 'D' #define OP_RDMA_READ 'A' -#define OP_RDMA_WRITE_COMMIT 'T' +#define OP_RDMA_WRITE 'W' #define OP_CHECK_EXIST 'C' #define OP_GET_MATCH_LAST_IDX 'M' #define OP_DELETE_KEYS 'X' #define OP_SIZE 1 + +#define OP_TCP_PUT 'P' +#define OP_TCP_GET 'G' +#define OP_TCP_PAYLOAD 'L' + // please add op name in protocol.cpp std::string op_name(char op); @@ -66,18 +71,6 @@ typedef struct __attribute__((packed)) { unsigned int body_size; } header_t; -// remote_block_t is used to to talk to PYTHON layer. not used in RDMA/TCP layer. -typedef struct { - uint32_t rkey; - uintptr_t remote_addr; -} remote_block_t; - -// block_t is used to to talk to PYTHON layer. not used in RDMA/TCP layer. 
-typedef struct { - std::string key; - unsigned long offset; -} block_t; - typedef struct __attribute__((packed)) rdma_conn_info_t { uint32_t qpn; uint32_t psn; @@ -101,7 +94,4 @@ class FixedBufferAllocator : public Allocator { size_t offset_; }; -const RemoteBlock FAKE_REMOTE_BLOCK = RemoteBlock(0, 0); -bool is_fake_remote_block(remote_block_t& block); - #endif diff --git a/src/pybind.cpp b/src/pybind.cpp index 35136f6..8b0d15a 100644 --- a/src/pybind.cpp +++ b/src/pybind.cpp @@ -44,118 +44,44 @@ PYBIND11_MODULE(_infinistore, m) { .def_readwrite("link_type", &client_config_t::link_type) .def_readwrite("host_addr", &client_config_t::host_addr); - PYBIND11_NUMPY_DTYPE(remote_block_t, rkey, remote_addr); - py::class_>(m, "Connection") .def(py::init<>()) .def("close", &Connection::close_conn, py::call_guard(), "close the connection") - - .def( - "r_rdma", - [](Connection &self, const std::vector> &blocks, - int block_size, uintptr_t ptr) { - std::vector c_blocks; - for (const auto &block : blocks) { - c_blocks.push_back(block_t{std::get<0>(block), std::get<1>(block)}); - } - return self.r_rdma(c_blocks, block_size, (void *)ptr); - }, - py::call_guard(), "Read remote memory") - .def( - "w_rdma", - [](Connection &self, - py::array_t offsets, - int block_size, - py::array_t remote_blocks, - uintptr_t base_ptr) { - py::buffer_info block_buf = remote_blocks.request(); - py::buffer_info offset_buf = offsets.request(); - - assert(block_buf.ndim == 1); - assert(offset_buf.ndim == 1); - - remote_block_t *p_remote_blocks = static_cast(block_buf.ptr); - unsigned long *p_offsets = static_cast(offset_buf.ptr); - size_t remote_blocks_len = block_buf.shape[0]; - size_t offsets_len = offset_buf.shape[0]; - return self.w_rdma(p_offsets, offsets_len, block_size, p_remote_blocks, - remote_blocks_len, (void *)base_ptr); + "w_tcp", + [](Connection &self, const std::string &key, uintptr_t ptr, size_t size) { + return self.w_tcp(key, (void *)ptr, size); }, - "Write remote memory") - + 
py::call_guard(), "Write remote memory using TCP") .def( - "r_rdma_async", - [](Connection &self, const std::vector> &blocks, - int block_size, uintptr_t ptr, std::function callback) { - std::vector c_blocks; - for (const auto &block : blocks) { - c_blocks.push_back(block_t{std::get<0>(block), std::get<1>(block)}); - } - return self.r_rdma_async(c_blocks, block_size, (void *)ptr, callback); + "r_tcp", + [](Connection &self, const std::string &key) { + auto vector_ptr = self.r_tcp(key); + py::gil_scoped_acquire acquire; + return as_pyarray(std::move(*vector_ptr)); }, - py::call_guard(), "Read remote memory asynchronously") - + py::call_guard(), "Read remote memory using TCP") .def( "w_rdma_async", - [](Connection &self, - py::array_t offsets, - int block_size, - py::array_t remote_blocks, - uintptr_t base_ptr, std::function callback) { - py::buffer_info block_buf = remote_blocks.request(); - py::buffer_info offset_buf = offsets.request(); - - assert(block_buf.ndim == 1); - assert(offset_buf.ndim == 1); - - remote_block_t *p_remote_blocks = static_cast(block_buf.ptr); - unsigned long *p_offsets = static_cast(offset_buf.ptr); - size_t remote_blocks_len = block_buf.shape[0]; - size_t offsets_len = offset_buf.shape[0]; - return self.w_rdma_async(p_offsets, offsets_len, block_size, p_remote_blocks, - remote_blocks_len, (void *)base_ptr, [callback]() { - // python code will take_gil by itself - callback(); - }); + [](Connection &self, const std::vector &keys, + const std::vector offsets, int block_size, uintptr_t base_ptr, + std::function callback) { + return self.w_rdma_async(keys, offsets, block_size, (void *)base_ptr, callback); }, - "Write remote memory asynchronously") - + py::call_guard(), "write rdma async") .def( - "allocate_rdma", - [](Connection &self, std::vector &keys, int block_size) { - std::vector *blocks = self.allocate_rdma(keys, block_size); - // throw python exception if blocks is nullptr - if (blocks == nullptr) { - throw std::runtime_error("Failed to 
allocate remote memory"); - } - py::gil_scoped_acquire acquire; - return as_pyarray(std::move(*blocks)); - }, - py::call_guard(), "Allocate remote memory") - - .def( - "allocate_rdma_async", - [](Connection &self, std::vector &keys, int block_size, - std::function callback) { - self.allocate_rdma_async( - keys, block_size, - [callback](std::vector *blocks, unsigned int error_code) { - py::gil_scoped_acquire acquire; - callback(as_pyarray(std::move(*blocks)), error_code); - delete blocks; - }); - return; + "r_rdma_async", + [](Connection &self, const std::vector &keys, + const std::vector offsets, int block_size, uintptr_t base_ptr, + std::function callback) { + return self.r_rdma_async(keys, offsets, block_size, (void *)base_ptr, callback); }, - py::call_guard(), "Allocate remote memory asynchronously") - + py::call_guard(), "Read remote memory asynchronously") .def("init_connection", &Connection::init_connection, py::call_guard(), "init connection") .def("setup_rdma", &Connection::setup_rdma, py::call_guard(), "setup rdma connection") - .def("sync_rdma", &Connection::sync_rdma, py::call_guard(), - "sync the remote server") .def("check_exist", &Connection::check_exist, py::call_guard(), "check if the key exists in the store") .def("get_match_last_index", &Connection::get_match_last_index, @@ -163,7 +89,6 @@ PYBIND11_MODULE(_infinistore, m) { "get the last index of a key list which is in the store") .def("delete_keys", &Connection::delete_keys, py::call_guard(), "delete a list of keys which are in store") - .def( "register_mr", [](Connection &self, uintptr_t ptr, size_t ptr_region_size) { diff --git a/src/tcp_payload_request.fbs b/src/tcp_payload_request.fbs new file mode 100644 index 0000000..1154cb7 --- /dev/null +++ b/src/tcp_payload_request.fbs @@ -0,0 +1,7 @@ +table TCPPayloadRequest { + key: string; + value_length: int; + op: byte; +} + +root_type TCPPayloadRequest; \ No newline at end of file diff --git a/src/tcp_payload_request_generated.h 
b/src/tcp_payload_request_generated.h new file mode 100644 index 0000000..afe3162 --- /dev/null +++ b/src/tcp_payload_request_generated.h @@ -0,0 +1,94 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef FLATBUFFERS_GENERATED_TCPPAYLOADREQUEST_H_ +#define FLATBUFFERS_GENERATED_TCPPAYLOADREQUEST_H_ + +#include "flatbuffers/flatbuffers.h" + +struct TCPPayloadRequest; +struct TCPPayloadRequestBuilder; + +struct TCPPayloadRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TCPPayloadRequestBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEY = 4, + VT_VALUE_LENGTH = 6, + VT_OP = 8 + }; + const flatbuffers::String *key() const { + return GetPointer(VT_KEY); + } + int32_t value_length() const { return GetField(VT_VALUE_LENGTH, 0); } + int8_t op() const { return GetField(VT_OP, 0); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_KEY) && + verifier.VerifyString(key()) && VerifyField(verifier, VT_VALUE_LENGTH) && + VerifyField(verifier, VT_OP) && verifier.EndTable(); + } +}; + +struct TCPPayloadRequestBuilder { + typedef TCPPayloadRequest Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset key) { + fbb_.AddOffset(TCPPayloadRequest::VT_KEY, key); + } + void add_value_length(int32_t value_length) { + fbb_.AddElement(TCPPayloadRequest::VT_VALUE_LENGTH, value_length, 0); + } + void add_op(int8_t op) { fbb_.AddElement(TCPPayloadRequest::VT_OP, op, 0); } + explicit TCPPayloadRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTCPPayloadRequest( + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset key = 0, + int32_t value_length = 
0, int8_t op = 0) { + TCPPayloadRequestBuilder builder_(_fbb); + builder_.add_value_length(value_length); + builder_.add_key(key); + builder_.add_op(op); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTCPPayloadRequestDirect( + flatbuffers::FlatBufferBuilder &_fbb, const char *key = nullptr, int32_t value_length = 0, + int8_t op = 0) { + auto key__ = key ? _fbb.CreateString(key) : 0; + return CreateTCPPayloadRequest(_fbb, key__, value_length, op); +} + +inline const TCPPayloadRequest *GetTCPPayloadRequest(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const TCPPayloadRequest *GetSizePrefixedTCPPayloadRequest(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline bool VerifyTCPPayloadRequestBuffer(flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(nullptr); +} + +inline bool VerifySizePrefixedTCPPayloadRequestBuffer(flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(nullptr); +} + +inline void FinishTCPPayloadRequestBuffer(flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root); +} + +inline void FinishSizePrefixedTCPPayloadRequestBuffer(flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root); +} + +#endif // FLATBUFFERS_GENERATED_TCPPAYLOADREQUEST_H_ From 9000a9126cfe54522e9609c9d4fea52e28d5dd7b Mon Sep 17 00:00:00 2001 From: thesues Date: Thu, 13 Mar 2025 22:37:27 +0000 Subject: [PATCH 04/13] support overwrite existing key --- infinistore/test_infinistore.py | 37 ++++++++++++++++++++++++++++----- src/infinistore.cpp | 27 +++++++----------------- src/infinistore.h | 11 +++++----- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/infinistore/test_infinistore.py b/infinistore/test_infinistore.py index ea4453e..368f735 100644 --- a/infinistore/test_infinistore.py +++ b/infinistore/test_infinistore.py @@ -375,11 +375,6 @@ async def run(): dst_conn.close() -def 
test_overwrite(server): - # FIXME: implement this test - pass - - def test_async_api(server): config = infinistore.ClientConfig( host_addr="127.0.0.1", @@ -542,3 +537,35 @@ def test_simple_tcp_read_write(server): assert dst[i] == src[i] finally: conn.close() + + +def test_overwrite_tcp(server): + config = infinistore.ClientConfig( + host_addr="127.0.0.1", + service_port=92345, + connection_type=infinistore.TYPE_TCP, + ) + + try: + conn = infinistore.InfinityConnection(config) + conn.connect() + key = generate_random_string(10) + size = 256 * 1024 + src = bytearray(size) + for i in range(size): + src[i] = i % 200 + conn.tcp_write_cache(key, get_ptr(src), len(src)) + dst = conn.tcp_read_cache(key) + assert len(dst) == len(src) + for i in range(len(src)): + assert dst[i] == src[i] + + # overwrite the key + src = bytearray(size) + for i in range(size): + src[i] = i % 100 + conn.tcp_write_cache(key, get_ptr(src), len(src)) + dst = conn.tcp_read_cache(key) + assert len(dst) == len(src) + finally: + conn.close() diff --git a/src/infinistore.cpp b/src/infinistore.cpp index 8dadac6..b444031 100644 --- a/src/infinistore.cpp +++ b/src/infinistore.cpp @@ -12,7 +12,6 @@ #include #include -#include #include #include #include @@ -41,8 +40,7 @@ ibv_mtu active_mtu; // indicate if the MM extend is in flight bool extend_in_flight = false; -std::unordered_map> inflight_rdma_writes; - +static std::deque> lru_queue; std::unordered_map> kv_map; typedef enum { @@ -253,8 +251,8 @@ int Client::tcp_payload_request(const TCPPayloadRequest *req) { case OP_TCP_PUT: { int ret = mm->allocate(req->value_length(), 1, [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { - current_tcp_task_ = boost::intrusive_ptr( - new PTR(addr, req->value_length(), pool_idx, false)); + current_tcp_task_ = boost::intrusive_ptr(new PTR( + addr, req->value_length(), pool_idx, req->key()->str())); }); if (ret < 0) { ERROR("Failed to allocate memory"); @@ -276,9 +274,6 @@ int Client::tcp_payload_request(const 
TCPPayloadRequest *req) { if (it == kv_map.end()) { return KEY_NOT_FOUND; } - if (!it->second->committed) { - return KEY_NOT_FOUND; - } auto ptr = it->second; uint32_t *header_buf = (uint32_t *)malloc(sizeof(uint32_t) * 2); @@ -416,7 +411,7 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { auto inflight_rdma_writes = (std::vector> *)wc.wr_id; for (auto ptr : *inflight_rdma_writes) { - ptr->committed = true; + kv_map[std::move(ptr->key)] = ptr; } delete inflight_rdma_writes; post_ack(FINISH); @@ -571,9 +566,8 @@ int Client::write_rdma_cache(const RemoteMetaRequest *remote_meta_req) { bool allocated = mm->allocate(block_size, n, [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { const auto *key = remote_meta_req->keys()->Get(key_idx); - auto ptr = boost::intrusive_ptr(new PTR(addr, block_size, pool_idx, false)); + auto ptr = boost::intrusive_ptr(new PTR(addr, block_size, pool_idx, key->str())); inflight_rdma_writes->push_back(ptr); - kv_map[key->str()] = ptr; key_idx++; }); @@ -606,12 +600,6 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) { WARN("Key not found: {}", key->str()); return KEY_NOT_FOUND; } - - if (!it->second->committed) { - WARN("Key not committed: {}, return KEY_NOT_FOUND", key->str()); - return KEY_NOT_FOUND; - } - const auto &ptr = it->second; inflight_rdma_reads->push_back(ptr); @@ -919,7 +907,7 @@ void Client::send_resp(int return_code, void *buf, size_t size) { int Client::check_key(const std::string &key_to_check) { int ret; // check if the key exists and committed - if (kv_map.count(key_to_check) > 0 && kv_map[key_to_check]->committed) { + if (kv_map.count(key_to_check) > 0) { ret = 0; } else { @@ -1090,7 +1078,8 @@ void on_read(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) { client->bytes_read_ += to_copy; offset += to_copy; if (client->bytes_read_ == client->expected_bytes_) { - client->current_tcp_task_->committed = true; + auto ptr = client->current_tcp_task_; + 
kv_map[std::move(ptr->key)] = ptr; client->current_tcp_task_.reset(); client->send_resp(FINISH, NULL, 0); client->reset_client_read_state(); diff --git a/src/infinistore.h b/src/infinistore.h index 8aad059..83a9074 100644 --- a/src/infinistore.h +++ b/src/infinistore.h @@ -2,6 +2,8 @@ #define INFINISTORE_H #include +#include + #include "config.h" #include "log.h" #include "mempool.h" @@ -23,8 +25,6 @@ extern ibv_mtu active_mtu; // indicate if the MM extend is in flight extern bool extend_in_flight; -// indicate the number of cudaIpcOpenMemHandle -extern std::atomic opened_ipc; // PTR is shared by kv_map and inflight_rdma_kv_map class PTR : public IntrusivePtrTarget { @@ -32,9 +32,10 @@ class PTR : public IntrusivePtrTarget { void *ptr = nullptr; size_t size; int pool_idx; - bool committed; - PTR(void *ptr, size_t size, int pool_idx, bool committed = false) - : ptr(ptr), size(size), pool_idx(pool_idx), committed(committed) {} + std::string key; + std::deque>::iterator lru_it; + PTR(void *ptr, size_t size, int pool_idx, const std::string &key) + : ptr(ptr), size(size), pool_idx(pool_idx), key(key) {} ~PTR() { if (ptr) { DEBUG("deallocate ptr: {}, size: {}, pool_idx: {}", ptr, size, pool_idx); From b116efcdd35c7383e48a8ae333ec6408f19b458a Mon Sep 17 00:00:00 2001 From: thesues Date: Fri, 14 Mar 2025 06:43:47 +0000 Subject: [PATCH 05/13] Feature: add evict policy(#131) --- infinistore/__init__.py | 2 ++ infinistore/lib.py | 28 +++++++++++++-- infinistore/server.py | 40 ++++++++++++++++++--- src/config.h | 1 - src/infinistore.cpp | 78 +++++++++++++++++++++++++++++++++-------- src/infinistore.h | 5 +-- src/mempool.cpp | 16 +++++++++ src/mempool.h | 1 + src/pybind.cpp | 2 +- 9 files changed, 147 insertions(+), 26 deletions(-) diff --git a/infinistore/__init__.py b/infinistore/__init__.py index e96dc30..46bcd6c 100644 --- a/infinistore/__init__.py +++ b/infinistore/__init__.py @@ -12,6 +12,7 @@ get_kvmap_len, InfiniStoreException, InfiniStoreKeyNotFound, + evict_cache, ) 
__all__ = [ @@ -28,4 +29,5 @@ "get_kvmap_len", "InfiniStoreException", "InfiniStoreKeyNotFound", + "evict_cache", ] diff --git a/infinistore/lib.py b/infinistore/lib.py index cb23fbb..11d14af 100644 --- a/infinistore/lib.py +++ b/infinistore/lib.py @@ -98,7 +98,6 @@ class ServerConfig: link_type (str): The type of link. Defaults to "IB". prealloc_size (int): The preallocation size. Defaults to 16. minimal_allocate_size (int): The minimal allocation size. Defaults to 64. - num_stream (int): The number of streams. Defaults to 1. auto_increase (bool): indicate if infinistore will be automatically increased. 10GB each time. Default False. """ @@ -112,8 +111,10 @@ def __init__(self, **kwargs): self.link_type = kwargs.get("link_type", "IB") self.prealloc_size = kwargs.get("prealloc_size", 16) self.minimal_allocate_size = kwargs.get("minimal_allocate_size", 64) - self.num_stream = kwargs.get("num_stream", 1) self.auto_increase = kwargs.get("auto_increase", False) + self.evict_min_threshold = kwargs.get("evict_min_threshold", 0.1) + self.evict_max_threshold = kwargs.get("evict_max_threshold", 0.2) + self.evict_interval = kwargs.get("evict_interval", 5) def __repr__(self): return ( @@ -121,7 +122,8 @@ def __repr__(self): f"log_level='{self.log_level}', " f"dev_name='{self.dev_name}', ib_port={self.ib_port}, link_type='{self.link_type}', " f"prealloc_size={self.prealloc_size}, minimal_allocate_size={self.minimal_allocate_size}, " - f"num_stream={self.num_stream}" + f"auto_increase={self.auto_increase}, evict_min_threshold={self.evict_min_threshold}, " + f"evict_max_threshold={self.evict_max_threshold}, evict_interval={self.evict_interval}" ) def verify(self): @@ -216,6 +218,26 @@ def register_server(loop, config: ServerConfig): raise Exception("Failed to register server") +def evict_cache(min_threshold: float, max_threshold: float): + """ + Evicts the cache in the infinistore. 
+ + This function calls the underlying _infinistore.evict_cache() method to + clear all entries in the cache, effectively resetting it. + + Returns: + The result of the _infinistore.evict_cache() method call. + """ + if min_threshold >= max_threshold: + raise Exception("min_threshold should be less than max_threshold") + if min_threshold > 1 or min_threshold < 0: + raise Exception("min_threshold should be in [0, 1]") + if max_threshold > 1 or max_threshold < 0: + raise Exception("max_threshold should be in [0, 1]") + + return _infinistore.evict_cache(min_threshold, max_threshold) + + def _kernel_modules(): modules = set() try: diff --git a/infinistore/server.py b/infinistore/server.py index cd67eed..2777ffa 100644 --- a/infinistore/server.py +++ b/infinistore/server.py @@ -5,6 +5,7 @@ get_kvmap_len, ServerConfig, Logger, + evict_cache, ) import asyncio import uvloop @@ -112,11 +113,29 @@ def parse_args(): type=int, ) parser.add_argument( - "--num-stream", + "--evict-interval", required=False, - default=1, - help="(deprecated)number of streams, default 1, can only be 1, 2, 4", - type=int, + default=5, + help="evict interval, default 5s", + ) + parser.add_argument( + "--evict-min-threshold", + required=False, + default=0.6, + help="evict min threshold, default 0.6", + ) + parser.add_argument( + "--evict-max-threshold", + required=False, + default=0.8, + help="evict max threshold, default 0.8", + ) + parser.add_argument( + "--enable-periodic-evict", + required=False, + action="store_true", + default=False, + help="enable evict cache, default False", ) return parser.parse_args() @@ -128,6 +147,12 @@ def prevent_oom(): f.write("-1000") +async def periodic_evict(min_threshold=0.6, max_threshold=0.9, interval=5): + while True: + evict_cache(min_threshold, max_threshold) + await asyncio.sleep(interval) + + def main(): args = parse_args() config = ServerConfig( @@ -139,8 +164,10 @@ def main(): ib_port=args.ib_port, link_type=args.link_type, 
minimal_allocate_size=args.minimal_allocate_size, - num_stream=args.num_stream, auto_increase=args.auto_increase, + evict_interval=args.evict_interval, + evict_min_threshold=args.evict_min_threshold, + evict_max_threshold=args.evict_max_threshold, ) config.verify() @@ -153,7 +180,10 @@ def main(): # TODO: find the minimum size for pinning memory and ib_reg_mr register_server(loop, config) + if args.enable_periodic_evict: + loop.create_task(periodic_evict()) prevent_oom() + Logger.info("set oom_score_adj to -1000 to prevent OOM") http_config = uvicorn.Config( diff --git a/src/config.h b/src/config.h index efb09ed..dc7942c 100644 --- a/src/config.h +++ b/src/config.h @@ -18,7 +18,6 @@ typedef struct ServerConfig { int ib_port; std::string link_type; int minimal_allocate_size; // unit: KB - int num_stream; // can only be 1,2,4, number of stream for each client bool auto_increase; } server_config_t; diff --git a/src/infinistore.cpp b/src/infinistore.cpp index b444031..5c543d2 100644 --- a/src/infinistore.cpp +++ b/src/infinistore.cpp @@ -40,7 +40,9 @@ ibv_mtu active_mtu; // indicate if the MM extend is in flight bool extend_in_flight = false; -static std::deque> lru_queue; +// evict memory from head to tail, the PTR is shared by lru_queue and kv_map +// so we have to pop from both to evict the memory +std::list> lru_queue; std::unordered_map> kv_map; typedef enum { @@ -52,6 +54,9 @@ typedef enum { // the max data could be send in uv_write static const size_t MAX_SEND_SIZE = 256 << 10; +const float ON_DEMAND_MIN_THRESHOLD = 0.8; +const float ON_DEMAND_MAX_THRESHOLD = 0.95; + struct Client { uv_tcp_t *handle_ = NULL; // uv_stream_t read_state_t state_; // state of the client, for parsing the request @@ -244,17 +249,33 @@ void on_head_write(uv_write_t *req, int status) { free(req); } +void evict_cache(float min_threshold, float max_threshold) { + if (mm->usage() >= max_threshold) { + // stop when mm->usage is below min_threshold + float usage = mm->usage(); + while 
(mm->usage() >= min_threshold && !lru_queue.empty()) { + auto ptr = lru_queue.front(); + lru_queue.pop_front(); + kv_map.erase(ptr->key); + } + INFO("evict memory done, usage: from {:.2f} => {:.2f}", usage, mm->usage()); + } +} + int Client::tcp_payload_request(const TCPPayloadRequest *req) { DEBUG("do tcp_payload_request... {}", op_name(req->op())); switch (req->op()) { case OP_TCP_PUT: { - int ret = mm->allocate(req->value_length(), 1, - [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { - current_tcp_task_ = boost::intrusive_ptr(new PTR( - addr, req->value_length(), pool_idx, req->key()->str())); - }); - if (ret < 0) { + evict_cache(ON_DEMAND_MIN_THRESHOLD, ON_DEMAND_MAX_THRESHOLD); + + bool allocated = + mm->allocate(req->value_length(), 1, + [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { + current_tcp_task_ = boost::intrusive_ptr(new PTR( + addr, req->value_length(), pool_idx, req->key()->str())); + }); + if (!allocated) { ERROR("Failed to allocate memory"); return OUT_OF_MEMORY; } @@ -276,6 +297,11 @@ int Client::tcp_payload_request(const TCPPayloadRequest *req) { } auto ptr = it->second; + // move ptr to the end of lru_queue + lru_queue.erase(ptr->lru_it); + lru_queue.push_back(ptr); + ptr->lru_it = --lru_queue.end(); + uint32_t *header_buf = (uint32_t *)malloc(sizeof(uint32_t) * 2); header_buf[0] = FINISH; header_buf[1] = static_cast(ptr->size); @@ -411,7 +437,9 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { auto inflight_rdma_writes = (std::vector> *)wc.wr_id; for (auto ptr : *inflight_rdma_writes) { - kv_map[std::move(ptr->key)] = ptr; + kv_map[ptr->key] = ptr; + lru_queue.push_back(ptr); + ptr->lru_it = --lru_queue.end(); } delete inflight_rdma_writes; post_ack(FINISH); @@ -442,6 +470,15 @@ void add_mempool_completion(uv_work_t *req, int status) { delete req; } +void extend_mempool() { + if (global_config.auto_increase && mm->need_extend && !extend_in_flight) { + INFO("Extend another mempool"); + 
uv_work_t *req = new uv_work_t(); + uv_queue_work(loop, req, add_mempool, add_mempool_completion); + extend_in_flight = true; + } +} + int Client::prepare_recv_rdma_request(int buf_idx) { struct ibv_sge sge = {0}; struct ibv_recv_wr rwr = {0}; @@ -562,6 +599,7 @@ int Client::write_rdma_cache(const RemoteMetaRequest *remote_meta_req) { auto *inflight_rdma_writes = new std::vector>; inflight_rdma_writes->reserve(n); + evict_cache(ON_DEMAND_MIN_THRESHOLD, ON_DEMAND_MAX_THRESHOLD); int key_idx = 0; bool allocated = mm->allocate(block_size, n, [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { @@ -605,6 +643,13 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) { inflight_rdma_reads->push_back(ptr); } + // loop over inflight_rdma_reads to update lru_queue + for (auto ptr : *inflight_rdma_reads) { + lru_queue.erase(ptr->lru_it); + lru_queue.push_back(ptr); + ptr->lru_it = --lru_queue.end(); + } + // write to remote address data from local address perform_batch_rdma(remote_meta_req, inflight_rdma_reads, IBV_WR_RDMA_WRITE); return 0; @@ -954,7 +999,11 @@ int Client::get_match_last_index(const GetMatchLastIndexRequest *request) { int Client::delete_keys(const DeleteKeysRequest *request) { int count = 0; for (const auto *key : *request->keys()) { - if (kv_map.erase(key->str()) == 1) { + auto it = kv_map.find(key->str()); + if (it != kv_map.end()) { + auto ptr = it->second; + kv_map.erase(it); + lru_queue.erase(ptr->lru_it); count++; } } @@ -1079,7 +1128,12 @@ void on_read(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) { offset += to_copy; if (client->bytes_read_ == client->expected_bytes_) { auto ptr = client->current_tcp_task_; - kv_map[std::move(ptr->key)] = ptr; + kv_map[ptr->key] = ptr; + + // put the ptr into lru queue + lru_queue.push_back(ptr); + ptr->lru_it = --lru_queue.end(); + client->current_tcp_task_.reset(); client->send_resp(FINISH, NULL, 0); client->reset_client_read_state(); @@ -1124,10 +1178,6 @@ int 
register_server(unsigned long loop_ptr, server_config_t config) { signal(SIGFPE, signal_handler); signal(SIGILL, signal_handler); - // verification - assert(config.num_stream > 0 && - (config.num_stream == 1 || config.num_stream == 2 || config.num_stream == 4)); - global_config = config; loop = uv_default_loop(); diff --git a/src/infinistore.h b/src/infinistore.h index 83a9074..c7fd476 100644 --- a/src/infinistore.h +++ b/src/infinistore.h @@ -2,7 +2,7 @@ #define INFINISTORE_H #include -#include +#include #include "config.h" #include "log.h" @@ -33,7 +33,7 @@ class PTR : public IntrusivePtrTarget { size_t size; int pool_idx; std::string key; - std::deque>::iterator lru_it; + std::list>::iterator lru_it; PTR(void *ptr, size_t size, int pool_idx, const std::string &key) : ptr(ptr), size(size), pool_idx(pool_idx), key(key) {} ~PTR() { @@ -48,6 +48,7 @@ extern std::unordered_map> inflight_rdma_wr // global function to bind with python int register_server(unsigned long loop_ptr, server_config_t config); +void evict_cache(float min_threshold, float max_threshold); void purge_kv_map(); #endif \ No newline at end of file diff --git a/src/mempool.cpp b/src/mempool.cpp index a805484..774dd34 100644 --- a/src/mempool.cpp +++ b/src/mempool.cpp @@ -145,6 +145,7 @@ void MemoryPool::deallocate(void* ptr, size_t size) { } } last_search_position_ = 0; + allocated_blocks_ -= blocks_to_free; } void MM::add_mempool(struct ibv_pd* pd) { @@ -155,6 +156,16 @@ void MM::add_mempool(size_t pool_size, size_t block_size, struct ibv_pd* pd) { mempools_.push_back(new MemoryPool(pool_size, block_size, pd)); } +float MM::usage() { + size_t total_blocks = 0; + size_t allocated_blocks = 0; + for (auto pool : mempools_) { + total_blocks += pool->get_total_blocks(); + allocated_blocks += pool->get_allocated_blocks(); + } + return (float)allocated_blocks / total_blocks; +} + bool MM::allocate(size_t size, size_t n, AllocationCallback callback) { bool allocated = false; size_t mempool_cnt = 
mempools_.size(); @@ -173,10 +184,15 @@ bool MM::allocate(size_t size, size_t n, AllocationCallback callback) { "Mempool Count: {}, Pool idx: {}, Total blocks: {}, allocated blocks: {}, block usage: " "{}%", mempool_cnt, i, total_blocks, allocated_blocks, 100 * allocated_blocks / total_blocks); + if (i == mempools_.size() - 1 && (float)allocated_blocks / total_blocks > BLOCK_USAGE_RATIO) { need_extend = true; } + else { + need_extend = false; + } + if (n == 0) { allocated = true; break; diff --git a/src/mempool.h b/src/mempool.h index f2dd801..34609d3 100644 --- a/src/mempool.h +++ b/src/mempool.h @@ -65,6 +65,7 @@ class MM { void add_mempool(size_t pool_size, size_t block_size, struct ibv_pd* pd); bool allocate(size_t size, size_t n, AllocationCallback callback); void deallocate(void* ptr, size_t size, int pool_idx); + float usage(); uint32_t get_lkey(int pool_idx) const { assert(pool_idx >= 0 && (size_t)pool_idx < mempools_.size()); return mempools_[pool_idx]->get_lkey(); diff --git a/src/pybind.cpp b/src/pybind.cpp index 8b0d15a..5a9c627 100644 --- a/src/pybind.cpp +++ b/src/pybind.cpp @@ -106,13 +106,13 @@ PYBIND11_MODULE(_infinistore, m) { .def_readwrite("link_type", &ServerConfig::link_type) .def_readwrite("prealloc_size", &ServerConfig::prealloc_size) .def_readwrite("minimal_allocate_size", &ServerConfig::minimal_allocate_size) - .def_readwrite("num_stream", &ServerConfig::num_stream) .def_readwrite("auto_increase", &ServerConfig::auto_increase); m.def( "purge_kv_map", []() { kv_map.clear(); }, "purge kv map"); m.def( "get_kvmap_len", []() { return kv_map.size(); }, "get kv map size"); m.def("register_server", ®ister_server, "register the server"); + m.def("evict_cache", &evict_cache, "evict the mempool"); // //both side m.def("log_msg", &log_msg, "log"); From 6b29bec732573445663522456c8778c4d1a720b5 Mon Sep 17 00:00:00 2001 From: thesues Date: Fri, 14 Mar 2025 06:50:47 +0000 Subject: [PATCH 06/13] Example: add tcp_client.py --- 
infinistore/example/tcp_client.py | 59 +++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 infinistore/example/tcp_client.py diff --git a/infinistore/example/tcp_client.py b/infinistore/example/tcp_client.py new file mode 100644 index 0000000..6775ed9 --- /dev/null +++ b/infinistore/example/tcp_client.py @@ -0,0 +1,59 @@ +import infinistore +import uuid +import ctypes +import time + + +def generate_uuid(): + return str(uuid.uuid4()) + + +def get_ptr(mv: memoryview): + return ctypes.addressof(ctypes.c_char.from_buffer(mv)) + + +config = infinistore.ClientConfig( + host_addr="127.0.0.1", + service_port=12345, + log_level="warning", + connection_type=infinistore.TYPE_TCP, + ib_port=1, + link_type=infinistore.LINK_ETHERNET, + dev_name="mlx5_0", +) + + +def main(): + try: + conn = infinistore.InfinityConnection(config) + conn.connect() + key = generate_uuid() + + size = 128 * 1024 + src = bytearray(size) + # dst = memoryview(bytearray(100)) + for i in range(1000): + src[i] = i % 200 + + now = time.time() + N = 1000 + for i in range(N): + conn.tcp_write_cache(key + str(i), get_ptr(src), len(src)) + print("TCP write time taken: ", time.time() - now) + + now = time.time() + ret = [] + for i in range(N): + ret.append(conn.tcp_read_cache(key + str(i))) + print("TCP read Time taken: ", time.time() - now) + + assert len(ret) == len(src) + for i in range(len(src)): + assert ret[i] == src[i] + except Exception as e: + print(e) + finally: + conn.close() + + +main() From 5e9495438bda4e2d01627d9db6698477f17ecaca Mon Sep 17 00:00:00 2001 From: thesues Date: Mon, 17 Mar 2025 06:23:36 +0000 Subject: [PATCH 07/13] minor fix --- infinistore/lib.py | 8 ++++---- infinistore/server.py | 10 ++++++++-- src/infinistore.cpp | 2 +- src/mempool.cpp | 10 ---------- src/mempool.h | 10 +++++++++- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/infinistore/lib.py b/infinistore/lib.py index 11d14af..4be7e63 100644 --- a/infinistore/lib.py +++ 
b/infinistore/lib.py @@ -112,8 +112,8 @@ def __init__(self, **kwargs): self.prealloc_size = kwargs.get("prealloc_size", 16) self.minimal_allocate_size = kwargs.get("minimal_allocate_size", 64) self.auto_increase = kwargs.get("auto_increase", False) - self.evict_min_threshold = kwargs.get("evict_min_threshold", 0.1) - self.evict_max_threshold = kwargs.get("evict_max_threshold", 0.2) + self.evict_min_threshold = kwargs.get("evict_min_threshold", 0.6) + self.evict_max_threshold = kwargs.get("evict_max_threshold", 0.8) self.evict_interval = kwargs.get("evict_interval", 5) def __repr__(self): @@ -231,9 +231,9 @@ def evict_cache(min_threshold: float, max_threshold: float): if min_threshold >= max_threshold: raise Exception("min_threshold should be less than max_threshold") if min_threshold > 1 or min_threshold < 0: - raise Exception("min_threshold should be in [0, 1]") + raise Exception("min_threshold should be in (0, 1)") if max_threshold > 1 or max_threshold < 0: - raise Exception("max_threshold should be in [0, 1]") + raise Exception("max_threshold should be in (0, 1)") return _infinistore.evict_cache(min_threshold, max_threshold) diff --git a/infinistore/server.py b/infinistore/server.py index 2777ffa..e97d866 100644 --- a/infinistore/server.py +++ b/infinistore/server.py @@ -147,7 +147,7 @@ def prevent_oom(): f.write("-1000") -async def periodic_evict(min_threshold=0.6, max_threshold=0.9, interval=5): +async def periodic_evict(min_threshold: float, max_threshold: float, interval: int): while True: evict_cache(min_threshold, max_threshold) await asyncio.sleep(interval) @@ -181,7 +181,13 @@ def main(): register_server(loop, config) if args.enable_periodic_evict: - loop.create_task(periodic_evict()) + loop.create_task( + periodic_evict( + config.evict_min_threshold, + config.evict_max_threshold, + config.evict_interval, + ) + ) prevent_oom() Logger.info("set oom_score_adj to -1000 to prevent OOM") diff --git a/src/infinistore.cpp b/src/infinistore.cpp index 
5c543d2..2c16e82 100644 --- a/src/infinistore.cpp +++ b/src/infinistore.cpp @@ -233,8 +233,8 @@ void on_head_write(uv_write_t *req, int status) { ERROR("Write error {}", uv_strerror(status)); free(ctx->header_buf); delete ctx; - free(req); uv_close((uv_handle_t *)req->handle, on_close); + free(req); return; } diff --git a/src/mempool.cpp b/src/mempool.cpp index 774dd34..4017ff5 100644 --- a/src/mempool.cpp +++ b/src/mempool.cpp @@ -156,16 +156,6 @@ void MM::add_mempool(size_t pool_size, size_t block_size, struct ibv_pd* pd) { mempools_.push_back(new MemoryPool(pool_size, block_size, pd)); } -float MM::usage() { - size_t total_blocks = 0; - size_t allocated_blocks = 0; - for (auto pool : mempools_) { - total_blocks += pool->get_total_blocks(); - allocated_blocks += pool->get_allocated_blocks(); - } - return (float)allocated_blocks / total_blocks; -} - bool MM::allocate(size_t size, size_t n, AllocationCallback callback) { bool allocated = false; size_t mempool_cnt = mempools_.size(); diff --git a/src/mempool.h b/src/mempool.h index 34609d3..00d0156 100644 --- a/src/mempool.h +++ b/src/mempool.h @@ -65,7 +65,15 @@ class MM { void add_mempool(size_t pool_size, size_t block_size, struct ibv_pd* pd); bool allocate(size_t size, size_t n, AllocationCallback callback); void deallocate(void* ptr, size_t size, int pool_idx); - float usage(); + float usage() { + size_t total_blocks = 0; + size_t allocated_blocks = 0; + for (auto pool : mempools_) { + total_blocks += pool->get_total_blocks(); + allocated_blocks += pool->get_allocated_blocks(); + } + return (float)allocated_blocks / total_blocks; + } uint32_t get_lkey(int pool_idx) const { assert(pool_idx >= 0 && (size_t)pool_idx < mempools_.size()); return mempools_[pool_idx]->get_lkey(); From 7c2c971454f4cf788ca122bdbd238bd22c903291 Mon Sep 17 00:00:00 2001 From: jxp Date: Wed, 19 Mar 2025 10:43:59 -0700 Subject: [PATCH 08/13] add debug info and fix bug (#158) --- src/infinistore.cpp | 2 ++ src/libinfinistore.cpp | 3 ++- 2 
files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/infinistore.cpp b/src/infinistore.cpp index 2c16e82..615b4b3 100644 --- a/src/infinistore.cpp +++ b/src/infinistore.cpp @@ -438,6 +438,7 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) { (std::vector> *)wc.wr_id; for (auto ptr : *inflight_rdma_writes) { kv_map[ptr->key] = ptr; + DEBUG("writing key done, {}", ptr->key); lru_queue.push_back(ptr); ptr->lru_it = --lru_queue.end(); } @@ -605,6 +606,7 @@ int Client::write_rdma_cache(const RemoteMetaRequest *remote_meta_req) { mm->allocate(block_size, n, [&](void *addr, uint32_t lkey, uint32_t rkey, int pool_idx) { const auto *key = remote_meta_req->keys()->Get(key_idx); auto ptr = boost::intrusive_ptr(new PTR(addr, block_size, pool_idx, key->str())); + DEBUG("writing key: {}", key->str()); inflight_rdma_writes->push_back(ptr); key_idx++; }); diff --git a/src/libinfinistore.cpp b/src/libinfinistore.cpp index fc234a2..6c2d677 100644 --- a/src/libinfinistore.cpp +++ b/src/libinfinistore.cpp @@ -327,7 +327,8 @@ void Connection::cq_handler() { DEBUG("RDMA write cache done: Received IMM, imm_data: {}", wc[i].imm_data); auto *info = reinterpret_cast(ptr); - info->callback(wc->imm_data); + info->callback(wc[i].imm_data); + DEBUG("RDMA_WRITE_ACK callback done"); delete info; break; } From 5f52fa3819a53cf144c1ce06c22445cef303335a Mon Sep 17 00:00:00 2001 From: dongmao zhang Date: Thu, 20 Mar 2025 16:04:57 -0700 Subject: [PATCH 09/13] fix: numpy can have too many threads(#156) --- infinistore/lib.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/infinistore/lib.py b/infinistore/lib.py index 4be7e63..7f5363c 100644 --- a/infinistore/lib.py +++ b/infinistore/lib.py @@ -7,6 +7,11 @@ import asyncio from functools import singledispatchmethod from typing import Optional, Union, List, Tuple + +os.environ["OPENBLAS_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["OMP_NUM_THREADS"] = 
"1" import numpy as np From b043302d28a90c57e5c5442395962292b49e7d0d Mon Sep 17 00:00:00 2001 From: "henry.guo" Date: Mon, 24 Mar 2025 10:55:58 -0700 Subject: [PATCH 10/13] Add Dockerfile.build for a docker image to build infinistore project --- Dockerfile.build | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/Dockerfile.build b/Dockerfile.build index b1530d1..f0f50ed 100644 --- a/Dockerfile.build +++ b/Dockerfile.build @@ -5,30 +5,37 @@ RUN yum -y install rdma-core-devel libuv-devel RUN dnf clean all RUN dnf makecache -# build spdlog +# build spdlog v1.15.1 WORKDIR /tmp -RUN git clone --branch v1.15.1 --recurse-submodules https://github.com/gabime/spdlog.git +RUN git clone --recurse-submodules https://github.com/gabime/spdlog.git && \ + cd spdlog && \ + git checkout f355b3d && \ + git submodule update --recursive WORKDIR /tmp/spdlog RUN cmake -G "Unix Makefiles" && \ - make && \ + make -j8 && \ make install RUN rm -rf /tmp/spdlog -# build fmt +# build fmt 11.1.3 WORKDIR /tmp -RUN git clone --branch 11.1.3 https://github.com/fmtlib/fmt.git +RUN git clone https://github.com/fmtlib/fmt.git && \ + cd fmt && \ + git checkout 9cf9f38 WORKDIR /tmp/fmt RUN cmake -G "Unix Makefiles" && \ - make && \ + make -j8 && \ make install RUN rm -rf /tmp/fmt -# build flatbuffer +# build flatbuffer v25.2.10 WORKDIR /tmp -RUN git clone --branch v25.2.10 https://github.com/google/flatbuffers.git +RUN git clone https://github.com/google/flatbuffers.git && \ + cd flatbuffers && \ + git checkout 1c51462 WORKDIR /tmp/flatbuffers RUN cmake -G "Unix Makefiles" && \ - make && \ + make -j8 && \ make install ENV PATH=/usr/local/flatbuffers/bin:$PATH From 783d182917971c49982b99676bd0d75d21b08971 Mon Sep 17 00:00:00 2001 From: "henry.guo" Date: Mon, 24 Mar 2025 10:55:58 -0700 Subject: [PATCH 11/13] Add Dockerfile.build for a docker image to build infinistore project --- Dockerfile.build | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/Dockerfile.build b/Dockerfile.build index f0f50ed..c636a6e 100644 --- a/Dockerfile.build +++ b/Dockerfile.build @@ -32,7 +32,7 @@ RUN rm -rf /tmp/fmt WORKDIR /tmp RUN git clone https://github.com/google/flatbuffers.git && \ cd flatbuffers && \ - git checkout 1c51462 + git checkout 33e2d80 WORKDIR /tmp/flatbuffers RUN cmake -G "Unix Makefiles" && \ make -j8 && \ From 1149137a72933f11ba9bc6ea402165807a9c87d4 Mon Sep 17 00:00:00 2001 From: "henry.guo" Date: Mon, 24 Mar 2025 11:28:42 -0700 Subject: [PATCH 12/13] add -j8 and use commit SHA instead of tag --- Dockerfile.build | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.build b/Dockerfile.build index c636a6e..f967f69 100644 --- a/Dockerfile.build +++ b/Dockerfile.build @@ -28,7 +28,7 @@ RUN cmake -G "Unix Makefiles" && \ make install RUN rm -rf /tmp/fmt -# build flatbuffer v25.2.10 +# build flatbuffer WORKDIR /tmp RUN git clone https://github.com/google/flatbuffers.git && \ cd flatbuffers && \ From 3e672de3c42958023750f56c5b908d3466a2a189 Mon Sep 17 00:00:00 2001 From: "henry.guo" Date: Mon, 7 Apr 2025 22:19:11 -0700 Subject: [PATCH 13/13] Infinistore service in cluster-mode --- README.md | 72 +++++ infinistore/__init__.py | 7 + infinistore/cluster_mgr.py | 450 ++++++++++++++++++++++++++ infinistore/example/cluster_client.py | 81 +++++ infinistore/lib.py | 2 + infinistore/server.py | 83 ++++- 6 files changed, 691 insertions(+), 4 deletions(-) create mode 100644 infinistore/cluster_mgr.py create mode 100644 infinistore/example/cluster_client.py diff --git a/README.md b/README.md index 8a2a358..107e26d 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,78 @@ Check the following example code to run an InfiniStore client: * ```infinistore/example/client_async.py``` * ```infinistore/example/client_async_single.py``` +## Run infinistore server in cluster mode + +1. 
**Start consul cluster** +InfiniStore leverages Consul to manage the membership of service nodes, so you must set up a Consul cluster in order to run InfiniStore in cluster mode. + + a. Run a consul cluster with a single node + +```shell + docker run -d \ + --name=consul \ + --network=host \ + -e CONSUL_BIND_INTERFACE=eth0 \ + -e CONSUL_CLIENT_ADDR=0.0.0.0 \ + -p 8500:8500 \ + hashicorp/consul + + b. Start a consul cluster with three nodes +# On the first node +```shell + docker run -d \ + --name consul-node1 \ + -p 8301:8301 -p 8301:8301/udp \ + -p 8302:8302 -p 8302:8302/udp \ + -p 8500:8500 -p 8300:8300 \ + -advertise {ip of first node} \ + -server \ + -bootstrap-expect=3 \ + hashicorp/consul + +# On the second node +```shell + docker run -d \ + --name consul-node2 \ + -p 8301:8301 -p 8301:8301/udp \ + -p 8302:8302 -p 8302:8302/udp \ + -p 8500:8500 -p 8300:8300 \ + -advertise {ip of the second node} \ + -server \ + hashicorp/consul \ + -join consul-node1 + +# On the third node +```shell + docker run -d \ + --name consul-node3 \ + -p 8301:8301 -p 8301:8301/udp \ + -p 8302:8302 -p 8302:8302/udp \ + -p 8500:8500 -p 8300:8300 \ + -advertise {ip of the third node} \ + -server \ + hashicorp/consul \ + -join consul-node1 + +2. 
**Start InfiniStore Server** +For RDMA, the following commands start the InfiniStore server on two nodes (you can add more nodes); each node must use its own unique --service-id + +# On the first node (may be different from the consul cluster nodes) +```shell + infinistore --service-port 12345 \ + --dev-name mlx5_0 --link-type Ethernet \ + --manage-port 8081 --bootstrap-ip {ip of one of the consul cluster nodes}:8500 \ + --service-id infinistore1 --cluster-mode \ + --host {ip of the current host} + +# On the second node (may be different from the consul cluster nodes) +```shell + infinistore --service-port 12345 \ + --dev-name mlx5_0 --link-type Ethernet \ + --manage-port 8081 --bootstrap-ip {ip of one of the consul cluster nodes}:8500 \ + --service-id infinistore2 --cluster-mode \ + --host {ip of the current host} + ## Run Within a vLLM Cluster As illustrated in the previous section, InfiniStore enables different functionalities in a vLLM cluster: KV cache transfer between prefill nodes and decoding nodes, extended KV cache pool, cross-node KV cache reuse, etc. 
diff --git a/infinistore/__init__.py b/infinistore/__init__.py index 46bcd6c..16336c8 100644 --- a/infinistore/__init__.py +++ b/infinistore/__init__.py @@ -14,6 +14,11 @@ InfiniStoreKeyNotFound, evict_cache, ) +from .cluster_mgr import ( + ConsulClusterMgr, + NoClusterMgr, +) + __all__ = [ "InfinityConnection", @@ -30,4 +35,6 @@ "InfiniStoreException", "InfiniStoreKeyNotFound", "evict_cache", + "ConsulClusterMgr", + "NoClusterMgr", ] diff --git a/infinistore/cluster_mgr.py b/infinistore/cluster_mgr.py new file mode 100644 index 0000000..a9a6364 --- /dev/null +++ b/infinistore/cluster_mgr.py @@ -0,0 +1,450 @@ +from infinistore import ( + ClientConfig, + Logger, + InfinityConnection, +) +import infinistore +import asyncio +import hashlib +from consul import Consul +from typing import Dict +import json +import requests +from requests.exceptions import HTTPError +from http import HTTPStatus +from dataclasses import dataclass + + +__all__ = ["ConsulClusterMgr", "NoClusterMgr"] + + +# The consistent hashing function which always hashes the same string to same integer +def sha256_hash(key: str) -> int: + # Create a sha256 hash object + sha256 = hashlib.sha256() + # Update the hash object with the string (encode to bytes) + sha256.update(key.encode("utf-8")) + hex_digest = sha256.hexdigest() + return int(hex_digest, 16) + + +@dataclass +class ServiceNode: + host: str + port: int + manage_port: int + conn: InfinityConnection + + +class ClusterMgrBase: + def __init__( + self, + bootstrap_address: str, + cluster_mode: bool = True, + service_manage_port: int = 8080, + refresh_interval=10, + ): + """ + Args: + bootstrap_address (str): The initial address in ip:port format used to query cluster information + cluster_mode (bool): whether the infinistore service runs in cluster mode (requires a consul cluster) + service_manage_port (int): the port which service uses to provide management functionalites + """ + self.bootstrap_ip, self.bootstrap_port = bootstrap_address.split(":") 
+ self.cluster_nodes = [bootstrap_address] + self.cluster_mode = cluster_mode + self.service_manage_port = service_manage_port + self.service_nodes: Dict[str, ServiceNode] = {} + self.refresh_interval = refresh_interval + + def get_cluster_info(self, cluster_node_ip: str) -> list[str]: + """ + The function get the current alive cluster nodes in the cluster. One of the nodes will + be chosen to send request to + + Args: + cluster_node_ip (str): The node ip to query + + Returns: + list[str]: The list of addresses(ip:port) of the alive nodes in the cluster + """ + pass + + def get_service_config(self, service_host: str, service_manage_port: int) -> dict: + """ + The function retrieves the service config parameters + Args: + service_host (str): The host(ip) where you can query the service config from + service_manage_port (int): the port number(may be different with service port) of the service Web APIs + """ + # Default values for insfinistore server config parameters + conn_type = infinistore.TYPE_RDMA + link_type = infinistore.LINK_ETHERNET + dev_name = "mlx5_0" + ib_port = 1 + service_port = 9988 + manage_port = 8080 + + # The infinistore server must implement the API to provide the running parameters + # TODO: Alternative way is registering the parameters to consul cluster, but it + # doesn't work for the case non-cluster setup of infinistore server + + url = f"http://{service_host}:{service_manage_port}/service/config" + with requests.get(url=url) as resp: + if resp.status_code == HTTPStatus.OK: + json_data = json.loads(resp.json()) + manage_port = json_data["manage_port"] + conn_type = json_data["connection_type"] + link_type = json_data["link_type"] + dev_name = json_data["dev_name"] + ib_port = json_data["ib_port"] + service_port = int(json_data["service_port"]) + + return { + "manage_port": manage_port, + "connection_type": conn_type, + "link_type": link_type, + "dev_name": dev_name, + "ib_port": ib_port, + "service_port": service_port, + } + + def 
refresh_service_nodes(self, service_name: str = "infinistore") -> bool: + """ + The function refresh the alive nodes which have infinistore servers running + Currently only infinistore service is supported(tested) + """ + pass + + def register_service_node( + self, + service_id: str = None, + service_name: str = "infinistore", + service_host: str = "", + service_port: int = 12345, + service_manage_port: int = 8080, + check: dict = None, + ) -> bool: + """ + The function is called by a service node to register itself to the cluster + service_id is uniquely identify a running instance of the service + + Args: + service_id (str): The unique ID of the service instance + service_name (str): str="infinistore", + service_host (str): IP address of the host where the server is running on + service_port (int): The service port which provides domain APIs + service_manage_port (int): The port number which provides management APIs + check:dict check is a dict struct which contains (http|tcp|script and interval fields) + Returns: + bool: If the register success or exists, return true. Otherwise return false + """ + pass + + def deregister_service(self, service_id: str = None): + """ + The function is called to deregister a service id + + Args: + service_id (str): The unique ID of the service instance + Returns: + bool: If the deregister success, return true. 
Otherwise return false + """ + pass + + def refresh_cluster(self): + """ + The function refresh the alive nodes of the cluster + """ + # If not in cluster mode, do nothing + if not self.cluster_mode: + return + for node_ip in self.cluster_nodes: + try: + updated_nodes = self.get_cluster_info(node_ip) + if len(updated_nodes) != 0: + self.cluster_nodes = updated_nodes + # A non-empty list indicates a working node, so no need to query further + break + except Exception: + Logger.warn(f"Cannot refresh cluster info from {node_ip}") + # Check next node if something wrong with this node + continue + + async def refresh_task(self): + """ + Task function to refresh cluster periodically + """ + loop = asyncio.get_running_loop() + while True: + await loop.run_in_executor(None, self.refresh_cluster) + await loop.run_in_executor(None, self.refresh_service_nodes) + await asyncio.sleep(self.refresh_interval) + + def setup_connection( + self, service_host: str, service_port: int, service_manage_port: int + ) -> bool: + """ + The function setup a connection to an infinistore service instance + Args: + service_host (str): The host(ip) to connect to + service_port (int): The port number the infinistore service is running at + service_manage_port (int): The port number the infinistore web server is running at + """ + try: + service_config = self.get_service_config( + service_host=service_host, service_manage_port=service_manage_port + ) + except Exception as ex: + Logger.warn( + f"Cannot get service config for {service_host}:{service_port}, exception: {ex} " + ) + return False + + config = ClientConfig( + host_addr=service_host, + service_port=int(service_port), + log_level="info", + connection_type=service_config["connection_type"], + ib_port=service_config["ib_port"], + link_type=service_config["link_type"], + dev_name=service_config["dev_name"], + ) + service_node = ServiceNode( + host=service_host, + port=service_port, + manage_port=service_config["manage_port"], + 
conn=infinistore.InfinityConnection(config), + ) + service_key = f"{service_host}:{service_port}" + self.service_nodes[service_key] = service_node + return True + + def get_connection(self, key: str = None) -> InfinityConnection: + """ + The function chooses an infinistore service connection + based upon a query key. If no key is specified, return the first + available connection + + Args: + key (str, optional): The key to choose service node + + Returns: + InfinityConnection: The connection to infinistore server node + """ + if len(self.service_nodes) == 0: + Logger.warn( + "There are no live nodes in the cluster, forgot to register the service node?" + ) + return None + + # Default to choose the first one + k = 0 if key is None else sha256_hash(key) % len(self.service_nodes) + + # Retrieve the service connection based upon the service address + keys = list(self.service_nodes.keys()) + service_node = self.service_nodes[keys[k]] + if service_node.conn is None: + ok = self.setup_connection( + service_host=service_node.host, + service_port=service_node.port, + service_manage_port=service_node.manage_port, + ) + if not ok: + return None + + assert self.service_nodes[keys[k]].conn is not None + return self.service_nodes[keys[k]].conn + + +class ConsulClusterMgr(ClusterMgrBase): + def __init__(self, bootstrap_address: str, service_manage_port: int = 8080): + super().__init__( + bootstrap_address=bootstrap_address, service_manage_port=service_manage_port + ) + + def get_consul(self, cluster_node_ip: str) -> Consul: + consul_ip, consul_port = cluster_node_ip.split(":") + return Consul(host=consul_ip, port=consul_port) + + def get_cluster_info(self, cluster_node_ip: str) -> list[str]: + updated_cluster_nodes: list[str] = [] + consul = self.get_consul(cluster_node_ip) + try: + members = consul.agent.members() + for member in members: + # member['Port'] is the port which the consul agents communicate + # not the port which can be queries for members, so change it to 
bootstrap port + updated_cluster_nodes.append(f"{member['Addr']}:{self.bootstrap_port}") + except Exception as ex: + Logger.error( + f"Could not get cluster info from {cluster_node_ip}, exception: {ex}" + ) + + return updated_cluster_nodes + + def get_service_status(self, service_name: str = "infinistore"): + # Deregister the service nodes which are not healthy + failed_services = [] + live_services = [] + for cluster_node_ip in self.cluster_nodes: + consul = self.get_consul(cluster_node_ip) + _, service_checks = consul.health.service(service_name) + + # We got a list of service nodes, no need to try further + if len(service_checks) > 0: + break + + if service_checks: + for service_check in service_checks: + service = service_check["Service"] + if service["Service"] != service_name: + continue + for i in range(len(service_check["Checks"])): + check = service_check["Checks"][i] + if check["ServiceName"] != service_name: + continue + if check["Status"] == "critical": + self.deregister_service(service["ID"]) + failed_services.append(service) + else: + live_services.append(service) + return live_services, failed_services + + def refresh_service_nodes(self, service_name: str = "infinistore"): + refresh_services = {} + Logger.info("Refresh service nodes for {service_name}...") + # Get the registered service nodes + for cluster_node_ip in self.cluster_nodes: + consul = self.get_consul(cluster_node_ip) + _, registered_services = consul.catalog.service(service_name) + # We got a list of service nodes, no need to try further + if len(registered_services) > 0: + break + + # Get the service_manage_port (in tags) and put them into dict key service_host:service_port + for service in registered_services: + key = f"{service['ServiceAddress']}:{service['ServicePort']}" + for tag in service["ServiceTags"]: + if tag.startswith("service_manage_port"): + service_manage_port = tag.split("=")[1] + refresh_services[key] = service_manage_port + + # Remove the services which are not in the 
live node list + # from service_nodes + for service_key in self.service_nodes: + if service_key not in refresh_services: + service_node = self.service_nodes.pop(service_key) + if service_node.conn is not None: + service_node.conn.close() + + # Add the new services(which are not in the current service node list) + for s in refresh_services: + # We don't support update operation for now. + Logger.info(f"Service node {s} added") + if s in self.service_nodes: + continue + + service_host, service_port = s.split(":") + service_node = ServiceNode( + host=service_host, + port=service_port, + manage_port=refresh_services[s], + conn=None, + ) + self.service_nodes[s] = service_node + + def register_service_node( + self, + service_id: str = "infinistore", + service_name: str = "infinistore", + service_host: str = "", + service_port: int = 12345, + service_manage_port: int = 8080, + check: dict = None, + ) -> bool: + ret = True + try: + # Create a Consul client + consul = self.get_consul(self.cluster_nodes[0]) + + # Register the service with Consul + consul.agent.service.register( + name=service_name, + service_id=service_id, + address=service_host, + port=service_port, + tags=[f"service_manage_port={service_manage_port}"], + check={ + "http": check["http"], + "interval": check["interval"], + "timeout": "5s", + }, + timeout="5s", + ) + except HTTPError as ex: + # Check for 409 Conflict if the service already exists + if ex.response.status_code == 409: + Logger.warn(f"Service {service_name} already exists.") + else: + ret = False + Logger.error( + f"Error registering service {service_name}, exception: {ex}" + ) + + return ret + + def deregister_service(self, service_id: str): + ret = True + try: + # Create a Consul client + consul = self.get_consul(self.cluster_nodes[0]) + + # Deregister the service with Consul + consul.agent.service.deregister(service_id) + except HTTPError as ex: + ret = False + Logger.error(f"Error deregistering service {service_id}, exception: {ex}") + + 
return ret + + +class NoClusterMgr(ClusterMgrBase): + def __init__(self, bootstrap_address: str, service_manage_port: int = 8080): + super().__init__( + bootstrap_address, + cluster_mode=False, + service_manage_port=service_manage_port, + ) + + def refresh_service_nodes(self, service_name: str = "infinistore"): + # For NoCluster cluster, the service node address is + if len(self.service_nodes) > 0: + return + cluster_node_ip = self.cluster_nodes[0] + service_host, service_port = cluster_node_ip.split(":") + # Call service to get service running arguments + service_config = self.get_service_config( + service_host=service_host, service_manage_port=self.service_manage_port + ) + service_host = cluster_node_ip.split(":")[0] + # Setup a ClientConfig + config = ClientConfig( + host_addr=service_host, + service_port=service_config["service_port"], + log_level="info", + connection_type=service_config["connection_type"], + ib_port=service_config["ib_port"], + link_type=service_config["link_type"], + dev_name=service_config["dev_name"], + ) + service_port = service_config["service_port"] + service_key = f"{service_host}:{service_port}" + service_node = ServiceNode( + host=service_host, + port=service_port, + manage_port=service_config["manage_port"], + conn=infinistore.InfinityConnection(config), + ) + self.service_nodes[service_key] = service_node + diff --git a/infinistore/example/cluster_client.py b/infinistore/example/cluster_client.py new file mode 100644 index 0000000..96749bb --- /dev/null +++ b/infinistore/example/cluster_client.py @@ -0,0 +1,81 @@ +from infinistore import ( + ClientConfig, + Logger, + InfinityConnection, + ConsulClusterMgr, + NoClusterMgr, +) +import uvloop +import torch +import time +import asyncio +import threading + +def start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + +loop = asyncio.new_event_loop() +t = threading.Thread(target=start_loop, args=(loop,)) +t.start() + +def run(conn, src_device="cuda:0", 
dst_device="cuda:2"): + src_tensor = torch.tensor( + [i for i in range(4096)], device=src_device, dtype=torch.float32 + ) + conn.register_mr( + src_tensor.data_ptr(), src_tensor.numel() * src_tensor.element_size() + ) + keys_offsets = [("key1", 0), ("key2", 1024 * 4), ("key3", 2048 * 4)] + now = time.time() + + future = asyncio.run_coroutine_threadsafe( + conn.rdma_write_cache_async(keys_offsets, 1024 * 4, src_tensor.data_ptr()), loop + ) + + future.result() + print(f"write elapse time is {time.time() - now}") + + dst_tensor = torch.zeros(4096, device=dst_device, dtype=torch.float32) + conn.register_mr( + dst_tensor.data_ptr(), dst_tensor.numel() * dst_tensor.element_size() + ) + now = time.time() + + future = asyncio.run_coroutine_threadsafe( + conn.rdma_read_cache_async(keys_offsets, 1024 * 4, dst_tensor.data_ptr()), loop + ) + future.result() + print(f"read elapse time is {time.time() - now}") + + assert torch.equal(src_tensor[0:1024].cpu(), dst_tensor[0:1024].cpu()) + assert torch.equal(src_tensor[1024:2048].cpu(), dst_tensor[1024:2048].cpu()) + +if __name__ == "__main__": + cluster_mode = True + if cluster_mode: + cluster_mgr = ConsulClusterMgr(bootstrap_address="127.0.0.1:8500") + else: + cluster_mgr = NoClusterMgr(bootstrap_address="127.0.0.1:8081", service_manage_port=8081) + # Refresh cluster first to get the alive service nodes + cluster_mgr.refresh_service_nodes() + #asyncio.create_task(cluster_mgr.refresh_task()) + cluster_mgr.refresh_task() + + rdma_conn = cluster_mgr.get_connection() + + try: + rdma_conn.connect() + m = [ + ("cpu", "cuda:0"), + ("cuda:0", "cuda:1"), + ("cuda:0", "cpu"), + ("cpu", "cpu"), + ] + for src, dst in m: + print(f"rdma connection: {src} -> {dst}") + run(rdma_conn, src, dst) + finally: + rdma_conn.close() + loop.call_soon_threadsafe(loop.stop) + t.join() diff --git a/infinistore/lib.py b/infinistore/lib.py index f5e5bda..78a4605 100644 --- a/infinistore/lib.py +++ b/infinistore/lib.py @@ -97,6 +97,7 @@ class ServerConfig: 
ServerConfig is a configuration class for the server settings. Attributes: + host (str): The IP address of the server manage_port (int): The port used for management. Defaults to 0. service_port (int): The port used for service. Defaults to 0. log_level (str): The logging level. Defaults to "warning". @@ -112,6 +113,7 @@ class ServerConfig: def __init__(self, **kwargs): super().__init__() + self.host = kwargs.get("host", "127.0.0.1") self.manage_port = kwargs.get("manage_port", 0) self.service_port = kwargs.get("service_port", 0) self.log_level = kwargs.get("log_level", "warning") diff --git a/infinistore/server.py b/infinistore/server.py index 5c39bdd..f518461 100644 --- a/infinistore/server.py +++ b/infinistore/server.py @@ -6,6 +6,8 @@ ServerConfig, Logger, evict_cache, + ConsulClusterMgr, + TYPE_RDMA, ) import asyncio import uvloop @@ -14,12 +16,14 @@ import argparse import logging import os - +import json +from fastapi.responses import JSONResponse, Response # disable standard logging, we will use our own logger logging.disable(logging.INFO) app = FastAPI() +config: ServerConfig = None @app.post("/purge") @@ -39,7 +43,37 @@ async def kvmap_len(): return {"len": get_kvmap_len()} -def parse_args(): +@app.get("/health") +async def health(): + Logger.info(f"Health check received at {config.host}:{config.manage_port}...") + return Response(content="Healthy", status_code=200) + + +@app.get("/service/config") +async def service_config() -> Response: + """ + Query the configuration about how to connect to this server node + + Response: + { + "connection_type": "TYPE_RDMA", + "ib_port": "1", + "link_type": "LINK_ETHERNET", + "dev_name": "mlx5_0" + } + """ + service_conf = { + "manage_port": config.manage_port, + "service_port": config.service_port, + "connection_type": TYPE_RDMA, + "ib_port": config.ib_port, + "link_type": config.link_type, + "dev_name": config.dev_name, + } + return JSONResponse(status_code=200, content=service_conf) + + +def get_parser(): parser = 
argparse.ArgumentParser() parser.add_argument( "--auto-increase", @@ -144,8 +178,31 @@ def parse_args(): help="hint gid index, default 1, -1 means no hint", type=int, ) + parser.add_argument( + "--cluster-mode", + required=False, + action="store_true", + help="Specify whether the infinistore server is in a cluster", + dest="cluster_mode", + ) + parser.add_argument( + "--bootstrap-ip", + required=False, + default="127.0.0.1:18080", + help="The bootstrap ip:port address to query for cluster information", + type=str, + dest="bootstrap_ip", + ) + parser.add_argument( + "--service-id", + required=True, + default="infinistore-standalone", + help="The service ID which is used by consul cluster to identify the service instance", + type=str, + dest="service_id", + ) - return parser.parse_args() + return parser def prevent_oom(): @@ -161,7 +218,10 @@ async def periodic_evict(min_threshold: float, max_threshold: float, interval: i def main(): - args = parse_args() + global config + + parser = get_parser() + args = parser.parse_args() config = ServerConfig( **vars(args), ) @@ -192,6 +252,21 @@ def main(): app, host="0.0.0.0", port=config.manage_port, loop="uvloop" ) + if args.cluster_mode: + # Initialized the cluster mgr with a bootstrap ip:port + health_url = f"http://{args.host}:{config.manage_port}/health" + cluster_mgr = ConsulClusterMgr(bootstrap_address=args.bootstrap_ip) + + # Note: service_id is required by consul cluster to uniquely identify a service instance + cluster_mgr.register_service_node( + service_host=args.host, + service_port=config.service_port, + service_id=args.service_id, + service_manage_port=config.manage_port, + check={"http": health_url, "interval": "5s"}, + ) + loop.create_task(cluster_mgr.refresh_task()) + server = uvicorn.Server(http_config) Logger.warn("server started")