Skip to content

Commit a8904fd

Browse files
committed
Enhance ConcurrentDictionary with get_locked and key_lock methods for improved thread safety; update README and tests accordingly. Bump version to 2.0.0.
1 parent cb4cef9 commit a8904fd

4 files changed

Lines changed: 155 additions & 10 deletions

File tree

README.md

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,41 @@ print(list(bag)) # [1, 2, 3, 4]
7272

7373
### ConcurrentDictionary
7474

75-
A thread-safe dictionary. For atomic compound updates, use `update_atomic`.
75+
A thread-safe dictionary. It has a few notable methods:
76+
77+
- assign_atomic()
78+
- get_locked()
79+
- update_atomic()
80+
81+
#### ConcurrentDictionary's `assign_atomic()`
82+
83+
Assigns a dictionary value under a key in a thread-safe way.
84+
While `dict["somekey"] = value` is allowed, it's best to use `assign_atomic()` for clarity of intent. Using normal assignment will work but raise a UserWarning.
85+
86+
87+
#### ConcurrentDictionary's `get_locked()`
88+
89+
When working with `ConcurrentDictionary`, you should use the `get_locked` method to safely read or update the value for a specific key in a multi-threaded environment. This ensures that only one thread can access or modify the value for a given key at a time, preventing race conditions.
7690

7791
```python
7892
from concurrent_collections import ConcurrentDictionary
7993

80-
d = ConcurrentDictionary({'x': 1})
81-
d['y'] = 2 # Simple assignment is thread-safe
82-
# For atomic updates:
83-
d.update_atomic('x', lambda v: v + 1)
84-
print(d['x']) # 2
94+
d = ConcurrentDictionary({'x': "some value" })
95+
96+
# Safely read and update the value for 'x'
97+
with d.get_locked('x') as value:
98+
# value is locked for this thread
99+
d['x'] = "new value"
100+
```
101+
102+
#### ConcurrentDictionary's `update_atomic()`
103+
104+
Performs a thread-safe, in-place update to an existing value under a key.
105+
106+
```python
107+
108+
d = ConcurrentDictionary({'x': 1 })
109+
d.update_atomic("x", lambda v: v + 1) # d now contains 2 under the 'x' key.
85110
```
86111

87112
### ConcurrentQueue

concurrent_collections/concurrent_dict.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import threading
2-
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Generic, Tuple
2+
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Generic, Tuple, ContextManager
33
import warnings
4-
import sys
54

65
T = TypeVar('T')
76
K = TypeVar('K')
@@ -21,6 +20,47 @@ class ConcurrentDictionary(Generic[K, V]):
2120
def __init__(self, *args: Any, **kwargs: Any) -> None:
2221
self._lock = threading.RLock()
2322
self._dict: Dict[K, V] = dict(*args, **kwargs) # type: ignore
23+
self._key_locks: Dict[K, threading.RLock] = {}
24+
25+
def _get_key_lock(self, key: K) -> threading.RLock:
26+
with self._lock:
27+
if key not in self._key_locks:
28+
self._key_locks[key] = threading.RLock()
29+
return self._key_locks[key]
30+
31+
class _KeyLockContext:
32+
def __init__(self, outer : "ConcurrentDictionary[K,V]", key: K):
33+
self._outer = outer
34+
self._key = key
35+
self._lock = outer._get_key_lock(key)
36+
37+
def __enter__(self) -> V:
38+
self._lock.acquire()
39+
return self._outer._dict[self._key]
40+
41+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
42+
self._lock.release()
43+
44+
def get_locked(self, key: K) -> ContextManager[V]:
45+
"""
46+
Context manager: lock the key, yield its value, unlock on exit.
47+
48+
Usage:
49+
with d.get_locked('x') as value:
50+
# safely read/update value for 'x'
51+
"""
52+
return self._KeyLockContext(self, key)
53+
54+
def key_lock(self, key: K):
55+
"""
56+
Context manager: lock the key, yield nothing, unlock on exit.
57+
58+
Usage:
59+
with d.key_lock('x'):
60+
# safely update d['x'] or perform multiple operations
61+
"""
62+
lock = self._get_key_lock(key)
63+
return lock
2464

