Skip to content

Commit c1b0edd

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1452f5d commit c1b0edd

File tree

1 file changed

+50
-52
lines changed

1 file changed

+50
-52
lines changed

graphs/travelling_salesman.py

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import heapq
22

3+
34
def tsp(cost):
45
"""
56
https://www.geeksforgeeks.org/dsa/approximate-solution-for-travelling-salesman-problem-using-mst/
@@ -18,8 +19,8 @@ def tsp(cost):
1819
Assumptions:
1920
1. The graph is complete.
2021
21-
2. The problem instance satisfies Triangle-Inequality.(The least distant path to reach a vertex j from i is always to reach j
22-
directly from i, rather than through some other vertex k)
22+
2. The problem instance satisfies Triangle-Inequality.(The least distant path to reach a vertex j from i is always to reach j
23+
directly from i, rather than through some other vertex k)
2324
2425
3. The cost matrix is symmetric, i.e., cost[i][j] = cost[j][i]
2526
@@ -29,65 +30,67 @@ def tsp(cost):
2930
"""
3031
# create the adjacency list
3132
adj = createList(cost)
32-
33-
#check for triangle inequality violations
33+
34+
# check for triangle inequality violations
3435
if triangleInequality(adj):
3536
print("Triangle Inequality Violation")
3637
return -1
37-
38+
3839
# construct the travelling salesman tour
3940
tspTour = approximateTSP(adj)
40-
41+
4142
# calculate the cost of the tour
4243
tspCost = tourCost(tspTour)
43-
44+
4445
return tspCost
4546

47+
4648
# function to implement approximate TSP
4749
def approximateTSP(adj):
4850
n = len(adj)
49-
51+
5052
# to store the cost of minimum spanning tree
5153
mstCost = [0]
52-
54+
5355
# stores edges of minimum spanning tree
5456
mstEdges = findMST(adj, mstCost)
55-
57+
5658
# to mark the visited nodes
5759
visited = [False] * n
58-
60+
5961
# create adjacency list for mst
6062
mstAdj = [[] for _ in range(n)]
6163
for e in mstEdges:
6264
mstAdj[e[0]].append([e[1], e[2]])
6365
mstAdj[e[1]].append([e[0], e[2]])
64-
66+
6567
# to store the eulerian tour
6668
tour = []
6769
eulerianCircuit(mstAdj, 0, tour, visited, -1)
68-
70+
6971
# add the starting node to the tour
7072
tour.append(0)
71-
73+
7274
# to store the final tour path
7375
tourPath = []
74-
76+
7577
for i in range(len(tour) - 1):
7678
u = tour[i]
7779
v = tour[i + 1]
7880
weight = 0
79-
81+
8082
# find the weight of the edge u -> v
8183
for neighbor in adj[u]:
8284
if neighbor[0] == v:
8385
weight = neighbor[1]
8486
break
85-
87+
8688
# add the edge to the tour path
8789
tourPath.append([u, v, weight])
88-
90+
8991
return tourPath
9092

93+
9194
def tourCost(tour):
9295
cost = 0
9396
for edge in tour:
@@ -98,67 +101,67 @@ def tourCost(tour):
98101
def eulerianCircuit(adj, u, tour, visited, parent):
99102
visited[u] = True
100103
tour.append(u)
101-
104+
102105
for neighbor in adj[u]:
103106
v = neighbor[0]
104107
if v == parent:
105108
continue
106-
109+
107110
if visited[v] == False:
108111
eulerianCircuit(adj, v, tour, visited, u)
109-
112+
113+
110114
# function to find the minimum spanning tree
111115
def findMST(adj, mstCost):
112116
n = len(adj)
113-
117+
114118
# to marks the visited nodes
115119
visited = [False] * n
116-
120+
117121
# stores edges of minimum spanning tree
118122
mstEdges = []
119-
123+
120124
pq = []
121125
heapq.heappush(pq, [0, 0, -1])
122-
126+
123127
while pq:
124128
current = heapq.heappop(pq)
125-
129+
126130
u = current[1]
127131
weight = current[0]
128132
parent = current[2]
129-
133+
130134
if visited[u]:
131135
continue
132-
136+
133137
mstCost[0] += weight
134138
visited[u] = True
135-
139+
136140
if parent != -1:
137141
mstEdges.append([u, parent, weight])
138-
142+
139143
for neighbor in adj[u]:
140144
v = neighbor[0]
141145
if v == parent:
142146
continue
143147
w = neighbor[1]
144-
148+
145149
if not visited[v]:
146150
heapq.heappush(pq, [w, v, u])
147151
return mstEdges
148-
149152

150-
151-
# function to calculate if the
153+
154+
# function to calculate if the
152155
# triangle inequality is violated
153156
def triangleInequality(adj):
154157
n = len(adj)
155-
156-
# Sort each adjacency list based
158+
159+
# Sort each adjacency list based
157160
# on the weight of the edges
158161
for i in range(n):
159162
adj[i].sort(key=lambda a: a[1])
160-
161-
# check triangle inequality for each
163+
164+
# check triangle inequality for each
162165
# triplet of nodes (u, v, w)
163166
for u in range(n):
164167
for x in adj[u]:
@@ -174,33 +177,28 @@ def triangleInequality(adj):
174177
return True
175178
# no violations found
176179
return False
177-
180+
181+
178182
# function to create the adjacency list
179183
def createList(cost):
180184
n = len(cost)
181-
185+
182186
# to store the adjacency list
183187
adj = [[] for _ in range(n)]
184-
188+
185189
for u in range(n):
186190
for v in range(n):
187191
# if there is no edge between u and v
188192
if cost[u][v] == 0:
189193
continue
190194
# add the edge to the adjacency list
191195
adj[u].append([v, cost[u][v]])
192-
196+
193197
return adj
194-
195198

196-
199+
197200
if __name__ == "__main__":
198-
#test
199-
cost = [
200-
[0, 1000, 5000],
201-
[5000, 0, 1000],
202-
[1000, 5000, 0]
203-
]
204-
201+
# test
202+
cost = [[0, 1000, 5000], [5000, 0, 1000], [1000, 5000, 0]]
203+
205204
print(tsp(cost))
206-

0 commit comments

Comments
 (0)