Skip to content

Commit 37ae0d3

Browse files
committed
wandb style
1 parent 0b91461 commit 37ae0d3

File tree

2 files changed

+97
-91
lines changed

2 files changed

+97
-91
lines changed

eval_protocol/human_id/__init__.py

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,85 +12,66 @@
1212
def generate_id(
1313
separator: str = "-",
1414
seed: int | float | str | bytes | bytearray | None = None,
15-
word_count: int = 5,
1615
index: int | None = None,
1716
) -> str:
1817
"""
19-
Generate a human readable ID
18+
Generate a human readable ID in format: adjective-noun-NN
2019
2120
:param separator: The string to use to separate words
2221
:param seed: The seed to use. The same seed will produce the same ID or index-based mapping
2322
:param index: Optional non-negative integer providing a 1:1 mapping to an ID.
2423
When provided, the mapping is deterministic and bijective for
2524
all integers in range [0, total_combinations).
26-
:param word_count: The number of words to use. Minimum of 3.
2725
:return: A human readable ID
2826
"""
29-
if word_count < 3:
30-
raise ValueError("word_count cannot be lower than 3")
3127

32-
# If a specific index is provided, use mixed-radix encoding into a fixed
33-
# sequence of parts to guarantee a bijection between integers and IDs.
34-
# The sequence cycles as: verb, adjective, noun, verb, adjective, noun, ...
28+
# If a specific index is provided, use it for deterministic generation
3529
if index is not None:
3630
if not isinstance(index, int) or index < 0:
3731
raise ValueError("index must be a non-negative integer if provided")
3832

3933
# Prepare category lists; if seed is provided, shuffle deterministically
40-
base_categories = [dictionary.verbs, dictionary.adjectives, dictionary.nouns]
4134
if seed is not None:
4235
rnd = random.Random(seed)
43-
categories = [tuple(rnd.sample(cat, len(cat))) for cat in base_categories]
36+
adjectives = tuple(rnd.sample(dictionary.adjectives, len(dictionary.adjectives)))
37+
nouns = tuple(rnd.sample(dictionary.nouns, len(dictionary.nouns)))
4438
else:
45-
categories = base_categories
46-
# Build the category order for the desired word_count
47-
ordered_categories = [categories[i % 3] for i in range(word_count)]
39+
adjectives = dictionary.adjectives
40+
nouns = dictionary.nouns
4841

49-
# Compute total number of combinations for this word_count
50-
radices = [len(cat) for cat in ordered_categories]
51-
total = num_combinations(word_count)
42+
# Calculate total combinations: adjectives * nouns * 100 (for 00-99)
43+
total = len(adjectives) * len(nouns) * 100
5244

5345
if index >= total:
54-
raise ValueError(f"index out of range for given word_count. Received {index}, max allowed is {total - 1}")
46+
raise ValueError(f"index out of range. Received {index}, max allowed is {total - 1}")
5547

56-
# Mixed-radix decomposition (least significant position is the last word)
57-
digits: list[int] = []
58-
remaining = index
59-
for base in reversed(radices):
60-
digits.append(remaining % base)
61-
remaining //= base
62-
digits.reverse()
48+
# Decompose index into adjective, noun, and number
49+
number = index % 100
50+
remaining = index // 100
51+
noun_idx = remaining % len(nouns)
52+
adj_idx = remaining // len(nouns)
6353

64-
words = [ordered_categories[pos][digits[pos]] for pos in range(word_count)]
65-
return separator.join(words)
54+
adjective = adjectives[adj_idx]
55+
noun = nouns[noun_idx]
6656

57+
return f"{adjective}{separator}{noun}{separator}{number:02d}"
58+
59+
# Random generation
6760
random_obj = system_random
6861
if seed is not None:
6962
random_obj = random.Random(seed)
7063

71-
parts = {dictionary.verbs: 1, dictionary.adjectives: 1, dictionary.nouns: 1}
72-
73-
for _ in range(3, word_count):
74-
parts[random_obj.choice(list(parts.keys()))] += 1
75-
76-
parts = itertools.chain.from_iterable(random_obj.sample(part, count) for part, count in parts.items())
64+
adjective = random_obj.choice(dictionary.adjectives)
65+
noun = random_obj.choice(dictionary.nouns)
66+
number = random_obj.randint(0, 99)
7767

78-
return separator.join(parts)
68+
return f"{adjective}{separator}{noun}{separator}{number:02d}"
7969

8070

81-
def num_combinations(word_count: int = 5) -> int:
71+
def num_combinations() -> int:
8272
"""
83-
Return the total number of unique IDs possible for the given word_count.
73+
Return the total number of unique IDs possible.
8474
85-
The sequence of categories cycles as: verb, adjective, noun, then repeats.
86-
This value can be used to mod an index when calling generate_id(index=...).
75+
Format uses adjective-noun-NN, so total = adjectives * nouns * 100.
8776
"""
88-
if word_count < 3:
89-
raise ValueError("word_count cannot be lower than 3")
90-
91-
categories = [dictionary.verbs, dictionary.adjectives, dictionary.nouns]
92-
radices = [len(categories[i % 3]) for i in range(word_count)]
93-
total = 1
94-
for r in radices:
95-
total *= r
96-
return total
77+
return len(dictionary.adjectives) * len(dictionary.nouns) * 100

tests/test_human_id.py

Lines changed: 70 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,66 +4,91 @@
44
from eval_protocol.human_id import generate_id, num_combinations
55

66

