字典树(Trie Tree)数据结构

字典树(Trie Tree)是一种树形数据结构,用于高效地存储和检索字符串集合。每个节点代表一个字符,从根到任意节点的路径代表一个字符串前缀。字典树在字符串匹配、前缀查询、自动补全等场景中有广泛应用。

字典树基础实现

字典树的基本操作包括插入、查询、删除字符串,以及前缀匹配。每个节点包含子节点指针和表示是否为单词结尾的标志。

标准字典树实现

#include <iostream>
#include <vector>
#include <string>
#include <unordered_map>
using namespace std;

class TrieNode {
public:
    unordered_map<char, TrieNode*> children;
    bool isEnd;
    int count; // 以当前节点为结尾的字符串数量
    
    TrieNode() : isEnd(false), count(0) {}
};

class Trie {
private:
    TrieNode* root;
    
public:
    Trie() {
        root = new TrieNode();
    }
    
    // 插入字符串
    void insert(const string& word) {
        TrieNode* node = root;
        for (char c : word) {
            if (node->children.find(c) == node->children.end()) {
                node->children[c] = new TrieNode();
            }
            node = node->children[c];
        }
        node->isEnd = true;
        node->count++;
    }
    
    // 查询字符串是否存在
    bool search(const string& word) {
        TrieNode* node = findNode(word);
        return node != nullptr && node->isEnd;
    }
    
    // 查询前缀是否存在
    bool startsWith(const string& prefix) {
        return findNode(prefix) != nullptr;
    }
    
    // 删除字符串
    bool remove(const string& word) {
        return removeHelper(root, word, 0);
    }
    
    // 获取所有以prefix为前缀的字符串
    vector<string> getWordsWithPrefix(const string& prefix) {
        vector<string> result;
        TrieNode* prefixNode = findNode(prefix);
        if (prefixNode != nullptr) {
            dfs(prefixNode, prefix, result);
        }
        return result;
    }
    
    // 获取字典树中的所有字符串
    vector<string> getAllWords() {
        vector<string> result;
        dfs(root, "", result);
        return result;
    }
    
    // 计算以prefix为前缀的字符串数量
    int countWordsWithPrefix(const string& prefix) {
        TrieNode* prefixNode = findNode(prefix);
        if (prefixNode == nullptr) return 0;
        return countWordsHelper(prefixNode);
    }
    
private:
    TrieNode* findNode(const string& word) {
        TrieNode* node = root;
        for (char c : word) {
            if (node->children.find(c) == node->children.end()) {
                return nullptr;
            }
            node = node->children[c];
        }
        return node;
    }
    
    bool removeHelper(TrieNode* node, const string& word, int index) {
        if (index == word.length()) {
            if (!node->isEnd) return false;
            node->isEnd = false;
            node->count--;
            return node->children.empty() && node->count == 0;
        }
        
        char c = word[index];
        if (node->children.find(c) == node->children.end()) {
            return false;
        }
        
        TrieNode* child = node->children[c];
        bool shouldDeleteChild = removeHelper(child, word, index + 1);
        
        if (shouldDeleteChild) {
            delete child;
            node->children.erase(c);
            return !node->isEnd && node->children.empty() && node->count == 0;
        }
        
        return false;
    }
    
    void dfs(TrieNode* node, const string& current, vector<string>& result) {
        if (node->isEnd) {
            for (int i = 0; i < node->count; i++) {
                result.push_back(current);
            }
        }
        
        for (auto& [c, child] : node->children) {
            dfs(child, current + c, result);
        }
    }
    
    int countWordsHelper(TrieNode* node) {
        int count = node->count;
        for (auto& [c, child] : node->children) {
            count += countWordsHelper(child);
        }
        return count;
    }
};

int main() {
    Trie trie;
    
    // 插入单词
    vector<string> words = {"apple", "app", "apricot", "banana", "band", "bandana"};
    
    cout << "插入单词: ";
    for (const string& word : words) {
        cout << word << " ";
        trie.insert(word);
    }
    cout << endl;
    
    // 搜索测试
    vector<string> searchWords = {"app", "apple", "application", "ban"};
    cout << "\n搜索结果:" << endl;
    for (const string& word : searchWords) {
        bool found = trie.search(word);
        cout << word << ": " << (found ? "找到" : "未找到") << endl;
    }
    
    // 前缀测试
    vector<string> prefixes = {"ap", "ban", "car"};
    cout << "\n前缀结果:" << endl;
    for (const string& prefix : prefixes) {
        bool hasPrefix = trie.startsWith(prefix);
        cout << prefix << ": " << (hasPrefix ? "有前缀" : "无前缀") << endl;
        
        if (hasPrefix) {
            vector<string> wordsWithPrefix = trie.getWordsWithPrefix(prefix);
            cout << "  单词: ";
            for (const string& word : wordsWithPrefix) {
                cout << word << " ";
            }
            cout << endl;
        }
    }
    
    return 0;
}

