后缀自动机扩展到树形结构上。
先把从每个叶子出发的所有路径插入同一棵大的 Trie；然后按 BFS 序遍历这棵 Trie 来构建后缀自动机：对 Trie 上的每个节点，保存后缀自动机从根读入该节点所代表的字符串后到达的状态，它的每个儿子就从父亲的这个状态出发继续扩展。
/**************************************************************
    Problem: 3926
    User: idy002
    Language: C++
    Result: Accepted
    Time:4320 ms
    Memory:460976 kb
****************************************************************/
// NOTE: the judge statistics above were recorded for the originally
// submitted version of this program, not for this revision.
//
// Approach: count the distinct colour strings that can be read along
// simple paths of a tree.
//   1. For every vertex of degree <= 1, run a BFS rooted there and insert
//      the colour string of every root-to-vertex path into one shared trie.
//   2. Walk the trie in BFS order and grow a suffix automaton from it:
//      each trie node remembers the automaton state reached by its string,
//      and every child is appended starting from its parent's state.
//   3. The answer is the sum of val[u] - val[pnt[u]] over all non-root
//      automaton states.

#include <cstdio>
#include <cassert>
#include <cstring>
#include <algorithm>

// Capacities chosen by the original author (presumably derived from the
// problem limits -- confirm against the statement before changing them).
#define N 100010        // tree vertices (1-based), plus slack
#define ST 3000100      // trie nodes
#define SS (ST<<1)      // suffix-automaton states; parenthesised so the macro
                        // stays correct inside larger expressions

typedef long long dnt;

// qu: shared BFS queue (used for both the tree BFS and the trie BFS).
// stk[1..top]: the vertices whose degree is <= 1 (at most n of them).
int qu[SS], stk[N], bg, ed, top;

// Suffix automaton over the alphabet {0..9}.
// The object must have static storage duration: it relies on son/val/ntot
// being zero-initialised, and a transition value of 0 means "absent".
struct Sam {
    int son[SS][10];    // transitions
    int val[SS];        // length of the longest string of each state
    int pnt[SS];        // suffix link (-1 for the root state 0)
    int ntot;           // index of the last allocated state
    Sam() { pnt[0] = -1; }

    // Extends the automaton with character c after state p and returns the
    // state that represents the extended string.
    int append( int p, int c ) {
        assert( ntot+2 < SS );      // at most two states are created below
        int np = ++ntot;
        val[np] = val[p]+1;
        // Add the missing c-transitions along the suffix-link chain.
        while( p!=-1 && !son[p][c] ) {
            son[p][c] = np;
            p = pnt[p];
        }
        if( p==-1 ) {
            pnt[np] = 0;
        } else {
            int q = son[p][c];
            if( val[q]==val[p]+1 ) {
                pnt[np] = q;
            } else {
                // q also holds longer strings: split off a clone whose
                // longest string has exactly length val[p]+1.
                int nq = ++ntot;
                std::memcpy( son[nq], son[q], sizeof(son[nq]) );
                val[nq] = val[p]+1;
                pnt[nq] = pnt[q];
                pnt[q] = pnt[np] = nq;
                while( p!=-1 && son[p][c]==q ) {
                    son[p][c] = nq;
                    p = pnt[p];
                }
            }
        }
        return np;
    }

    // Returns the sum over all non-root states of val[u] - val[pnt[u]],
    // i.e. the number of distinct non-empty substrings represented.
    dnt count() {
        dnt rt = 0;
        for( int u=1; u<=ntot; u++ )
            rt += val[u]-val[pnt[u]];
        return rt;
    }
}sam;

// Trie over the alphabet {0..9}; node 0 is the root and a child value of 0
// means "absent" (relies on static zero-initialisation, like Sam).
struct Trie {
    int son[ST][10];    // children
    int last[ST];       // automaton state reached by each node's string
    int ntot;           // index of the last allocated node

    // Returns the child of node u along character c, creating it if needed.
    int child( int u, int c ) {
        assert( 0<=c && c<10 );
        if( !son[u][c] ) {
            assert( ntot+1 < ST );
            son[u][c] = ++ntot;
        }
        return son[u][c];
    }

    // Inserts the string T, which is terminated by -1.
    void insert( int *T ) {
        int u = 0;
        for( int i=0; T[i]!=-1; i++ )
            u = child( u, T[i] );
    }

    // Builds the global automaton `sam` from the trie. Nodes are visited in
    // BFS order, so last[u] is always set before u's children are appended.
    void build() {
        last[0] = 0;
        qu[bg=ed=1] = 0;
        while( bg<=ed ) {
            int u = qu[bg++];
            for( int c=0; c<10; c++ ) {
                int v = son[u][c];
                if( !v ) continue;
                last[v] = sam.append( last[u], c );
                qu[++ed] = v;
            }
        }
    }
}trie;

int n, c;
// Adjacency lists: head[u] is the first edge of u, nxt[e] the following one
// (0 terminates the list). Named `nxt` rather than `next` so that it cannot
// clash with std::next on C++11-and-later toolchains.
int head[N], wght[N], dest[N<<1], nxt[N<<1], etot;
// dgr: vertex degree; fat: BFS parent; tnode[v]: trie node spelling the
// colours on the path from the current BFS root to v.
int dgr[N], fat[N], tnode[N];

// Adds the directed edge u -> v.
void adde( int u, int v ) {
    etot++;
    dest[etot] = v;
    nxt[etot] = head[u];
    head[u] = etot;
}

// BFS over the tree rooted at s that inserts the colour string of every
// path s -> v into the trie. Each vertex extends its parent's trie node by
// one character, so a BFS performs one trie step per vertex instead of
// re-walking whole root-to-leaf paths. The resulting set of trie strings is
// the same, because every vertex lies on the path from s to some vertex of
// degree <= 1.
void bfs( int s ) {
    qu[bg=ed=1] = s;
    fat[s] = 0;                 // vertices are 1-based, 0 means "no parent"
    tnode[s] = trie.child( 0, wght[s] );
    while( bg<=ed ) {
        int u = qu[bg++];
        for( int t=head[u]; t; t=nxt[t] ) {
            int v = dest[t];
            if( v==fat[u] ) continue;
            fat[v] = u;
            tnode[v] = trie.child( tnode[u], wght[v] );
            qu[++ed] = v;
        }
    }
}

// Reads "n c", then n colours, then n-1 edges from stdin and prints the
// number of distinct path strings. Returns 1 on malformed input.
int main() {
    if( std::scanf( "%d%d", &n, &c )!=2 || n<0 || n>=N )
        return 1;
    for( int i=1; i<=n; i++ ) {
        if( std::scanf( "%d", wght+i )!=1 )
            return 1;
    }
    for( int i=1; i<n; i++ ) {
        int u, v;
        if( std::scanf( "%d%d", &u, &v )!=2 )
            return 1;
        if( u<1 || u>n || v<1 || v>n )
            return 1;
        adde( u, v );
        adde( v, u );
        dgr[u]++, dgr[v]++;
    }
    // Collect every vertex of degree <= 1 as a BFS root.
    for( int u=1; u<=n; u++ ) {
        if( dgr[u]<=1 ) stk[++top] = u;
    }
    for( int t=1; t<=top; t++ )
        bfs( stk[t] );
    trie.build();
    std::printf( "%lld\n", sam.count() );
    return 0;
}