-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrust-ext-example.py
More file actions
128 lines (97 loc) · 3.49 KB
/
rust-ext-example.py
File metadata and controls
128 lines (97 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import time
from typing import Iterator
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import numpy as np
import msgpack
from tqdm import tqdm
from loguru import logger
import rust_ext
SIZE_ARRAY_DIM = 512
COUNT_PER_MSGPACK_UPPERBOUND = 50000
def iterate_msgpack(filename):
    """Lazily yield every object unpacked from the msgpack file *filename*.

    The file is opened in binary mode and streamed through
    ``msgpack.Unpacker``. If an ``Exception`` is raised while iterating, it is
    logged with ``logger.error`` and the generator simply ends instead of
    propagating the error to the consumer.
    """
    with open(filename, "rb") as stream:
        records = msgpack.Unpacker(stream)
        try:
            for record in records:
                yield record
        except Exception as err:
            # Best effort: report the problem and stop producing items.
            logger.error(err)
def take_iter_py(iterator: Iterator[bytes], np_vectors: np.ndarray) -> int:
    """Copy float32 vectors from *iterator* into the rows of *np_vectors*.

    Each item is interpreted as a raw float32 buffer and reshaped to the row
    shape of *np_vectors* (``np_vectors.shape[1:]``). Items whose size does
    not fit that shape are reported on stdout and skipped.

    Args:
        iterator: Source of raw byte buffers, one vector per item.
        np_vectors: Pre-allocated destination array; row ``i`` receives the
            ``i``-th valid vector. It is modified in place.

    Returns:
        The number of rows written. Consumption of *iterator* stops as soon
        as the destination is full, so the result never exceeds
        ``np_vectors.shape[0]``.
    """
    capacity = np_vectors.shape[0]
    # Take the expected vector shape from the destination so the function
    # works for any buffer width, not just one hard-coded dimension.
    row_shape = np_vectors.shape[1:]

    idx = 0
    if capacity == 0:
        # Nothing can be stored; do not consume the iterator at all.
        return idx

    for bytes_vector in iterator:
        try:
            # frombuffer rejects lengths that are not a multiple of 4 bytes,
            # reshape rejects a wrong element count; both raise ValueError.
            np_vectors[idx] = np.frombuffer(
                bytes_vector, dtype=np.float32
            ).reshape(row_shape)
        except ValueError:
            print("Array size does not match!")
            continue
        idx += 1
        if idx == capacity:
            # Buffer is full: stop here instead of indexing past the end.
            break
    return idx
def process_py(filepath):
    """Load the vectors of the msgpack file *filepath* with the Python path.

    An uninitialised float32 buffer of shape
    ``(COUNT_PER_MSGPACK_UPPERBOUND, SIZE_ARRAY_DIM)`` is allocated, filled by
    ``take_iter_py`` from the items produced by ``iterate_msgpack``, and
    trimmed to the number of rows that ``take_iter_py`` reports.

    Returns:
        The first ``count`` rows of the buffer (a slice of the full
        allocation, not a copy).
    """
    np_vectors = np.empty(
        (COUNT_PER_MSGPACK_UPPERBOUND, SIZE_ARRAY_DIM), dtype=np.float32
    )
    count = take_iter_py(iterate_msgpack(filepath), np_vectors)
    # Drop this return if memory consumption is too big.
    return np_vectors[:count]
def process_rs(filepath):
    """Load the vectors of the msgpack file *filepath* via ``rust_ext.take_iter``.

    The file is still iterated in Python with ``iterate_msgpack``; the
    resulting iterator and an uninitialised float32 buffer of shape
    ``(COUNT_PER_MSGPACK_UPPERBOUND, SIZE_ARRAY_DIM)`` are handed to
    ``rust_ext.take_iter``, whose return value is used as the number of
    filled rows (presumably mirroring ``take_iter_py`` -- confirm against
    the extension).

    Returns:
        The first ``count`` rows of the buffer (a slice of the full
        allocation, not a copy).
    """
    np_vectors = np.empty(
        (COUNT_PER_MSGPACK_UPPERBOUND, SIZE_ARRAY_DIM), dtype=np.float32
    )
    count = rust_ext.take_iter(iterate_msgpack(filepath), np_vectors)
    # Drop this return if memory consumption is too big.
    return np_vectors[:count]
def process_rs_full(filepath):
    """Load the vectors of the msgpack file *filepath* via ``rust_ext.take_filepath``.

    Unlike the iterator-based variants, only the path is passed on: an
    uninitialised float32 buffer of shape
    ``(COUNT_PER_MSGPACK_UPPERBOUND, SIZE_ARRAY_DIM)`` and *filepath* are
    handed to ``rust_ext.take_filepath``, whose return value is used as the
    number of filled rows (presumably the extension reads and decodes the
    file itself -- confirm against the extension).

    Returns:
        The first ``count`` rows of the buffer (a slice of the full
        allocation, not a copy).
    """
    np_vectors = np.empty(
        (COUNT_PER_MSGPACK_UPPERBOUND, SIZE_ARRAY_DIM), dtype=np.float32
    )
    count = rust_ext.take_filepath(filepath, np_vectors)
    # Drop this return if memory consumption is too big.
    return np_vectors[:count]
def _run_benchmark(pool, label, task, filepath, n_tasks):
    """Log *label*, submit ``task(filepath)`` *n_tasks* times, wait for all."""
    logger.info(label)
    futures = [pool.submit(task, filepath) for _ in range(n_tasks)]
    for future in tqdm(futures):
        # The value is discarded; result() blocks until the task is done and
        # re-raises any exception that occurred in the worker.
        future.result()


if __name__ == "__main__":
    filepath = "bytes_vectors.msgpack"
    N_tasks = 32
    N_workers = 4

    # Context managers shut each pool down when its runs are finished, so the
    # worker processes are gone before the thread-pool runs start.
    with ProcessPoolExecutor(max_workers=N_workers) as pool:
        _run_benchmark(
            pool, "With Python / process pool:", process_py, filepath, N_tasks
        )
        _run_benchmark(
            pool,
            "With Rust extension / process pool:",
            process_rs,
            filepath,
            N_tasks,
        )

    with ThreadPoolExecutor(max_workers=N_workers) as pool:
        _run_benchmark(
            pool, "With Python / thread pool:", process_py, filepath, N_tasks
        )
        _run_benchmark(
            pool,
            "With Rust extension / thread pool:",
            process_rs,
            filepath,
            N_tasks,
        )
        _run_benchmark(
            pool,
            "With Rust extension - full / thread pool:",
            process_rs_full,
            filepath,
            N_tasks,
        )