正题
luogu 2486
金牌导航 树链剖分-3
题目大意
给定一棵树,每个结点有一个初始颜色,要求支持以下操作:
1. 将一条路径上的所有结点染成同一种颜色
2. 查询一条路径上有多少个颜色段(颜色相同且连续的一段结点算作一段)
解题思路
用树链剖分把问题转化为链上问题
然后用线段树维护区间最左端颜色、最右端颜色和颜色段数;合并两个相邻区间时,若左区间最右端颜色等于右区间最左端颜色,段数要减一。查询路径时,两条链在连接处也要用同样的规则判断。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
typedef long long ll;  // formerly a macro; no use is visible in this file, kept for compatibility
const int N = 100010;  // capacity: vertices are 1-based, so n must stay below N
using namespace std;
// n: vertex count, m: operation count; x, y, z: input scratch;
// w: running DFS-order counter (dfs2); tot: number of stored edges (add).
int n, m, x, y, z, w, tot;
// c[v]: initial colour of vertex v
// q[v]: DFS-order position of vertex v (segment-tree index)
// hs[v]: heavy child of v (0 = none)
// fa[v]: parent of v (fa[root] = 0)
// sz[v]: subtree size of v
// fq[p]: vertex at DFS-order position p (inverse of q)
// top[v]: head vertex of v's heavy chain
// dep[v]: depth of v (root has depth 1)
// head[v]: index of v's first adjacency edge (0 = no edges)
int c[N], q[N], hs[N], fa[N], sz[N], fq[N], top[N], dep[N], head[N];
char str[100];  // operation token read in main ("C" or "Q")
// One adjacency-list edge: the vertex it points to, and the index of the next
// edge leaving the same source vertex (0 terminates the list).
struct rec
{
    int to;
    int next;
};
// Each undirected tree edge is stored twice (once per direction).
rec a[N << 1];
// Summary of a contiguous run of vertices laid out along a heavy chain:
// ln = colour of the run's first vertex, rn = colour of its last vertex,
// cn = number of maximal same-colour segments inside the run.
struct node
{
    int ln, rn, cn;
};

// Summary of run `lhs` immediately followed by run `rhs`.
// If the colours meeting at the junction are equal, the last segment of lhs
// and the first segment of rhs are the same segment, so it is counted once.
node merge(node lhs, node rhs)
{
    node h;
    h.ln = lhs.ln;
    h.rn = rhs.rn;
    h.cn = lhs.cn + rhs.cn - (lhs.rn == rhs.ln ? 1 : 0);
    return h;
}

// Exchange two summaries. Kept for existing callers; std::swap already does this.
void swapp(node &a, node &b)
{
    std::swap(a, b);
}
void add(int x, int y)
{a[++tot].to = y;a[tot].next = head[x];head[x] = tot;return;
}
struct Tree
{#define ls x*2#define rs x*2+1node v[N<<2];int lazy[N<<2];void push_up(int x){v[x] = merge(v[ls], v[rs]);return;}void build(int x, int l, int r){lazy[x] = -1;if (l == r){v[x] = (node){c[fq[l]], c[fq[l]], 1};return;}int mid = l + r >> 1;build(ls, l, mid);build(rs, mid + 1, r);push_up(x);}void push_down(int x){if (lazy[x] != -1){lazy[ls] = lazy[rs] = lazy[x];v[ls] = v[rs] = (node){lazy[x], lazy[x], 1};lazy[x] = -1;}return;}void change(int x, int L, int R, int l, int r, int y){if (L == l && R == r){lazy[x] = y;v[x] = (node){y, y, 1};return;}int mid = L + R >> 1;push_down(x);if (r <= mid) change(ls, L, mid, l, r, y);else if (l > mid) change(rs, mid + 1, R, l, r, y);else change(ls, L, mid, l, mid, y), change(rs, mid + 1, R, mid + 1, r, y);push_up(x);return;}node ask(int x, int L, int R, int l, int r){if (L == l && R == r) return (node){v[x].ln, v[x].rn, v[x].cn};int mid = L + R >> 1;push_down(x);if (r <= mid) return ask(ls, L, mid, l, r);else if (l > mid) return ask(rs, mid + 1, R, l, r);else return merge(ask(ls, L, mid, l, mid), ask(rs, mid + 1, R, mid + 1, r));}
}T;
// First HLD pass over x's subtree: fills fa[] for children, dep[], sz[], and
// hs[] (the child with the largest subtree; the first such child wins ties).
void dfs1(int x)
{
    dep[x] = dep[fa[x]] + 1;
    sz[x] = 1;
    for (int e = head[x]; e != 0; e = a[e].next)
    {
        const int child = a[e].to;
        if (child == fa[x])
            continue;  // do not walk back up the tree
        fa[child] = x;
        dfs1(child);
        sz[x] += sz[child];
        if (sz[hs[x]] < sz[child])
            hs[x] = child;
    }
}
// Second HLD pass: assigns DFS positions (q / fq) and chain heads (top).
// y is the head of the chain x belongs to. The heavy child is visited first,
// so every chain occupies a contiguous range of positions.
void dfs2(int x, int y)
{
    top[x] = y;
    fq[++w] = x;
    q[x] = w;
    if (hs[x] != 0)
        dfs2(hs[x], y);  // heavy child continues the current chain
    for (int e = head[x]; e != 0; e = a[e].next)
    {
        const int child = a[e].to;
        if (child != fa[x] && child != hs[x])
            dfs2(child, child);  // each light child starts its own chain
    }
}
// Paint every vertex on the tree path x -> y with colour z.
// Repeatedly paints the chain segment of whichever endpoint has the deeper
// chain head and jumps past it, until both endpoints share a chain.
void solve(int x, int y, int z)
{
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[y]] > dep[top[x]])
            std::swap(x, y);
        T.change(1, 1, n, q[top[x]], q[x], z);
    }
    // Same chain: the shallower endpoint has the smaller DFS position.
    const bool xDeeper = dep[x] > dep[y];
    const int from = xDeeper ? q[y] : q[x];
    const int to = xDeeper ? q[x] : q[y];
    T.change(1, 1, n, from, to, z);
}
int ask(int x, int y)
{int pa = 0, pb = 0;node na, nb;while(top[x] != top[y]){if (dep[top[x]] < dep[top[y]]){swap(pa, pb);swapp(na, nb);swap(x, y);}if (!pa) na = T.ask(1, 1, n, q[top[x]], q[x]), pa = 1;else na = merge(T.ask(1, 1, n, q[top[x]], q[x]), na);x = fa[top[x]];}if (dep[x] > dep[y]){swap(pa, pb);swapp(na, nb);swap(x, y);}if (!pb) nb = T.ask(1, 1, n, q[x], q[y]), pb = 1;else nb = merge(T.ask(1, 1, n, q[x], q[y]), nb);if (!pa) return nb.cn;else if (!pb) return na.cn;else return merge((node){na.rn, na.ln, na.cn}, nb).cn;//连接点对上
}
int main()
{scanf("%d%d", &n, &m);for (int i = 1; i <= n; ++i)scanf("%d", &c[i]);for (int i = 1; i < n; ++i){scanf("%d%d", &x, &y);add(x, y);add(y, x);}dfs1(1);dfs2(1, 1);T.build(1, 1, n);while(m--){scanf("%s", str);if (str[0] == 'C'){scanf("%d%d%d", &x, &y, &z);solve(x, y, z);}else{scanf("%d%d", &x, &y);printf("%d\n", ask(x, y));}}return 0;
}