倍增法
祖孙询问
给定一棵包含 n 个节点的有根无向树,节点编号互不相同,但不一定是 1∼n。
有 m 个询问,每个询问给出了一对节点的编号 x 和 y,询问 x 与 y 的祖孙关系。
输入格式
输入第一行包括一个整数 表示节点个数;
接下来 n 行每行一对整数 a 和 b,表示 a和 b 之间有一条无向边。如果 b是 −1,那么 a 就是树的根;
第 n+2 行是一个整数 m 表示询问个数;
接下来 m 行,每行两个不同的正整数 x 和 y,表示一个询问。
输出格式
对于每一个询问,若 x 是 y 的祖先则输出 1,若 y 是 x 的祖先则输出 2,否则输出 0。
数据范围
1≤n,m≤4×10^4
1≤每个节点的编号≤4×10^4
输入样例:
10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19
输出样例:
1
0
0
0
2
#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
using namespace std;
const int N=4e4+10,M=2*N;
int h[N],e[M],ne[M],idx;
int fa[N][16];
int depth[N];
int n,m;
void add(int a,int b)
{
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void bfs(int root)
{
memset(depth,0x3f,sizeof depth);
depth[0]=0;//哨兵
depth[root]=1;//根节点深度为1
queue<int>q;
q.push(root);
//宽搜
while(q.size())
{
int t=q.front();
q.pop();
//遍历下一层
for(int i=h[t];i!=-1;i=ne[i])
{
int j=e[i];
//此时还没有更新
if(depth[j]>depth[t]+1)//depth[j]==0x3f3f3f3f
{
//深度加一
depth[j]=depth[t]+1;
q.push(j);
//跳一层是j的父节点 2^0
fa[j][0]=t;
/*
i → mid → t
2^j-1 2^j-1
f[i][j-1] f[i][j]
mid = f[i][j-1]
t = f[i][j]
则f[i][j] = f[mid][j-1] = f[f[i][j-1]][j-1]
*/
for(int k=1;k<=15;k++)
{
fa[j][k]=fa[fa[j][k-1]][k-1];
}
/*
举个例子理解超过根节点是怎么超过的
因为我们没有对根节点fa[1][0]赋值,那么fa[1][0] = 0;
1
/ \
2 3
fa[1][0] = 0;
fa[2][0] = 1;
fa[2][1] = fa[fa[2][0]][0] = fa[1][0] = 0;
*/
}
}
}
}
int lca(int a,int b)
{
// 为方便处理 当a在b上面时 把a b 互换
if(depth[a]<depth[b]) swap(a,b);
for(int k=15;k>=0;k--)
{
//当a跳完2^k依然在b下面 我们就一直跳
//二进制拼凑法
if(depth[fa[a][k]]>=depth[b])
{
//更新跳
a=fa[a][k];
}
}
//如果跳到了b 判断一下 是否跳到了b
if(a==b) return a;
for(int k=15;k>=0;k--)
{
// 假如a,b都跳出根节点,fa[a][k]==fa[b][k]==0 不符合更新条件
if(fa[a][k]!=fa[b][k])
{
a=fa[a][k];
b=fa[b][k];
}
}
//循环结束 到达lca下一层
//lca(a,b) = 再往上跳1步即可
return fa[a][0];
}
int main()
{
cin>>n;
//初始化头节点
memset(h,-1,sizeof h);
int root=0;
while(n--)
{
//建图
int a,b;cin>>a>>b;
if(b==-1) root=a;
else
{
add(a,b);
add(b,a);
}
}
//建depth[N] fa[N][15]
bfs(root);
cin>>m;
while(m--)
{
int a,b;cin>>a>>b;
int num=lca(a,b);
if(num==a) cout<<"1"<<endl;
else if(num==b) cout<<"2"<<endl;
else cout<<"0"<<endl;
}
return 0;
}
tarjan求lca
距离
给出 n 个点的一棵树,多次询问两点之间的最短距离。
注意:
- 边是无向的。
- 所有节点的编号是 1,2,…,n。
输入格式
第一行为两个整数 n 和 m。n表示点数,m 表示询问次数;
下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;
再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。
树中结点编号从 1 到 n。
输出格式
共 m 行,对于每次询问,输出一行询问结果。
数据范围
2≤n≤10^4
1≤m≤2×10^4
0<k≤100
1≤x,y≤n
输入样例1:
2 2
1 2 100
1 2
2 1
输出样例1:
100
100
输入样例2:
3 2
1 2 10
3 1 15
1 2
3 2
输出样例2:
10
25
树的最短距离只有一条
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef pair<int,int>PII;
const int N=1e4+10,M=2*N;
int h[N],e[M],w[M],ne[M],idx;
int st[N],res[M];
int dist[N];//随便设一个点为根节点 设1为根节点 求1到其它节点的距离
int p[N];//用来标记
int n,m;
vector<PII>query[N];//把询问存下来
// query[i][first][second] first存查询距离i的另外一个点j,second存查询编号idx
//建立无向图
void add(int a,int b,int c)
{
e[idx]=b;
w[idx]=c;
ne[idx]=h[a];
h[a]=idx++;
}
//初始化
void unit()
{
for(int i=0;i<=n;i++) p[i]=i;
}
int find(int x)
{
if(x!=p[x]) p[x]=find(p[x]);
return p[x];
}
//求1到其它节点的距离 时间复杂度O(m+n)
void dfs(int u,int fa)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(fa==j) continue;
dist[j]=dist[u]+w[i];
dfs(j,u);
}
}
void tarjan(int u)
{
//第二类 标记为1 当前正着搜的点
st[u]=1;
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
//若没有被标记过继续往下搜
if(!st[j])
{
tarjan(j);//往左下搜
//回溯之后
p[j]=u;//从左下回溯后把左下的点合并到根节点 一用find() 这条分支上的所有点都合并到根节点了
}
}
// 对于当前点u 搜索所有和u
for(auto item:query[u])
{
int y=item.first,id=item.second;
if(st[y]==2)//如果查询的这个点已经是左下的点(已经搜索过且回溯过,标记为2)
{
int anc=find(y);
// x到y的距离 = d[x]+d[y] - 2*d[lca]
res[id]=dist[u]+dist[y]-2*dist[anc];
}
}
//点u已经搜索完且要回溯了 就把st[u]标记为2
st[u]=2;
}
int main()
{
memset(h,-1,sizeof h);
cin>>n>>m;
for(int i=1;i<n;i++)
{
int a,b,c;cin>>a>>b>>c;
add(a,b,c);
add(b,a,c);
}
for(int i=0;i<m;i++)
{
int a,b;cin>>a>>b;
if(a!=b)
{
query[a].push_back({b,i});
query[b].push_back({a,i});
}
}
unit();
dfs(1,-1);
tarjan(1);
for(int i=0;i<m;i++) cout<<res[i]<<endl;
return 0;
}
景区导游(蓝桥杯)
题目描述
某景区一共有 N 个景点,编号 1 到 N。景点之间共有 N − 1 条双向的摆渡车线路相连,形成一棵树状结构。在景点之间往返只能通过这些摆渡车进行,需要花费一定的时间。
小明是这个景区的资深导游,他每天都要按固定顺序带客人游览其中 K 个景点:A1, A2, . . . , AK。今天由于时间原因,小明决定跳过其中一个景点,只带游客按顺序游览其中 K − 1 个景点。具体来说,如果小明选择跳过 Ai,那么他会按顺序带游客游览 A1, A2, . . . , Ai−1, Ai+1, . . . , AK, (1 ≤ i ≤ K)。
请你对任意一个 Ai,计算如果跳过这个景点,小明需要花费多少时间在景点之间的摆渡车上?
输入格式
第一行包含 2 个整数 N 和 K。
以下 N − 1 行,每行包含 3 个整数 u, v 和 t,代表景点 u 和 v 之间有摆渡车线路,花费 t 个单位时间。
最后一行包含 K 个整数 A1, A2, . . . , AK 代表原定游览线路。
输出格式
输出 K 个整数,其中第 i 个代表跳过 Ai 之后,花费在摆渡车上的时间。
样例输入
6 4
1 2 1
1 3 1
3 4 2
3 5 2
4 6 3
2 6 5 1
样例输出
10 7 13 14
提示
原路线是 2 → 6 → 5 → 1。
当跳过 2 时,路线是 6 → 5 → 1,其中 6 → 5 花费时间 3 + 2 + 2 = 7,5 → 1 花费时间 2 + 1 = 3,总时间花费 10。
当跳过 6 时,路线是 2 → 5 → 1,其中 2 → 5 花费时间 1 + 1 + 2 = 4,5 → 1 花费时间 2 + 1 = 3,总时间花费 7。
当跳过 5 时,路线是 2 → 6 → 1,其中 2 → 6 花费时间 1 + 1 + 2 + 3 = 7,6 → 1 花费时间 3 + 2 + 1 = 6,总时间花费 13。
当跳过 1 时,路线时 2 → 6 → 5,其中 2 → 6 花费时间 1 + 1 + 2 + 3 = 7,6 → 5 花费时间 3 + 2 + 2 = 7,总时间花费 14。
对于 20% 的数据,2 ≤ K ≤ N ≤ 102。
对于 40% 的数据,2 ≤ K ≤ N ≤ 104。
对于 100% 的数据,2 ≤ K ≤ N ≤ 105,1 ≤ u, v, Ai ≤ N,1 ≤ t ≤ 105。保证Ai 两两不同。
#include<iostream>
#include<queue>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
const int N=1e5+10,M=2*N;
int h[N],ne[M],e[M],w[M],idx;
int depth[N];
ll dist[N];
int fa[N][20];
int a[N];
ll s[N];
int n,k;
void add(int a,int b,int c)
{
e[idx]=b;
w[idx]=c;
ne[idx]=h[a];
h[a]=idx++;
}
void dfs(int u,int fa)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==fa)continue;
dist[j]=dist[u]+w[i];
dfs(j,u);
}
}
void bfs(int root)
{
memset(depth,0x3f,sizeof depth);
depth[root]=1;
depth[0]=0;
queue<int>q;
q.push(root);
while(q.size())
{
int t=q.front();
q.pop();
for(int i=h[t];i!=-1;i=ne[i])
{
int j=e[i];
if(depth[j]>depth[t]+1)
{
depth[j]=depth[t]+1;
q.push(j);
fa[j][0]=t;
for(int k=1;k<=19;k++)
{
fa[j][k]=fa[fa[j][k-1]][k-1];
}
}
}
}
}
int lca(int a,int b)
{
if(depth[a]<depth[b]) swap(a,b);
for(int k=19;k>=0;k--)
{
if(depth[fa[a][k]]>=depth[b])
{
a=fa[a][k];
}
}
if(a==b) return a;
for(int k>=19;k>=0;k--)
{
if(fa[a][k]!=fa[b][k])
{
a=fa[a][k];
b=fa[b][k];
}
}
return fa[a][0];
}
int main()
{
cin>>n>>k;
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int a,b,c;cin>>a>>b>>c;
add(a,b,c);
add(b,a,c);
}
for(int i=1;i<=k;i++)
{
cin>>a[i];
}
dfs(1,-1);
bfs(1);
ll sum=0;
for(int i=1;i<k;k++)
{
int num=lca(a[i],a[i+1]);
s[i]=dist[a[i]]+dist[a[i+1]]-2*dist[num];
sum+=dist[a[i]]+dist[a[i+1]]-2*dist[num];
}
int num=lca(a[k-1],a[k]);
sum+=dist[a[k-1]]+dist[a[k]]-2*dist[num];
for(int i=1;i<k;i++)
{
if(i==1)
{
cout<<sum-s[i]<<" ";
}
else
{
int num1=lca(a[i-1],a[i+1]);
cout<<sum-s[i]-s[i+1]+dist[a[i-1]]+dist[a[i+1]]-2*dist[num1]<<" ";
}
}
return 0;
}