问题
最小生成树的模板题。之前用Kruskal解过LeetCode上针对边的问题。这次用Prim算法解决。
解决
标准的Prim算法不再赘述。其实Prim算法与Dijkstra算法很类似。区别是,Dijkstra算法每次更新的是新扩增的点到其后继节点的最短距离,而Prim算法每次更新的是,当构建最小生成树的集合(记作MST)中增加新的点之后,更新这个集合与剩余集合之间的最短距离。
如果直接按照这种方式实现,那么找最短边时需要大量遍历。但其实,我们可以维护一个数组,记录每个点到MST集合的最短距离。这些值什么时候可能会变呢?其实就是MST集合每新增一个点时会变,因为可能某剩余点i与新增点new之间的距离小于原来i到原MST集合的距离。因此,只需要每次往MST中新增点new时,更新new的后继节点到MST的距离即可。由于这里有找后继的操作,因此显然图的表示方法采用邻接表(adjacency list)的方式最好。
在此基础上,如果不想每次找距离MST最近的点时需要遍历所有的点,就需要借助小顶堆来实现。小顶堆其实就是优先队列。这个优先队列的每个元素是一个二元组,第一个值是点的序号,第二个值是该点到MST的距离。一开始,我们随机选一个点加入MST,然后将其后继节点加入到优先队列中。其它点不加是因为他们目前到MST的距离是无穷大,加了也没有意义。之后同理,每次新增点new进入MST时,如果new->剩余点i的距离小于i之前到MST的距离,就可以把(i, weight(new, i))加入优先队列,意味着new加入MST之后,i到MST的距离变成了更小的new->i的距离。至于之前加入优先队列的信息可以不管了,因为最优的一定会跑到top去。优先队列保证了top元素是到MST距离最短的,只要它还是剩余点,它就是应该被新扩增的点。
实现
C++的priority_queue的使用方式需要注意。
声明格式:priority_queue<元素类型, 容器类型, 比较函数> 名称;
比较函数默认是less,对应大顶堆!(与直觉是反的。)
如果自定义为greater,就是小顶堆。
自定义比较函数的话就需要补全第二个参数容器类型,一般用vector。
#include <cstdio>
#include <iostream>
#include <queue>
#include <cstring>
#include <list>
#include <stack>
#include <ctime>
using namespace std;
class Prim
{
vector< vector< pair<int, int> > >& graph; // adjacency list
int n;
int mst_total;
struct Dist
{
int node;
int dist;
Dist(const int _node, const int _dist):
node(_node), dist(_dist) {}
friend bool operator< (const Dist& x, const Dist& y)
{
return x.dist < y.dist;
}
friend bool operator> (const Dist& x, const Dist& y)
{
return x.dist > y.dist;
}
};
int getRandom()
{
srand(time(0));
return rand();
}
public:
Prim(const int _n, vector< vector< pair<int, int> > >& _graph):
n(_n), graph(_graph), mst_total(0)
{
mst();
}
int mst()
{
const int inf = 1000000000;
vector<int> dist_to_mst(n, inf);
priority_queue<Dist, vector<Dist>, greater<Dist>> q; // greater<int> minimum is on the top
vector<bool> visited(n, false);
int mst_n = 0;
int start = getRandom() % n;
dist_to_mst[start] = 0;
visited[start] = true;
//q.emplace(Dist(start, 0));
mst_n++;
for (auto& nxt : graph[start])
{
dist_to_mst[nxt.first] = nxt.second;
q.emplace(Dist(nxt.first, nxt.second));
}
while (!q.empty())
{
auto now = q.top();
q.pop();
if (!visited[now.node])
{
visited[now.node] = true;
// dist_to_mst[now.node] = 0;
mst_total += now.dist;
mst_n++;
for (auto& nxt : graph[now.node])
{
// now -> next.node
if (!visited[nxt.first] && nxt.second < dist_to_mst[nxt.first])
{
dist_to_mst[nxt.first] = nxt.second;
q.emplace(Dist(nxt.first, nxt.second));
}
}
}
}
if (mst_n == n)
return mst_total;
else
{
mst_total = -1;
return -1;
}
}
string getAns()
{
if (mst_total != -1)
return to_string(mst_total);
else
return "orz";
}
};
int main()
{
int n, m;
cin >> n >> m;
vector< vector< pair<int, int> > > graph(n);
for (int i = 0; i < m; i++)
{
int u, v, w;
cin >> u >> v >> w;
u--, v--;
graph[u].emplace_back(v, w);
graph[v].emplace_back(u, w);
}
cout << Prim(n, graph).getAns() << endl;
return 0;
}