正题
题目链接:https://www.luogu.com.cn/problem/P4769
题目大意
有一个冒泡排序的算法
输入:一个长度为 n 的排列 p[1...n]
输出:p 排序后的结果。
for i = 1 to n dofor j = 1 to n - 1 doif(p[j] > p[j + 1])交换 p[j] 与 p[j + 1] 的值
然后给出一个排列aaa,求在所有字典序大于aaa的排列ppp中冒泡排序交换次数恰好为∑i=1n∣i−pi∣\sum_{i=1}^n|i-p_i|∑i=1n∣i−pi∣的排列数。
1≤n≤6×105,∑n≤2×1061\leq n\leq 6\times 10^5,\sum n\leq 2\times 10^61≤n≤6×105,∑n≤2×106
解题思路
打一下表发现合法的排列条件是最长下降子序列不超过222。
然后我们先不考虑字典序限制条件怎么做,我们设fi,1/2f_{i,1/2}fi,1/2表示目前下降子序列长度为1/21/21/2中末尾最大的那个。
那么fi,1f_{i,1}fi,1就是目前出现的数中最大的,然后如果我们从前往后填数,那么如果≤fi,2\leq f_{i,2}≤fi,2的数中有没有填进去的,肯定不合法,所以fi,2f_{i,2}fi,2肯定比目前没有填进去的数中所有数字都小,不需要考虑。
设gi,jg_{i,j}gi,j表示目前还剩下iii个数没填,其中fi,1f_{i,1}fi,1大于其中的jjj个数,那么有gi,jg_{i,j}gi,j可以转移到gi−1,j−1g_{i-1,j-1}gi−1,j−1(填在最底)和gi−1,k(k≥j)g_{i-1,k}(k\geq j)gi−1,k(k≥j)(填在jjj上面)。
我们考虑快速的求出每个ggg,反过来就是gi,jg_{i,j}gi,j转移到gi+1,j+1g_{i+1,j+1}gi+1,j+1和gi+1,jg_{i+1,j}gi+1,j。
我们维护一个hi,j=gi,i−jh_{i,j}=g_{i,i-j}hi,j=gi,i−j,那么每次的转移就是hi,jh_{i,j}hi,j转移到hi,k(j≤k≤i)h_{i,k}(j\leq k\leq i)hi,k(j≤k≤i)
这个转移很像卡特兰数的要求,每次可以往下或者往右,但是不能超过对角线。
这样来说移动到位置(n,m)(m≤n)(n,m)(m\leq n)(n,m)(m≤n)的话方案数就是(n+mm)−(n+mm−1)\binom{n+m}{m}-\binom{n+m}{m-1}(mn+m)−(m−1n+m)。
然后就是枚举第一个超过该字典序的位置,这样前面的方案固定,剩下的数可以用树状数组计算得出,再用组合数求答案即可。
时间复杂度:O(nlogn)O(n\log n)O(nlogn)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define lowbit(x) (x&-x)
using namespace std;
const ll N=6e5*2,P=998244353;
ll T,n,t[N],a[N],fac[N],inv[N],ans;
ll C(ll n,ll m)
{if(m<0)return 0;return fac[n]*inv[m]%P*inv[n-m]%P;}
void Change(ll x,ll val){while(x<=n){t[x]+=val;x+=lowbit(x);}return;
}
ll Ask(ll x){ll ans=0;while(x){ans+=t[x];x-=lowbit(x);}return ans;
}
ll F(ll n,ll m){m=n-1-m;if(!m)return 0;m--;return (C(n+m,m)-C(n+m,m-1)+P)%P;
}
signed main()
{
// freopen("inverse3.in","r",stdin);fac[0]=inv[0]=inv[1]=1;for(ll i=2;i<N;i++)inv[i]=P-inv[P%i]*(P/i)%P;for(ll i=1;i<N;i++)fac[i]=fac[i-1]*i%P,inv[i]=inv[i-1]*inv[i]%P;scanf("%lld",&T);while(T--){scanf("%lld",&n);ans=0;for(ll i=1;i<=n;i++)scanf("%lld",&a[i]),Change(a[i],1);for(ll i=1,mx=0;i<=n;i++){Change(a[i],-1);mx=max(mx,a[i]);(ans+=F(n-i+1,Ask(mx)))%=P;if(a[i]<mx&&Ask(a[i]))break;}for(int i=1;i<=n;i++)t[i]=0;printf("%lld\n",ans);}return 0;
}