应该是一道中等难度的点分?麻烦在一些细节。
题目描述
lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义s(i,j) 为i 到j 的颜色数量。以及
现在他想让你求出所有的sum[i]
输入输出格式
输入格式:
第一行为一个整数n,表示树节点的数量
第二行为n个整数,分别表示n个节点的颜色c[1],c[2]……c[n]
接下来n-1行,每行为两个整数x,y,表示x和y之间有一条边
输出格式:
输出n行,第i行为sum[i]
说明
sum[1]=s(1,1)+s(1,2)+s(1,3)+s(1,4)+s(1,5)=1+2+3+2+2=10
sum[2]=s(2,1)+s(2,2)+s(2,3)+s(2,4)+s(2,5)=2+1+2+1+3=9
sum[3]=s(3,1)+s(3,2)+s(3,3)+s(3,4)+s(3,5)=3+2+1+2+3=11
sum[4]=s(4,1)+s(4,2)+s(4,3)+s(4,4)+s(4,5)=2+1+2+1+3=9
sum[5]=s(5,1)+s(5,2)+s(5,3)+s(5,4)+s(5,5)=2+3+3+3+1=12
对于40%的数据,n<=2000
对于100%的数据,1<=n,c[i]<=10^5
题目分析
想法一:按颜色拆贡献
这里应该是有一种小颜色大颜色的分块套路的。但是这个想法我只能解决全局路径的数量和,并不会落实到点的询问。
想法二:点分治
目前尚未归结出点分治适用的具体问题范围……不过这一题是可以用点分解决的。
考虑每一层点分树,我们只需要对它的节点处理贡献。这里的贡献分为两部分:重心答案;经过重心的路径对子树的贡献。
重心的答案只需要以它自身为根,遍历一边该层点分树即可。子树内的答案处理要略微麻烦一些,需要分颜色来考虑贡献。记$colCnt[i]$为所有以重心为起点的路径中,含有颜色$i$的路径条数。然后首先假定子树内所有点的答案都为$\sum colCnt[i]$,再容斥考虑重心到子树路径上的颜色所产生的贡献。
记当前点分树中除去正在处理的子树的大小为$outTot$,那么对于子树内点$x$,由于它具有颜色$c[x]$,所以对自身的答案有一个$outTot-colCnt[c[x]]$的贡献。并且,这一个贡献对于$x$的子树也是一概适用的,所以这一个标记要差分式地下传。
整体思路就是这些。这一题的点分涉及到例如“子树结构的重定向”或是“两个颜色桶并存”的一些细节问题,所以实现上面可能有一定的难度。
(话说这题的码风怎么这么丑)
1 #include<bits/stdc++.h> 2 typedef long long ll; 3 const int maxn = 100035; 4 const int maxm = 200035; 5 6 ll ans[maxn]; 7 int n,bloTot,outTot,c[maxn]; 8 int size[maxn],son[maxn],root; 9 int edgeTot,head[maxn],nxt[maxm],edges[maxm]; 10 int cols,cnt,cl,col[maxn],colTmp[maxn],colTim[maxn],colCnt[maxn],subCnt[maxn]; 11 bool colEx[maxn],divEx[maxn]; 12 13 int read() 14 { 15 char ch = getchar(); 16 int num = 0, fl = 1; 17 for (; !isdigit(ch); ch=getchar()) 18 if (ch=='-') fl = -1; 19 for (; isdigit(ch); ch=getchar()) 20 num = (num<<1)+(num<<3)+ch-48; 21 return num*fl; 22 } 23 void addedge(int u, int v) 24 { 25 edges[++edgeTot] = v, nxt[edgeTot] = head[u], head[u] = edgeTot; 26 edges[++edgeTot] = u, nxt[edgeTot] = head[v], head[v] = edgeTot; 27 } 28 void getRoot(int x, int fa) 29 { 30 size[x] = 1, son[x] = 0; 31 for (int i=head[x]; i!=-1; i=nxt[i]) 32 { 33 int v = edges[i]; 34 if (divEx[v]||v==fa) continue; 35 getRoot(v, x), size[x] += size[v]; 36 son[x] = std::max(son[x], size[v]); 37 } 38 son[x] = std::max(son[x], bloTot-size[x]); 39 if (son[x] < son[root]) root = x; 40 } 41 void colDfs(int x, int fa, int *cnt) 42 { 43 if (!colEx[c[x]]) colEx[c[x]] = 1, col[++cols] = c[x]; 44 if ((++colTim[c[x]])==1) cnt[c[x]] += size[x]; 45 for (int i=head[x]; i!=-1; i=nxt[i]) 46 if ((!divEx[edges[i]])&&(edges[i]!=fa)) 47 colDfs(edges[i], x, cnt); 48 --colTim[c[x]]; 49 } 50 void colClear() 51 { 52 for (int i=1; i<=cl; i++) colEx[colTmp[i]] = 0; 53 cols = 0; 54 } 55 void modify(int x, int fa, ll tag) 56 { 57 if ((++colTim[c[x]])==1) tag += outTot-colCnt[c[x]]; 58 ans[x] += tag+cnt; 59 for (int i=head[x]; i!=-1; i=nxt[i]) 60 { 61 int v = edges[i]; 62 if (v==fa||divEx[v]) continue; 63 modify(v, x, tag); 64 } 65 --colTim[c[x]]; 66 } 67 void calc(int rt) //核心操作在这里 68 { 69 colClear(), getRoot(rt, 0); 70 colDfs(rt, 0, colCnt); 71 cnt = 0, cl = cols; 72 for (int i=1; i<=cols; i++) 73 cnt += colCnt[col[i]], colTmp[i] = col[i]; 74 ans[rt] += cnt; 75 for (int i=head[rt]; i!=-1; i=nxt[i]) 76 { 77 int v = edges[i]; 78 if (divEx[v]) continue; 79 for (int j=1; j<=cl; j++) subCnt[colTmp[j]] = 0; //及时清除数组 80 colClear(); 81 colEx[c[rt]] = 1; 82 colDfs(v, rt, subCnt); //统计子树内的含颜色i路径条数 83 colEx[c[rt]] = 0; 84 colCnt[c[rt]] -= size[v], cnt -= size[v]; //除去重心出发的路径 85 for (int j=1; j<=cols; j++) 86 { 87 colCnt[col[j]] -= subCnt[col[j]]; //除去子树内的路径(因为考虑子树外路径) 88 cnt -= subCnt[col[j]]; 89 } 90 outTot = size[rt]-size[v], modify(v, rt, 0); //对子树内累加贡献 91 colCnt[c[rt]] += size[v], cnt += size[v]; //恢复处理子树前状态 92 for (int j=1; j<=cols; j++) 93 { 94 colCnt[col[j]] += subCnt[col[j]]; 95 cnt += subCnt[col[j]]; 96 } 97 } 98 for (int i=1; i<=cl; i++) 99 colCnt[colTmp[i]] = 0; //colTmp[]的作用;清空colCnt[] 100 } 101 void deal(int rt) 102 { 103 calc(rt), divEx[rt] = 1; 104 for (int i=head[rt]; i!=-1; i=nxt[i]) 105 { 106 int v = edges[i]; 107 if (divEx[v]) continue; 108 root = 0, bloTot = size[v]; 109 getRoot(v, 0), deal(root); 110 } 111 } 112 int main() 113 { 114 memset(head, -1, sizeof head); 115 n = read(), son[0] = n; 116 for (int i=1; i<=n; i++) c[i] = read(); 117 for (int i=1; i<n; i++) addedge(read(), read()); 118 bloTot = n, getRoot(1, 0), deal(root); 119 for (int i=1; i<=n; i++) printf("%lld\n",ans[i]); 120 return 0; 121 }
END