Leetcode 1584. Min Cost to Connect All Points



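A fairly classic minimum spanning tree problem. I wasn't very interested in it at first, since the algorithm is about as textbook as it gets. After a few submissions, though, I found quite a lot of room for optimization, so I'm writing it down here.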

PS: There are two classic algorithms for minimum spanning trees, Prim's and Kruskal's; this post uses Prim's algorithm.
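The PS names the other standard algorithm, Kruskal's. For comparison only, here is a minimal sketch of it applied to this problem's input (a list of [x, y] points), assuming a plain union-find with path halving. The function name min_cost_kruskal is hypothetical and not part of the original solutions.

```python
from typing import List


def min_cost_kruskal(points: List[List[int]]) -> int:
    """Kruskal's algorithm: sort all edges, keep each one that joins two components."""
    n = len(points)
    # Build every edge of the complete graph as (cost, i, j).
    edges = []
    for i in range(n):
        xi, yi = points[i]
        for j in range(i + 1, n):
            xj, yj = points[j]
            edges.append((abs(xi - xj) + abs(yi - yj), i, j))
    edges.sort()

    parent = list(range(n))  # union-find forest: each point starts as its own root

    def find(x: int) -> int:
        # Path halving keeps the trees shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total, used = 0, 0
    for cost, i, j in edges:
        ri, rj = find(i), find(j)
        if ri != rj:  # the edge connects two different components
            parent[ri] = rj
            total += cost
            used += 1
            if used == n - 1:  # a spanning tree has exactly n - 1 edges
                break
    return total
```

Sorting all n(n-1)/2 edges costs O(n^2 log n), which is one reason an O(n^2) Prim variant tends to suit this dense, complete graph better.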

Problem

You are given an array points representing integer coordinates of some points on a 2D-plane, where points[i] = [xi, yi].

The cost of connecting two points [xi, yi] and [xj, yj] is the Manhattan distance between them: |xi - xj| + |yi - yj|, where |val| denotes the absolute value of val.
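The cost formula maps directly onto a one-line helper. Below is a minimal sketch (the name manhattan is a hypothetical helper, not from the original code) that computes the three pairwise costs for Example 2; the tree made of the two cheapest edges, 12 + 6, gives the expected answer of 18.

```python
def manhattan(p, q):
    """Cost of connecting points p = [xi, yi] and q = [xj, yj]."""
    return abs(p[0] - q[0]) + abs(p[1] - q[1])


# Pairwise costs for Example 2: points = [[3,12],[-2,5],[-4,1]]
pts = [[3, 12], [-2, 5], [-4, 1]]
print(manhattan(pts[0], pts[1]))  # 12
print(manhattan(pts[0], pts[2]))  # 18
print(manhattan(pts[1], pts[2]))  # 6
```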

Return the minimum cost to make all points connected. All points are connected if there is exactly one simple path between any two points.

Example 1:

Input: points = [[0,0],[2,2],[3,10],[5,2],[7,0]]
Output: 20
Explanation: We can connect the points as shown in the problem's figure (not preserved in this copy) to get the minimum cost of 20.
Notice that there is a unique path between every pair of points.

Example 2:

Input: points = [[3,12],[-2,5],[-4,1]]
Output: 18

Constraints:

  • 1 <= points.length <= 1000
  • -10^6 <= xi, yi <= 10^6
  • All pairs (xi, yi) are distinct.
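The constraints set the scale of the problem. Any pair of points may be connected, so the graph is complete, and at the maximum n = 1000 the number of candidate edges is:

```latex
\binom{n}{2} = \frac{n(n-1)}{2} = \frac{1000 \cdot 999}{2} = 499\,500
```

A single edge can cost at most $2 \cdot (10^6 - (-10^6)) = 4 \times 10^6$. So an algorithm that touches each pair a constant number of times does roughly $5 \times 10^5$ units of work, while anything that rescans all crossing edges every round grows like $n^3$.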

Analysis

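First, I have to admit that I had completely forgotten what a minimum spanning tree even is.

After reading the problem I was groping around in the dark, until I checked the hint and saw that this is a textbook minimum spanning tree scenario.

So let's go straight to Prim's algorithm for minimum spanning trees. The description below is easy to follow.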

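1. Let $V$ be the set of all vertices of the graph. Initially let $u = \{s\}$ and $v = V - u$.
2. Among the edges between the two sets $u$ and $v$, choose the one with the smallest cost, $(u_0, v_0)$ with $u_0 \in u$ and $v_0 \in v$. Add it to the minimum spanning tree and move $v_0$ into $u$.
3. Repeat step 2 until the minimum spanning tree has $n-1$ edges, i.e. until it contains all $n$ vertices.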
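The three steps can be traced by hand on Example 2. Below is a minimal sketch that follows them literally on a precomputed weight matrix; the function name prim_trace and the hard-coded matrix w (the Example 2 Manhattan distances) are illustrative assumptions, not part of the original solutions.

```python
def prim_trace(w, start=0):
    """Naive Prim on an n x n weight matrix w; returns the chosen edges in order.

    Grows the set u from {start}; each round takes the cheapest edge (u0, v0)
    with u0 in u and v0 in v = V - u.
    """
    n = len(w)
    u = {start}
    v = set(range(n)) - u
    chosen = []
    while v:  # stop once the tree spans all n vertices (n - 1 edges)
        u0, v0 = min(((a, b) for a in u for b in v), key=lambda e: w[e[0]][e[1]])
        chosen.append((u0, v0, w[u0][v0]))
        u.add(v0)
        v.remove(v0)
    return chosen


# Example 2 distances, computed from the Manhattan formula.
w = [[0, 12, 18],
     [12, 0, 6],
     [18, 6, 0]]
print(prim_trace(w))  # [(0, 1, 12), (1, 2, 6)], total 18
```

The first round can only choose between 12 and 18; once point 1 joins, the 6-cost edge to point 2 becomes available and beats the 18-cost one.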

Now let's write some code.

Code

v1

The most naive implementation. Unsurprisingly, it exceeded the time limit.

class Solution(object):
    def minCostConnectPoints(self, points):
        """
        :type points: List[List[int]]
        :rtype: int
        """
        if len(points) < 2:
            return 0
        n = len(points)
        # Precompute pairwise Manhattan distances (upper triangle only, i < j).
        dists = [[None] * n for _ in range(n)]
        ret = 0
        for i in range(n):
            for j in range(i + 1, n):
                dis = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
                dists[i][j] = dis
        # s1: points already in the tree; s2: points not yet connected.
        s1, s2 = set([0]), set(range(1, n))
        while len(s2) > 0:
            # Scan every edge crossing the cut (s1, s2) for the cheapest one.
            e1, e2, d = None, None, None
            for i in s1:
                for j in s2:
                    dis = dists[i][j] if i < j else dists[j][i]
                    if d is None or dis < d:
                        e1 = i
                        e2 = j
                        d = dis
            # Move the new endpoint into the tree and pay for the edge.
            s1.add(e2)
            s2.remove(e2)
            ret += d
        return ret
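Why v1 times out can be estimated by counting how many (i, j) pairs its nested loop visits. In round k, s1 holds k points and s2 holds n - k, so the total is the sum of k(n - k), which simplifies to n(n-1)(n+1)/6. The helper name v1_inner_iterations is hypothetical; it only counts, it does not run v1.

```python
def v1_inner_iterations(n):
    """Total (i, j) pairs visited by v1's nested loop over all n - 1 rounds.

    In round k (k = 1 .. n-1), s1 holds k points and s2 holds n - k points.
    """
    return sum(k * (n - k) for k in range(1, n))


n = 1000
print(v1_inner_iterations(n))      # 166666500
print(n * (n - 1) * (n + 1) // 6)  # same value, closed form
```

Roughly 1.7 x 10^8 interpreted loop iterations at n = 1000 is a plausible explanation for the time-limit failure.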

v2

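Looking closer, the nested loop in v1 (for i in s1 / for j in s2) is clumsy: every time it picks a new edge, it recomputes and compares all |s1| x |s2| crossing edges, so consecutive rounds repeat a great deal of work.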

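As the original figure illustrated (it is not preserved here), between one round and the next, only the edges incident to the newly added point (shown in blue) need their distances computed and compared against the others.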

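That leads to a new idea: keep the currently addable edges in a min-heap. Whenever a new point joins the tree, push the new edges it brings (from the blue point to every point not yet in the tree), and in each round pop the smallest edge that still crosses the cut. The code: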

import heapq


class Solution(object):
    def minCostConnectPoints(self, points):
        """
        :type points: List[List[int]]
        :rtype: int
        """
        if len(points) < 2:
            return 0
        n = len(points)
        # Full symmetric distance matrix.
        dists = [[None] * n for _ in range(n)]
        ret = 0
        for i in range(n):
            for j in range(i + 1, n):
                dis = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
                dists[i][j] = dis
                dists[j][i] = dis
        s1, s2 = set([0]), set(range(1, n))
        # Min-heap of candidate edges (distance, endpoint, endpoint),
        # seeded with the edges out of point 0.
        hp = [(dists[0][i], 0, i) for i in range(1, n)]
        heapq.heapify(hp)
        while len(s2) > 0:
            dis, e1, e2 = heapq.heappop(hp)
            # Lazy deletion: skip stale edges whose endpoints are both in the tree.
            while e1 in s1 and e2 in s1:
                dis, e1, e2 = heapq.heappop(hp)
            ret += dis
            # Exactly one endpoint is new; find it.
            new_e = e2 if e1 in s1 else e1
            s1.add(new_e)
            s2.remove(new_e)
            # Only edges from the new (blue) point to the remaining points are new candidates.
            for e in s2:
                dis = dists[new_e][e]
                heapq.heappush(hp, (dis, e, new_e))
        return ret

done
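A rough size check on v2's heap: it starts with the n - 1 edges out of point 0, and after the k-th point joins it pushes one edge per point still in s2, i.e. n - 1 - k edges. The sketch below (the name v2_heap_entries is hypothetical) adds these up; the total equals n(n-1)/2, so each pair enters the heap exactly once, and with a pop costing O(log n) the whole run is O(n^2 log n).

```python
def v2_heap_entries(n):
    """Number of tuples that enter v2's heap (initial heapify plus all pushes)."""
    initial = n - 1                            # edges out of point 0
    later = sum(n - 1 - k for k in range(1, n))  # pushes after the k-th point joins
    return initial + later


print(v2_heap_entries(1000))  # 499500, one entry per pair of points
```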

v3

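So the question is: can it go even faster?

I borrowed the idea from Huahua (花花酱) (https://zxi.mytechroad.com/blog/graph/leetcode-1584-min-cost-to-connect-all-points/).

In fact, there is no need to store every candidate edge. It is enough to store, for each point not yet in the tree, its minimum distance to any point already in the tree. The code: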

class Solution(object):
    def minCostConnectPoints(self, points):
        """
        :type points: List[List[int]]
        :rtype: int
        """
        n = len(points)

        def _dist(i, j):
            return abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])

        # min_dist[i]: shortest distance from point i to any point already in the tree.
        # None marks points already in the tree (point 0 starts there).
        min_dist = [None] + [_dist(0, i) for i in range(1, n)]

        ans = 0
        for _ in range(1, n):
            # Pick the closest point that is not yet in the tree.
            new_edge_weight = min(filter(lambda x: x is not None, min_dist))
            new_node_idx = min_dist.index(new_edge_weight)
            ans += new_edge_weight
            min_dist[new_node_idx] = None
            # Update the remaining points with their distance to the newly added point.
            for j in range(n):
                if min_dist[j] is None:
                    continue
                min_dist[j] = min(min_dist[j], _dist(j, new_node_idx))
        return ans

done!
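The same O(n^2) idea can also be written with a boolean in_tree list and float("inf") as the "not reached yet" value instead of None, finding the closest point in a single pass rather than filter followed by index. This is a minimal sketch of that variant, not the code that was submitted; the function name min_cost_dense_prim is hypothetical.

```python
from typing import List


def min_cost_dense_prim(points: List[List[int]]) -> int:
    """O(n^2) Prim without a heap, using inf instead of None as the sentinel."""
    n = len(points)
    INF = float("inf")
    in_tree = [False] * n
    min_dist = [INF] * n  # cheapest known edge from each point into the tree
    min_dist[0] = 0       # start the tree from point 0
    total = 0
    for _ in range(n):
        # Single pass to find the closest point not yet in the tree.
        best = -1
        for i in range(n):
            if not in_tree[i] and (best == -1 or min_dist[i] < min_dist[best]):
                best = i
        in_tree[best] = True
        total += min_dist[best]
        # Relax: the new tree point may offer a cheaper connection.
        bx, by = points[best]
        for i in range(n):
            if not in_tree[i]:
                d = abs(points[i][0] - bx) + abs(points[i][1] - by)
                if d < min_dist[i]:
                    min_dist[i] = d
    return total


print(min_cost_dense_prim([[0, 0], [2, 2], [3, 10], [5, 2], [7, 0]]))  # 20
print(min_cost_dense_prim([[3, 12], [-2, 5], [-4, 1]]))               # 18
```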

Other notes

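Another interesting detail: v2 uses precomputed distances, while v3 does not. Why is that? I won't spell it out here :) If you're curious, feel free to discuss it in the comments.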



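(End of article)

Unless otherwise noted, all articles are original works by 广陌. Reposts must credit this article with a link.

Permalink: https://www.utopiafar.com/2022/04/27/leetcode-1584-min-cost-to-connect-all-points/

Writing takes effort. If you found this helpful, feel free to leave a comment or a like!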

