高级线段树

掌握可持久化线段树、复杂区间操作和高级线段树技术

概述

高级线段树技术是处理复杂区间问题的核心工具。本章节深入探讨可持久化线段树、李超线段树、动态开点线段树等高级技术,为解决世界级竞赛中的复杂区间查询和更新问题提供强大支持。

核心内容

1. 可持久化线段树

可持久化线段树保存数据结构的历史版本,支持查询任意历史状态。

// 可持久化线段树实现
#include <iostream>
#include <vector>
using namespace std;

struct PersistentSegmentTree {
    struct Node {
        int left, right, sum;
        Node() : left(0), right(0), sum(0) {}
        Node(int l, int r, int s) : left(l), right(r), sum(s) {}
    };
    
    vector<Node> tree;
    vector<int> roots;
    int n, nodeCount;
    
    PersistentSegmentTree(vector<int>& arr) {
        n = arr.size();
        tree.reserve(n * 20); // 预留足够空间
        nodeCount = 0;
        
        // 构建初始版本
        roots.push_back(build(arr, 0, n - 1));
    }
    
    int build(vector<int>& arr, int l, int r) {
        int node = nodeCount++;
        tree.push_back(Node());
        
        if (l == r) {
            tree[node].sum = arr[l];
        } else {
            int mid = (l + r) / 2;
            tree[node].left = build(arr, l, mid);
            tree[node].right = build(arr, mid + 1, r);
            tree[node].sum = tree[tree[node].left].sum + tree[tree[node].right].sum;
        }
        
        return node;
    }
    
    // 创建新版本(单点更新)
    int update(int prevRoot, int l, int r, int pos, int val) {
        int node = nodeCount++;
        tree.push_back(Node());
        
        if (l == r) {
            tree[node].sum = val;
        } else {
            int mid = (l + r) / 2;
            if (pos <= mid) {
                tree[node].left = update(tree[prevRoot].left, l, mid, pos, val);
                tree[node].right = tree[prevRoot].right;
            } else {
                tree[node].left = tree[prevRoot].left;
                tree[node].right = update(tree[prevRoot].right, mid + 1, r, pos, val);
            }
            tree[node].sum = tree[tree[node].left].sum + tree[tree[node].right].sum;
        }
        
        return node;
    }
    
    // 查询历史版本
    int query(int root, int l, int r, int ql, int qr) {
        if (ql > r || qr < l) return 0;
        if (ql <= l && r <= qr) return tree[root].sum;
        
        int mid = (l + r) / 2;
        return query(tree[root].left, l, mid, ql, qr) + 
               query(tree[root].right, mid + 1, r, ql, qr);
    }
    
    // 添加新版本
    void addVersion(int pos, int val) {
        int newRoot = update(roots.back(), 0, n - 1, pos, val);
        roots.push_back(newRoot);
    }
    
    // 查询指定版本
    int queryVersion(int version, int l, int r) {
        return query(roots[version], 0, n - 1, l, r);
    }
    
    // 获取版本数
    int getVersionCount() {
        return roots.size();
    }
};

// 可持久化数组应用
class PersistentArray {
private:
    PersistentSegmentTree* pst;
    
public:
    PersistentArray(vector<int>& initial) {
        pst = new PersistentSegmentTree(initial);
    }
    
    void set(int pos, int val) {
        pst->addVersion(pos, val);
    }
    
    int get(int version, int pos) {
        return pst->queryVersion(version, pos, pos);
    }
    
    int rangeSum(int version, int l, int r) {
        return pst->queryVersion(version, l, r);
    }
    
    int getCurrentVersion() {
        return pst->getVersionCount() - 1;
    }
};

int main() {
    vector<int> arr = {1, 2, 3, 4, 5};
    PersistentArray pa(arr);
    
    cout << "初始版本区间和[1,3]: " 
         << pa.rangeSum(0, 1, 3) << endl;
    
    // 修改位置2为10
    pa.set(2, 10);
    
    cout << "版本0位置2的值: " << pa.get(0, 2) << endl;
    cout << "版本1位置2的值: " << pa.get(1, 2) << endl;
    
    cout << "版本1区间和[1,3]: " 
         << pa.rangeSum(1, 1, 3) << endl;
    
    return 0;
}

2. 李超线段树

李超线段树专门处理直线查询问题,支持动态添加直线和查询最值。

// 李超线段树实现
#include <iostream>
#include <vector>
#include <climits>
using namespace std;

struct Line {
    long long k, b; // y = kx + b
    int id;
    
    Line() : k(0), b(LLONG_MIN), id(-1) {}
    Line(long long _k, long long _b, int _id) : k(_k), b(_b), id(_id) {}
    
    long long eval(long long x) {
        return k * x + b;
    }
};

class LiChaoTree {
private:
    vector<Line> tree;
    int n;
    
    // 判断直线1在x处是否优于直线2
    bool better(Line& l1, Line& l2, long long x) {
        return l1.eval(x) > l2.eval(x);
    }
    
