I love counting
题意
给定一个含有n
个元素的数组C
。
对这个数组进行查询操作,每次查询给出四个值L
、R
、a
和b
。
问L--R
之间有多少种
c
c
c使得,
(
c
异
或
a
)
≤
b
(c\ 异或\ a)\le b
(c 异或 a)≤b。(如果
C
i
=
C
j
=
c
C_i=C_j=c
Ci=Cj=c,则只算作一种)
思路
这道题和思路和杭电第一场的Xor sum很像(01字典树),最好把上一场的补了再看这一题,这一题比上一题更麻烦,因为还得用到树状数组。Xor sum题解链接
我感觉这道题还是挺难的,所以可能题解废话有点多。
考虑将区间问题转换为前驱问题,定义
f
(
R
)
f(R)
f(R)表示R
前面满足
(
c
⨁
a
)
≤
b
(c\bigoplus a)\le b
(c⨁a)≤b的c
的个数,则L--R
之间满足条件的个数就是
f
(
R
)
−
f
(
L
−
1
)
f(R)-f(L-1)
f(R)−f(L−1)。
在
(
c
⨁
a
)
≤
b
(c\bigoplus a)\le b
(c⨁a)≤b中,因为a
和b
已知,可以确定
c
c
c二进制位的前缀,这个前缀用字典树来维护。
以下面这组数据为例来进行分析:
5
1 2 2 4 5
4
1 3 1 3
2 4 4 2
1 5 2 3
4 5 3 6
建立字典树
用
C
C
C来建立一棵字典树,字典树上每个节点(node)存储所有经过这个节点的
C
i
C_i
Ci的下标
i
i
i,这些节点都存放到数组vec[node_id]
中,这个数组是用来建立树状数组的,等下再考虑。
- 插入第一个点
001
后字典树为:
- 插入第二个点
010
后字典树为:
-
……
-
最后,插入结束后字典树为:
注
:我这里为了演示方便,将位数设置成了3位。
以第一个查询(1 3 1 3)为例,对a
(二进制位为001
)和b
(二进制位为011
)进行分析:
- 先看最高位,因为
b
的最高位为0
,无法直接在这一位确定 ( c ⨁ a ) ≤ b (c\bigoplus a)\le b (c⨁a)≤b,又因为a
的最高位为0
,所以c
的这一位应该定义为0
,这样才能确保 c ⨁ a c\bigoplus a c⨁a不会大于b
,字典树指针cur
从root移动到左子树。 - 再看次高位,因为
b
的最高位为1
,所以如果 c ⨁ a c\bigoplus a c⨁a的这一位为0
的话, c ⨁ a c\bigoplus a c⨁a的值一定小于b
,可以在该节点上直接执行查询操作(离线查询,将查询操作也放到vec[node_id]中),但是 c ⨁ a c\bigoplus a c⨁a这一位的值可能也为1
,又因a
的这一位为0
,所以c
的这一位应该定义为1
,cur
移动到当前节点的右子树。 - ……
注
:当cur
走到最后的时候,需要特判
(
c
⨁
a
)
=
b
(c\bigoplus a)\ = b
(c⨁a) =b的情况。
对字典树上的每一个节点建立一个树状数组
说一下如何在一个节点上执行查询操作,还记得在刚才建立字典树时,所有经过该节点(node_id)的值的下标都被存储下来了,相当于vec[node_id]中存储的值的前缀都是相同的(除去查询操作),并且这些值都是满足条件 ( c ⨁ a ) ≤ b (c\bigoplus a)\le b (c⨁a)≤b的。
问题就转换成了:求数组S
中,下标R
前面的数有多少个不同的值。
这是一个很经典的问题,就不再废话了。
AC的代码(加注释的标程)
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e5 + 10;
struct {
int l, r, a, b;
}q[N];
int n, c[N], Q;
int res[N];// 记录最后的答案
int nxt[N], head[N];
// =======================================================================================
// 树状数组
int tr[N];
inline void add(int x, int y) {
while (x <= n)
tr[x] += y, x += x & -x;
}
inline int query(int x) {
int ans = 0;
while (x)
ans += tr[x], x -= x & -x;
return ans;
}
// =======================================================================================
// 字典树
int sum[N * 20];
int ch[N * 20][2], tot;
struct node {
int l, r, x, id;
inline bool operator <(const node &b)const {
if (x != b.x) return x < b.x;
if (id != b.id) return id < b.id;
return l < b.l;
}
};// node用来记录经过某一节点的值
vector<node>vec[N * 20];
void query(int rt, int now, int a, int b, int id) {
if (!rt && now != 16) {
// 字典树上存在这一节点时
return;
}
if (now < 0) {
// 已经移动到最后
vec[rt].push_back(node{ q[id].l, q[id].r, n - q[id].r + 1, id });
return;
}
int xx = (a & (1 << now)) != 0;
int yy = (b & (1 << now)) != 0;
if (yy) {// 这时可以确保后缀无论是什么,c^a一定小于b
vec[ch[rt][xx]].push_back(node{ q[id].l, q[id].r, n - q[id].r + 1, id });
}
xx ^= yy;
query(ch[rt][xx], now - 1, a, b, id);
}
void insert(int id) {
int x = c[id];
int now = 0;
for (int i = 16; i >= 0; --i) {
int xx = (x&(1 << i)) != 0;
if (!ch[now][xx])// 如果不存在这条边,创建新的
ch[now][xx] = ++tot;
now = ch[now][xx];// 移动到下一个节点
sum[now]++;
vec[now].push_back(node{ id, 0, n - nxt[id] + 2, 0 });
}
}
// =======================================================================================
int main() {
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); cin.exceptions(ios::badbit | ios::failbit);
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> c[i];
}
// ==================================================
for (int i = n; i >= 1; i--) {
if (!head[c[i]])
nxt[i] = n + 1;
else
nxt[i] = head[c[i]];
head[c[i]] = i;
}
// ==================================================构造字典树
for (int i = 1; i <= n; i++) insert(i);
// ==================================================查询离线
cin >> Q;
for (int i = 1; i <= Q; i++) {
cin >> q[i].l >> q[i].r >> q[i].a >> q[i].b;
query(0, 16, q[i].a, q[i].b, i);
}
// ==================================================
for (int i = 1; i <= tot; ++i) {
sort(vec[i].begin(), vec[i].end());
// ==================================================
// 树状数组
for (auto it : vec[i]) {
if (it.id == 0) {
add(it.l, 1);
}
else {
res[it.id] += query(it.r) - query(it.l - 1);
}
}
// 清空树状数组
// 不能直接memset,会超时
for (auto it : vec[i]) {
if (it.id == 0) {
add(it.l, -1);
}
}
}
for (int i = 1; i <= Q; ++i)
cout << res[i] << endl;
return 0;
}