正题
题目链接:https://www.ybtoj.com.cn/problem/526
题目大意
一个n×mn\times mn×m的网格上有字母,你每次可以沿平行坐标轴对折网格,要求对折的对应位置字母相同。
询问有多少个可能对折出来的子矩阵。
1≤n×m≤1061\leq n\times m\leq 10^61≤n×m≤106
解题思路
首先行和列是独立的,行的对折不会和列的对折有任何关联,所以可以分开考虑行和列可以对折出的区间。
然后设每一行分开对每个轴求出一个最大对折距离(这个用二分+hashhashhash或者马拉车就可以求出来了),然后同位置的所有行取最小值就好了。
之后对于每个轴的位置就有一个可以转移过来的区间,而且左右的对折如果过头了不会影响答案(可以自己画个图,因为回文串的性质,那么两边一定可以先对折出一个更小不会冲突的区间)
维护一个前缀和就好了(考场上犯病写了个树状数组)
时间复杂度O(nlogn)O(n\log n)O(nlogn)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define ull unsigned long long
#define lowbit(x) (x&-x)
using namespace std;
const ll N=1e6+10;
const ull g=131;
ll n,m,t[N],ac[N],cr[N],dp[N],lim;
ull h[N],f[N],pw[N];
char c[N],*s[N];
void Change(ll x,ll val){while(x<=lim){t[x]+=val;x+=lowbit(x);}return;
}
ll Ask(ll x){ll ans=0;while(x){ans+=t[x];x-=lowbit(x);}return ans;
}
ll Query(ll l,ll r)
{return Ask(r)-Ask(l-1);}
ull geth(ll l,ll r)
{return h[r]-h[l-1]*pw[r-l+1];}
ull getf(ll l,ll r)
{return f[l]-f[r+1]*pw[r-l+1];}
signed main()
{freopen("paper.in","r",stdin);freopen("paper.out","w",stdout);scanf("%lld%lld",&n,&m);pw[0]=1;for(ll i=1;i<=max(n,m);i++)pw[i]=pw[i-1]*g;memset(ac,0x3f,sizeof(ac));memset(cr,0x3f,sizeof(cr));s[1]=c-1;for(ll p=1;p<=n;p++){scanf("%s",s[p]+1);for(ll i=1;i<=m;i++)h[i]=h[i-1]*g+s[p][i]-'a';for(ll i=m;i>=1;i--)f[i]=f[i+1]*g+s[p][i]-'a';for(ll i=2;i<=m;i++){ll l=0,r=min(i-2,m-i);while(l<=r){ll mid=(l+r)>>1;if(geth(i-mid-1,i-1)==getf(i,i+mid))l=mid+1;else r=mid-1;}ac[i]=min(ac[i],r);}s[p+1]=s[p]+m;}f[n+1]=0;for(ll p=1;p<=m;p++){for(ll i=1;i<=n;i++)h[i]=h[i-1]*g+s[i][p]-'a';for(ll i=n;i>=1;i--)f[i]=f[i+1]*g+s[i][p]-'a';for(ll i=2;i<=n;i++){ll l=0,r=min(i-2,n-i);while(l<=r){ll mid=(l+r)>>1;if(geth(i-mid-1,i-1)==getf(i,i+mid))l=mid+1;else r=mid-1;}cr[i]=min(cr[i],r);}}lim=m;Change(1,1);dp[1]=1;for(ll i=2;i<=m;i++){bool tmp=(Query(i-ac[i]-1,i-1)!=0);dp[i]=dp[i-1]+tmp;if(tmp)Change(i,1);}memset(t,0,sizeof(t));Change(m,1);ll sum=dp[m];for(ll i=m-1;i>=1;i--){bool tmp=(Query(i+1,i+ac[i+1]+1)!=0);if(tmp)sum+=dp[i],Change(i,1);}memset(t,0,sizeof(t));lim=n;Change(1,1);for(ll i=2;i<=n;i++){bool tmp=(Query(i-cr[i]-1,i-1)!=0);dp[i]=dp[i-1]+tmp;if(tmp)Change(i,1);}memset(t,0,sizeof(t));Change(n,1);ll ans=dp[n];for(ll i=n-1;i>=1;i--){bool tmp=(Query(i+1,i+cr[i+1]+1)!=0);if(tmp)ans+=dp[i],Change(i,1);}printf("%lld\n",ans*sum);return 0;
}