HDU 4812
思路:
点分治
先预处理出 1e6 + 3 以内所有数的逆元（1e6 + 3 是质数，可用线性递推）
然后用数组 mp（充当 map）把以分治点为起点（含分治点）的链的乘积 a 映射到该链另一端点的最小下标 u
然后暴力求出以分治点的儿子为起点（不含分治点）的链的乘积 b，再在 mp 里查找 inv[b]*k（mod 1e6 + 3）；每棵子树先查询、再插入，避免同一子树内的两条链互相配对
代码:
#include <algorithm>
#include <cstdio>
#include <cstring>

// HDU 4812 (D Tree): among all vertex pairs (a, b), a < b, whose path product
// of vertex values is congruent to k modulo 1e6 + 3, print the
// lexicographically smallest pair, or "No solution". Uses centroid
// decomposition plus a precomputed table of modular inverses.
//
// Assumes 1 <= v[i] < MOD (problem constraint): mp[] is indexed directly by
// vertex values / chain products and inv[] must be defined for every product.

const int MOD = 1e6 + 3;     // prime modulus
const int INF = 0x7f7f7f7f;  // "no answer yet" sentinel for ans1 / ans2
const int N = 1e5 + 5;       // vertex capacity

int inv[MOD + 5];  // inv[x] = x^-1 mod MOD, for 1 <= x < MOD
int mp[MOD + 5];   // mp[p] = smallest vertex w whose centroid->w chain product
                   // (centroid value included) is p; 0 means "none"
int head[N], mxsz[N], sz[N], v[N];
int cnt = 0, rt = 0, n, k, ans1, ans2;
int deep[N], dis[N], id[N], top = 0;  // scratch buffers filled by get_d
bool vis[N];                          // vertices already used as centroids

struct Edge {
    int to, nxt;
} edge[N * 2];

// Appends the directed edge u -> to to u's adjacency list.
void add_edge(int u, int to) {
    edge[cnt].to = to;
    edge[cnt].nxt = head[u];
    head[u] = cnt++;
}

// Fills inv[1..MOD-1] with modular inverses using the linear recurrence
// inv[i] = -(MOD / i) * inv[MOD % i] (mod MOD), valid because MOD is prime.
void init() {
    inv[1] = 1;
    for (int i = 2; i < MOD; i++)
        inv[i] = (MOD - MOD / i) * 1LL * inv[MOD % i] % MOD;
}

// Vertex y's chain below the current centroid (centroid excluded) has product
// x. Looks up the smallest stored vertex whose chain completes the path
// product to k, and keeps the lexicographically smallest (min, max) pair.
// Taking the smallest partner is optimal for a fixed y in either order.
void update(int x, int y) {
    int need = static_cast<int>(1LL * inv[x] * k % MOD);
    int other = mp[need];
    if (!other) return;
    int a = std::min(other, y);
    int b = std::max(other, y);
    if (a < ans1 || (a == ans1 && b < ans2)) {
        ans1 = a;
        ans2 = b;
    }
}

// Computes subtree sizes of the unvisited component reached from u (entered
// from o) and stores in rt the vertex that minimises its largest remaining
// piece, given that the component has n vertices. The caller sets rt = 0 and
// mxsz[0] = n beforehand. Recursive: depth can reach the component size.
void get_rt(int o, int u) {
    sz[u] = 1;
    mxsz[u] = 0;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        int to = edge[i].to;
        if (to == o || vis[to]) continue;
        get_rt(u, to);
        sz[u] += sz[to];
        mxsz[u] = std::max(mxsz[u], sz[to]);
    }
    mxsz[u] = std::max(mxsz[u], n - sz[u]);
    if (mxsz[u] < mxsz[rt]) rt = u;
}

// Appends every vertex of the unvisited subtree rooted at u (entered from o)
// to id[1..top], with deep[j] = product of the chain ending at id[j]. The
// caller seeds dis[u] with the chain product up to and including u.
void get_d(int o, int u) {
    deep[++top] = dis[u];
    id[top] = u;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        int to = edge[i].to;
        if (vis[to] || to == o) continue;
        dis[to] = static_cast<int>(1LL * dis[u] * v[to] % MOD);
        get_d(u, to);
    }
}

// Handles every path through centroid u, then recurses into each remaining
// component. On entry n must be the size of u's component, and sz[] must come
// from the get_rt pass that selected u.
void solve(int u) {
    const int total = n;  // n is overwritten by the recursive calls below
    vis[u] = true;
    mp[v[u]] = u;  // a path may end at the centroid itself

    for (int i = head[u]; ~i; i = edge[i].nxt) {
        int c = edge[i].to;
        if (vis[c]) continue;
        // Query first, using chain products without the centroid, so two
        // vertices of the same child subtree are never paired.
        top = 0;
        dis[c] = v[c];
        get_d(u, c);
        for (int j = 1; j <= top; j++) update(deep[j], id[j]);
        // Then insert this subtree's chains, with the centroid value included.
        top = 0;
        dis[c] = static_cast<int>(1LL * v[u] * v[c] % MOD);
        get_d(u, c);
        for (int j = 1; j <= top; j++) {
            int p = deep[j];
            if (!mp[p] || id[j] < mp[p]) mp[p] = id[j];
        }
    }

    // Clear only the touched entries, so mp is all-zero for the next centroid.
    mp[v[u]] = 0;
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        int c = edge[i].to;
        if (vis[c]) continue;
        top = 0;
        dis[c] = static_cast<int>(1LL * v[u] * v[c] % MOD);
        get_d(u, c);
        for (int j = 1; j <= top; j++) mp[deep[j]] = 0;
    }

    for (int i = head[u]; ~i; i = edge[i].nxt) {
        int c = edge[i].to;
        if (vis[c]) continue;
        // If c was u's parent in the get_rt pass that found u, sz[c] also
        // counts u's side; the real size of c's component is total - sz[u].
        int sub = sz[c] < sz[u] ? sz[c] : total - sz[u];
        n = sub;
        mxsz[0] = sub;
        rt = 0;
        get_rt(0, c);
        solve(rt);
    }
}

int main() {
    init();
    int a, b;
    // Stop on EOF or malformed input; "~scanf" would spin forever if scanf
    // returned 0 without consuming anything.
    while (std::scanf("%d%d", &n, &k) == 2) {
        std::memset(head, -1, sizeof(head));
        std::memset(vis, false, sizeof(vis));
        std::memset(mp, 0, sizeof(mp));
        cnt = 0;
        ans1 = ans2 = INF;
        for (int i = 1; i <= n; i++) std::scanf("%d", &v[i]);
        for (int i = 1; i < n; i++) {
            std::scanf("%d%d", &a, &b);
            add_edge(a, b);
            add_edge(b, a);
        }
        mxsz[0] = n;
        rt = 0;
        get_rt(0, 1);
        solve(rt);
        if (ans1 == INF)
            std::printf("No solution\n");
        else
            std::printf("%d %d\n", ans1, ans2);
    }
    return 0;
}