Advanced Tree Problems

Master DFS ordering, tree + segment tree combinations, and advanced tree algorithms

Overview

Advanced tree problems are a staple of competitive programming. They typically combine DFS ordering (flattening a tree into an array), trees paired with range data structures, and specialized tree algorithms such as heavy-light decomposition. This chapter covers these techniques as they are used to solve hard tree problems in top-level contests.

Core Content

1. DFS Ordering and Tree Linearization

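DFS ordering flattens a tree into a linear sequence. Because a DFS finishes an entire subtree before leaving it, every subtree occupies a contiguous range [start_time[u], end_time[u]] of that sequence, so subtree queries become range queries.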

// DFS ordering and tree linearization
#include <iostream>
#include <vector>
using namespace std;

// Flattens a rooted tree into its DFS pre-order.
// After build(), the subtree of u is exactly dfs_order[start_time[u] .. end_time[u]],
// which turns subtree questions into contiguous range questions.
class TreeDFS {
private:
    vector<vector<int>> adj;           // undirected adjacency lists
    vector<int> dfs_order;             // dfs_order[t] = node entered at time t
    vector<int> start_time, end_time;  // entry time of u / last time inside u's subtree
    vector<int> depth, parent;         // edges from the root / parent (-1 for the root)
    int timer = 0;

    // Recursive pre-order DFS of u (reached from p) at depth d.
    // NOTE: recursion depth equals the tree height; very deep (path-like) trees
    // may need an iterative version to avoid exhausting the call stack.
    void dfs(int u, int p, int d) {
        start_time[u] = timer;
        dfs_order[timer++] = u;
        depth[u] = d;
        parent[u] = p;

        for (int v : adj[u]) {
            if (v != p) {
                dfs(v, u, d + 1);
            }
        }

        // All of u's descendants were entered after u, so the last time used
        // so far is the end of u's subtree range.
        end_time[u] = timer - 1;
    }

public:
    // Builds the tree from n-1 undirected edges {a, b} over nodes 0..n-1 and
    // runs the DFS from root. State from any previous build() is discarded.
    void build(const vector<vector<int>>& edges, int root = 0) {
        const int n = static_cast<int>(edges.size()) + 1;
        // assign() rather than resize(): on a second build() resize() would keep
        // the old adjacency lists, so edges would be duplicated and the DFS would
        // visit nodes more than once (writing past the end of dfs_order).
        adj.assign(n, vector<int>());
        dfs_order.assign(n, -1);
        start_time.assign(n, 0);
        end_time.assign(n, 0);
        depth.assign(n, 0);
        parent.assign(n, -1);
        timer = 0;

        for (const auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }

        dfs(root, -1, 0);
    }

    // True if u is an ancestor of v (every node counts as its own ancestor):
    // v's entry time lies inside u's subtree range.
    bool isAncestor(int u, int v) const {
        return start_time[u] <= start_time[v] && end_time[v] <= end_time[u];
    }

    // Inclusive DFS-order range [first, second] occupied by u's subtree.
    pair<int, int> getSubtreeRange(int u) const {
        return {start_time[u], end_time[u]};
    }

    // Number of edges between the root and u.
    int getDepth(int u) const {
        return depth[u];
    }

    // Parent of u in the rooted tree; -1 for the root.
    int getParent(int u) const {
        return parent[u];
    }

    // Prints the nodes in DFS visiting order on one line.
    void printDFSOrder() const {
        cout << "DFS order: ";
        for (int node : dfs_order) {
            cout << node << " ";
        }
        cout << endl;
    }
};

int main() {
    // Sample tree rooted at 0: 0 -> {1, 2}, 1 -> {3, 4}, 2 -> {5}.
    vector<vector<int>> edges = {{0,1}, {0,2}, {1,3}, {1,4}, {2,5}};

    TreeDFS treeDFS;
    treeDFS.build(edges);
    treeDFS.printDFSOrder();

    // Node 1's subtree is a contiguous slice of the DFS order.
    const pair<int, int> range = treeDFS.getSubtreeRange(1);
    cout << "Node 1 subtree range: ["
         << range.first << ", "
         << range.second << "]" << endl;

    return 0;
}

2. Tree + Segment Tree

Once DFS ordering maps every subtree to a contiguous range, a segment tree with lazy propagation over that order supports subtree updates and subtree queries in O(log n) each.

// Tree + Segment Tree for subtree problems
#include <iostream>
#include <vector>
using namespace std;

// Subtree-add / subtree-sum queries on a rooted tree.
// A DFS numbers the nodes so that every subtree occupies a contiguous range of
// positions; a lazy-propagation segment tree over those positions then supports
// "add val to every node in subtree(u)" and "sum over subtree(u)" in O(log n).
class TreeSegmentTree {
private:
    vector<vector<int>> adj;           // undirected adjacency lists
    vector<int> start_time, end_time;  // subtree(u) = positions [start_time[u], end_time[u]]
    vector<int> values;                // values[pos] = initial value of the node at DFS position pos
    vector<long long> tree, lazy;      // segment sums / pending per-element additions
    int n = 0, timer = 0;

    // Assigns DFS positions; end_time[u] is the last position inside u's subtree.
    void dfs(int u, int p) {
        start_time[u] = timer++;
        for (int v : adj[u]) {
            if (v != p) dfs(v, u);
        }
        end_time[u] = timer - 1;
    }

    // Builds segment node `node` covering positions [start, end] from `values`.
    void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = values[start];
        } else {
            int mid = (start + end) / 2;
            build(2*node, start, mid);
            build(2*node+1, mid+1, end);
            tree[node] = tree[2*node] + tree[2*node+1];
        }
    }

    // Applies node's pending addition to its own sum and defers it to its
    // children, so tree[node] is exact once this returns.
    void updateLazy(int node, int start, int end) {
        if (lazy[node] != 0) {
            tree[node] += lazy[node] * (end - start + 1);
            if (start != end) {
                lazy[2*node] += lazy[node];
                lazy[2*node+1] += lazy[node];
            }
            lazy[node] = 0;
        }
    }

    // Adds val to every position in [l, r] within node's range [start, end].
    void updateRange(int node, int start, int end, int l, int r, long long val) {
        updateLazy(node, start, end);
        if (start > r || end < l) return;

        if (start >= l && end <= r) {
            lazy[node] += val;
            updateLazy(node, start, end);
            return;
        }

        int mid = (start + end) / 2;
        updateRange(2*node, start, mid, l, r, val);
        updateRange(2*node+1, mid+1, end, l, r, val);

        // Children may still hold pending additions; settle them before combining.
        updateLazy(2*node, start, mid);
        updateLazy(2*node+1, mid+1, end);
        tree[node] = tree[2*node] + tree[2*node+1];
    }

    // Sum of positions in [l, r] within node's range [start, end].
    long long queryRange(int node, int start, int end, int l, int r) {
        if (start > r || end < l) return 0;
        updateLazy(node, start, end);

        if (start >= l && end <= r) return tree[node];

        int mid = (start + end) / 2;
        return queryRange(2*node, start, mid, l, r) +
               queryRange(2*node+1, mid+1, end, l, r);
    }

public:
    // Builds the structure for the tree given by n-1 undirected edges over nodes
    // 0..n-1, rooted at `root`. nodeValues[i] is the initial value of node i and
    // must have at least n entries (std::out_of_range is thrown otherwise).
    // State from any previous initialize() is discarded.
    void initialize(const vector<vector<int>>& edges, const vector<int>& nodeValues, int root = 0) {
        n = static_cast<int>(edges.size()) + 1;
        // assign() rather than resize(): a re-initialize must not keep the old
        // adjacency lists or pending lazy additions from the previous tree.
        adj.assign(n, vector<int>());
        start_time.assign(n, 0);
        end_time.assign(n, 0);
        values.assign(n, 0);
        tree.assign(4 * n, 0LL);
        lazy.assign(4 * n, 0LL);
        timer = 0;

        for (const auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }

        dfs(root, -1);

        // Rearrange values by DFS order; at() rejects a too-short nodeValues.
        for (int i = 0; i < n; i++) {
            values[start_time[i]] = nodeValues.at(i);
        }

        build(1, 0, n - 1);
    }

    // Adds val to every node in the subtree rooted at u.
    void updateSubtree(int u, long long val) {
        updateRange(1, 0, n - 1, start_time[u], end_time[u], val);
    }

    // Returns the sum of current values over the subtree rooted at u.
    long long querySubtree(int u) {
        return queryRange(1, 0, n - 1, start_time[u], end_time[u]);
    }
};

int main() {
    // Sample tree rooted at 0; values[i] is the initial value of node i.
    vector<vector<int>> edges = {{0,1}, {0,2}, {1,3}, {1,4}, {2,5}};
    vector<int> values = {1, 2, 3, 4, 5, 6};

    TreeSegmentTree tst;
    tst.initialize(edges, values);

    const long long sumBefore = tst.querySubtree(1);
    cout << "Node 1 subtree sum: " << sumBefore << endl;

    // Adds 10 to each node of subtree(1), i.e. nodes 1, 3 and 4.
    tst.updateSubtree(1, 10);
    const long long sumAfter = tst.querySubtree(1);
    cout << "After update, node 1 subtree sum: " << sumAfter << endl;

    return 0;
}

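3. Advanced Tree Algorithms

Heavy-light decomposition splits a tree into chains so that any root-to-node path crosses only O(log n) of them, which gives O(log n) LCA and distance queries. The diameter is found with two BFS passes, and the center is the middle vertex (or middle pair of vertices) of a diameter path.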

// Heavy-light (tree chain) decomposition for LCA/distance, plus tree diameter and center
#include <algorithm>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;

// Heavy-light decomposition (HLD) of a rooted tree. Each node's child with the
// largest subtree is its "heavy" child; following heavy edges splits the tree
// into chains, and any root-to-node path crosses O(log n) chains, so lca() and
// distance() run in O(log n).
class TreeChainDecomposition {
private:
    vector<vector<int>> adj;  // undirected adjacency lists
    vector<int> parent;       // parent[u]; -1 for the root
    vector<int> depth;        // number of edges from the root
    vector<int> heavy;        // child with the largest subtree; -1 for leaves
    vector<int> head;         // topmost node of the chain containing u
    vector<int> pos;          // index of u in chain-contiguous order (not read by lca/distance)
    int timer = 0;

    // First pass: sets parent/depth of u's descendants, picks heavy[u], and
    // returns the size of u's subtree.
    int dfs1(int u) {
        int size = 1, maxSize = 0;
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                depth[v] = depth[u] + 1;
                int childSize = dfs1(v);
                if (childSize > maxSize) {
                    maxSize = childSize;
                    heavy[u] = v;
                }
                size += childSize;
            }
        }
        return size;
    }

    // Second pass: visits the heavy child first so every chain gets consecutive
    // pos values; each light child starts a new chain headed by itself.
    void dfs2(int u, int h) {
        head[u] = h;
        pos[u] = timer++;

        if (heavy[u] != -1) {
            dfs2(heavy[u], h);
        }

        for (int v : adj[u]) {
            if (v != parent[u] && v != heavy[u]) {
                dfs2(v, v);
            }
        }
    }

public:
    // Decomposes the tree given by n-1 undirected edges over nodes 0..n-1,
    // rooted at `root`. State from any previous build() is discarded.
    void build(const vector<vector<int>>& edges, int root = 0) {
        const int n = static_cast<int>(edges.size()) + 1;
        // assign() rather than resize(): a second build() must not append to the
        // old adjacency lists (duplicate edges would make both DFS passes revisit
        // nodes and run timer past the end of pos).
        adj.assign(n, vector<int>());
        parent.assign(n, -1);
        depth.assign(n, 0);
        heavy.assign(n, -1);
        head.assign(n, 0);
        pos.assign(n, 0);
        timer = 0;

        for (const auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }

        parent[root] = -1;
        depth[root] = 0;
        dfs1(root);
        dfs2(root, root);
    }

    // Lowest common ancestor of u and v: repeatedly move whichever node's chain
    // head is deeper to just above that head, until both share a chain; the
    // shallower of the two is then the LCA.
    int lca(int u, int v) const {
        while (head[u] != head[v]) {
            if (depth[head[u]] > depth[head[v]]) {
                u = parent[head[u]];
            } else {
                v = parent[head[v]];
            }
        }
        return depth[u] < depth[v] ? u : v;
    }

    // Number of edges on the path between u and v.
    int distance(int u, int v) const {
        return depth[u] + depth[v] - 2 * depth[lca(u, v)];
    }
};

// Tree diameter and center
// Tree diameter (longest path, counted in edges) and center via two BFS passes.
// NOTE: bfs() uses std::queue, so this snippet needs #include <queue>.
class TreeDiameterCenter {
private:
    vector<vector<int>> adj;  // undirected adjacency lists
    vector<int> dist;         // BFS distances from the latest bfs() source; -1 = unreached

    // BFS from start; fills dist and returns {a farthest node, its distance}.
    // Ties go to the node discovered first.
    pair<int, int> bfs(int start) {
        fill(dist.begin(), dist.end(), -1);
        queue<int> q;
        q.push(start);
        dist[start] = 0;

        int farthest = start, maxDist = 0;
        while (!q.empty()) {
            const int u = q.front();
            q.pop();

            for (int v : adj[u]) {
                if (dist[v] != -1) continue;
                dist[v] = dist[u] + 1;
                q.push(v);
                if (dist[v] > maxDist) {
                    maxDist = dist[v];
                    farthest = v;
                }
            }
        }

        return {farthest, maxDist};
    }

public:
    // Stores the tree given by n-1 undirected edges over nodes 0..n-1.
    // Any previously built tree is discarded.
    void build(const vector<vector<int>>& edges) {
        const int n = static_cast<int>(edges.size()) + 1;
        // assign() rather than resize(): a second build() must not keep the old
        // adjacency lists, which may also reference nodes >= the new n.
        adj.assign(n, vector<int>());
        dist.assign(n, -1);

        for (const auto& edge : edges) {
            adj[edge[0]].push_back(edge[1]);
            adj[edge[1]].push_back(edge[0]);
        }
    }

    // Returns {diameter length in edges, nodes along one diameter path}.
    pair<int, vector<int>> findDiameter() {
        // In a tree, a node farthest from any start node is a diameter endpoint.
        const int u = bfs(0).first;

        // A node farthest from that endpoint is the other endpoint.
        const pair<int, int> far = bfs(u);
        const int v = far.first;
        const int diameter = far.second;

        // Walk back from v to u; dist now holds distances from u, so each step
        // moves to the neighbour that is one edge closer to u.
        vector<int> path;
        path.reserve(diameter + 1);
        int curr = v;
        while (curr != u) {
            path.push_back(curr);
            for (int next : adj[curr]) {
                if (dist[next] == dist[curr] - 1) {
                    curr = next;
                    break;
                }
            }
        }
        path.push_back(u);
        reverse(path.begin(), path.end());

        return {diameter, path};
    }

    // Returns the center: the middle node of a diameter path, or its two middle
    // nodes when the path has an even number of nodes.
    vector<int> findCenter() {
        const vector<int> path = findDiameter().second;
        const size_t len = path.size();
        if (len % 2 == 1) {
            return {path[len / 2]};
        }
        return {path[len / 2 - 1], path[len / 2]};
    }
};

int main() {
    // Path 0-1-2-3 with a side branch 1-4-5.
    vector<vector<int>> edges = {{0,1}, {1,2}, {2,3}, {1,4}, {4,5}};

    // Heavy-light decomposition: LCA and distance queries.
    TreeChainDecomposition tcd;
    tcd.build(edges);

    const int a = 3, b = 5;
    cout << "LCA of nodes 3 and 5: " << tcd.lca(a, b) << endl;
    cout << "Distance between nodes 3 and 5: " << tcd.distance(a, b) << endl;

    // Diameter length and center node(s).
    TreeDiameterCenter tdc;
    tdc.build(edges);

    const int diameter = tdc.findDiameter().first;
    cout << "Tree diameter: " << diameter << endl;

    const vector<int> centers = tdc.findCenter();
    cout << "Tree centers: ";
    for (size_t i = 0; i < centers.size(); ++i) {
        cout << centers[i] << " ";
    }
    cout << endl;

    return 0;
}

Problem-Solving Tips

🎯 DFS Order Applications

Use DFS ordering to turn tree problems into sequence problems: subtree operations become range operations, so standard array data structures apply directly and the implementation becomes much simpler.

⚡ Data Structure Integration

Combining trees with segment trees, Fenwick trees, and other data structures enables efficient range queries and updates on trees.

🔄 Tree Properties

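Know the key structural properties of trees, such as the diameter, center, and centroid. For example, the vertex farthest from any starting vertex is an endpoint of a diameter, which is why two BFS passes find the diameter. The center is the middle vertex (or middle pair) of a diameter path.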

🧮 Algorithm Optimization

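Use heavy-light decomposition, LCA preprocessing, and similar techniques to speed up tree queries. They reduce the per-query cost from O(n) (walking the path naively) to O(log n). Path updates and queries done with a segment tree over the HLD order take O(log² n).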