2565
def __getitem__(self, key: K) -> V:
2666
with self._lock:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "concurrent_collections"
7-
version = "1.4.0"
7+
version = "2.0.0"
88
description = "A brief description of concurrent-collections"
99
authors = [{ name = "Alessio Lombardi", email = "work@alelom.com" }]
1010
license = { file = "MIT" }

tests/concurrent_dict_test.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,87 @@ def worker():
9595

9696
# No thread safety errors should occur
9797
assert not errors, f"Thread safety errors occurred: {errors}"
98-
98+
99+
def test_get_locked_context_manager_returns_value_and_locks():
100+
d: ConcurrentDictionary[str, int] = ConcurrentDictionary({'a': 1})
101+
with d.get_locked('a') as value:
102+
assert value == 1
103+
# Update inside lock
104+
d['a'] = value + 1
105+
assert d['a'] == 2
106+
107+
def test_get_locked_raises_keyerror_for_missing_key():
108+
d: ConcurrentDictionary[str, int] = ConcurrentDictionary()
109+
try:
110+
with d.get_locked('missing'):
111+
assert False, "Should raise KeyError for missing key"
112+
except KeyError:
113+
pass
114+
115+
def test_key_lock_context_manager_locks_and_unlocks():
116+
d: ConcurrentDictionary[str, int] = ConcurrentDictionary({'b': 10})
117+
lock = d.key_lock('b')
118+
# Should be a context manager (RLock)
119+
assert hasattr(lock, '__enter__') and hasattr(lock, '__exit__')
120+
with lock:
121+
d['b'] = 20
122+
assert d['b'] == 20
123+
124+
def test_get_locked_allows_nested_access():
125+
d: ConcurrentDictionary[str, int] = ConcurrentDictionary({'c': 5})
126+
with d.get_locked('c') as value:
127+
assert value == 5
128+
# Nested get_locked should not deadlock
129+
with d.get_locked('c') as value2:
130+
assert value2 == 5
131+
132+
def test_get_locked_thread_safety():
133+
d: ConcurrentDictionary[str, int] = ConcurrentDictionary({'x': 0})
134+
errors : List[Exception] = []
135+
136+
def worker():
137+
try:
138+
for _ in range(1000):
139+
with d.get_locked('x') as v:
140+
d['x'] = v + 1
141+
except Exception as e:
142+
errors.append(e)
143+
144+
threads = [threading.Thread(target=worker) for _ in range(4)]
145+
for t in threads:
146+
t.start()
147+
for t in threads:
148+
t.join()
149+
150+
assert not errors, f"Thread safety errors occurred: {errors}"
151+
assert d['x'] == 4000
152+
153+
def test_modify_without_get_locked_causes_race_condition():
154+
"""
155+
Demonstrates that modifying a value at a key without using get_locked
156+
can result in a race condition and incorrect result.
157+
"""
158+
d: ConcurrentDictionary[str, int] = ConcurrentDictionary({'x': 0})
159+
160+
def worker():
161+
# Not using get_locked or any locking
162+
for _ in range(1000):
163+
v = d['x'] # Read without lock
164+
# Simulate context switch
165+
time.sleep(0.000001)
166+
d['x'] = v + 1 # Write without lock
167+
168+
threads = [threading.Thread(target=worker) for _ in range(4)]
169+
for t in threads:
170+
t.start()
171+
for t in threads:
172+
t.join()
173+
174+
# The correct value should be 4000, but due to race conditions, it will likely be less
175+
assert d['x'] != 4000, (
176+
"Modifying without get_locked should result in incorrect value due to race conditions, "
177+
f"but got {d['x']}"
178+
)
99179

100180
if __name__ == "__main__":
101181
pytest.main([__file__])

0 commit comments

Comments
 (0)