Skip to content

Commit 10b1a2a

Browse files
committed
Add Traveling Salesman Problem algorithms and tests
1 parent c3d4b9e commit 10b1a2a

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
from graphs.traveling_salesman_problem import tsp_brute_force, tsp_dp, tsp_greedy
3+
4+
def sample_graph_1() -> list[list[int]]:
5+
return [
6+
[0, 29, 20],
7+
[29, 0, 15],
8+
[20, 15, 0],
9+
]
10+
11+
def sample_graph_2() -> list[list[int]]:
12+
return [
13+
[0, 10, 15, 20],
14+
[10, 0, 35, 25],
15+
[15, 35, 0, 30],
16+
[20, 25, 30, 0],
17+
]
18+
19+
def test_brute_force():
20+
graph = sample_graph_1()
21+
assert tsp_brute_force(graph) == 64
22+
23+
def test_dp():
24+
graph = sample_graph_1()
25+
assert tsp_dp(graph) == 64
26+
27+
def test_greedy():
28+
graph = sample_graph_1()
29+
# The greedy algorithm does not guarantee an optimal solution;
30+
# it is necessary to verify that its output is an integer greater than 0.
31+
result = tsp_greedy(graph)
32+
assert isinstance(result, int)
33+
assert result > 0
34+
35+
def test_dp_larger_graph():
36+
graph = sample_graph_2()
37+
assert tsp_dp(graph) == 80
38+
39+
def test_brute_force_larger_graph():
40+
graph = sample_graph_2()
41+
assert tsp_brute_force(graph) == 80
42+
43+
def test_greedy_larger_graph():
44+
graph = sample_graph_2()
45+
# An approximate solution cannot be represented by '==' and can only ensure that the result is reasonable.
46+
result = tsp_greedy(graph)
47+
assert isinstance(result, int)
48+
assert result >= 80
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from itertools import permutations
2+
3+
def tsp_brute_force(graph: list[list[int]]) -> int:
4+
"""
5+
Solves TSP using brute-force permutations.
6+
7+
Args:
8+
graph: 2D list representing distances between cities.
9+
10+
Returns:
11+
The minimal total travel distance visiting all cities exactly once and returning to the start.
12+
13+
Example:
14+
>>> tsp_brute_force([[0, 29, 20], [29, 0, 15], [20, 15, 0]])
15+
64
16+
"""
17+
n = len(graph)
18+
# Apart from other cities aside from City 0, City 0 serves as the starting point.
19+
nodes = list(range(1, n))
20+
min_path = float('inf')
21+
22+
# Enumerate all the permutations from city 1 to city n-1.
23+
for perm in permutations(nodes):
24+
# Construct a complete path:
25+
# starting from point 0, visit in the order of arrangement, and then return to point 0.
26+
path = [0] + list(perm) + [0]
27+
28+
# Calculate the total distance of the path.
29+
# Update the shortest path.
30+
total_cost = sum(graph[path[i]][path[i + 1]] for i in range(n))
31+
min_path = min(min_path, total_cost)
32+
33+
return min_path
34+
35+
def tsp_dp(graph: list[list[int]]) -> int:
36+
"""
37+
Solves the Traveling Salesman Problem using Held-Karp dynamic programming.
38+
39+
Args:
40+
graph: A 2D list representing distances between cities (n x n matrix).
41+
42+
Returns:
43+
The minimum cost to visit all cities exactly once and return to the origin.
44+
45+
Example:
46+
>>> tsp_dp([[0, 29, 20], [29, 0, 15], [20, 15, 0]])
47+
64
48+
"""
49+
n = len(graph)
50+
# Create a dynamic programming table of size (2^n) x n.
51+
# Noting: 1 << n = 2^n
52+
# dp[mask][i] represents the shortest path starting from city 0, passing through the cities in the mask, and ultimately ending at city i.
53+
dp = [[float('inf')] * n for _ in range(1 << n)]
54+
# Initial state: only city 0 is visited, and the path length is 0.
55+
dp[1][0] = 0
56+
57+
for mask in range(1 << n):
58+
# The mask indicates which cities have been visited.
59+
for u in range(n):
60+
if not (mask & (1 << u)):
61+
# If the city u is not included in the mask, skip it.
62+
continue
63+
64+
for v in range(n):
65+
# City v has not been accessed and is different from city u.
66+
if mask & (1 << v) or u == v:
67+
continue
68+
69+
# New State: Transition to city v
70+
# State Transition: From city u to city v, updating the shortest path.
71+
next_mask = mask | (1 << v)
72+
dp[next_mask][v] = min(dp[next_mask][v], dp[mask][u] + graph[u][v])
73+
74+
# After completing visits to all cities, return to city 0 and obtain the minimum value.
75+
return min(dp[(1 << n) - 1][i] + graph[i][0] for i in range(1, n))
76+
77+
def tsp_greedy(graph: list[list[int]]) -> int:
78+
"""
79+
Solves TSP approximately using the nearest neighbor heuristic.
80+
Warming: This algorithm is not guaranteed to find the optimal solution! But it is fast and applicable to any input size.
81+
82+
Args:
83+
graph: 2D list representing distances between cities.
84+
85+
Returns:
86+
The total distance of the approximated TSP route.
87+
88+
Example:
89+
>>> tsp_greedy([[0, 29, 20], [29, 0, 15], [20, 15, 0]])
90+
64
91+
"""
92+
n = len(graph)
93+
visited = [False] * n # Mark whether each city has been visited.
94+
path = [0]
95+
total_cost = 0
96+
visited[0] = True # Start from city 0.
97+
current = 0 # Current city.
98+
99+
for _ in range(n - 1):
100+
# Find the nearest city to the current location that has not been visited.
101+
next_city = min(
102+
((city, cost) for city, cost in enumerate(graph[current]) if not visited[city] and city != current),
103+
key=lambda x: x[1],
104+
default=(None, float('inf'))
105+
)[0]
106+
107+
# If no such city exists, break the loop.
108+
if next_city is None:
109+
break
110+
111+
# Update the total cost and the current city.
112+
# Mark the city as visited.
113+
# Append the city to the path.
114+
total_cost += graph[current][next_city]
115+
visited[next_city] = True
116+
current = next_city
117+
path.append(current)
118+
119+
# Back to start
120+
total_cost += graph[current][0]
121+
path.append(0)
122+
123+
return total_cost
124+
125+
126+
def test_tsp_example():
127+
graph = [[0, 29, 20], [29, 0, 15], [20, 15, 0]]
128+
129+
result = tsp_brute_force(graph)
130+
if result != 64:
131+
raise Exception('tsp_brute_force Incorrect result')
132+
else:
133+
print('Test passed')
134+
135+
result = tsp_dp(graph)
136+
if result != 64:
137+
raise Exception('tsp_dp Incorrect result')
138+
else:
139+
print("Test passed")
140+
141+
result = tsp_greedy(graph)
142+
if result != 64:
143+
if result < 0:
144+
raise Exception('tsp_greedy Incorrect result')
145+
else:
146+
print("tsp_greedy gets an approximate result.")
147+
else:
148+
print('Test passed')
149+
150+
if __name__ == '__main__':
151+
test_tsp_example()

0 commit comments

Comments
 (0)