@@ -21,6 +21,8 @@ def __init__(
2121 min_temperature : float = 1e-3 ,
2222 iterations_per_temp : int = 100 ,
2323 neighbor_scale : float = 0.1 ,
24+ local_search : Optional [Callable [[Sequence [float ], Callable [[Sequence [float ]], float ], Callable [[Sequence [float ]], Sequence [float ]], int ], Tuple [Sequence [float ], float ]]] = None ,
25+ local_search_iters : int = 10 ,
2426 seed : Optional [int ] = None ,
2527 ):
2628 self .func = func
@@ -33,6 +35,9 @@ def __init__(
3335 self .min_temperature = float (min_temperature )
3436 self .iterations_per_temp = int (iterations_per_temp )
3537 self .neighbor_scale = float (neighbor_scale )
38+ # local_search: callable(solution, func, neighbor_fn, iters) -> (improved_solution, improved_cost)
39+ self .local_search = local_search
40+ self .local_search_iters = int (local_search_iters )
3641 if seed is not None :
3742 random .seed (seed )
3843
@@ -91,7 +96,17 @@ def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object]
9196 return best , best_cost , history
9297
9398 candidate = self ._neighbor (current )
94- candidate_cost = float (self .func (candidate ))
99+ # Optionally refine candidate with local search before evaluating/accepting
100+ if self .local_search is not None :
101+ try :
102+ improved , improved_cost = self .local_search (candidate , self .func , self ._neighbor , self .local_search_iters )
103+ candidate = list (improved )
104+ candidate_cost = float (improved_cost )
105+ except Exception :
106+ # Fall back to plain candidate evaluation if local search fails
107+ candidate_cost = float (self .func (candidate ))
108+ else :
109+ candidate_cost = float (self .func (candidate ))
95110 delta = candidate_cost - current_cost
96111 if self ._accept (delta , temp ):
97112 current = candidate
@@ -131,5 +146,66 @@ def _test_quadratic():
131146 print ("best:" , best , "cost:" , cost )
132147
133148
149+ def simple_local_search (solution : Sequence [float ], func : Callable [[Sequence [float ]], float ], neighbor_fn : Callable [[Sequence [float ]], Sequence [float ]], iterations : int = 10 ) -> Tuple [Sequence [float ], float ]:
150+ """A tiny hill-climbing local search that repeatedly accepts improving neighbors.
151+
152+ Parameters
153+ - solution: starting solution sequence
154+ - func: objective function (lower is better)
155+ - neighbor_fn: function that given a solution returns a new neighbor solution
156+ - iterations: number of neighbor attempts
157+
158+ Returns a tuple (best_solution, best_cost).
159+
160+ >>> func = lambda x: (x[0] - 5) ** 2
161+ >>> start = [0.0]
162+ >>> def neighbor(x):
163+ ... return [x[0] + 0.5]
164+ >>> best, cost = simple_local_search(start, func, neighbor, iterations=5)
165+ >>> best[0] > start[0]
166+ True
167+ >>> cost == func(best)
168+ True
169+ """
170+ best = list (solution )
171+ best_cost = float (func (best ))
172+ for _ in range (int (iterations )):
173+ cand = neighbor_fn (best )
174+ cand_cost = float (func (cand ))
175+ if cand_cost < best_cost :
176+ best = list (cand )
177+ best_cost = cand_cost
178+ return best , best_cost
179+
180+
181+ def _doctest_local_search_benefit ():
182+ """Demonstrate that providing a local_search can improve or match the solution found by SimulatedAnnealing.
183+
184+ The test uses a deterministic seed so the result is reproducible in doctest.
185+
186+ >>> func = lambda x: (x[0] - 5) ** 2
187+ >>> sa1 = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=20, seed=1)
188+ >>> best1, cost1, _ = sa1.optimize(max_steps=200)
189+ >>> # define a deterministic, greedy local search that moves toward the known minimum (5.0)
190+ >>> def my_local_search(sol, f, neighbour, iters):
191+ ... s = list(sol)
192+ ... bestc = float(f(s))
193+ ... for _ in range(int(iters)):
194+ ... # move halfway toward 5.0 (gradient-free, deterministic)
195+ ... s[0] = s[0] + 0.5 * (5.0 - s[0])
196+ ... c = float(f(s))
197+ ... if c < bestc:
198+ ... bestc = c
199+ ... else:
200+ ... break
201+ ... return s, bestc
202+ >>> sa2 = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=20, seed=1, local_search=my_local_search, local_search_iters=5)
203+ >>> best2, cost2, _ = sa2.optimize(max_steps=200)
204+ >>> cost2 <= cost1
205+ True
206+ """
207+ pass
208+
209+
134210if __name__ == "__main__" :
135211 _test_quadratic ()
0 commit comments