-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_fuzz_cross_impl.py
More file actions
212 lines (164 loc) · 6.01 KB
/
test_fuzz_cross_impl.py
File metadata and controls
212 lines (164 loc) · 6.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "bip32",
# "hypothesis",
# "pytest",
# "verystable",
# ]
# ///
import sys
import logging
import time
from contextlib import contextmanager
import bip32 as py_bip32
from verystable import bip32 as vs_bip32
import pytest
from hypothesis import given, strategies as st, target, settings
from bindings import derive, derive_from_seed, b58_decode, b58_encode
log = logging.getLogger(__name__)
logging.basicConfig()

# Hypothesis strategy for BIP32 seeds. BIP32 seeds must be 128-512 bits
# (16-64 bytes); each one is rendered as a lowercase hex string, which the
# derive wrappers decode again with bytes.fromhex().
valid_seeds = st.binary(min_size=16, max_size=64).map(bytes.hex)

# Sentinel the derive wrappers return when derivation fails, so that two
# implementations can be compared even on inputs both of them reject.
INVALID_KEY = '_'
@contextmanager
def timer(label):
    """
    Measure how long the body of a ``with`` block takes and report it.

    The elapsed ``time.perf_counter()`` interval is passed to
    ``hypothesis.target`` under ``label``. Hypothesis then favours examples
    with larger observed values, i.e. slower inputs. Nothing is reported if
    the body raises.
    """
    t0 = time.perf_counter()
    yield
    elapsed = time.perf_counter() - t0
    target(elapsed, label=label)
def py_derive(seed_hex_str: str, bip32_path: str) -> str:
    """
    Derive an extended private key using python-bip32.

    :param seed_hex_str: the BIP32 seed, hex encoded.
    :param bip32_path: textual derivation path such as ``"m/0h/1"``.
    :return: the serialized xpriv, or ``INVALID_KEY`` if path derivation
        raises. Errors while building the master key are not caught.
    """
    master = py_bip32.BIP32.from_seed(bytes.fromhex(seed_hex_str))
    try:
        xpriv = master.get_xpriv_from_path(bip32_path)
    except Exception:
        xpriv = INVALID_KEY
    return xpriv
def _path_str_to_ints(bip32_path) -> list[int] | None:
if not bip32_path.startswith('m'):
return None
path_ints = []
components = filter(None, bip32_path.lstrip('m').split('/'))
for comp in components:
if comp.endswith('h'):
path_ints.append(int(comp[:-1]) | vs_bip32.HARDENED_INDEX)
else:
path_ints.append(int(comp))
return path_ints
def verystable_derive(seed_hex_str: str, bip32_path: str) -> str:
    """
    Derive an extended private key using verystable's BIP32 implementation.

    The textual path is first converted to integer indices with
    ``_path_str_to_ints``:
    - if it is rejected (None), ``INVALID_KEY`` is returned without
      attempting derivation;
    - if it has no child indices, the master key itself is serialized.

    :return: the serialized key, or ``INVALID_KEY`` if derivation or
        serialization raises.
    """
    root = vs_bip32.BIP32.from_bytes(bytes.fromhex(seed_hex_str), True)
    indices = _path_str_to_ints(bip32_path)
    if indices is None:
        return INVALID_KEY
    if not indices:
        # Bare "m": nothing to derive.
        return root.serialize()
    try:
        # Only the first element of derive()'s return value is used here.
        child, _ = root.derive(*indices)
        return child.serialize()
    except Exception:
        return INVALID_KEY
def py_xpub_derive(base58: str, bip32_path: str) -> str:
    """
    Derive an extended public key from a base58 xpub using python-bip32.

    :param base58: the parent extended public key.
    :param bip32_path: textual derivation path.
    :return: the derived xpub, or ``INVALID_KEY`` if path derivation raises.
        Errors while parsing ``base58`` are not caught.
    """
    node = py_bip32.BIP32.from_xpub(base58)
    try:
        result = node.get_xpub_from_path(bip32_path)
    except Exception:
        result = INVALID_KEY
    return result
def our_derive(hex_str, path) -> str:
    """
    Derive a key with the implementation under test (``bindings.derive``).

    ``hex_str`` is forwarded unchanged to ``derive``. In this file it is
    either a hex seed or, in ``test_xpub_impls``, a base58 xpub.

    :return: the serialized result, or ``INVALID_KEY`` if either ``derive``
        or ``serialize`` raises.
    """
    try:
        return derive(hex_str, path).serialize()
    except Exception:
        return INVALID_KEY
# Bounds shared by the BIP32 path strategies below.
_MAX_ALLOWED_DEPTH = 255
_MAX_UNHARDENED_IDX = 2**31 - 1


def _draw_bip32_path(draw, max_index: int) -> str:
    """
    Draw one textual BIP32 path; shared body of the path strategies below.

    The ranges deliberately include invalid paths:
    - depth ranges over 0.._MAX_ALLOWED_DEPTH + 3, so some paths are deeper
      than the maximum depth;
    - each index ranges over -2..max_index, so negative indices always occur.
    Each component is independently marked hardened (``h`` suffix) or not.
    The draw order (depth, then index/hardened per component) matches the
    original per-strategy implementations.
    """
    depth = draw(st.integers(min_value=0, max_value=_MAX_ALLOWED_DEPTH + 3))
    parts = ["m"]
    for _ in range(depth):
        index = draw(st.integers(min_value=-2, max_value=max_index))
        hardened = draw(st.booleans())
        parts.append(f"{index}{'h' if hardened else ''}")
    return "/".join(parts)


