题干:
Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u, v) of vertices is called valid if and only if dist(u, v) does not exceed k.
Write a program that counts how many valid pairs the given tree has.
Input
The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
题目大意:
第一行给定n和k,然后给定一棵n个点的无向树,每条边有边权w,求点对(u,v)的个数,使得dis(u,v)<=k,n<=1e4
解题报告:
点分治裸题。注意求重心的时候大于号、小于号别弄反了,应该是求最大子树最小的顶点。还有别忘了size的重构:很多代码貌似换根之后都没有重构,这样得到的all变量就是不对的,相应求出的当前子树的重心也是不对的(答案本身仍然正确,但复杂度就无法保证了)。另外别忘了容斥。
AC代码:
// Centroid decomposition that always splits at the exact centroid of each component
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>
#include<stack>
#include<map>
#include<vector>
#include<set>
#include<string>
#include<cmath>
#include<cstring>
#define FF first
#define SS second
#define ll long long
#define pb push_back
#define pm make_pair
using namespace std;
using PII = pair<int, int>;  // shorthand for a pair of ints

const int MAX = 200005;      // capacity of every node-indexed and edge-indexed array

int tot;                     // number of directed edges stored so far in e[]
int head[MAX];               // head[u]: index of the first edge leaving u, -1 if none
int n, k;                    // vertex count and distance limit of the current test case
ll ans;                      // number of valid pairs found for the current test case
// One directed half of an undirected edge; edges leaving the same vertex are
// chained through `ne`, starting from head[u].
struct Edge {
    int u;   // tail vertex
    int v;   // head vertex
    int w;   // edge length
    int ne;  // index of the next edge leaving u, or -1 at the end of the chain
} e[MAX];
void add(int u,int v,int w) {e[++tot].u = u; e[tot].v = v; e[tot].w = w;e[tot].ne = head[u]; head[u] = tot;
}
int size[MAX];  // size[x]: node count of x's subtree in the latest getRoot/gx traversal
int vis[MAX];   // vis[x] = 1 once x has been chosen as a centroid (removed from the tree)
int son[MAX];   // son[x]: size of the largest part left if x were removed from its component
int rt;         // centroid found by the latest getRoot call (0 = no candidate yet)
int all;        // node count of the component currently being decomposed
// Finds the centroid of the live (not yet removed) component containing cur.
// Fills size[] and son[] for every node of that component (rooted at cur) and
// leaves in rt the node whose largest remaining part son[] is smallest.
// Callers must set rt = 0 and `all` = component size before calling.
void getRoot(int cur, int fa) {
    // Written as `::size`: with `using namespace std;` a bare `size` is
    // ambiguous with std::size under C++17 and later.
    ::size[cur] = 1;
    son[cur] = 0;
    for (int i = head[cur]; ~i; i = e[i].ne) {
        int v = e[i].v;
        if (v == fa || vis[v]) continue;
        getRoot(v, cur);
        ::size[cur] += ::size[v];
        son[cur] = max(son[cur], ::size[v]);
    }
    // The part on fa's side also becomes a separate component when cur is removed.
    son[cur] = max(son[cur], all - ::size[cur]);
    // rt == 0 means no candidate yet; otherwise keep the node with the smaller maximal part.
    if (rt == 0 || son[cur] < son[rt]) rt = cur;
}
int tott;      // number of distances currently held in dis[]
int dis[MAX];  // dis[1..tott]: distances collected by getdis
// Walks every live (non-removed) node reachable from cur without going back
// through fa, appending each node's distance to dis[] (cur itself gets ddis).
void getdis(int cur, int fa, int ddis) {
    dis[++tott] = ddis;
    for (int edge = head[cur]; edge != -1; edge = e[edge].ne) {
        const int to = e[edge].v;
        if (to != fa && !vis[to]) {
            getdis(to, cur, ddis + e[edge].w);
        }
    }
}
// Collects the distances of all live nodes reachable from cur (cur's own
// distance is diss), then counts unordered pairs i < j with dis[i] + dis[j] <= k
// using a two-pointer sweep over the sorted distances.
int cal(int cur, int diss) {
    tott = 0;
    getdis(cur, 0, diss);
    sort(dis + 1, dis + tott + 1);

    int pairs = 0;
    int lo = 1;
    int hi = tott;
    while (lo < hi) {
        if (dis[lo] + dis[hi] <= k) {
            // dis is sorted, so every index in (lo, hi] pairs validly with lo.
            pairs += hi - lo;
            ++lo;
        } else {
            --hi;
        }
    }
    return pairs;
}
void gx(int cur,int fa) {size[cur] = 1;for(int i = head[cur]; ~i; i = e[i].ne) {int v = e[i].v;if(v == fa || vis[v]) continue;gx(v,cur);size[cur] += size[v]; }
}
void dfs(int cur) {rt = 0;getRoot(cur,0);cur = rt;//getRoot(cur,0);vis[cur] = 1;gx(cur,0);ans += cal(cur,0);for(int i = head[cur]; ~i; i = e[i].ne) {int v = e[i].v;if(vis[v]) continue;ans -= cal(v,e[i].w);all = size[v]; dfs(v);}
}
void init() {for(int i = 1; i<=n; i++) {vis[i] = size[i] = 0;head[i] = -1; }tot=ans=0;
}
int main()
{while(~scanf("%d%d",&n,&k) && n+k) {init(); for(int u,v,w,i = 1; i<=n-1; i++) {scanf("%d%d%d",&u,&v,&w);add(u,v,w);add(v,u,w);} dfs(1);printf("%lld\n",ans);} return 0 ;
}