问题

洛谷 P3366

最小生成树的模板题。之前用Kruskal解过LeetCode上针对边的问题。这次用Prim算法解决。

解决

标准的Prim算法不再赘述。其实Prim算法与Dijkstra算法很类似。区别是,Dijkstra算法每次更新的是新扩增的点到其后继节点的最短距离,而Prim算法每次更新的是,当构建最小生成树的集合(记作MST)中增加新的点之后,更新这个集合与剩余集合之间的最短距离。

如果直接按照这种方式实现,那么找最短边时需要大量遍历。但其实,我们可以维护一个数组,记录每个点到MST集合的最短距离。这些值什么时候可能会变呢?其实就是MST集合每新增一个点时会变,因为可能某剩余点i与新增点new之间的距离小于原来i到原MST集合的距离。因此,只需要每次往MST中新增点new时,更新new的后继节点到MST的距离即可。由于这里有找后继的操作,因此显然图的表示方法采用邻接表(adjacency list)的方式最好。

在此基础上,如果不想每次找距离MST最近的点时需要遍历所有的点,就需要借助小顶堆来实现。小顶堆其实就是优先队列。这个优先队列的每个元素是一个二元组,第一个值是点的序号,第二个值是该点到MST的距离。一开始,我们随机选一个点加入MST,然后将其后继节点加入到优先队列中。其它点不加是因为他们目前到MST的距离是无穷大,加了也没有意义。之后同理,每次新增点new进入MST时,如果new->剩余点i的距离小于i之前到MST的距离,就可以把(i, weight(new, i))加入优先队列,意味着new加入MST之后,i到MST的距离变成了更小的new->i的距离。至于之前加入优先队列的信息可以不管了,因为最优的一定会跑到top去。优先队列保证了top元素是到MST距离最短的,只要它还是剩余点,它就是应该被新扩增的点。

实现

C++的priority_queue的使用方式需要注意。

声明格式:priority_queue<元素类型, 容器类型, 比较函数> 名称;
比较函数默认是less,对应大顶堆!(与直觉是反的。)
如果自定义为greater,就是小顶堆。
自定义比较函数的话就需要补全第二个参数容器类型,一般用vector。

#include <cstdio>
#include <iostream>
#include <queue>
#include <cstring>
#include <list>
#include <stack>
#include <ctime>
using namespace std;

class Prim
{
    vector< vector< pair<int, int> > >& graph;   // adjacency list
    int n;
    int mst_total;
    struct Dist
    {
        int node;
        int dist;

        Dist(const int _node, const int _dist):
            node(_node), dist(_dist) {}

        friend bool operator< (const Dist& x, const Dist& y)
        {
            return x.dist < y.dist;
        }

        friend bool operator> (const Dist& x, const Dist& y)
        {
            return x.dist > y.dist;
        }
    };

    int getRandom()
    {
        srand(time(0));
        return rand();
    }

public:
    Prim(const int _n, vector< vector< pair<int, int> > >& _graph):
        n(_n), graph(_graph), mst_total(0)
    {
        mst();
    }

    int mst()
    {
        const int inf = 1000000000;
        vector<int> dist_to_mst(n, inf);
        priority_queue<Dist, vector<Dist>, greater<Dist>> q;    // greater<int> minimum is on the top
        vector<bool> visited(n, false);
        int mst_n = 0;
        int start = getRandom() % n;
        dist_to_mst[start] = 0;
        visited[start] = true;
        //q.emplace(Dist(start, 0));
        mst_n++;
        for (auto& nxt : graph[start])
        {
            dist_to_mst[nxt.first] = nxt.second;
            q.emplace(Dist(nxt.first, nxt.second));
        }
        while (!q.empty())
        {
            auto now = q.top();
            q.pop();
            if (!visited[now.node])
            {
                visited[now.node] = true;
                // dist_to_mst[now.node] = 0;
                mst_total += now.dist;
                mst_n++;
                for (auto& nxt : graph[now.node])
                {
                    // now -> next.node
                    if (!visited[nxt.first] && nxt.second < dist_to_mst[nxt.first])
                    {
                        dist_to_mst[nxt.first] = nxt.second;
                        q.emplace(Dist(nxt.first, nxt.second));
                    }
                }
            }
        }
        if (mst_n == n)
            return mst_total;
        else
        {
            mst_total = -1;
            return -1;
        }
    }

    string getAns()
    {
        if (mst_total != -1)
            return to_string(mst_total);
        else
            return "orz";
    }
};

int main()
{
    int n, m;
    cin >> n >> m;
    vector< vector< pair<int, int> > > graph(n);
    for (int i = 0; i < m; i++)
    {
        int u, v, w;
        cin >> u >> v >> w;
        u--, v--;
        graph[u].emplace_back(v, w);
        graph[v].emplace_back(u, w);
    }
    cout << Prim(n, graph).getAns() << endl;

    return 0;
}
Last modification:June 27th, 2020 at 03:49 pm
如果觉得我的文章对你有用,请随意赞赏