这个字典树实现使用哈希表存储子节点,支持相同单词的多次出现,提供包括前缀查询和单词枚举在内的全面操作。

数组实现的字典树(更高效)

在竞赛编程中,基于数组的字典树实现通常在时间和空间上都更高效,特别是当字母表大小较小时。

数组实现的字典树

#include <iostream>
#include <vector>
#include <string>
#include <cstring>
using namespace std;

class ArrayTrie {
private:
    static const int MAXN = 100005; // 最大节点数
    static const int ALPHABET_SIZE = 26; // 字母表大小
    
    int trie[MAXN][ALPHABET_SIZE];
    bool isEnd[MAXN];
    int count[MAXN];
    int nodeCount;
    
public:
    ArrayTrie() {
        nodeCount = 1; // 根节点
        memset(trie[0], 0, sizeof(trie[0]));
        isEnd[0] = false;
        count[0] = 0;
    }
    
    void insert(const string& word) {
        int node = 0;
        for (char c : word) {
            int idx = c - 'a';
            if (trie[node][idx] == 0) {
                trie[node][idx] = nodeCount;
                memset(trie[nodeCount], 0, sizeof(trie[nodeCount]));
                isEnd[nodeCount] = false;
                count[nodeCount] = 0;
                nodeCount++;
            }
            node = trie[node][idx];
        }
        isEnd[node] = true;
        count[node]++;
    }
    
    bool search(const string& word) {
        int node = findNode(word);
        return node != -1 && isEnd[node];
    }
    
    bool startsWith(const string& prefix) {
        return findNode(prefix) != -1;
    }
    
    int countWordsWithPrefix(const string& prefix) {
        int node = findNode(prefix);
        if (node == -1) return 0;
        return countWordsHelper(node);
    }
    
    // 找最长公共前缀
    string longestCommonPrefix(const vector<string>& words) {
        if (words.empty()) return "";
        
        // 插入所有单词
        for (const string& word : words) {
            insert(word);
        }
        
        string lcp = "";
        int node = 0;
        
        while (true) {
            int childCount = 0;
            int nextNode = -1;
            char nextChar = 'a';
            
            // 计算子节点数
            for (int i = 0; i < ALPHABET_SIZE; i++) {
                if (trie[node][i] != 0) {
                    childCount++;
                    nextNode = trie[node][i];
                    nextChar = 'a' + i;
                }
            }
            
            // 如果有多个子节点或到达单词末尾则停止
            if (childCount != 1 || isEnd[node]) {
                break;
            }
            
            lcp += nextChar;
            node = nextNode;
        }
        
        return lcp;
    }
    
private:
    int findNode(const string& word) {
        int node = 0;
        for (char c : word) {
            int idx = c - 'a';
            if (trie[node][idx] == 0) {
                return -1;
            }
            node = trie[node][idx];
        }
        return node;
    }
    
    int countWordsHelper(int node) {
        int totalCount = count[node];
        for (int i = 0; i < ALPHABET_SIZE; i++) {
            if (trie[node][i] != 0) {
                totalCount += countWordsHelper(trie[node][i]);
            }
        }
        return totalCount;
    }
};

int main() {
    ArrayTrie trie;
    
    vector<string> words = {"flower", "flow", "flight", "fly"};
    
    cout << "单词: ";
    for (const string& word : words) {
        cout << word << " ";
        trie.insert(word);
    }
    cout << endl;
    
    // 测试搜索
    vector<string> testWords = {"flow", "flower", "flowing", "fl"};
    cout << "\n搜索结果:" << endl;
    for (const string& word : testWords) {
        bool found = trie.search(word);
        cout << word << ": " << (found ? "找到" : "未找到") << endl;
    }
    
    // 测试前缀计数
    cout << "\n前缀计数:" << endl;
    vector<string> prefixes = {"fl", "flo", "fly"};
    for (const string& prefix : prefixes) {
        int count = trie.countWordsWithPrefix(prefix);
        cout << prefix << ": " << count << " 个单词" << endl;
    }
    
    // 找最长公共前缀
    ArrayTrie lcpTrie;
    string lcp = lcpTrie.longestCommonPrefix(words);
    cout << "\n最长公共前缀: " << lcp << endl;
    
    return 0;
}

基于数组的字典树使用静态数组存储子节点,提供更好的缓存局部性和更快的访问速度。对于固定字母表大小的小写英文字母特别高效。

关键要点

  • 字典树为插入、搜索和删除提供O(m)时间复杂度,其中m是字符串长度
  • 空间复杂度为O(ALPHABET_SIZE × N × M),其中N是字符串数量,M是平均长度
  • 对于固定字母表的竞赛编程,基于数组的实现更高效
  • 常见应用:自动补全、拼写检查、IP路由、字符串匹配问题