一、基本概念:
字典树(Trie)是一种用于实现字符串快速检索的多叉树结构。字典树的每一个结点都拥有若干个字符指针,若在插入或检索字符串时扫描到一个字符 c,就沿着当前结点的 c 字符指针,走向该指针指向的结点。
1、初始化
一棵空字典树仅包含一个根节点,该点的字符指针均指向空。
2、插入操作
当需要插入一个字符串 str 时,我们另一个指针 p 指向根节点。人后,依次扫描 s 中的每一个字符 c :
(1)若 p 的 c 字符指针指向一个已经存在的结点 q,则令 p = q
若 p 的 c 字符指针指向空,则新建一个结点 q,令 p = q
当 s 中的字符扫描完毕时,在当前结点 p 上标记它是一个字符串的尾部
int trie[maxn][26],tot=1;
bool nd[maxn];
void _insert(char str[])
{
int p = 1;
for(int k=0;str[k];k++){
int ch = str[k]-'a';
if(trie[p][ch]==0) trie[p][ch] = ++tot;
p = trie[p][ch];
}
nd[p] = true;
}
2、检索操作
当需要检索一个字符串 str 在字典树是否存在时,我们令一个指针 p 指向根节点,然后,依次扫描 str 中的每个字符 c :
(1)若 p 的 c 字符指针指向空,则说明 str 没有插入过 字典树,结束检索
(2)若 p 的 c 字符指针指向一个已经存在的结点 q,则令 p = q
当 s 中的字符扫描完毕时,若当前结点 p 被标记为一个字符串的尾部,则说明 str 在字典树中存在,否则不存在
bool _search(char str[])
{
int p = 1;
for(int k=0;str[k];k++){
p = trie[p][str[k]-'a'];
if(p==0) return false;
}
return nd[p];
}
代码比较简单,但是我们进一步想统计前缀出现的次数怎么办?那就开一个 sum 数组,表示某节点被访问过的次数。我们知道对于每一个前缀单词的插入,只要出现过这个前缀,那么总是要遍历一次从根节点到这个前缀单词的终节点路径中所有的节点,在遍历每一个节点的时候,我们都让此节点的sum计数数组加一即可。而对于某个前缀出现的次数,我们最后只需要返回此前缀单词最后一个字符对应的sum值即可。代码如下:
int trie[maxn][26],tot;
int sum[maxn];
void _insert(char str[])
{
int p = 0;
for(int k=0;str[k];k++){
int ch = str[k]-'a';
if(trie[p][ch]==0) trie[p][ch] = ++tot;
p = trie[p][ch];
sum[p]++;
}
}
int _search(char str[])
{
int p = 0;
for(int k=0;str[k];k++){
p = trie[p][str[k]-'a'];
if(sum[p]==0) return 0;
}
return sum[p];
}
二、例题
1、模板题目:HDU 1251 统计难题
注意控制输入输出即可,其余模板。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<string>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<string>
#include<map>
#include<set>
#include<tr1/unordered_set>
#define ll long long
#define ld long double
#define ull unsigned long long
using namespace std;
using namespace tr1;
const int maxn = 1000010;
int trie[maxn][26],tot;
int sum[maxn];
char a[24];
void _insert(char str[])
{
int p = 0;
for(int k=0;str[k];k++){
int ch = str[k]-'a';
if(trie[p][ch]==0) trie[p][ch] = ++tot;
p = trie[p][ch];
sum[p]++;
}
}
int _search(char str[])
{
int p = 0;
for(int k=0;str[k];k++){
p = trie[p][str[k]-'a'];
if(sum[p]==0) return 0;
}
return sum[p];
}
int main(void)
{
while(gets(a),strlen(a)!=0){
_insert(a);
}
while(~scanf("%s",a)){
printf("%d\n",_search(a));
}
return 0;
}
2、AcWing 142. 前缀统计
给定N个字符串S1,S2…SN,接下来进行M次询问,每次询问给定一个字符串T,求S1~SN中有多少个字符串是T的前缀。
反向提问,那么我们就把每一个以 p 的 c 字符指针为结尾的字符,用 nd 数组记入一下个数。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<string>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<string>
#include<map>
#include<set>
#include<tr1/unordered_set>
#define ll long long
#define ld long double
#define ull unsigned long long
using namespace std;
using namespace tr1;
const int maxn = 1000010;
int trie[maxn][26];
int nd[maxn],tot;
char a[maxn];
void _insert(char str[])
{
int p = 0;
for(int k=0;str[k];k++){
int ch = str[k]-'a';
if(trie[p][ch]==0) trie[p][ch] = ++tot;
p = trie[p][ch];
}
nd[p]++;
}
int _search(char str[])
{
int ans = 0,p = 0;
for(int k=0;str[k];k++){
p = trie[p][str[k]-'a'];
if(p==0) return ans;
ans += nd[p];
}
return ans;
}
int main(void)
{
int n,m;
scanf("%d%d",&n,&m);
while(n--){
scanf("%s",a);
_insert(a);
}
while(m--){
scanf("%s",a);
printf("%d\n",_search(a));
}
return 0;
}
三、0/1字典树
0/1字典树主要解决求异或最值的问题。
0/1字典树是一棵最多32层(或64层)的二叉树,其每个节点的两条边分别表示二进制的某一位的值为 0 还是为 1,将某个路径上边的值连起来就得到一个二进制串。
根节点的边表示二进制串的最高位
简单来说,0/1字典树存放的某个数的二进制串,跟字典树存放字符串的原理差不多,
所以插入函数:
int trie[32*maxn][2],tot; //如果是64位级别,就64*maxn
ll val[32*maxn];
void Init()
{
memset(trie[0],0,sizeof(trie[0]));
tot = 1;
}
void Insert(ll x)
{
int p = 0;
for(int i=32;i>=0;i--){ //从高到低,依次插入
int c = (x>>i)&1;
if(trie[p][c]==0){
memset(trie[tot],0,sizeof(trie[tot]));
val[tot] = 0;
trie[p][c] = tot++;
}
p = trie[p][c];
}
val[p] = x; //以 P 为二进制串结尾的数是 x
}
如果求解异或 x 最大值的数
那我们就找寻每一位尽可能与 x 对应的二进制串位异或结果为 1 的边
这样得到的二进制串组成的数 p,p^x一定是最大的
ll query(ll x)
{
int p = 0;
for(int i=32;i>=0;i--){
int c = (x>>i)&1;
if(trie[p][c^1]) p = trie[p][c^1];
else p = trie[p][c];
}
return val[p];
}
例题1:HDU 4825 Xor Sum
Zeus 和 Prometheus 做了一个游戏,Prometheus 给 Zeus 一个集合,集合中包含了N个正整数,随后 Prometheus 将向 Zeus 发起M次询问,每次询问中包含一个正整数 S ,之后 Zeus 需要在集合当中找出一个正整数 K ,使得 K 与 S 的异或结果最大。Prometheus 为了让 Zeus 看到人类的伟大,随即同意 Zeus 可以向人类求助。你能证明人类的智慧么?
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<string>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<string>
#include<map>
#include<set>
#include<tr1/unordered_set>
#define ll long long
#define ld long double
#define ull unsigned long long
using namespace std;
using namespace tr1;
const int maxn = 100010;
int trie[32*maxn][2],tot;
ll val[32*maxn];
void Init()
{
memset(trie[0],0,sizeof(trie[0]));
tot = 1;
}
void Insert(ll x)
{
int p = 0;
for(int i=32;i>=0;i--){
int c = (x>>i)&1;
if(trie[p][c]==0){
memset(trie[tot],0,sizeof(trie[tot]));
val[tot] = 0;
trie[p][c] = tot++;
}
p = trie[p][c];
}
val[p] = x;
}
ll query(ll x)
{
int p = 0;
for(int i=32;i>=0;i--){
int c = (x>>i)&1;
if(trie[p][c^1]) p = trie[p][c^1];
else p = trie[p][c];
}
return val[p];
}
int main(void)
{
int t,n,m;
ll x;
cin>>t;
for(int cas=1;cas<=t;cas++){
Init();
scanf("%d%d",&n,&m);
while(n--){
scanf("%lld",&x);
Insert(x);
}
printf("Case #%d:\n",cas);
while(m--){
scanf("%lld",&x);
printf("%lld\n",query(x));
}
}
return 0;
}
例题2:AcWing 143 最大异或对
在给定的N个整数A1,A2……AN中选出两个进行xor(异或)运算,得到的结果最大是多少?
注意数据范围:0≤Ai<231,对应着插入和检索的范围
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<string>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<string>
#include<map>
#include<set>
#include<tr1/unordered_set>
#define ll long long
#define ld long double
#define ull unsigned long long
using namespace std;
using namespace tr1;
const int maxn = 100010;
int trie[32*maxn][2],tot; //数据范围决定
int a[maxn];
void Init()
{
memset(trie,0,sizeof(trie));
tot = 1;
}
void Insert(ll x)
{
int p = 0;
for(int i=30;i>=0;i--){ //数据范围决定
int c = (x>>i)&1;
if(trie[p][c]==0){
trie[p][c] = tot++;
}
p = trie[p][c];
}
}
int query(int x)
{
int p = 0,ans = 0;
for(int i=30;i>=0;i--){ //数据范围决定
int c = (x>>i)&1;
if(trie[p][c^1]){
p = trie[p][c^1];
ans |= (1<<i);
}
else p = trie[p][c];
}
return ans;
}
int main(void)
{
int n,ans=0;
cin>>n;
Init();
for(int i=1;i<=n;i++){
scanf("%d",a+i);
Insert(a[i]);
}
for(int i=1;i<=n;i++){
ans = max(ans,query(a[i]));
}
printf("%d\n",ans);
return 0;
}