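LeetCode 1584. Min Cost to Connect All Points
A classic minimum spanning tree (MST) problem. I wasn't very interested in it at first, since the algorithm is so textbook. After a few submissions, though, I found quite a lot of room for optimization, so I'm writing it down here.
PS: There are two standard MST algorithms; this post uses Prim's algorithm.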
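For comparison, here is a minimal sketch of the other common MST algorithm, Kruskal's, applied to this problem. It is not the approach used in this post. The function name `min_cost_kruskal` is my own. It builds all n(n−1)/2 edges of the complete graph, sorts them by Manhattan distance, and uses a union-find (disjoint-set) structure to keep only edges that join two different components.

```python
from itertools import combinations
from typing import List


def min_cost_kruskal(points: List[List[int]]) -> int:
    """Kruskal's algorithm on the complete Manhattan-distance graph (sketch)."""
    n = len(points)
    parent = list(range(n))  # union-find forest: each point starts alone

    def find(x: int) -> int:
        # Path halving keeps the trees shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # All n*(n-1)/2 candidate edges as (cost, i, j), cheapest first.
    edges = sorted(
        (abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1]), i, j)
        for i, j in combinations(range(n), 2)
    )
    total, used = 0, 0
    for cost, i, j in edges:
        ri, rj = find(i), find(j)
        if ri != rj:  # the edge joins two different components
            parent[ri] = rj
            total += cost
            used += 1
            if used == n - 1:  # a spanning tree has exactly n - 1 edges
                break
    return total
```

With n up to 1000 this sorts about half a million edges, so in Python it is probably heavier than the dense Prim versions below; it is mainly useful as an independent cross-check.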
Problem
You are given an array points representing integer coordinates of some points on a 2D-plane, where points[i] = [xi, yi].
The cost of connecting two points [xi, yi] and [xj, yj] is the manhattan distance between them: |xi - xj| + |yi - yj|, where |val| denotes the absolute value of val.
Return the minimum cost to make all points connected. All points are connected if there is exactly one simple path between any two points.
Example 1:
Input: points = [[0,0],[2,2],[3,10],[5,2],[7,0]]
Output: 20
Explanation: We can connect the points as shown above to get the minimum cost of 20. Notice that there is a unique path between every pair of points.
Example 2:
Input: points = [[3,12],[-2,5],[-4,1]]
Output: 18
Constraints:
- 1 <= points.length <= 1000
- -10^6 <= xi, yi <= 10^6
- All pairs (xi, yi) are distinct.
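The cost definition above can be written as a one-line helper. Because any two points may be connected, the input is effectively a complete graph. The names `manhattan` and `candidate_edges` are hypothetical helpers for illustration, not part of the problem's API.

```python
from typing import Sequence


def manhattan(p: Sequence[int], q: Sequence[int]) -> int:
    """Cost of connecting p = [xi, yi] and q = [xj, yj]: |xi - xj| + |yi - yj|."""
    return abs(p[0] - q[0]) + abs(p[1] - q[1])


def candidate_edges(n: int) -> int:
    """Every pair of points can be connected, so the graph is complete."""
    return n * (n - 1) // 2


print(manhattan([0, 0], [2, 2]))  # 4
print(candidate_edges(1000))      # 499500
```

At the upper bound of 1000 points there are 499,500 candidate edges. That is why the later versions try to avoid re-examining edges they have already seen.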
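Analysis
First, I have to admit that I had completely forgotten what a minimum spanning tree even is.
After reading the problem I fumbled around for a while, then checked the hints and realized this is a textbook MST scenario.
So let's go straight to Prim's algorithm for MSTs. The following description is easy to follow: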
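1. Let V be the set of all vertices of the graph; initially let u = {s} and v = V − u.
2. Among the edges that connect set u to set v, choose the one with the smallest cost, (u0, v0), add it to the minimum spanning tree, and move v0 into u.
3. Repeat step 2 until the minimum spanning tree has n−1 edges, or equivalently n vertices.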
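The three steps above translate almost line by line into Python. This sketch targets a generic connected, undirected weighted graph. The adjacency-map format (`graph[a][b] = weight`, stored in both directions), the example graph, and the name `prim_edges` are my own assumptions. If the graph is disconnected, `min()` receives an empty sequence and raises `ValueError`.

```python
from typing import Dict, List, Tuple

# Hypothetical adjacency map: graph[a][b] = weight, stored in both directions.
Graph = Dict[str, Dict[str, int]]


def prim_edges(graph: Graph, s: str) -> List[Tuple[str, str, int]]:
    """Literal transcription of steps 1-3; returns the MST edges in order."""
    u = {s}                # step 1: u = {s}
    v = set(graph) - u     #         v = V - u
    tree: List[Tuple[str, str, int]] = []
    while v:               # step 3: stop once all n vertices are in u
        # step 2: the cheapest edge (u0, v0) with u0 in u and v0 in v
        w, u0, v0 = min(
            (graph[a][b], a, b) for a in u for b in graph[a] if b in v
        )
        tree.append((u0, v0, w))
        u.add(v0)
        v.remove(v0)
    return tree


g = {
    "a": {"b": 1, "c": 4},
    "b": {"a": 1, "c": 2, "d": 5},
    "c": {"a": 4, "b": 2, "d": 3},
    "d": {"b": 5, "c": 3},
}
print(prim_edges(g, "a"))  # [('a', 'b', 1), ('b', 'c', 2), ('c', 'd', 3)]
```

Step 2 rescans every crossing edge in each round. That is exactly the cost that v1 below runs into.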
Now let's write some code.
Code
v1
The most naive version. Unsurprisingly, it timed out.
```python
class Solution(object):
    def minCostConnectPoints(self, points):
        """
        :type points: List[List[int]]
        :rtype: int
        """
        if len(points) < 2:
            return 0
        n = len(points)
        dists = [[None] * n for _ in range(n)]
        ret = 0
        for i in range(n):
            for j in range(i + 1, n):
                dis = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
                dists[i][j] = dis
        # only the upper triangle (i < j) of dists is filled
        s1, s2 = set([0]), set(list(range(1, n)))
        # s1: points in the tree, s2: the rest; each round scans every (s1, s2) pair
        while len(s2) > 0:
            e1, e2, d = None, None, None
            for i in s1:
                for j in s2:
                    dis = dists[i][j] if i < j else dists[j][i]
                    if d is None or dis < d:
                        e1 = i
                        e2 = j
                        d = dis
            # move the far endpoint of the cheapest crossing edge into the tree
            s1.add(e2)
            s2.remove(e2)
            ret += d
        return ret
```
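Here is a rough count of why v1 times out. This is my own back-of-the-envelope estimate, not a measurement. When s1 holds k points, the double loop inspects k(n−k) pairs, so over all n−1 rounds:

```latex
\sum_{k=1}^{n-1} k\,(n-k)
  = n\cdot\frac{n(n-1)}{2} - \frac{(n-1)\,n\,(2n-1)}{6}
  = \frac{n(n-1)(n+1)}{6}
  = \frac{n^{3}-n}{6}
```

For n = 1000 that is 166,666,500 pair inspections, all executed as interpreted Python loops.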
v2
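Looking at it again, the double loop on lines 21 & 22 (`for i in s1` / `for j in s2`) is rather naive. Every time a new edge is chosen, it looks up and compares all |s1|×|s2| candidate edges, so consecutive rounds repeat a great deal of work.
Between one iteration and the next, only the edges touching the point that just changed (shown in blue in the original figure) actually need their distances computed and compared with the other edges.
This leads to a new idea: keep the currently available edges in a min-heap. Whenever a new endpoint joins the tree, push the new edges it brings (those touching the blue point), and in each iteration pop the smallest edge that still connects the tree to a point outside it. The code is below.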
```python
class Solution(object):
    def minCostConnectPoints(self, points):
        """
        :type points: List[List[int]]
        :rtype: int
        """
        import heapq
        if len(points) < 2:
            return 0
        n = len(points)
        # precompute the full symmetric distance matrix
        dists = [[None] * n for _ in range(n)]
        ret = 0
        for i in range(n):
            for j in range(i + 1, n):
                dis = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
                dists[i][j] = dis
                dists[j][i] = dis
        # s1: points in the tree; s2: points not yet in the tree
        s1, s2 = set([0]), set(list(range(1, n)))
        # min-heap of candidate edges as (distance, endpoint, endpoint)
        hp = [(dists[0][i], 0, i) for i in range(1, n)]
        heapq.heapify(hp)
        while len(s2) > 0:
            dis, e1, e2 = heapq.heappop(hp)
            # lazy deletion: skip stale edges whose endpoints are both in the tree
            while e1 in s1 and e2 in s1:
                dis, e1, e2 = heapq.heappop(hp)
            ret += dis
            new_e = e2 if e1 in s1 else e1
            s1.add(new_e)
            s2.remove(new_e)
            # push only the edges introduced by the newly added point
            for e in s2:
                dis = dists[new_e][e]
                heapq.heappush(hp, (dis, e, new_e))
        return ret
```
done
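Here is a tally for v2 along the same lines as the v1 count (again my own estimate). The heap starts with n−1 edges. When the k-th point joins (k = 1, …, n−1), v2 pushes one edge for each of the n−1−k points still outside the tree, so the total number of pushes is:

```latex
(n-1) + \sum_{k=1}^{n-1} (n-1-k)
  = (n-1) + \frac{(n-1)(n-2)}{2}
  = \frac{n(n-1)}{2}
```

The heap never holds more than n(n−1)/2 entries, and log(n²) = 2 log n, so each push or pop costs O(log n). That makes v2 O(n² log n) instead of cubic. For n = 1000 it comes to 499,500 pushes.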
v3
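So the question is: can it go any faster?
I borrowed the idea from Huahua (花花酱) (https://zxi.mytechroad.com/blog/graph/leetcode-1584-min-cost-to-connect-all-points/).
In fact, there's no need to store every available edge. It's enough to store, for each point not yet added, its minimum distance to any point already added. The code is below.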
```python
class Solution(object):
    def minCostConnectPoints(self, points):
        """
        :type points: List[List[int]]
        :rtype: int
        """
        n = len(points)

        def _dist(i, j):
            # Manhattan distance between points i and j
            return abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])

        # min_dist[k]: cheapest edge from point k to the tree; None once k is in the tree
        min_dist = [None] + [_dist(0, i) for i in range(1, n)]
        ans = 0
        for _ in range(1, n):
            # pick the closest point that is not yet in the tree
            new_edge_weight = min(filter(lambda x: x is not None, min_dist))
            new_node_idx = min_dist.index(new_edge_weight)
            ans += new_edge_weight
            min_dist[new_node_idx] = None
            # only edges to the newly added point can lower the remaining distances
            for j in range(n):
                if min_dist[j] is None:
                    continue
                min_dist[j] = min(min_dist[j], _dist(j, new_node_idx))
        return ans
```
done!
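Here is a variant of v3 that I have not benchmarked on LeetCode; the function name `min_cost_single_pass` is my own. Instead of marking added points with `None` and then scanning `min_dist` twice (`filter` + `index`), it keeps a boolean `in_tree` list and picks the next point during the same loop that relaxes the distances. It is still the dense O(n²) form of Prim's algorithm, so the single pass should only trim the constant factor.

```python
from typing import List


def min_cost_single_pass(points: List[List[int]]) -> int:
    """Dense O(n^2) Prim: one loop both relaxes distances and picks the next point."""
    n = len(points)
    in_tree = [False] * n
    min_dist: List[float] = [float("inf")] * n  # cheapest edge from each point to the tree
    min_dist[0] = 0
    total = 0
    cur = 0  # point being added in this round
    for _ in range(n):
        in_tree[cur] = True
        total += min_dist[cur]
        x0, y0 = points[cur]
        nxt, best = -1, float("inf")
        for j in range(n):
            if in_tree[j]:
                continue
            d = abs(points[j][0] - x0) + abs(points[j][1] - y0)
            if d < min_dist[j]:      # relax via the newly added point
                min_dist[j] = d
            if min_dist[j] < best:   # track the next point in the same pass
                best, nxt = min_dist[j], j
        cur = nxt                    # -1 after the last round; the loop ends anyway
    return total
```

For example, `min_cost_single_pass([[0,0],[2,2],[3,10],[5,2],[7,0]])` returns 20, matching Example 1.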
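Other
Another interesting little detail: v2 uses precomputed distances, while v3 does not. Why is that? I'll leave it out here :) If you're interested, feel free to discuss in the comments.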
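The following observation may help with the question above. It is offered as a starting point, not as the answer. `v3_dist_calls` is a hypothetical instrumented copy of v3 that also counts how often `_dist` runs. It shows that v3 evaluates each unordered pair of points exactly once, n(n−1)/2 calls in total. A pair {a, b} is measured when the earlier of the two joins the tree, and skipped later because that point is then `None`.

```python
from typing import List, Tuple


def v3_dist_calls(points: List[List[int]]) -> Tuple[int, int]:
    """Instrumented copy of v3: returns (MST cost, number of _dist calls)."""
    n = len(points)
    calls = 0

    def _dist(i: int, j: int) -> int:
        nonlocal calls
        calls += 1  # count every distance evaluation
        return abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])

    min_dist = [None] + [_dist(0, i) for i in range(1, n)]
    ans = 0
    for _ in range(1, n):
        new_edge_weight = min(filter(lambda x: x is not None, min_dist))
        new_node_idx = min_dist.index(new_edge_weight)
        ans += new_edge_weight
        min_dist[new_node_idx] = None
        for j in range(n):
            if min_dist[j] is None:
                continue
            min_dist[j] = min(min_dist[j], _dist(j, new_node_idx))
    return ans, calls
```

For example, `v3_dist_calls([[0,0],[2,2],[3,10],[5,2],[7,0]])` returns `(20, 10)`: five points, 10 pairs, 10 calls. Since no pair is ever measured twice, a precomputed matrix would not save any distance evaluations in v3.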
Unless otherwise noted, all articles here are original works by 广陌; reprints must credit this article with a link.
Permalink: https://www.utopiafar.com/2022/04/27/leetcode-1584-min-cost-to-connect-all-points/