Skip to content

Commit 0552cdb

Browse files
committed
Revise ConcurrentReferenceHashMap for @ConcurrencyLimit race condition
Closes gh-35788 See gh-35794
1 parent 721c40b commit 0552cdb

File tree

3 files changed

+200
-46
lines changed

3 files changed

+200
-46
lines changed

spring-context/src/main/java/org/springframework/resilience/annotation/ConcurrencyLimitBeanPostProcessor.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.lang.reflect.Method;
2020
import java.util.Map;
2121
import java.util.concurrent.ConcurrentHashMap;
22+
import java.util.concurrent.ConcurrentMap;
2223

2324
import org.aopalliance.intercept.MethodInterceptor;
2425
import org.aopalliance.intercept.MethodInvocation;
@@ -73,7 +74,7 @@ public void setEmbeddedValueResolver(StringValueResolver resolver) {
7374

7475
private class ConcurrencyLimitInterceptor implements MethodInterceptor {
7576

76-
private final Map<Object, ConcurrencyThrottleCache> cachePerInstance =
77+
private final ConcurrentMap<Object, ConcurrencyThrottleCache> cachePerInstance =
7778
new ConcurrentReferenceHashMap<>(16, ConcurrentReferenceHashMap.ReferenceType.WEAK);
7879

7980
@Override
@@ -87,8 +88,11 @@ private class ConcurrencyLimitInterceptor implements MethodInterceptor {
8788
}
8889
Assert.state(target != null, "Target must not be null");
8990

91+
// Build unique ConcurrencyThrottleCache instance per target object
9092
ConcurrencyThrottleCache cache = this.cachePerInstance.computeIfAbsent(target,
9193
k -> new ConcurrencyThrottleCache());
94+
95+
// Determine method-specific interceptor instance with isolated concurrency count
9296
MethodInterceptor interceptor = cache.methodInterceptors.get(method);
9397
if (interceptor == null) {
9498
synchronized (cache) {

spring-core/src/main/java/org/springframework/util/ConcurrentReferenceHashMap.java

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@
3333
import java.util.concurrent.ConcurrentMap;
3434
import java.util.concurrent.atomic.AtomicInteger;
3535
import java.util.concurrent.locks.ReentrantLock;
36+
import java.util.function.BiFunction;
37+
import java.util.function.Function;
3638

3739
import org.jspecify.annotations.Nullable;
3840

3941
/**
40-
* A {@link ConcurrentHashMap} that uses {@link ReferenceType#SOFT soft} or
42+
* A {@link ConcurrentHashMap} variant that uses {@link ReferenceType#SOFT soft} or
4143
* {@linkplain ReferenceType#WEAK weak} references for both {@code keys} and {@code values}.
4244
*
4345
* <p>This class can be used as an alternative to
@@ -320,7 +322,7 @@ protected Boolean execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> e
320322
return false;
321323
}
322324
});
323-
return (Boolean.TRUE.equals(result));
325+
return Boolean.TRUE.equals(result);
324326
}
325327

326328
@Override
@@ -335,7 +337,7 @@ protected Boolean execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> e
335337
return false;
336338
}
337339
});
338-
return (Boolean.TRUE.equals(result));
340+
return Boolean.TRUE.equals(result);
339341
}
340342

341343
@Override
@@ -353,6 +355,114 @@ protected Boolean execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> e
353355
});
354356
}
355357

358+
@Override
359+
public @Nullable V computeIfAbsent(@Nullable K key, Function<@Nullable ? super K, @Nullable ? extends V> mappingFunction) {
360+
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
361+
@Override
362+
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
363+
if (entry != null) {
364+
return entry.getValue();
365+
}
366+
V value = mappingFunction.apply(key);
367+
// Add entry only if not null
368+
if (value != null) {
369+
Assert.state(entries != null, "No entries segment");
370+
entries.add(value);
371+
}
372+
return value;
373+
}
374+
});
375+
}
376+
377+
@Override
378+
public @Nullable V computeIfPresent(@Nullable K key, BiFunction<@Nullable ? super K, @Nullable ? super V, @Nullable ? extends V> remappingFunction) {
379+
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
380+
@Override
381+
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
382+
if (entry != null) {
383+
V oldValue = entry.getValue();
384+
V value = remappingFunction.apply(key, oldValue);
385+
if (value != null) {
386+
// Replace entry
387+
entry.setValue(value);
388+
return value;
389+
}
390+
else {
391+
// Remove entry
392+
if (ref != null) {
393+
ref.release();
394+
}
395+
}
396+
}
397+
return null;
398+
}
399+
});
400+
}
401+
402+
@Override
403+
public @Nullable V compute(@Nullable K key, BiFunction<@Nullable ? super K, @Nullable ? super V, @Nullable ? extends V> remappingFunction) {
404+
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
405+
@Override
406+
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
407+
V oldValue = null;
408+
if (entry != null) {
409+
oldValue = entry.getValue();
410+
}
411+
V value = remappingFunction.apply(key, oldValue);
412+
if (value != null) {
413+
if (entry != null) {
414+
// Replace entry
415+
entry.setValue(value);
416+
}
417+
else {
418+
// Add entry
419+
Assert.state(entries != null, "No entries segment");
420+
entries.add(value);
421+
}
422+
return value;
423+
}
424+
else {
425+
// Remove entry
426+
if (ref != null) {
427+
ref.release();
428+
}
429+
}
430+
return null;
431+
}
432+
});
433+
}
434+
435+
@Override
436+
public @Nullable V merge(@Nullable K key, @Nullable V value, BiFunction<@Nullable ? super V, @Nullable ? super V, @Nullable ? extends V> remappingFunction) {
437+
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
438+
@Override
439+
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
440+
if (entry != null) {
441+
V oldValue = entry.getValue();
442+
V newValue = remappingFunction.apply(oldValue, value);
443+
if (newValue != null) {
444+
// Replace entry
445+
entry.setValue(newValue);
446+
return newValue;
447+
}
448+
else {
449+
// Remove entry
450+
if (ref != null) {
451+
ref.release();
452+
}
453+
return null;
454+
}
455+
}
456+
else {
457+
// Add entry
458+
Assert.state(entries != null, "No entries segment");
459+
entries.add(value);
460+
return value;
461+
}
462+
}
463+
});
464+
}
465+
356466
@Override
357467
public void clear() {
358468
for (Segment segment : this.segments) {

0 commit comments

Comments
 (0)