高级树问题

掌握DFS序、树与线段树结合等高级树算法技巧

概述

高级树问题是算法竞赛中的核心内容,涉及DFS序列化、树与数据结构结合、复杂树算法等技术。本章节深入探讨这些高级技巧,为解决世界级竞赛中的复杂树问题提供强大工具。

核心内容

1. DFS序与树的线性化

DFS序将树结构转换为线性序列,使得树上的子树查询转化为区间查询。

// DFS序与树的线性化
#include <iostream>
#include <vector>
using namespace std;

class TreeDFS {
private:
    vector<vector<int>> adj;
    vector<int> dfs_order, start_time, end_time;
    vector<int> depth, parent;
    int timer;
    
    void dfs(int u, int p, int d) {
        start_time[u] = timer;
        dfs_order[timer++] = u;
        depth[u] = d;
        parent[u] = p;
        
        for (int v : adj[u]) {
            if (v != p) {
                dfs(v, u, d + 1);
            }
        }
        
        end_time[u] = timer - 1;
    }
    
public:
    void build(vector<vector<int>>& edges, int root = 0) {
        int n = edges.size() + 1;
        adj.resize(n);
        dfs_order.resize(n);
        start_time.resize(n);
        end_time.resize(n);
        depth.resize(n);
        parent.resize(n);
        timer = 0;
        
        for (auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }
        
        dfs(root, -1, 0);
    }
    
    // 检查u是否是v的祖先
    bool isAncestor(int u, int v) {
        return start_time[u] <= start_time[v] && end_time[v] <= end_time[u];
    }
    
    // 获取子树的DFS序区间
    pair<int, int> getSubtreeRange(int u) {
        return {start_time[u], end_time[u]};
    }
    
    void printDFSOrder() {
        cout << "DFS序: ";
        for (int i = 0; i < dfs_order.size(); i++) {
            cout << dfs_order[i] << " ";
        }
        cout << endl;
    }
};

int main() {
    vector<vector<int>> edges = {{0,1}, {0,2}, {1,3}, {1,4}, {2,5}};
    TreeDFS treeDFS;
    treeDFS.build(edges);
    treeDFS.printDFSOrder();
    
    cout << "节点1的子树区间:[" 
         << treeDFS.getSubtreeRange(1).first << ", " 
         << treeDFS.getSubtreeRange(1).second << "]" << endl;
    
    return 0;
}

2. 树与线段树结合

通过DFS序将树问题转化为线段树问题,支持子树更新和查询操作。

// 树与线段树结合解决子树问题
#include <iostream>
#include <vector>
using namespace std;

class TreeSegmentTree {
private:
    vector<vector<int>> adj;
    vector<int> start_time, end_time, values;
    vector<long long> tree, lazy;
    int n, timer;
    
    void dfs(int u, int p) {
        start_time[u] = timer++;
        for (int v : adj[u]) {
            if (v != p) dfs(v, u);
        }
        end_time[u] = timer - 1;
    }
    
    void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = values[start];
        } else {
            int mid = (start + end) / 2;
            build(2*node, start, mid);
            build(2*node+1, mid+1, end);
            tree[node] = tree[2*node] + tree[2*node+1];
        }
    }
    
    void updateLazy(int node, int start, int end) {
        if (lazy[node] != 0) {
            tree[node] += lazy[node] * (end - start + 1);
            if (start != end) {
                lazy[2*node] += lazy[node];
                lazy[2*node+1] += lazy[node];
            }
            lazy[node] = 0;
        }
    }
    
    void updateRange(int node, int start, int end, int l, int r, long long val) {
        updateLazy(node, start, end);
        if (start > r || end < l) return;
        
        if (start >= l && end <= r) {
            lazy[node] += val;
            updateLazy(node, start, end);
            return;
        }
        
        int mid = (start + end) / 2;
        updateRange(2*node, start, mid, l, r, val);
        updateRange(2*node+1, mid+1, end, l, r, val);
        
        updateLazy(2*node, start, mid);
        updateLazy(2*node+1, mid+1, end);
        tree[node] = tree[2*node] + tree[2*node+1];
    }
    
    long long queryRange(int node, int start, int end, int l, int r) {
        if (start > r || end < l) return 0;
        updateLazy(node, start, end);
        
        if (start >= l && end <= r) return tree[node];
        
        int mid = (start + end) / 2;
        return queryRange(2*node, start, mid, l, r) + 
               queryRange(2*node+1, mid+1, end, l, r);
    }
    
public:
    void initialize(vector<vector<int>>& edges, vector<int>& nodeValues, int root = 0) {
        n = edges.size() + 1;
        adj.resize(n);
        start_time.resize(n);
        end_time.resize(n);
        values.resize(n);
        tree.resize(4 * n);
        lazy.resize(4 * n);
        timer = 0;
        
        for (auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }
        
        dfs(root, -1);
        
        // 按DFS序重排值
        for (int i = 0; i < n; i++) {
            values[start_time[i]] = nodeValues[i];
        }
        
        build(1, 0, n - 1);
    }
    
    // 子树更新
    void updateSubtree(int u, long long val) {
        updateRange(1, 0, n - 1, start_time[u], end_time[u], val);
    }
    
    // 子树查询
    long long querySubtree(int u) {
        return queryRange(1, 0, n - 1, start_time[u], end_time[u]);
    }
};