    void update(int node, int l, int r, Line& newLine) {
        if (l == r) {
            if (better(newLine, tree[node], l)) {
                tree[node] = newLine;
            }
            return;
        }
        
        int mid = (l + r) / 2;
        bool betterLeft = better(newLine, tree[node], l);
        bool betterMid = better(newLine, tree[node], mid);
        
        if (betterMid) {
            swap(tree[node], newLine);
        }
        
        if (betterLeft != betterMid) {
            update(2 * node, l, mid, newLine);
        } else {
            update(2 * node + 1, mid + 1, r, newLine);
        }
    }
    
    Line query(int node, int l, int r, long long x) {
        if (l == r) {
            return tree[node];
        }
        
        int mid = (l + r) / 2;
        Line result = tree[node];
        
        if (x <= mid) {
            Line candidate = query(2 * node, l, mid, x);
            if (better(candidate, result, x)) {
                result = candidate;
            }
        } else {
            Line candidate = query(2 * node + 1, mid + 1, r, x);
            if (better(candidate, result, x)) {
                result = candidate;
            }
        }
        
        return result;
    }
    
public:
    LiChaoTree(int size) {
        n = 1;
        while (n < size) n *= 2;
        tree.resize(4 * n);
    }
    
    // 添加直线
    void addLine(long long k, long long b, int id) {
        Line newLine(k, b, id);
        update(1, 0, n - 1, newLine);
    }
    
    // 查询x处的最大值
    pair<long long, int> queryMax(long long x) {
        Line result = query(1, 0, n - 1, x);
        return {result.eval(x), result.id};
    }
};

// 凸包优化DP应用
class ConvexHullOptimization {
private:
    LiChaoTree* lct;
    
public:
    ConvexHullOptimization(int maxX) {
        lct = new LiChaoTree(maxX + 1);
    }
    
    // 添加新的转移
    void addTransition(long long slope, long long intercept, int id) {
        lct->addLine(slope, intercept, id);
    }
    
    // 查询最优转移
    pair<long long, int> queryOptimal(long long x) {
        return lct->queryMax(x);
    }
};

// 示例:最大子矩形问题
long long maxSubmatrixDP(vector<vector<int>>& matrix) {
    int n = matrix.size(), m = matrix[0].size();
    long long result = LLONG_MIN;
    
    // 枚举上下边界
    for (int top = 0; top < n; top++) {
        vector<long long> heights(m, 0);
        
        for (int bottom = top; bottom < n; bottom++) {
            // 更新高度数组
            for (int j = 0; j < m; j++) {
                heights[j] += matrix[bottom][j];
            }
            
            // 使用李超树优化DP
            ConvexHullOptimization cho(m);
            vector<long long> dp(m);
            
            dp[0] = heights[0];
            cho.addTransition(0, heights[0], 0);
            result = max(result, dp[0]);
            
            for (int j = 1; j < m; j++) {
                auto [maxVal, fromIdx] = cho.queryOptimal(j);
                dp[j] = max(heights[j], maxVal + heights[j]);
                cho.addTransition(-j, dp[j], j);
                result = max(result, dp[j]);
            }
        }
    }
    
    return result;
}

int main() {
    // 测试李超线段树
    LiChaoTree lct(100);
    
    // 添加直线 y = 2x + 1
    lct.addLine(2, 1, 1);
    
    // 添加直线 y = -x + 10
    lct.addLine(-1, 10, 2);
    
    // 添加直线 y = x + 3
    lct.addLine(1, 3, 3);
    
    for (int x = 0; x <= 10; x++) {
        auto [maxVal, lineId] = lct.queryMax(x);
        cout << "x=" << x << ": max=" << maxVal << " (line " << lineId << ")" << endl;
    }
    
    // 测试矩阵DP
    vector<vector<int>> matrix = {
        {1, -2, 3},
        {-1, 4, -2},
        {2, -1, 1}
    };
    
    cout << "最大子矩形和: " 
         << maxSubmatrixDP(matrix) << endl;
    
    return 0;
}

3. 动态开点线段树

// 动态开点线段树
#include <iostream>
#include <unordered_map>
using namespace std;

class DynamicSegmentTree {
private:
    struct Node {
        long long sum;
        int left, right;
        Node() : sum(0), left(-1), right(-1) {}
    };
    
    vector<Node> tree;
    int nodeCount;
    long long L, R; // 值域范围
    
    void pushUp(int node) {
        tree[node].sum = 0;
        if (tree[node].left != -1) {
            tree[node].sum += tree[tree[node].left].sum;
        }
        if (tree[node].right != -1) {
            tree[node].sum += tree[tree[node].right].sum;
        }
    }
    
    int createNode() {
        tree.push_back(Node());
        return nodeCount++;
    }
    
    void update(int& node, long long l, long long r, long long pos, long long val) {
        if (node == -1) {
            node = createNode();
        }
        
        if (l == r) {
            tree[node].sum += val;
            return;
        }
        
        long long mid = l + (r - l) / 2;
        if (pos <= mid) {
            update(tree[node].left, l, mid, pos, val);
        } else {
            update(tree[node].right, mid + 1, r, pos, val);
        }
        
        pushUp(node);
    }
    
