Leetcode 310: 最小高度树

310 最小高度树

每日一题 4月6日

给定一个树的各个边,要求选择根节点,使得树的高度最小

初步的想法是使用广度优先搜索,对于每一个根节点进行尝试,找到最小的那个。因为广度优先搜索的复杂度为O(n),因此整体复杂度为O(n^2)

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <queue>
using namespace std;

struct Graph{
    vector<vector<int>> arr;
    int node_num;
    Graph(int n, vector<vector<int>> edges){
        this->node_num = n;
        this->arr = vector<vector<int>>(n,vector<int>());
        for(int i=0;i<edges.size();i++){
            int s,t;
            s = edges[i][0];
            t = edges[i][1];
            arr[s].push_back(t);
            arr[t].push_back(s);
        }
    }

    int bfs_maxlength(int node){
        // bfs distance from node
        queue<int> q;
        vector<int> distance(this->node_num,-1);
        distance[node] = 0;
        q.push(node);
        int max_distance = 0;
        while(!q.empty()){
            int cur_node = q.front(); q.pop();
            for(int i=0;i<this->arr[cur_node].size();i++){
                int dest_node = this->arr[cur_node][i];
                if(distance[dest_node] == -1){
                    q.push(dest_node);
                    distance[dest_node] = distance[cur_node] + 1;
                    if(distance[dest_node] > max_distance){
                        max_distance = distance[dest_node];
                    }
                }
            }
        }
        return max_distance;
    }
};

class Solution {
public:
    vector<int> findMinHeightTrees(int n, vector<vector<int>>& edges) {
        if(n==1){
            return vector<int>{0};
        }else if(n==2){
            return edges[0];
        }
        Graph graph(n,edges);
        vector<int> result;
        int min_dist = n;
        for(int i=0;i<n;i++){
            int dist = graph.bfs_maxlength(i);
//            cout<<"node "<<i<<" has dist "<<dist<<endl;
            if(dist < min_dist){
                min_dist = dist;
                result = vector<int>{i};
            }else if(dist == min_dist){
                result.push_back(i);
            }
        }
        return result;
    }
};

int main() {
    Solution s;
    vector<vector<int>> edges;
    edges.push_back(vector<int>{0,1});
    edges.push_back(vector<int>{0,2});
    edges.push_back(vector<int>{0,3});
    edges.push_back(vector<int>{3,4});
    edges.push_back(vector<int>{4,5});

//    edges.push_back(vector<int>{1,2});
//    edges.push_back(vector<int>{1,3});
//    for(int i=0;i<edges.size();i++){
//        for(int j=0;j<edges[i].size();j++){
//            cout<<edges[i][j]<<" ";
//        }
//        cout<<endl;
//    }
    vector<int> result(s.findMinHeightTrees(edges.size()+1,edges));
//    cout<<result.size()<<endl;
    for(int i=0;i<result.size();i++){
        cout<<result[i]<<endl;
    }
    return 0;
}

不出意外,直接超时。

看了一下题解,感觉自己有很多地方没有想到。

有两个应该想到的定理,一个是最大树的高度一定为最长距离的一半,另外一个是根节点一定在这个最大距离上。因此问题转换为如何求出这个最大距离。

最大距离的求解很简单,但是证明起来很麻烦,这里就不放了。原证明为算法导论9-1的解答,求解过程大致为:

  1. 从任意点出发,找到距离其最大的节点x
  2. 从节点x出发,找到距离其最大的节点y
  3. x-y即为树的直径,其路径最大。

这个是leetcode官方给的题解,如果有证明的话后面其实可以不用写了……这道题还可以用树形DP,不过我实在没弄明白。

class Solution {
public:
    int findLongestNode(int u, vector<int> & parent, vector<vector<int>>& adj) {
        int n = adj.size();
        queue<int> qu;
        vector<bool> visit(n);
        qu.emplace(u);
        visit[u] = true;
        int node = -1;
  
        while (!qu.empty()) {
            int curr = qu.front();
            qu.pop();
            node = curr;
            for (auto & v : adj[curr]) {
                if (!visit[v]) {
                    visit[v] = true;
                    parent[v] = curr;
                    qu.emplace(v);
                }
            }
        }
        return node;
    }

    vector<int> findMinHeightTrees(int n, vector<vector<int>>& edges) {
        if (n == 1) {
            return {0};
        }
        vector<vector<int>> adj(n);
        for (auto & edge : edges) {
            adj[edge[0]].emplace_back(edge[1]);
            adj[edge[1]].emplace_back(edge[0]);
        }
        
        vector<int> parent(n, -1);
        /* 找到与节点 0 最远的节点 x */
        int x = findLongestNode(0, parent, adj);
        /* 找到与节点 x 最远的节点 y */
        int y = findLongestNode(x, parent, adj);
        /* 求出节点 x 到节点 y 的路径 */
        vector<int> path;
        parent[x] = -1;
        while (y != -1) {
            path.emplace_back(y);
            y = parent[y];
        }
        int m = path.size();
        if (m % 2 == 0) {
            return {path[m / 2 - 1], path[m / 2]};
        } else {
            return {path[m / 2]};
        }
    }
};
0%