int main() {
    vector<vector<int>> edges = {{0,1}, {0,2}, {1,3}, {1,4}, {2,5}};
    vector<int> values = {1, 2, 3, 4, 5, 6};
    
    TreeSegmentTree tst;
    tst.initialize(edges, values);
    
    cout << "节点1子树和: " << tst.querySubtree(1) << endl;
    
    tst.updateSubtree(1, 10);
    cout << "更新后节点1子树和: " << tst.querySubtree(1) << endl;
    
    return 0;
}

3. 高级树算法

// 树链剖分与LCA
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

class TreeChainDecomposition {
private:
    vector<vector<int>> adj;
    vector<int> parent, depth, heavy, head, pos;
    int timer;
    
    int dfs1(int u) {
        int size = 1, maxSize = 0;
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                depth[v] = depth[u] + 1;
                int childSize = dfs1(v);
                if (childSize > maxSize) {
                    maxSize = childSize;
                    heavy[u] = v;
                }
                size += childSize;
            }
        }
        return size;
    }
    
    void dfs2(int u, int h) {
        head[u] = h;
        pos[u] = timer++;
        
        if (heavy[u] != -1) {
            dfs2(heavy[u], h);
        }
        
        for (int v : adj[u]) {
            if (v != parent[u] && v != heavy[u]) {
                dfs2(v, v);
            }
        }
    }
    
public:
    void build(vector<vector<int>>& edges, int root = 0) {
        int n = edges.size() + 1;
        adj.resize(n);
        parent.resize(n);
        depth.resize(n);
        heavy.assign(n, -1);
        head.resize(n);
        pos.resize(n);
        timer = 0;
        
        for (auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }
        
        parent[root] = -1;
        depth[root] = 0;
        dfs1(root);
        dfs2(root, root);
    }
    
    int lca(int u, int v) {
        while (head[u] != head[v]) {
            if (depth[head[u]] > depth[head[v]]) {
                u = parent[head[u]];
            } else {
                v = parent[head[v]];
            }
        }
        return depth[u] < depth[v] ? u : v;
    }
    
    int distance(int u, int v) {
        return depth[u] + depth[v] - 2 * depth[lca(u, v)];
    }
};

// 树的直径和中心
class TreeDiameterCenter {
private:
    vector<vector<int>> adj;
    vector<int> dist;
    
    pair<int, int> bfs(int start) {
        fill(dist.begin(), dist.end(), -1);
        queue<int> q;
        q.push(start);
        dist[start] = 0;
        
        int farthest = start, maxDist = 0;
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                    if (dist[v] > maxDist) {
                        maxDist = dist[v];
                        farthest = v;
                    }
                }
            }
        }
        
        return {farthest, maxDist};
    }
    
public:
    void build(vector<vector<int>>& edges) {
        int n = edges.size() + 1;
        adj.resize(n);
        dist.resize(n);
        
        for (auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }
    }
    
    pair<int, vector<int>> findDiameter() {
        // 第一次BFS找到最远点
        auto [u, _] = bfs(0);
        
        // 第二次BFS找到直径
        auto [v, diameter] = bfs(u);
        
        // 重构直径路径
        vector<int> path;
        int curr = v;
        while (curr != u) {
            path.push_back(curr);
            for (int next : adj[curr]) {
                if (dist[next] == dist[curr] - 1) {
                    curr = next;
                    break;
                }
            }
        }
        path.push_back(u);
        reverse(path.begin(), path.end());
        
        return {diameter, path};
    }
    
    vector<int> findCenter() {
        auto [diameter, path] = findDiameter();
        vector<int> centers;
        
        int n = path.size();
        if (n % 2 == 1) {
            centers.push_back(path[n / 2]);
        } else {
            centers.push_back(path[n / 2 - 1]);
            centers.push_back(path[n / 2]);
        }
        
        return centers;
    }
};

int main() {
    vector<vector<int>> edges = {{0,1}, {1,2}, {2,3}, {1,4}, {4,5}};
    
    // 测试树链剖分
    TreeChainDecomposition tcd;
    tcd.build(edges);
    
    cout << "节点3和5的LCA: " << tcd.lca(3, 5) << endl;
    cout << "节点3和5的距离: " << tcd.distance(3, 5) << endl;
    
    // 测试直径和中心
    TreeDiameterCenter tdc;
    tdc.build(edges);
    
    auto [diameter, path] = tdc.findDiameter();
    cout << "树的直径: " << diameter << endl;
    
    vector<int> centers = tdc.findCenter();
    cout << "树的中心: ";
    for (int center : centers) {
        cout << center << " ";
    }
    cout << endl;
    
    return 0;
}

解题技巧

🎯 DFS序应用

利用DFS序将树问题转化为序列问题,子树操作变为区间操作,大大简化问题复杂度。

⚡ 数据结构结合

树与线段树、树状数组等数据结构结合,可以高效处理树上的区间查询和更新操作。

🔄 树的性质

深入理解树的性质:直径、中心、重心等,这些性质在解决复杂树问题时非常有用。

🧮 算法优化

使用重链剖分、LCA预处理等技术优化树上查询,将复杂度从O(n)降低到O(log n)。