Skip to content

Commit bb7fe44

Browse files
author
艾瑞 (Ai)
committed
feat(rolling_hash): add Rabin-Karp rolling hash algorithm
- Implement rabin_karp(text, pattern) with type hints and docstrings - Supports empty pattern, Unicode, overlapping matches - Include comprehensive test suite (basic, edge cases, unicode, long patterns) - This algorithm provides O(n) average substring search Next steps: add more rolling hash variants (e.g., for plagiarism detection).
1 parent 68473af commit bb7fe44

File tree

3 files changed

+127
-0
lines changed

3 files changed

+127
-0
lines changed

rolling_hash/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Rolling hash algorithms for string matching and similarity."""

rolling_hash/rabin_karp.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Rabin-Karp rolling hash algorithm for substring search.
2+
3+
Implements the classic Rabin-Karp algorithm using a rolling hash to find
4+
all occurrences of a pattern in a text in O(n) average time.
5+
6+
The algorithm uses a simple polynomial rolling hash with modulo prime to
7+
avoid overflow. It works well for ASCII/Unicode strings.
8+
9+
References:
10+
- Rabin, M. O., & Karp, R. M. (1987). Algorithms for pattern matching.
11+
"""
12+
from typing import List
13+
14+
15+
def rabin_karp(text: str, pattern: str) -> List[int]:
16+
"""Return starting indices of pattern in text using rolling hash.
17+
18+
Args:
19+
text: The text to search within.
20+
pattern: The pattern to find.
21+
22+
Returns:
23+
List of starting indices (0-based) where pattern occurs.
24+
25+
Example:
26+
>>> rabin_karp("abracadabra", "abra")
27+
[0, 7]
28+
"""
29+
# Edge cases
30+
if pattern == "":
31+
# By convention, empty pattern matches at each position plus one
32+
return list(range(len(text) + 1))
33+
if len(pattern) > len(text):
34+
return []
35+
36+
# Rolling hash parameters
37+
base = 256 # number of possible character values (ASCII/extended)
38+
prime = 101 # a small prime for modulus
39+
m, n = len(pattern), len(text)
40+
41+
# Precompute base^(m-1) mod prime for rolling removal
42+
h = 1
43+
for _ in range(m - 1):
44+
h = (h * base) % prime
45+
46+
# Compute initial hash values
47+
pattern_hash = 0
48+
window_hash = 0
49+
for i in range(m):
50+
pattern_hash = (base * pattern_hash + ord(pattern[i])) % prime
51+
window_hash = (base * window_hash + ord(text[i])) % prime
52+
53+
matches: List[int] = []
54+
# Slide the window over text
55+
for i in range(n - m + 1):
56+
if pattern_hash == window_hash:
57+
# Double-check to avoid hash collisions
58+
if text[i:i + m] == pattern:
59+
matches.append(i)
60+
if i < n - m:
61+
# Roll: remove leading char, add trailing char
62+
window_hash = (base * (window_hash - ord(text[i]) * h) + ord(text[i + m])) % prime
63+
if window_hash < 0:
64+
window_hash += prime
65+
return matches
66+
67+
68+
def demo() -> None:
69+
"""Run a simple demonstration."""
70+
text = "abracadabra"
71+
pattern = "abra"
72+
indices = rabin_karp(text, pattern)
73+
print(f"Pattern '{pattern}' found at positions: {indices}")
74+
75+
76+
if __name__ == "__main__":
77+
demo()

tests/test_rolling_hash.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Tests for rolling hash Rabin-Karp implementation."""
2+
import pytest
3+
from rolling_hash.rabin_karp import rabin_karp
4+
5+
6+
def test_basic_matches():
7+
assert rabin_karp("abracadabra", "abra") == [0, 7]
8+
assert rabin_karp("aaaaa", "aa") == [0, 1, 2, 3]
9+
assert rabin_karp("hello world", "world") == [6]
10+
11+
12+
def test_no_match():
13+
assert rabin_karp("abcdef", "gh") == []
14+
assert rabin_karp("abc", "abcd") == []
15+
16+
17+
def test_empty_pattern():
18+
# Empty pattern matches at every position (including end)
19+
assert rabin_karp("abc", "") == [0, 1, 2, 3]
20+
assert rabin_karp("", "") == [0]
21+
22+
23+
def test_single_character():
24+
assert rabin_karp("a", "a") == [0]
25+
assert rabin_karp("ab", "a") == [0]
26+
assert rabin_karp("ab", "b") == [1]
27+
28+
29+
def test_overlapping():
30+
text = "aaa"
31+
pattern = "aa"
32+
assert rabin_karp(text, pattern) == [0, 1]
33+
34+
35+
def test_case_sensitive():
36+
assert rabin_karp("ABCabc", "abc") == [3]
37+
assert rabin_karp("ABCabc", "ABC") == [0]
38+
39+
40+
def test_unicode():
41+
# Unicode characters
42+
assert rabin_karp("你好世界你好", "你好") == [0, 4]
43+
44+
45+
def test_long_pattern():
46+
text = "a" * 1000
47+
pattern = "a" * 100
48+
expected = list(range(0, 901))
49+
assert rabin_karp(text, pattern) == expected

0 commit comments

Comments
 (0)