@st.composite
def py_compatible_bip32_paths(draw):
    """
    Generate BIP32 paths whose indices never exceed 2**31 - 1.

    Because of python-bip32's too-large path issue, the maximum index is
    clamped so it can be fuzzed against:
    https://github.com/darosior/python-bip32/issues/46
    """
    return _draw_bip32_path(draw, _MAX_UNHARDENED_IDX)


@st.composite
def bip32_paths(draw):
    """
    Generate BIP32 paths with some out of bound values.

    Same as ``py_compatible_bip32_paths`` except that indices may reach
    2**31 + 1, i.e. past the unhardened index range.
    """
    return _draw_bip32_path(draw, _MAX_UNHARDENED_IDX + 2)
@given(seed_hex_str=valid_seeds, bip32_path=py_compatible_bip32_paths())
@settings(max_examples=2_000)
def test_versus_py(seed_hex_str, bip32_path):
    """
    Cross-check our BIP32 derivation against python-bip32.

    Both wrappers map failures to ``INVALID_KEY``, so the implementations
    must also agree on which random paths are rejected.
    """
    with timer('ours'):
        mine = our_derive(seed_hex_str, bip32_path)
    with timer('python-bip32'):
        reference = py_derive(seed_hex_str, bip32_path)
    assert mine == reference
@given(seedhex=valid_seeds, path=py_compatible_bip32_paths())
@settings(max_examples=2_000)
def test_versus_ourselves(seedhex, path):
    """
    Check that our two entry points agree on every input.

    ``derive`` (hex-string seed, via ``our_derive``) and ``derive_from_seed``
    (raw-bytes seed) must agree, including on which inputs fail.
    """
    raw_seed = bytes.fromhex(seedhex)
    try:
        expected = derive_from_seed(raw_seed, path).serialize()
    except Exception:
        expected = INVALID_KEY
    assert our_derive(seedhex, path) == expected
@given(seed_hex_str=valid_seeds, bip32_path=py_compatible_bip32_paths())
@settings(max_examples=100, deadline=5000)  # verystable is very slow: allow up to 5s per example
def test_versus_vs(seed_hex_str, bip32_path):
    """
    Cross-check our BIP32 derivation against the verystable implementation.

    verystable is VERY slow (100x+), so this runs far fewer examples than
    the other cross-implementation tests and uses a generous deadline.
    """
    with timer('ours'):
        mine = our_derive(seed_hex_str, bip32_path)
    with timer('verystable'):
        reference = verystable_derive(seed_hex_str, bip32_path)
    assert mine == reference
@given(bip32_path=py_compatible_bip32_paths())
@settings(max_examples=200)
def test_xpub_impls(bip32_path):
    """
    Cross-check public derivation from a fixed xpub against python-bip32.

    Paths both implementations reject compare equal as ``INVALID_KEY``.
    """
    parent_xpub = 'xpub6ASuArnXKPbfEwhqN6e3mwBcDTgzisQN1wXN9BJcM47sSikHjJf3UFHKkNAWbWMiGj7Wf5uMash7SyYq527Hqck2AxYysAA7xmALppuCkwQ'
    with timer('ours'):
        mine = our_derive(parent_xpub, bip32_path)
    with timer('python-bip32'):
        reference = py_xpub_derive(parent_xpub, bip32_path)
    assert mine == reference
@given(b58_data=st.binary(min_size=2, max_size=1000))
@settings(max_examples=1000)
def test_base58(b58_data: bytes):
    """
    Round-trip random byte strings through ``b58_encode`` / ``b58_decode``.

    Empty and single-byte inputs are excluded by the strategy
    (``min_size=2``) instead of being skipped inside the body. The old
    in-body skip made those examples pass without checking anything,
    wasting part of the example budget.
    """
    # TODO: figure out why the base58 impl is failing on example b':'
    assert b58_decode(b58_encode(b58_data)) == b58_data
def test_base58_known_vectors():
    """
    Check ``b58_encode`` / ``b58_decode`` against fixed vectors.

    The vectors show each leading zero byte mapping to a leading ``'1'``.
    Empty values are skipped in each direction, so the first (empty/empty)
    vector is currently not exercised at all.
    """
    vectors = [
        ("", ""),
        ("00", "1"),
        ("0000", "11"),
        ("68656c6c6f20776f726c64", "StV1DL6CwTryKyV"),
        ("0068656c6c6f20776f726c64", "1StV1DL6CwTryKyV"),
        ("000068656c6c6f20776f726c64", "11StV1DL6CwTryKyV"),
    ]
    for raw_hex, expected_text in vectors:
        raw = bytes.fromhex(raw_hex)
        if raw:  # empty input is not checked for encoding
            assert b58_encode(raw) == expected_text
        if expected_text:  # empty input is not checked for decoding
            assert b58_decode(expected_text) == raw
if __name__ == "__main__":
    # pytest.main() returns an exit code rather than exiting. Pass it to
    # sys.exit so that running this file as a script exits non-zero when
    # any test fails.
    sys.exit(pytest.main(
        [__file__, "-v", "--capture=no", "--hypothesis-show-statistics", "-x"]
        + sys.argv[1:]
    ))