Tree Cutting HDU - 5909
题意:
一个无根树,n个点,n-1条边,每个节点有一个权值,一棵树的权值就是其节点(包含本身及其子节点)的权值的异或和;求价值为[0,m)的树有多少颗?(所谓的树其实就是原连通图的任意子图)
n<=1000
m<=2102^{10}210
题解:
不难想到树形dp
设dp[u][i]表示以u节点为根的价值为i的树的数量
能得到转移方程:
dp[u][j⨁k]=dp[u][j⨁k]+dp[u][j]∗d[v][k]dp[u][j\bigoplus k]=dp[u][j\bigoplus k]+dp[u][j]*d[v][k]dp[u][j⨁k]=dp[u][j⨁k]+dp[u][j]∗d[v][k]
u是v的父亲节点
不过这个式子直接算会超时,复杂度为O(n∗m∗m)O(n*m*m)O(n∗m∗m)
我们将式子变形:
dp[u][i]=∑j∗k=idp[u][j]∗d[v][k]dp[u][i]=\sum_{j*k=i}dp[u][j]*d[v][k]dp[u][i]=∑j∗k=idp[u][j]∗d[v][k]
这个就长得很像FWT
没错,就可以用FWT优化了
每次将dp[u][]和dp[v][]卷起来,并记录当前节点u为根的答案
优化后复杂度为O(nmlogm)O(nmlogm)O(nmlogm)
代码:
#include <bits/stdc++.h>
#include <unordered_map>
#define debug(a, b) printf("%s = %d\n", a, b);
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
clock_t startTime, endTime;
//Fe~Jozky
const ll INF_ll= 1e18;
const int INF_int= 0x3f3f3f3f;
void read(){};
template <typename _Tp, typename... _Tps> void read(_Tp& x, _Tps&... Ar)
{x= 0;char c= getchar();bool flag= 0;while (c < '0' || c > '9')flag|= (c == '-'), c= getchar();while (c >= '0' && c <= '9')x= (x << 3) + (x << 1) + (c ^ 48), c= getchar();if (flag)x= -x;read(Ar...);
}
template <typename T> inline void write(T x)
{if (x < 0) {x= ~(x - 1);putchar('-');}if (x > 9)write(x / 10);putchar(x % 10 + '0');
}
void rd_test()
{
#ifdef ONLINE_JUDGE
#elsestartTime = clock ();freopen("data.in", "r", stdin);
#endif
}
void Time_test()
{
#ifdef ONLINE_JUDGE
#elseendTime= clock();printf("\nRun Time:%lfs\n", (double)(endTime - startTime) / CLOCKS_PER_SEC);
#endif
}
const int P=1e9+7;
const int mod=1e9+7;
const int maxn1=(1<<13);
#define int long long
int val[1020];
vector<int>vec[2000];
int ans[maxn1];
int tmp[maxn1];
int dp[1020][maxn1];
int n,m;
/*
设dp[i][j]:表示以i为根的子树中异或和为j的数量
*/
int qpow(int a,int b){int ans=1;while(b){if(b&1)ans=ans*a%mod;a=a*a%mod;b>>=1;}return ans%mod;
}
void FWT(int x[],int t1,int t2,int len)
{const ll inv2= qpow(2,mod-2);for(int i=1;i<len;i<<=1)for(int j=0;j<len;j+=(i<<1))for(int k=0;k<i;k++){ll p=x[j+k],q=x[i+j+k];if(t1==0) x[i+j+k]=(q+P+t2*p)%P; //orelse if(t1==1) x[j+k]=(p+P+t2*q)%P; //andelse if(t1==2) //xor{x[j+k]=(p+q)%P*(t2<0?inv2:1)%P;x[i+j+k]=(p+P-q)%P*(t2<0?inv2:1)%P;} }
}
void say(int a[],int len){for(int i=0;i<len;i++){printf("a[%d]=%d\n",i,a[i]);}}
void solve(int a[],int b[],int len){
// say(a,len);
// say(b,len);FWT(a,2,1,len);FWT(b,2,1,len);for(int i=0;i<len;i++)a[i]=a[i]*b[i]%mod;FWT(a,2,-1,len);
}
void dfs(int u,int fa){dp[u][val[u]]=1;for(auto v:vec[u]){if(v==fa)continue;dfs(v,u);for(int j=0;j<m;j++){tmp[j]=dp[u][j];}solve(dp[u],dp[v],m);for(int j=0;j<m;j++){dp[u][j]=(tmp[j]+dp[u][j])%mod;}}for(int i=0;i<m;i++){ans[i]=(ans[i]+dp[u][i])%mod;}
}
signed main()
{
// rd_test();int t;read(t);while(t--){read(n,m);for(int i=1;i<=n;i++)read(val[i]);for(int i=1;i<n;i++){int u,v;read(u,v);vec[u].push_back(v);vec[v].push_back(u);}memset(dp,0,sizeof dp);memset(ans,0,sizeof ans);dfs(1,1);for(int i=0;i<m;i++){if(i==0)cout<<ans[i];else cout<<" "<<ans[i];}cout<<endl;for(int i=1;i<=n;i++)vec[i].clear();}//Time_test();
}
/**/