字典树(Tire)简介
字典树,英文名 trie。顾名思义,就是一个像字典一样的树。
先放一张图:
可以发现,这颗字典树用边来表示字符,而从根结点到树上某一结点的路径就代表了一个字符串。举个例子, $ 1 \rightarrow 4 \rightarrow 8 \rightarrow 12 $表示的就是字符串 caa
。
字典树(Trie)的建立
原理
我们用$ tr[now][x] $ 表示结点$ now $的 $ x $ 字符指向的下一个结点,或着说是结点 $ now $ 代表的字符串后面添加一个字符 $ x $ 形成的字符串的结点。($ x $ 的取值范围和字符集大小有关,不一定是 $ 0 \sim 26 $。)
有时候我们需要在字典树中插入一系列标记,例如我们要寻找一整个单词而不是单词的前缀时,要在单词末尾字母指向的节点做一个标记,在查询时方便判断是否满足查询的条件。
代码
int tr[1000001][26],f[1000001],rt=1,tot=1; //tr[i][j]表示节点i边字符为j所指向的节点,rt表示根节点,tot表示指向的节点的编号。
char s[51];//字符串
void ins(int n)//n表示单词长度
{
int now=rt;//首先将now(当前节点)赋值为根节点
for(int i=1;i<=n;i++)
{
int x=s[i]-'a';
if(tr[now][x]==0)//如果字典树上没有要新建立的边
tr[now][x]=++tot;//更新一个新的节点
now=tr[now][x];//修改新当前节点为原当前节点所指向的节点
}
f[now]=1;//根据题目设置标记,这里判断是否为一整个单词
return;
}
字典树的查询
原理
如上图,我们已经建立了一棵字典树,我们要从根节点向下进行搜索,对于需要查询的字符串的当前字符,如果这个对应的字符指针为空,就说明不含这个单词,直接跳出。
同样地,我们可以利用建树时的标记来判断其他条件是否满足。
代码
int find(char c[],int n)//c为所要查找的单词,n为该单词的长度
{
int now=rt;//设当前节点为根节点
for(int i=1;i<=n;i++)
{
int x=c[i]-'a';
if(tr[now][x]==0)//如果不存在这一条边,就退出,返回不存在
return 0;
now=tr[now][x];//如果存在这条边,就向下搜索
}
return f[now]++;//根据其他的标记返回答案,这里是判断能否取到整个单词。
}
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
using namespace std;
#define ll long long
#define pb(x) push_back(x)
#define mp(a,b) make_pair(a,b)
#define fi first
#define se second
#define pr(x) cerr<<#x<<"="<<(x)<<endl
#define pri(x,lo) {cerr<<#x<<"={";for (int ol=0;ol<=lo;ol++)cerr<<x[ol]<<",";cerr<<"}"<<endl;}
#define inf 100000000
#define N 1000
inline int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
int tr[1000001][26],f[1000001],rt=1,tot=1;
char s[10001][51];
void ins(int d,int n)
{
int now=rt;
for(int i=1;i<=n;i++)
{
int x=s[d][i]-'a';
if(tr[now][x]==0)
tr[now][x]=++tot;
now=tr[now][x];
}
f[now]=1;
return;
}
int find(char c[],int n)
{
int now=rt;
for(int i=1;i<=n;i++)
{
int x=c[i]-'a';
if(tr[now][x]==0)
return 0;
now=tr[now][x];
}
return f[now]++;
}
int main()
{
int n=read();
for(int i=1;i<=n;i++)
{
scanf("%s",s[i]+1);
}
for(int i=1;i<=n;i++)
{
ins(i,strlen(s[i]+1));
}
int m=read();
for(int i=1;i<=m;i++)
{
char c[51];
scanf("%s",c+1);
int ans=find(c,strlen(c+1));
if(ans==0)
printf("WRONG\n");
if(ans==1)
printf("OK\n");
if(ans>=2)
printf("REPEAT\n");
}
return 0;
}