一个 01 序列,序列里面包含了 n 个数,下标从 0 开始。这些数要么是 0,要么是 1,现在对于这个序列有五种变换操作和询问操作:
0 l r 把 [l,r] 区间内的所有数全变成 0;
1 l r 把 [l,r] 区间内的所有数全变成 1;
2 l r 把 [l,r] 区间内的所有数全部取反,也就是说把所有的 0 变成 1,把所有的 1 变成 0;
3 l r 询问 [l,r] 区间内总共有多少个 1;
4 l r 询问 [l,r] 区间内最多有多少个连续的 1。
对于每一种询问操作,lxhgww 都需要给出回答,聪明的程序员们,你们能帮助他吗?
线段树区间合并模板
区间合并就是考虑父亲的区间信息如何由左右儿子得到
区间连续1最大值可以是下面三种情况的最大值
- 左儿子的区间连续1最大值
- 右儿子的区间连续1最大值
- 左儿子的以右端点结束的区间 + 右儿子的以左端点开始的区间 的和
那分别维护0,1即可
还有 以左端点开始的区间 以右端点结束的区间
然后注意重置操作的优先级更高
tmp.ar[i][2] = max({ar[i][2], b.ar[i][2], ar[i][1] + b.ar[i][0]});
写成tmp.ar[i][2] = max({ar[i][0], b.ar[i][1], ar[i][1] + b.ar[i][0]});
调了 n 个小时
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define _ cout << "----------" << endl
//#define int ll
struct tag {
bool flip, st, v;
tag() : flip(0), st(0), v(0) {}
void apply(tag& T) {
if (T.st) {
flip = 0;
st = 1;
v = T.v;
} else if (T.flip) {
if (st) {
v ^= 1;
} else {
flip ^= 1;
}
}
}
};
struct node {
array<int,2> sum;
array<int,3> ar[2]; // l r max
node() {
sum = {0,0};
ar[0] = ar[1] = {0,0,0};
}
node operator+ (const node& b) const {
node tmp;
for (int i = 0; i <= 1; i++) {
tmp.sum[i] = sum[i] + b.sum[i];
int lenL = sum[0] + sum[1];
int lenR = b.sum[0] + b.sum[1];
tmp.ar[i][0] = (sum[i] == lenL) ? lenL + b.ar[i][0] : ar[i][0];
tmp.ar[i][1] = (b.sum[i] == lenR) ? lenR + ar[i][1] : b.ar[i][1];
tmp.ar[i][2] = max({ar[i][2], b.ar[i][2], ar[i][1] + b.ar[i][0]});
}
return tmp;
}
void apply(tag& T) {
if (T.st) {
sum[T.v] += sum[T.v^1];
ar[T.v] = {sum[T.v],sum[T.v],sum[T.v]};
sum[T.v^1] = 0;
ar[T.v^1] = {0,0,0};
} else if (T.flip) {
swap(sum[0],sum[1]);
swap(ar[0],ar[1]);
}
}
};
struct lazy {
int n;
vector<node> tree;
vector<tag> tg;
lazy(int _n, vector<node>& b) : n(_n) {
tree.resize(4*n);
tg.resize(4*n);
build(1,n,1,b);
}
void up(int p) {
tree[p] = tree[p<<1] + tree[p<<1|1];
}
void build(int l, int r, int p, vector<node>& b) {
if (l == r) {
tree[p] = b[l];
return;
}
int mid = (l + r) >> 1;
build(l,mid,p<<1,b);
build(mid+1,r,p<<1|1,b);
up(p);
}
void apply(int p, tag& T) {
tree[p].apply(T);
tg[p].apply(T);
}
void down(int p) {
apply(p<<1,tg[p]);
apply(p<<1|1,tg[p]);
tg[p] = tag();
}
void modify(int ml, int mr, int l, int r, int p, tag& T) {
if (r < ml || l > mr) return;
if (l >= ml && r <= mr) {
apply(p,T);
return;
}
down(p);
int mid = (l + r) >> 1;
modify(ml, mr, l, mid, p<<1, T);
modify(ml, mr, mid + 1, r, p<<1|1, T);
up(p);
}
void modify(int l, int r, tag& T) {
modify(l,r,1,n,1,T);
}
void query(int ql, int qr, int l, int r, int p, node& t) {
if (r < ql || l > qr) return;
if (l >= ql && r <= qr) {
t = t + tree[p];
return;
}
down(p);
int mid = (l + r) >> 1;
query(ql,qr,l,mid,p<<1,t);
query(ql,qr,mid+1,r,p<<1|1,t);
}
node query(int l, int r) {
node tmp;
query(l,r,1,n,1,tmp);
return tmp;
}
};
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, q;
cin >> n >> q;
vector<node> a(n+1);
for (int i = 1; i <= n; i++) {
int x;
cin >> x;
a[i].sum[x] = 1;
a[i].ar[x] = {1,1,1};
}
lazy tree(n, a);
while (q--) {
int op, l, r;
cin >> op >> l >> r;
l++, r++;
tag tg;
if (op == 0) {
tg.st = 1;
tg.v = 0;
tree.modify(l,r,tg);
} else if (op == 1) {
tg.st = 1;
tg.v = 1;
tree.modify(l,r,tg);
} else if (op == 2) {
tg.flip = 1;
tree.modify(l,r,tg);
} else if (op == 3){
cout << tree.query(l,r).sum[1] << '\n';
} else {
cout << tree.query(l,r).ar[1][2] << '\n';
}
}
return 0;
}
变体
给一个0/1串
最早出现长度是 k的全 0 串或全 1 串的开头下标是多少
然后翻转这k个位置的串,找不到输出-1
和前面那个序列操作几乎一模一样,关键是找到第一个
显然具有二分性
一开始直接在线段树外面二分 log^2 TLE
while (L <= R) {
int mid = (L + R) >> 1;
if (tree.query(1, mid).ar[bit][2] >= k) {
loc = min(mid, loc);
R = mid - 1;
} else {
L = mid + 1;
}
}
改成在线段树内部进行二分,优化掉一个log 即可
如果当前左半区间不满足,直接合并左区间并记录,然后往右边找
满足直接砍掉右区间,往左
void find_k(int l, int r, int p, int bit, int k, node& tmp, int& loc) {
// if (l > r) return;
if (l > r || (tmp + tree[p]).ar[bit][2] < k) return;
if (l == r) {
loc = min(loc, l);
return;
}
int mid = l + r >> 1;
down(p);
if ((tmp + tree[p << 1]).ar[bit][2] >= k) {
loc = min(loc, mid);
find_k(l, mid, p << 1, bit, k, tmp, loc);
} else {
tmp = tmp + tree[p << 1];
find_k(mid + 1, r, p << 1 | 1, bit, k, tmp, loc);
}
}
int find_k(int bit, int k) {
int loc = n + 1;
node tmp;
find_k(1, n, 1, bit, k, tmp, loc);
return loc;
}