正题
大意
有n个节目,每个节目对3个东西贡献不同,要求选择至少k个让第一个东西的值最大。求方案数
解题思路
至少k个我们可以计算选择任何个数的结果减去选择k个的结果。由于k比较小,我们考虑直接暴搜
数据不是很大,我们可以将节目分成两段进行搜索所有结果。
然后第一部分计算第1个东西的值减去第2个东西的值ab1iab1i,和减去第3个东西的值ac1iac1i。
第二部分一样计算ab2iab2i,ac2iac2i。
问题就变成了选择两个数ii,使得
ab1i−ab2j>0ab1i−ab2j>0
andand
ac1i−ac2j>0ac1i−ac2j>0
然后我们将两个中的 abab合在一起进行离散化,之后用一个树状数组或权值线段数进行第二部分查询每个区间内数的个数,这样就可以 n log nnlogn查询了。
总共时间复杂度
O(217×log 217+1676116)O(217×log217+1676116)
代码
#include<cstdio>
#include<algorithm>
#define ll long long
#define N 131080*4
#define lobit(x) x&(-x)
using namespace std;
struct ansnode{ll c,b;bool f;
}ans[N];
ll n,k,a[51],b[51],c[51],t,nz;
long long sum,tr[N],ans1;
void dfs(ll dep,ll w,ll sum1,ll sum2,ll sum3)//暴力处理k以内的结果
{if (w>=k) return;if (w<k) {if (sum1>sum2&&sum1>sum3)ans1++;}for (ll i=dep+1;i<=n;i++) dfs(i,w+1,sum1+a[i],sum2+b[i],sum3+c[i]);
}
void dfs1(ll dep,ll sum1,ll sum2,ll sum3)//第一部分搜索
{ans[++t].b=sum1-sum2;ans[t].c=sum1-sum3;ans[t].f=0;for (ll i=dep+1;i<=nz;i++) dfs1(i,sum1+a[i],sum2+b[i],sum3+c[i]);
}
void dfs2(ll dep,ll sum1,ll sum2,ll sum3)//第二部分搜索
{ans[++t].b-=sum1-sum2;ans[t].c-=sum1-sum3;ans[t].f=1;for (ll i=dep+1;i<=n;i++) dfs2(i,sum1+a[i],sum2+b[i],sum3+c[i]);
}
bool cmp1(ansnode x,ansnode y)//排序
{return x.c<y.c;}
bool cmp2(ansnode x,ansnode y)
{return x.b<y.b||x.b==y.b;}
void change(ll x,ll up)//树状数组——修改
{while (x<=up){tr[x]++;x+=lobit(x);}
}
long long find(ll x)//树状数组——查询
{long long ans=0;while (x>0){ans+=tr[x];x-=lobit(x);}return ans;
}
int main()
{//freopen("show.in","r",stdin);//freopen("show.out","w",stdout);scanf("%lld%lld",&n,&k);nz=n/2;for (ll i=1;i<=n;i++)scanf("%lld",&a[i]);for (ll i=1;i<=n;i++)scanf("%lld",&b[i]);for (ll i=1;i<=n;i++)scanf("%lld",&c[i]);dfs(0,0,0,0,0);//搜索dfs1(0,0,0,0);//第一部分搜索dfs2(nz,0,0,0);//第二部分搜索sort(ans+1,ans+1+t,cmp1);//离散化——排序ans[t+1].c=-2147483647;ll e=1,last=1;for (ll i=2;i<=t+1;i++){if (ans[i].c!=ans[~-i].c)//离散化——去重,标号{for (ll j=last;j<i;j++)ans[j].c=e;e++;last=i;}}e--;sort(ans+1,ans+1+t,cmp2);ll o=0,g;while (o<=t){g=ans[o].b;last=o;while (ans[o].b==g&&o<=t) o++;for (ll i=last;i<o;i++) if (!ans[i].f) sum+=find(~-ans[i].c);//查询for (ll i=last;i<o;i++) if (ans[i].f) change(ans[i].c,e);//修改}printf("%lld",sum-ans1);//输出
}