7-
def test_generate_id_index_basic_3_words():
8-
# index 0 maps to the first element of each category (verb, adjective, noun)
9-
assert generate_id(index=0, word_count=3) == "be-other-time"
10-
11-
# incrementing index advances the least-significant position (noun)
12-
assert generate_id(index=1, word_count=3) == "be-other-year"
13-
14-
# carry into the adjective when nouns wrap
15-
# index == len(nouns) => adjective advances by 1, noun resets
16-
# nouns length inferred by probing with large indices is brittle; instead, compute via reach
17-
# We know index=0 gives be-other-time, and index that produces adjective=new, noun=time should be reachable.
18-
# Derive by scanning forward until adjective changes to 'new'. This keeps test robust to dictionary size edits.
19-
base = generate_id(index=0, word_count=3)
20-
# Find the first index where adjective becomes 'new' and noun resets to 'time'
21-
target = None
22-
for i in range(1, 2000):
23-
cand = generate_id(index=i, word_count=3)
24-
if cand.startswith("be-new-time"):
25-
target = i
26-
break
27-
assert target is not None, "Expected to find carry into adjective within search bound"
28-
assert generate_id(index=target, word_count=3) == "be-new-time"
29-
30-
31-
def test_generate_id_index_word_count_cycle():
32-
# word_count cycles categories: verb, adj, noun, verb, adj, ...
33-
assert generate_id(index=0, word_count=5) == "be-other-time-be-other"
34-
# increment least-significant position (adj at position 5)
35-
assert generate_id(index=1, word_count=5) == "be-other-time-be-new"
36-
37-
38-
def test_generate_id_index_out_of_range_and_negative():
39-
# Use exported total combinations for clean boundary checks
40-
total = num_combinations(word_count=3)
7+
def test_generate_id_basic_format():
8+
"""Test that generate_id produces the expected adjective-noun-NN format"""
9+
id_str = generate_id(index=0)
10+
# Should match pattern: adjective-noun-NN where NN is 00-99
11+
assert re.match(r"^[a-z]+-[a-z]+-\d{2}$", id_str)
12+
13+
# Test a few specific indices to ensure deterministic behavior
14+
assert generate_id(index=0) == "other-time-00"
15+
assert generate_id(index=1) == "other-time-01"
16+
assert generate_id(index=99) == "other-time-99"
17+
assert generate_id(index=100) == "other-year-00"
18+
19+
20+
def test_generate_id_index_mapping():
21+
"""Test that index mapping works correctly"""
22+
# Test number cycling (0-99)
23+
for i in range(100):
24+
id_str = generate_id(index=i)
25+
expected_num = f"{i:02d}"
26+
assert id_str.endswith(f"-{expected_num}")
27+
assert id_str.startswith("other-time-")
28+
29+
# Test noun advancement after 100 numbers
30+
id_100 = generate_id(index=100)
31+
assert id_100.startswith("other-year-00")
32+
33+
# Test adjective advancement (after all nouns * 100)
34+
# This will depend on dictionary size, so let's test the pattern
35+
from eval_protocol.human_id import dictionary
36+
37+
nouns_count = len(dictionary.nouns)
38+
adjective_boundary = nouns_count * 100
39+
40+
id_at_boundary = generate_id(index=adjective_boundary)
41+
# Should have advanced to the next adjective
42+
assert not id_at_boundary.startswith("other-")
43+
44+
45+
def test_generate_id_index_out_of_range():
46+
"""Test that invalid indices raise appropriate errors"""
47+
total = num_combinations()
4148
assert total > 0
42-
# Last valid index
43-
generate_id(index=total - 1, word_count=3)
44-
# First invalid index
49+
50+
# Last valid index should work
51+
generate_id(index=total - 1)
52+
53+
# First invalid index should raise error
4554
with pytest.raises(ValueError):
46-
generate_id(index=total, word_count=3)
55+
generate_id(index=total)
4756

57+
# Negative index should raise error
4858
with pytest.raises(ValueError):
49-
generate_id(index=-1, word_count=3)
59+
generate_id(index=-1)
5060

5161

52-
def test_generate_id_seed_stability_and_compat():
53-
# Without index, same seed yields same id
62+
def test_generate_id_seed_stability():
63+
"""Test that same seed produces same ID"""
5464
a = generate_id(seed=1234)
5565
b = generate_id(seed=1234)
5666
assert a == b
5767

5868
# Without index, default produces separator '-' and at least 3 components
5969
c = generate_id()
60-
assert re.match(r"^[a-z]+(-[a-z]+){2,}$", c)
70+
71+
assert re.match(r"^[a-z]+-[a-z]+-\d{2}$", c)
6172

6273

63-
def test_generate_id_index_ignores_seed():
64-
# With index provided, seed should affect the mapping deterministically
74+
def test_generate_id_seed_with_index():
75+
"""Test that seed affects index-based generation deterministically"""
6576
x = generate_id(index=42, seed=1)
6677
y = generate_id(index=42, seed=999)
6778
z = generate_id(index=42, seed=1)
68-
assert x != y
79+
80+
# Same seed should produce same result
6981
assert x == z
82+
# Different seeds should produce different results
83+
assert x != y
84+
85+
# All should follow the correct format
86+
assert re.match(r"^[a-z]+-[a-z]+-\d{2}$", x)
87+
assert re.match(r"^[a-z]+-[a-z]+-\d{2}$", y)
88+
89+
90+
def test_generate_id_random_format():
91+
"""Test that random generation (no index) produces correct format"""
92+
for _ in range(10):
93+
id_str = generate_id()
94+
assert re.match(r"^[a-z]+-[a-z]+-\d{2}$", id_str)

0 commit comments

Comments
 (0)