题意:给定长度为 nnn 的正整数序列 AAA,求满足 i<j<k,Aj−Ai=Ak−Aji<j<k,A_j-A_i=A_k-A_ji<j<k,Aj−Ai=Ak−Aj 的三元组个数。
n≤105,Ai≤3×104n\leq 10^5,A_i\leq 3\times 10^4n≤105,Ai≤3×104
三个位置只有 jjj 限制比较紧,“前后各找一个数差相同”显然比“往一个方向找一个数再找一个数差相同”容易得多。考虑计算每个位置作为 jjj 的贡献。
我们相当于要在前后各找一个数 x,yx,yx,y,使得 x+y=2akx+y=2a_kx+y=2ak。
这是个卷积形式,如果对每个 kkk 快速搞出 (∑i=1k−1xai)(∑i=k+1nxai)(\sum_{i=1}^{k-1}x^{a_i})(\sum_{i=k+1}^nx^{a_i})(∑i=1k−1xai)(∑i=k+1nxai) 就可以统计答案了。
但每个位置都要做 IDFT ,显得浪费。考虑分块。
设块长为 BBB,维护前面的块和后面的块的值域桶。暴力三元组里当前块有两个的,然后前面后面 FFT 乘起来,统计块内每个点作为中间位置的贡献。
这样复杂度是 O(NB(B2+AlogA))\Omicron\left(\frac NB(B^2+A\log A)\right)O(BN(B2+AlogA)),取 B∈O(AlogA)B\in \Omicron(\sqrt{A\log A})B∈O(AlogA) 后总复杂度 O(NAlogA)\Omicron(N\sqrt{A\log A})O(NAlogA),可以通过。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <cmath>
#define MAXN 100005
#define double long double
using namespace std;
typedef long long ll;
inline int read()
{int ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
const double Pi=acos(-1.0);
const int B=700;
struct complex{double x,y;inline complex(const double& x=0,const double& y=0):x(x),y(y){}};
inline complex operator +(const complex& a,const complex& b){return complex(a.x+b.x,a.y+b.y);}
inline complex operator -(const complex& a,const complex& b){return complex(a.x-b.x,a.y-b.y);}
inline complex operator *(const complex& a,const complex& b){return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
inline complex adj(const complex& a){return complex(a.x,-a.y);}
int l=16,lim=1<<l,r[MAXN];
inline void init(){for (int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));}
complex rt[2][20];
void fft(complex* a,int type)
{for (int i=0;i<lim;i++) if (i<r[i]) swap(a[i],a[r[i]]);for (int L=0;L<l;L++){int mid=1<<L,len=mid<<1;complex Wn=rt[type][L+1];for (int s=0;s<lim;s+=len){complex w(1,0);for (int k=0;k<mid;k++,w=w*Wn){complex x=a[s+k],y=w*a[s+mid+k];a[s+k]=x+y,a[s+mid+k]=x-y;}}}if (type) for (int i=0;i<lim;i++) a[i].x/=lim,a[i].y/=lim;
}
int a[MAXN],pre[MAXN],suf[MAXN],cnt[MAXN];
complex F[MAXN];
int main()
{init();for (int i=0;i<20;i++){double a=2*Pi/(1<<i);rt[1][i]=adj(rt[0][i]=complex(cos(a),sin(a)));}int n=read();for (int i=0;i<n;i++) a[i]=read();ll ans=0;for (int i=0;i<n;i++) ++suf[a[i]];for (int T=0;T<=n/B;T++){int l=T*B,r=min((T+1)*B,n)-1;for (int i=l;i<=r;i++) --suf[a[i]];for (int i=l+1;i<r;i++){for (int j=l;j<i;j++) ++cnt[a[j]];for (int j=i+1;j<=r;j++) ans+=cnt[2*a[i]-a[j]];for (int j=l;j<i;j++) cnt[a[j]]=0;}for (int i=l;i<r;i++)for (int j=i+1;j<=r;j++){if (2*a[j]>a[i]) ans+=suf[2*a[j]-a[i]];if (2*a[i]>a[j]) ans+=pre[2*a[i]-a[j]];}for (int i=0;i<lim;i++) F[i]=complex(pre[i],suf[i]);fft(F,0);for (int i=0;i<lim;i++) F[i]=F[i]*F[i];fft(F,1);for (int i=l;i<=r;i++) ans+=F[2*a[i]].y/2+0.5;for (int i=l;i<=r;i++) ++pre[a[i]];}cout<<ans;return 0;
}