-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathp2p_rdma_rc_read.py
More file actions
52 lines (42 loc) · 1.71 KB
/
p2p_rdma_rc_read.py
File metadata and controls
52 lines (42 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from dlslime import available_nic, RDMAEndpoint
def _register_tensor(endpoint, name, tensor):
    """Register ``tensor``'s storage with ``endpoint`` under ``name``.

    Forwards the tensor's data pointer, storage offset and total size in bytes
    to ``endpoint.register_memory_region`` and returns its result (the handle
    later used to address this region in transfers).
    """
    return endpoint.register_memory_region(
        name,
        tensor.data_ptr(),
        int(tensor.storage_offset()),
        tensor.numel() * tensor.itemsize,
    )


devices = available_nic()
assert devices, "No RDMA devices."

# Name under which both sides register their buffer; after the OOB exchange the
# initiator refers to the target's region by this same name.
mr_key = "kv"

# Create one RDMA endpoint per side. When only one NIC is available,
# devices[0] and devices[-1] are the same device and both endpoints share it.
initiator = RDMAEndpoint(device_name=devices[0], ib_port=1, link_type="RoCE")
target = RDMAEndpoint(device_name=devices[-1], ib_port=1, link_type="RoCE")

# Register host (CPU) buffers: the initiator's destination starts as all zeros
# and the target's source is all ones, so transferred bytes are easy to spot.
local_tensor = torch.zeros([16], device="cpu", dtype=torch.uint8)
handler = _register_tensor(initiator, mr_key, local_tensor)

remote_tensor = torch.ones([16], device="cpu", dtype=torch.uint8)
hremote = _register_tensor(target, mr_key, remote_tensor)

# Simulate OOB exchange (Target -> Initiator): hand the target's memory-region
# info to the initiator so it can address the remote buffer.
info = target.endpoint_info()
kv_info = info["mr_info"][mr_key]
hremote_on_initiator = initiator.register_remote_memory_region(mr_key, kv_info)

# Establish the bidirectional RDMA connection:
# 1. Target connects to the initiator's endpoint information.
# 2. Initiator connects to the target's endpoint information.
# Note: real-world setups exchange this information out of band (e.g. over TCP).
target.connect(initiator.endpoint_info())
initiator.connect(target.endpoint_info())

print("Remote tensor before RDMA read:", remote_tensor)

# RDMA read of 8 bytes from the target's buffer into the local buffer.
# NOTE(review): the tuple is assumed to be
# (local handle, remote handle, offset, offset, length); the assertions below
# imply the data lands at local offset 8. Confirm the field order against
# dlslime's RDMAEndpoint.read.
slot = initiator.read([(handler, hremote_on_initiator, 0, 8, 8)], None)
slot.wait()

# Bytes 0-7 are untouched (still 0); bytes 8-15 were filled from the remote (1).
assert torch.all(local_tensor[:8] == 0)
assert torch.all(local_tensor[8:] == 1)
print("Local tensor after RDMA read:", local_tensor)

# Drop the endpoints explicitly before reporting success.
del target, initiator
print("run rdma rc read example successful")