problem
luogu-P4070
魔咒串由许多魔咒字符组成,魔咒字符可以用数字表示。例如可以将魔咒字符 1,21,21,2 拼凑起来形成一个魔咒串 [1,2][1,2][1,2]。
一个魔咒串 S 的非空字串被称为魔咒串 S 的生成魔咒。
例如 S=[1,2,1]S=[1,2,1]S=[1,2,1] 时,它的生成魔咒有 [1],[2],[1,2],[2,1],[1,2,1][1],[2],[1,2],[2,1],[1,2,1][1],[2],[1,2],[2,1],[1,2,1] 五种。
S=[1,1,1]S=[1,1,1]S=[1,1,1] 时,它的生成魔咒有 [1],[1,1],[1,1,1][1],[1,1],[1,1,1][1],[1,1],[1,1,1] 三种,最初 S 为空串。
共进行 nnn 次操作,每次操作是在 SSS 的结尾加入一个魔咒字符。每次操作后都需要求出,当前的魔咒串 SSS 共有多少种生成魔咒。
solution
本质不同的子串个数,是后缀数组经典运用,∑n−i+1−hi=n(n+1)2−∑hi\sum n-i+1-h_i=\frac{n(n+1)}{2}-\sum h_i∑n−i+1−hi=2n(n+1)−∑hi。
考虑每次在字符串末尾加入一个数字的话,我们就需要每次重新求一遍所有的后缀,因为全都变了。hhh 的变化也是动态不定的。
但如果反过来,每次在字符串开头加入一个数字的话,相当于只是多加了一个后缀。hhh 的变化是 O(1)O(1)O(1) 的。
所以我们将 nnn 个数组成的字符串反转,提前求出每个后缀的排名,hhh 数组,建立 ststst 表。
倒着枚举 i=n∼1i=n\sim 1i=n∼1,加入 iii 开头的后缀,然后查询与其最相近的后缀,即 rnk[i]rnk[i]rnk[i] 的前后缀。
重复的子串个数为 cnt+lcp(pre,rnk[i])+lcp(rnk[i],nxt)−lcp(pre,nxt)cnt+lcp(pre,rnk[i])+lcp(rnk[i],nxt)-lcp(pre,nxt)cnt+lcp(pre,rnk[i])+lcp(rnk[i],nxt)−lcp(pre,nxt)。
用 setsetset 维护 rnkrnkrnk 数组即可,当然也可以链表等各种快速查前后缀的工具。
code
#include <bits/stdc++.h>
using namespace std;
#define maxn 100005
#define int long long
int h[maxn], s[maxn], x[maxn], sa[maxn], id[maxn], tot[maxn], rnk[maxn << 1];
int lg[maxn], st[maxn][20];
set < int > NB;
int n, m;void SA() {for( int i = 1;i <= n;i ++ ) tot[x[i] = s[i]] ++;for( int i = 1;i <= m;i ++ ) tot[i] += tot[i - 1];for( int i = n;i;i -- ) sa[tot[x[i]]--] = i;for( int k = 1;k <= n;k <<= 1 ) {int num = 0;for( int i = n - k + 1;i <= n;i ++ ) id[++ num] = i;for( int i = 1;i <= n;i ++ ) if( sa[i] > k ) id[++ num] = sa[i] - k;memset( tot, 0, sizeof( tot ) );for( int i = 1;i <= n;i ++ ) tot[x[i]] ++;for( int i = 1;i <= m;i ++ ) tot[i] += tot[i - 1];for( int i = n;i;i -- ) sa[tot[x[id[i]]]--] = id[i];for( int i = 1;i <= n;i ++ ) rnk[i] = x[i];x[sa[1]] = num = 1;for( int i = 2;i <= n;i ++ )x[sa[i]] = (rnk[sa[i]] == rnk[sa[i - 1]] and rnk[sa[i] + k] == rnk[sa[i - 1] + k]) ? num : ++ num;if( n == num ) break;m = num;}
}void height() {for( int i = 1;i <= n;i ++ ) rnk[sa[i]] = i;for( int i = 1, k = 0;i <= n;i ++ ) {if( rnk[i] == 1 ) continue;if( k ) k --;int j = sa[rnk[i] - 1];while( i + k <= n and j + k <= n and s[i + k] == s[j + k] ) k ++;h[rnk[i]] = k;}
}void ST() {lg[0] = -1; for( int i = 1;i <= n;i ++ ) lg[i] = lg[i >> 1] + 1;for( int i = 1;i <= n;i ++ ) st[i][0] = h[i];for( int j = 1;(1 << j) <= n;j ++ )for( int i = 1;i + (1 << j) - 1 <= n;i ++ )st[i][j] = min( st[i][j - 1], st[i + (1 << j - 1)][j - 1] );
}int lcp( int l, int r ) {l ++; int i = lg[r - l + 1];return min( st[l][i], st[r - (1 << i) + 1][i] );
}signed main() {scanf( "%lld", &n );for( int i = 1;i <= n;i ++ ) scanf( "%lld", &s[i] ), x[i] = s[i];sort( x + 1, x + n + 1 );m = unique( x + 1, x + n + 1 ) - x - 1;for( int i = 1;i <= n;i ++ ) s[i] = lower_bound( x + 1, x + m + 1, s[i] ) - x;reverse( s + 1, s + n + 1 );SA();height();ST();NB.insert( 0 );NB.insert( n + 1 );int ans = 0;for( int i = n;i;i -- ) {NB.insert( rnk[i] );auto it = NB.find( rnk[i] );auto pre = it; -- pre;auto nxt = it; ++ nxt;ans = ans - lcp( *pre, *nxt ) + lcp( *pre, rnk[i] ) + lcp( rnk[i], *nxt );printf( "%lld\n", (n - i + 1) * (n - i + 2) / 2 - ans );}return 0;
}