description
戳我看题目哦
solution
有一道非常相似的题目
一棵树,每条边限制两个端点的大小关系(限制 a[u]>a[v]a[u]>a[v]a[u]>a[v] 或 a[u]<a[v]a[u]<a[v]a[u]<a[v])
求有多少种符合要求的排列aaa满足整棵树的限制。n<=5000n<=5000n<=5000
考虑如果所有边都是朝一个方向的话很好做
答案就是n!n!n!除以每个子树的大小
如果存在反向边的话,暴力枚举断开若干个反向边,剩下的边改为正向,然后计算答案
容斥即可。这样暴力做的复杂度是 O(2n∗n)O(2^n*n)O(2n∗n) 的
考虑 dpdpdp,f(i,j,k)f(i,j,k)f(i,j,k) 表示以 iii 为根的子树,当前 iii 所在连通块内有 jjj 个点,总共反向 kkk 条边的方案数
合并两棵子树时,如果边是正向的,那么直接合并;
否则要么断开,要么让 k+1k+1k+1 并且按照正向合并
复杂度 nnn 的若干次方
考虑最后的容斥只需要关注 kkk 的奇偶性,因此第三维完全可以省掉
即合并两棵子树时,如果边是正向则直接合并,否则值就是断开的方案减掉把边正向的方案
因此就是一个简单的树背包,复杂度 O(n2)O(n^2)O(n2)
此题只是需要将二维dpdpdp再次优化即可
设dp[i]dp[i]dp[i]表示前缀iii的合法方案数,cnt[i]cnt[i]cnt[i]表示前缀iii中>>>的个数
dp[i]i!=∑j=0i−1[s[j]=′>′](i−j)!(−1)cnt[i−1]−cnt[j]×dp[j]j!\frac{dp[i]}{i!}=\sum_{j=0}^{i-1}\frac{[s[j]='>']}{(i-j)!}(-1)^{cnt[i-1]-cnt[j]}\times \frac{dp[j]}{j!}i!dp[i]=j=0∑i−1(i−j)![s[j]=′>′](−1)cnt[i−1]−cnt[j]×j!dp[j]
将(−1)cnt[i−1](-1)^{cnt[i-1]}(−1)cnt[i−1]提出来,剩余部分用NTTNTTNTT分治完成
有难度
code
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define mod 998244353
#define int long long
#define maxn 400005
int len, inv;
char s[maxn];
int cnt[maxn];
int fac[maxn], ifac[maxn], r[maxn];
int f[maxn], g[maxn], dp[maxn];int qkpow( int x, int y ) {int ans = 1;while( y ) {if( y & 1 ) ans = ans * x % mod;x = x * x % mod;y >>= 1;}return ans;
}void NTT( int *c, int opt ) {for( int i = 0;i < len;i ++ )if( i < r[i] ) swap( c[i], c[r[i]] );for( int i = 1;i < len;i <<= 1 ) {int omega = qkpow( opt == 1 ? 3 : mod / 3 + 1, ( mod - 1 ) / ( i << 1 ) );for( int j = 0;j < len;j += ( i << 1 ) ) {int w = 1;for( int k = 0;k < i;k ++, w = w * omega % mod ) {int x = c[j + k], y = w * c[j + k + i] % mod;c[j + k] = ( x + y ) % mod;c[j + k + i] = ( x - y + mod ) % mod;}}}if( opt == -1 ) {int inv = qkpow( len, mod - 2 );for( int i = 0;i < len;i ++ )c[i] = c[i] * inv % mod;}
}void solve( int L, int R ) {if( L == R ) {if( ! L ) dp[L] = 1;else dp[L] = cnt[L] & 1 ? mod - dp[L] : dp[L];//单独提出来 return;}int mid = ( L + R ) >> 1;solve( L, mid );len = 1; int l = 0;while( len <= R - L + 1 + mid - L ) len <<= 1, l ++;for( int i = 0;i < len;i ++ )r[i] = ( r[i >> 1] >> 1 ) | ( ( i & 1 ) << ( l - 1 ) );for( int i = 0;i <= mid - L;i ++ )if( s[i + L] == '<' && i + L != 0 ) f[i] = 0;else f[i] = cnt[i + L] & 1 ? dp[i + L] : mod - dp[i + L];//注意奇偶转换 for( int i = mid - L + 1;i < len;i ++ ) f[i] = 0;for( int i = 0;i <= R - L + 1;i ++ ) g[i] = ifac[i];for( int i = R - L + 2;i < len;i ++ ) g[i] = 0;NTT( f, 1 );NTT( g, 1 );for( int i = 0;i < len;i ++ ) f[i] = f[i] * g[i] % mod;NTT( f, -1 );for( int i = mid + 1;i <= R;i ++ ) dp[i] = ( dp[i] + f[i - L] ) % mod;solve( mid + 1, R );
}signed main() {scanf( "%s", s + 1 ); int n = strlen( s + 1 );s[++ n] = '>';fac[0] = 1;for( int i = 1;i <= n;i ++ )fac[i] = fac[i - 1] * i % mod;ifac[n] = qkpow( fac[n], mod - 2 );for( int i = n - 1;~ i;i -- )ifac[i] = ifac[i + 1] * ( i + 1 ) % mod;for( int i = 1;i <= n;i ++ )cnt[i] = cnt[i - 1] + ( s[i] == '>' );solve( 0, n );printf( "%lld\n", dp[n] * fac[n] % mod );return 0;
}