Skip to content

Commit 858855a

Browse files
fix: provide a more accurate problem scale log (#2183)
Previously, the logic just assumed the values from entity dependent value ranges are unique, which massively inflated the problem scale log for list variables. Now, we track the number of entities a value is allowed to be in, and use that to scale the problem scale. For instance, if a value is only allowed to be in 25% of entities, it should reduce the problem scale by 25%.
1 parent 65918e5 commit 858855a

10 files changed

Lines changed: 179 additions & 179 deletions

File tree

core/src/main/java/ai/timefold/solver/core/impl/domain/solution/descriptor/ProblemScaleTracker.java

Lines changed: 15 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,29 @@
11
package ai.timefold.solver.core.impl.domain.solution.descriptor;
22

3-
import java.util.Collections;
4-
import java.util.IdentityHashMap;
5-
import java.util.Set;
6-
3+
import ai.timefold.solver.core.impl.domain.variable.descriptor.ListVariableDescriptor;
4+
import ai.timefold.solver.core.impl.score.director.ListValueRangeStatistics;
5+
import ai.timefold.solver.core.impl.score.director.ValueRangeManager;
76
import ai.timefold.solver.core.impl.util.MathUtils;
87

9-
public class ProblemScaleTracker {
8+
public class ProblemScaleTracker<Solution_> {
109
private final long logBase;
11-
private final Set<Object> visitedAnchorSet = Collections.newSetFromMap(new IdentityHashMap<>());
12-
10+
private final ListValueRangeStatistics<Solution_> listValueRangeStatistics;
1311
private long basicProblemScaleLog = 0L;
14-
private int listPinnedValueCount = 0;
15-
private int listTotalEntityCount = 0;
16-
private int listMovableEntityCount = 0;
17-
private int listTotalValueCount = 0;
12+
private long cachedTotalProblemScaleLog = -1L;
1813

19-
public ProblemScaleTracker(long logBase) {
14+
public ProblemScaleTracker(ListVariableDescriptor<Solution_> listVariableDescriptor,
15+
ValueRangeManager<Solution_> valueRangeManager,
16+
long logBase) {
2017
this.logBase = logBase;
18+
this.listValueRangeStatistics = new ListValueRangeStatistics<>(listVariableDescriptor, valueRangeManager);
2119
}
2220

23-
// Simple getters
24-
public long getBasicProblemScaleLog() {
25-
return basicProblemScaleLog;
26-
}
27-
28-
public int getListPinnedValueCount() {
29-
return listPinnedValueCount;
30-
}
31-
32-
public int getListTotalEntityCount() {
33-
return listTotalEntityCount;
34-
}
35-
36-
public int getListMovableEntityCount() {
37-
return listMovableEntityCount;
38-
}
39-
40-
public int getListTotalValueCount() {
41-
return listTotalValueCount;
42-
}
43-
44-
public void setListTotalValueCount(int listTotalValueCount) {
45-
this.listTotalValueCount = listTotalValueCount;
46-
}
47-
48-
// Complex methods
49-
public boolean isAnchorVisited(Object anchor) {
50-
if (visitedAnchorSet.contains(anchor)) {
51-
return true;
52-
}
53-
visitedAnchorSet.add(anchor);
54-
return false;
55-
}
56-
57-
public void addListValueCount(int count) {
58-
listTotalValueCount += count;
59-
}
60-
61-
public void addPinnedListValueCount(int count) {
62-
listPinnedValueCount += count;
63-
}
64-
65-
public void incrementListEntityCount(boolean isMovable) {
66-
listTotalEntityCount++;
67-
if (isMovable) {
68-
listMovableEntityCount++;
21+
public long getProblemScaleLog() {
22+
if (cachedTotalProblemScaleLog != -1L) {
23+
return cachedTotalProblemScaleLog;
6924
}
25+
cachedTotalProblemScaleLog = basicProblemScaleLog + listValueRangeStatistics.computeListProblemScaleLog(logBase);
26+
return cachedTotalProblemScaleLog;
7027
}
7128

7229
public void addBasicProblemScale(long count) {

core/src/main/java/ai/timefold/solver/core/impl/heuristic/selector/common/ReachableValues.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.Objects;
11+
import java.util.Set;
1112
import java.util.function.Function;
1213

1314
import ai.timefold.solver.core.impl.domain.valuerange.descriptor.FromEntityPropertyValueRangeDescriptor;
@@ -82,6 +83,14 @@ public List<Entity_> extractEntitiesAsList(Object value) {
8283
return entityList;
8384
}
8485

86+
public int getReachableEntitiesSize(Object value) {
87+
var itemValue = fetchItemValue(value);
88+
if (itemValue == null) {
89+
return 0;
90+
}
91+
return itemValue.getReachableEntitySize();
92+
}
93+
8594
public List<Value_> extractValuesAsList(Object value) {
8695
var itemValue = fetchItemValue(value);
8796
if (itemValue == null) {
@@ -95,6 +104,14 @@ public List<Value_> extractValuesAsList(Object value) {
95104
return valueList;
96105
}
97106

107+
public Set<Entity_> extractAllEntitiesAsSet() {
108+
return entitiesIndex.indexMap().keySet();
109+
}
110+
111+
public Set<Value_> extractAllValuesAsSet() {
112+
return valuesIndex.indexMap().keySet();
113+
}
114+
98115
public int getSize() {
99116
return valuesIndex.allItems().size();
100117
}
@@ -184,6 +201,10 @@ boolean containsValue(int valueIndex) {
184201
return valueBitSet.get(valueIndex);
185202
}
186203

204+
int getReachableEntitySize() {
205+
return entityBitSet.cardinality();
206+
}
207+
187208
List<Entity_> getRandomAccessEntityList(List<Entity_> allEntities) {
188209
return new BitSetIndexedList<>(allEntities, entityBitSet);
189210
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package ai.timefold.solver.core.impl.score.director;
2+
3+
import ai.timefold.solver.core.impl.domain.variable.descriptor.ListVariableDescriptor;
4+
import ai.timefold.solver.core.impl.util.MathUtils;
5+
6+
import org.jspecify.annotations.NullMarked;
7+
import org.jspecify.annotations.Nullable;
8+
9+
@NullMarked
10+
public class ListValueRangeStatistics<Solution_> {
11+
@Nullable
12+
private final ListVariableDescriptor<Solution_> listVariableDescriptor;
13+
private final ValueRangeManager<Solution_> valueRangeManager;
14+
15+
public ListValueRangeStatistics(@Nullable ListVariableDescriptor<Solution_> listVariableDescriptor,
16+
ValueRangeManager<Solution_> valueRangeManager) {
17+
this.listVariableDescriptor = listVariableDescriptor;
18+
this.valueRangeManager = valueRangeManager;
19+
}
20+
21+
public long computeListProblemScaleLog(long logBase) {
22+
if (listVariableDescriptor == null) {
23+
// No list variable
24+
return 0L;
25+
}
26+
var allowsUnassignedValues = listVariableDescriptor.allowsUnassignedValues();
27+
var reachableValues = valueRangeManager.getReachableValues(listVariableDescriptor);
28+
var entityCount = reachableValues.extractAllEntitiesAsSet().size();
29+
if (entityCount == 0) {
30+
// No entities
31+
return 0L;
32+
}
33+
if (allowsUnassignedValues) {
34+
// Unassigned values are treated as if they are assigned to a virtual entity to simplify calculations
35+
entityCount++;
36+
}
37+
38+
var valueSet = reachableValues.extractAllValuesAsSet();
39+
var valueCount = valueSet.size();
40+
var validPercentageLog = 0L;
41+
var additionalCount = allowsUnassignedValues ? 1 : 0;
42+
43+
for (var value : valueSet) {
44+
validPercentageLog += MathUtils.getScaledApproximateLog(MathUtils.LOG_PRECISION, logBase,
45+
((double) reachableValues.getReachableEntitiesSize(value) + additionalCount) / entityCount);
46+
}
47+
48+
return MathUtils.getPossibleArrangementsScaledApproximateLog(MathUtils.LOG_PRECISION, logBase,
49+
valueCount, entityCount) + validPercentageLog;
50+
}
51+
}

core/src/main/java/ai/timefold/solver/core/impl/score/director/ValueRangeManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ public <Entity_, Value_> ValueRange<Value_> getFromEntity(
180180
AbstractValueRangeDescriptor<Solution_> abstractValueRangeDescriptor, Entity_ entity,
181181
@Nullable SelectionSorter<Solution_, Value_> sorter) {
182182
ValueRangeState<Solution_, Entity_, Value_> descriptor = fromDescriptor(abstractValueRangeDescriptor);
183-
return descriptor.getFromEntity(entity, getInitializationStatistics().genuineEntityCount(), sorter);
183+
return descriptor.getFromEntity(entity, sorter);
184184
}
185185

186186
public long countOnSolution(AbstractValueRangeDescriptor<Solution_> abstractValueRangeDescriptor, Solution_ solution) {

core/src/main/java/ai/timefold/solver/core/impl/score/director/ValueRangeState.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,21 +139,21 @@ private ValueRange<Value_> sortValueRange(ValueRange<Value_> originalValueRange,
139139
return sortableValueRange.sort(SelectionSorterAdapter.of(cachedWorkingSolution, sorter));
140140
}
141141

142-
public ValueRange<Value_> getFromEntity(Entity_ entity, int entityCount,
142+
public ValueRange<Value_> getFromEntity(Entity_ entity,
143143
@Nullable SelectionSorter<Solution_, Value_> sorter) {
144-
var entityMap = ensureEntityMapIsInitialized(entityCount);
144+
var entityMap = ensureEntityMapIsInitialized();
145145
var item = entityMap.get(entity);
146146
// No item, we set the left side by default
147147
if (item == null) {
148148
var newItem = buildEntityValueRangeItem(entity, sorter);
149149
entityMap.put(entity, newItem);
150150
if (newItem.entity() != null && newItem.leftItem() == null && newItem.rightItem() == null) {
151151
// Placeholder for another entity
152-
return getFromEntity(Objects.requireNonNull(newItem.entity()), entityCount, sorter);
152+
return getFromEntity(Objects.requireNonNull(newItem.entity()), sorter);
153153
}
154154
return Objects.requireNonNull(newItem.leftItem());
155155
}
156-
var valueRange = pickValueBySorter(item, sorter, (p, s) -> getFromEntity(p, entityCount, s));
156+
var valueRange = pickValueBySorter(item, sorter, (p, s) -> getFromEntity(p, s));
157157
if (valueRange != null) {
158158
return valueRange;
159159
}
@@ -177,10 +177,10 @@ public ValueRange<Value_> getFromEntity(Entity_ entity, int entityCount,
177177
}
178178

179179
private Map<Entity_, ValueRangeItem<Solution_, Entity_, ValueRange<Value_>, Value_>>
180-
ensureEntityMapIsInitialized(int entityCount) {
180+
ensureEntityMapIsInitialized() {
181181
if (fromEntityMap == null) {
182-
fromEntityMap = new IdentityHashMap<>(entityCount);
183-
valueRangeDeduplicationCache = HashMap.newHashMap(entityCount);
182+
fromEntityMap = new IdentityHashMap<>();
183+
valueRangeDeduplicationCache = new HashMap<>();
184184
}
185185
return fromEntityMap;
186186
}
@@ -284,7 +284,7 @@ private ReachableValues<Entity_, Value_> fetchReachableValues(GenuineVariableDes
284284
var valueIndexItem = new ReachableValuesIndex<>(valueIndexMap, reachableValueList);
285285
for (var i = 0; i < entityList.size(); i++) {
286286
var entity = entityList.get(i);
287-
var valueRange = getFromEntity(entity, entityList.size(), null);
287+
var valueRange = getFromEntity(entity, null);
288288
loadEntityValueRange(i, valueIndexMap, valueRange, reachableValueList);
289289
}
290290
var sorterAdapter = sorter != null ? SelectionSorterAdapter.of(cachedWorkingSolution, sorter) : null;
@@ -370,7 +370,7 @@ public static <Solution_, Entity_, Type_, Value_> ValueRangeItem<Solution_, Enti
370370
* The record holds a reference to {@link ValueRange},
371371
* a precomputed hash to avoid recalculating it every time.
372372
*/
373-
private record HashedValueRange<T>(ValueRange<T> item, int hash) {
373+
record HashedValueRange<T>(ValueRange<T> item, int hash) {
374374

375375
public static <Value_> HashedValueRange<Value_> of(ValueRange<Value_> valueRange) {
376376
return new HashedValueRange<>(valueRange, valueRange.hashCode());

0 commit comments

Comments
 (0)