Skip to content

Commit a84ebb3

Browse files
committed
Add optional local_search to SimulatedAnnealing; implement simple_local_search and doctest demonstrating benefit
1 parent 3e41fd6 commit a84ebb3

File tree

1 file changed

+77
-1
lines changed

1 file changed

+77
-1
lines changed

simulated_annealing/simulated_annealing.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def __init__(
2121
min_temperature: float = 1e-3,
2222
iterations_per_temp: int = 100,
2323
neighbor_scale: float = 0.1,
24+
local_search: Optional[Callable[[Sequence[float], Callable[[Sequence[float]], float], Callable[[Sequence[float]], Sequence[float]], int], Tuple[Sequence[float], float]]] = None,
25+
local_search_iters: int = 10,
2426
seed: Optional[int] = None,
2527
):
2628
self.func = func
@@ -33,6 +35,9 @@ def __init__(
3335
self.min_temperature = float(min_temperature)
3436
self.iterations_per_temp = int(iterations_per_temp)
3537
self.neighbor_scale = float(neighbor_scale)
38+
# local_search: callable(solution, func, neighbor_fn, iters) -> (improved_solution, improved_cost)
39+
self.local_search = local_search
40+
self.local_search_iters = int(local_search_iters)
3641
if seed is not None:
3742
random.seed(seed)
3843

@@ -91,7 +96,17 @@ def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object]
9196
return best, best_cost, history
9297

9398
candidate = self._neighbor(current)
94-
candidate_cost = float(self.func(candidate))
99+
# Optionally refine candidate with local search before evaluating/accepting
100+
if self.local_search is not None:
101+
try:
102+
improved, improved_cost = self.local_search(candidate, self.func, self._neighbor, self.local_search_iters)
103+
candidate = list(improved)
104+
candidate_cost = float(improved_cost)
105+
except Exception:
106+
# Fall back to plain candidate evaluation if local search fails
107+
candidate_cost = float(self.func(candidate))
108+
else:
109+
candidate_cost = float(self.func(candidate))
95110
delta = candidate_cost - current_cost
96111
if self._accept(delta, temp):
97112
current = candidate
@@ -131,5 +146,66 @@ def _test_quadratic():
131146
print("best:", best, "cost:", cost)
132147

133148

149+
def simple_local_search(solution: Sequence[float], func: Callable[[Sequence[float]], float], neighbor_fn: Callable[[Sequence[float]], Sequence[float]], iterations: int = 10) -> Tuple[Sequence[float], float]:
150+
"""A tiny hill-climbing local search that repeatedly accepts improving neighbors.
151+
152+
Parameters
153+
- solution: starting solution sequence
154+
- func: objective function (lower is better)
155+
- neighbor_fn: function that given a solution returns a new neighbor solution
156+
- iterations: number of neighbor attempts
157+
158+
Returns a tuple (best_solution, best_cost).
159+
160+
>>> func = lambda x: (x[0] - 5) ** 2
161+
>>> start = [0.0]
162+
>>> def neighbor(x):
163+
... return [x[0] + 0.5]
164+
>>> best, cost = simple_local_search(start, func, neighbor, iterations=5)
165+
>>> best[0] > start[0]
166+
True
167+
>>> cost == func(best)
168+
True
169+
"""
170+
best = list(solution)
171+
best_cost = float(func(best))
172+
for _ in range(int(iterations)):
173+
cand = neighbor_fn(best)
174+
cand_cost = float(func(cand))
175+
if cand_cost < best_cost:
176+
best = list(cand)
177+
best_cost = cand_cost
178+
return best, best_cost
179+
180+
181+
def _doctest_local_search_benefit():
182+
"""Demonstrate that providing a local_search can improve or match the solution found by SimulatedAnnealing.
183+
184+
The test uses a deterministic seed so the result is reproducible in doctest.
185+
186+
>>> func = lambda x: (x[0] - 5) ** 2
187+
>>> sa1 = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=20, seed=1)
188+
>>> best1, cost1, _ = sa1.optimize(max_steps=200)
189+
>>> # define a deterministic, greedy local search that moves toward the known minimum (5.0)
190+
>>> def my_local_search(sol, f, neighbour, iters):
191+
... s = list(sol)
192+
... bestc = float(f(s))
193+
... for _ in range(int(iters)):
194+
... # move halfway toward 5.0 (gradient-free, deterministic)
195+
... s[0] = s[0] + 0.5 * (5.0 - s[0])
196+
... c = float(f(s))
197+
... if c < bestc:
198+
... bestc = c
199+
... else:
200+
... break
201+
... return s, bestc
202+
>>> sa2 = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=20, seed=1, local_search=my_local_search, local_search_iters=5)
203+
>>> best2, cost2, _ = sa2.optimize(max_steps=200)
204+
>>> cost2 <= cost1
205+
True
206+
"""
207+
pass
208+
209+
134210
if __name__ == "__main__":
135211
_test_quadratic()

0 commit comments

Comments
 (0)