    long long query(int node, long long l, long long r, long long ql, long long qr) {
        if (node == -1 || ql > r || qr < l) {
            return 0;
        }
        
        if (ql <= l && r <= qr) {
            return tree[node].sum;
        }
        
        long long mid = l + (r - l) / 2;
        return query(tree[node].left, l, mid, ql, qr) + 
               query(tree[node].right, mid + 1, r, ql, qr);
    }
    
public:
    DynamicSegmentTree(long long minVal, long long maxVal) {
        L = minVal;
        R = maxVal;
        nodeCount = 0;
        tree.reserve(1000000); // 预留空间
    }
    
    int root = -1;
    
    void update(long long pos, long long val) {
        update(root, L, R, pos, val);
    }
    
    long long query(long long l, long long r) {
        return query(root, L, R, l, r);
    }
    
    // 查询第k小
    long long kthSmallest(int k) {
        return kthSmallest(root, L, R, k);
    }
    
private:
    long long kthSmallest(int node, long long l, long long r, int k) {
        if (node == -1 || k <= 0) return -1;
        
        if (l == r) {
            return l;
        }
        
        long long mid = l + (r - l) / 2;
        long long leftSum = (tree[node].left == -1) ? 0 : tree[tree[node].left].sum;
        
        if (k <= leftSum) {
            return kthSmallest(tree[node].left, l, mid, k);
        } else {
            return kthSmallest(tree[node].right, mid + 1, r, k - leftSum);
        }
    }
};

// 权值线段树应用
class WeightSegmentTree {
private:
    DynamicSegmentTree* dst;
    
public:
    WeightSegmentTree(long long minVal, long long maxVal) {
        dst = new DynamicSegmentTree(minVal, maxVal);
    }
    
    // 插入数字
    void insert(long long val) {
        dst->update(val, 1);
    }
    
    // 删除数字
    void remove(long long val) {
        dst->update(val, -1);
    }
    
    // 查询小于等于val的数字个数
    long long countLE(long long val) {
        return dst->query(LLONG_MIN, val);
    }
    
    // 查询第k小的数字
    long long kthSmallest(int k) {
        return dst->kthSmallest(k);
    }
    
    // 查询区间内数字个数
    long long countRange(long long l, long long r) {
        return dst->query(l, r);
    }
};

// 区间第k小问题
class RangeKthSmallest {
private:
    vector<DynamicSegmentTree*> trees;
    vector<long long> arr;
    int n;
    
public:
    RangeKthSmallest(vector<long long>& data) {
        arr = data;
        n = arr.size();
        trees.resize(n + 1);
        
        // 初始化前缀权值线段树
        for (int i = 0; i <= n; i++) {
            trees[i] = new DynamicSegmentTree(LLONG_MIN, LLONG_MAX);
        }
        
        // 构建前缀和
        for (int i = 0; i < n; i++) {
            trees[i + 1] = new DynamicSegmentTree(LLONG_MIN, LLONG_MAX);
            // 复制前一个版本并添加新元素
            for (int j = 0; j <= i; j++) {
                trees[i + 1]->update(arr[j], 1);
            }
        }
    }
    
    // 查询区间[l,r]的第k小
    long long queryKth(int l, int r, int k) {
        // 使用两个前缀树的差值
        // 这里简化实现,实际需要可持久化权值线段树
        vector<long long> values;
        for (int i = l; i <= r; i++) {
            values.push_back(arr[i]);
        }
        sort(values.begin(), values.end());
        return values[k - 1];
    }
};

int main() {
    // 测试动态开点线段树
    DynamicSegmentTree dst(-1000000, 1000000);
    
    dst.update(5, 3);
    dst.update(10, 2);
    dst.update(15, 1);
    
    cout << "区间[5,15]的和: " << dst.query(5, 15) << endl;
    cout << "第2小的数: " << dst.kthSmallest(2) << endl;
    
    // 测试权值线段树
    WeightSegmentTree wst(-100, 100);
    
    wst.insert(5);
    wst.insert(3);
    wst.insert(8);
    wst.insert(1);
    wst.insert(7);
    
    cout << "小于等于6的数字个数: " << wst.countLE(6) << endl;
    cout << "第3小的数字: " << wst.kthSmallest(3) << endl;
    
    return 0;
}

解题技巧

💾 空间优化

可持久化结构要注意空间使用,动态开点避免不必要的节点创建,合理估算空间需求。

⚡ 性能考虑

李超树适合直线查询,权值线段树处理第k小问题,根据问题特点选择合适的高级线段树。

🔧 实现技巧

注意边界处理和特殊情况,合理设计节点结构,使用对象池优化内存分配。

🎯 应用场景

可持久化处理历史查询,李超树优化DP转移,动态开点处理大值域稀疏数据。