D - Counting Stars HDU - 7059
题解:
长度为n的序列a,有三个操作:
- 对某个区间进行询问
- 对于某个区间内的每个数ai,减去ai&(-ai)
- 对于某个区间内的每个数ai,加上2k2^k2k,k满足2k<=ai<2k+12^k <= a_{i} <2^{k+1}2k<=ai<2k+1
题解:
很容易想到线段树维护,但是后两个操作都不是线段树的基础操作
对于第二个操作,如何维护,其实线段树问题中经常遇到,对于这种数值快速递降至稳定的函数(比如区间开根号,区间求欧拉函数),可以直接暴力修改。像本题中是减去lowbit(x),其实就是将x二进制中的最后一位1删除,那每个数最多也就操作个log x次就变成0,因此一共操作次数只有nlogn次,再加上线段树操作的复杂度,也就是O(nlog2n)O(nlog^2n)O(nlog2n)
对于第三个操作,加上2k2^k2k,k满足2k<=ai<2k+12^k <= a_{i} <2^{k+1}2k<=ai<2k+1,其实本质就是让ai的最左边的1左移一位。也就是说其实第三个操作只与ai的最高位有关,且是乘2,乘2这个操作是可以用线段树实现的。
具体实现就是:我们将ai的最高位和剩余位置拆开,sum1记录的是最高位的情况,sum2记录是剩余最高位情况,num记录ai中1的情况,因为操作2是要减去最后一位1,如果num==0,那么sum1和sum2就都等于0。对于操作2,让sum2-=lowbit(sum2),对于操作3,只需要对sum2进行乘2的维护。查询时再将sum1和sum2加在一起
代码:
详细看代码
// Problem: D - Counting Stars
// Contest: Virtual Judge - 2021杭电多校第八场
// URL: https://vjudge.net/contest/453140#problem/D
// Memory Limit: 131 MB
// Time Limit: 4000 ms
// Data:2021-08-13 12:59:28
// By Jozky#include <bits/stdc++.h>
#include <unordered_map>
#define debug(a, b) printf("%s = %d\n", a, b);
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
clock_t startTime, endTime;
//Fe~Jozky
const ll INF_ll= 1e18;
const int INF_int= 0x3f3f3f3f;
template <typename T> inline void read(T& x)
{T f= 1;x= 0;char ch= getchar();while (0 == isdigit(ch)) {if (ch == '-')f= -1;ch= getchar();}while (0 != isdigit(ch))x= (x << 1) + (x << 3) + ch - '0', ch= getchar();x*= f;
}
template <typename T> inline void write(T x)
{if (x < 0) {x= ~(x - 1);putchar('-');}if (x > 9)write(x / 10);putchar(x % 10 + '0');
}
void rd_test()
{
#ifdef ONLINE_JUDGE
#elsestartTime= clock();freopen("in.txt", "r", stdin);
#endif
}
void Time_test()
{
#ifdef ONLINE_JUDGE
#elseendTime= clock();printf("\nRun Time:%lfs\n", (double)(endTime - startTime) / CLOCKS_PER_SEC);
#endif
}
const int mod= 998244353;
const int maxn= 2e5 + 9;
int a[maxn];
int pw[maxn];
struct node
{int l, r;ll sum1; //记录最高位的1ll sum2; //记录除了最高位剩下的ll num; //记录二进制一共有几个1ll lazy; //记录乘2的标记
} tr[maxn << 2];
int lg[maxn]; //lg[i]表示a[i]的二进制有几位
int work(int x)
{ //计算x的二进制有几位int ans= 0;while (x) {ans++;x>>= 1;}return ans;
}
int work2(int x)
{ //计算x的二进制有几位是1int ans= 0;while (x) {if (x & 1)ans++;x>>= 1;}return ans;
}
int lowbit(int x)
{return x & (-x);
}
void pushup(int rt)
{tr[rt].sum1= (tr[rt << 1].sum1 + tr[rt << 1 | 1].sum1) % mod;tr[rt].sum2= (tr[rt << 1].sum2 + tr[rt << 1 | 1].sum2) % mod;tr[rt].num= max(tr[rt << 1].num, tr[rt << 1 | 1].num);
}
void solve(int rt, int val)
{tr[rt].sum1= 1ll * tr[rt].sum1 * pw[val] % mod;tr[rt].lazy+= val;
}
void pushdown(int rt)
{solve(rt << 1, tr[rt].lazy);solve(rt << 1 | 1, tr[rt].lazy);tr[rt].lazy= 0;
}
void build(int rt, int l, int r)
{tr[rt].l= l;tr[rt].r= r;tr[rt].lazy= 0;if (l == r) {tr[rt].num= work2(a[l]);tr[rt].sum1= (a[l] & (1 << (lg[l] - 1)));tr[rt].sum2= a[l] - tr[rt].sum1;return;}int mid= (l + r) >> 1;build(rt << 1, l, mid);build(rt << 1 | 1, mid + 1, r);pushup(rt);
}
void update1(int rt, int l, int r)
{if (tr[rt].num == 0)return;if (tr[rt].l > r || tr[rt].r < l)return;if (tr[rt].l == tr[rt].r) {tr[rt].sum2= tr[rt].sum2 - lowbit(tr[rt].sum2);tr[rt].num--;if (tr[rt].num == 0) { //全减没了tr[rt].sum1= 0;tr[rt].sum2= 0;}return;}pushdown(rt);int mid= (tr[rt].l + tr[rt].r) >> 1;if (l <= mid)update1(rt << 1, l, r);if (r > mid)update1(rt << 1 | 1, l, r);pushup(rt);
}
void update2(int rt, int l, int r)
{if (tr[rt].l > r || tr[rt].r < l)return;if (tr[rt].l >= l && tr[rt].r <= r) {solve(rt, 1);return;}pushdown(rt);int mid= (tr[rt].l + tr[rt].r) >> 1;if (l <= mid)update2(rt << 1, l, r);if (r > mid)update2(rt << 1 | 1, l, r);pushup(rt);
}
ll query(int rt, int l, int r)
{if (tr[rt].l > r || tr[rt].r < l)return 0;if (tr[rt].l >= l && tr[rt].r <= r) {return (1ll * tr[rt].sum1 + tr[rt].sum2) % mod;}pushdown(rt);int mid= (tr[rt].l + tr[rt].r) >> 1;ll ans= 0;if (l <= mid)ans= (ans + query(rt << 1, l, r)) % mod;if (r > mid)ans= (ans + query(rt << 1 | 1, l, r)) % mod;return ans % mod;
}
int main()
{//rd_test();int t;pw[0]= 1;for (int i= 1; i < 200004; i++)pw[i]= 1ll * pw[i - 1] * 2 % mod;scanf("%d", &t);while (t--) {int n;scanf("%d", &n);for (int i= 1; i <= n; i++) {scanf("%d", &a[i]);lg[i]= work(a[i]);}//memset(tr,0,sizeof(tr));build(1, 1, n);int m;read(m);for (int i= 1; i <= m; i++) {int op, l, r;scanf("%d%d%d", &op, &l, &r);if (op == 1)printf("%d\n", query(1, l, r));else if (op == 2)update1(1, l, r);else if (op == 3)update2(1, l, r);}}return 0;//Time_test();
}