diff --git a/query/shortest.go b/query/shortest.go index 435942496b2..45ef756c36e 100644 --- a/query/shortest.go +++ b/query/shortest.go @@ -88,6 +88,21 @@ func (h *priorityQueue) Pop() interface{} { return val } +func (pq *priorityQueue) TrimToMax(max int64) { + if max <= 0 || int64(pq.Len()) <= max { + return + } + imax := 0 + maxCost := (*pq)[0].cost + for i := 1; i < pq.Len(); i++ { + if (*pq)[i].cost > maxCost { + imax = i + maxCost = (*pq)[i].cost + } + } + heap.Remove(pq, imax) +} + type mapItem struct { attr string cost float64 @@ -405,10 +420,8 @@ func runKShortestPaths(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) { hop: item.hop + 1, path: route{route: curPath}, } - if int64(pq.Len()) > sg.Params.MaxFrontierSize { - pq.Pop() - } heap.Push(&pq, node) + pq.TrimToMax(sg.Params.MaxFrontierSize) } // Return the popped nodes path to pool. pathPool.Put(item.path.route) @@ -561,10 +574,8 @@ func shortestPath(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) { cost: nodeCost, hop: item.hop + 1, } - if int64(pq.Len()) > sg.Params.MaxFrontierSize { - pq.Pop() - } heap.Push(&pq, node) + pq.TrimToMax(sg.Params.MaxFrontierSize) } else { // We've already seen this node. So, just update the cost // and fix the priority in the heap and map. diff --git a/query/shortest_test.go b/query/shortest_test.go new file mode 100644 index 00000000000..1f39a1b92fa --- /dev/null +++ b/query/shortest_test.go @@ -0,0 +1,33 @@ +package query + +import ( + "container/heap" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPriorityQueueTrimToMax_RemovesHighestCost(t *testing.T) { + var pq priorityQueue + heap.Init(&pq) + + costs := []float64{1, 50, 2, 60, 3, 70, 4} + for _, c := range costs { + heap.Push(&pq, &queueItem{cost: c}) + } + + // Trim to keep N-1 elements. + (&pq).TrimToMax(int64(len(costs) - 1)) + require.Equal(t, len(costs)-1, pq.Len()) + + // Pop all remaining costs and ensure the maximum was removed. + seen := make(map[float64]bool, len(costs)) + for pq.Len() > 0 { + seen[heap.Pop(&pq).(*queueItem).cost] = true + } + + require.False(t, seen[70], "expected highest cost to be removed") + for _, c := range []float64{1, 2, 3, 4, 50, 60} { + require.True(t, seen[c], "expected cost to remain: %v", c) + } +}