Description

有一棵$n$个点的树,$q$次操作,每次操作将某个点周围一圈点的点权+1. 你需要算出每次操作过后操作的点周围的点的点权异或和.

Input

第一行两个数$n, q$,表示树的点数和操作数。

接下来$n - 1$行每行两个数表示树上的一条边。

接下来$q$行每行一个数$x$,表示把$x$周围一圈点点权+1。

$n, q \le 5 \times 10^5$

Output

由于输出量较大,设第$i$个询问输出为$ans_i$,你只需要输出 $\sum_{i=1}^q ans_i(i^2 + i) \text{ mod } 10^9 + 7$ 。

Solution

好妙啊...

对每个节点弄个Trie, 其中按二进制维护所有儿子的权值.Trie中从根到叶子表示低位到高位.

考虑一次操作,操作节点$o$的父亲$f$可以提出来单独处理,此时只需要对$f$的父亲的Trie做修改.主要问题就是怎么支持将一个节点的儿子的权值全部$+1$.

注意到$+1$之后,所有权值的最低位都会变化.如果原来最低位是$1$, 那么第二位会发生变化.依此类推.

于是我们只需要从$o$的Trie的根节点往下走,始终走表示这一位为$1$的那个儿子,一路上交换左右子树,并维护异或值.这只需要记录Trie上每个节点$size$的奇偶性就行了.

$O(n \log n)$

Source

#include <bits/stdc++.h>

#define For(i, j, k) for(int i = j; i <= k; i++)

#ifdef __LINUX__
#define getchar getchar_unlocked
#endif

inline int Read(){
	char c = getchar();
	while(c > '9' || c < '0') c = getchar();
	int x = c - '0'; c = getchar();
	while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
	return x;
}

using namespace std;

const int N = 5e5 + 10;
const int Mod = 1e9 + 7;

int ans[N];

struct Trie{

	const static int Node = N * 40;
	const static int DEP = 18;

	int ch[Node][2], sz[Node], node;
	int root[N];

	void init(int x, int w){
		root[x] = ++node, sz[node] = w;
		int o = node;
		For(i, 0, DEP){
			ch[o][0] = ++node;
			o = ch[o][0];
			sz[o] = w;
		}
	}

	void modify(int x, int w){
		int o = root[x];
		ans[x] ^= w;
		For(i, 0, DEP){
			int &c = ch[o][bool(w & (1 << i))];
			if(!c) c = ++node;
			o = c, sz[o] ^= 1;
		}
	}

	void work(int x){
		int o = root[x];
		For(i, 0, DEP){
			int &lc = ch[o][0], &rc = ch[o][1];
			o = rc;
			swap(lc, rc);
			if(sz[lc] ^ sz[rc]) ans[x] ^= 1 << i;
		}
	}
	
}T;

int Begin[N], Next[N << 1], to[N << 1], e;

void add(int u, int v){
	to[++e] = v, Next[e] = Begin[u], Begin[u] = e;
}

int fa[N], cnt[N], val[N], deg[N];

void DFS_init(int o){
	for(int i = Begin[o]; i; i = Next[i]){
		int u = to[i];
		if(u == fa[o]) continue;
		fa[u] = o;
		DFS_init(u);
	}
}

int n, q;

int main(){
	
	n = Read(), q = Read();
	For(i, 2, n){
		int u = Read(), v = Read();
		add(u, v), add(v, u);
		deg[u] ^= 1, deg[v] ^= 1;
	}

	DFS_init(1);
	For(i, 1, n) T.init(i, deg[i] ^ (i != 1));

	long long S = 0;
	For(id, 1, q){
		int o = Read(), ret;
		++cnt[o];

		T.work(o);
		ret = ans[o];

		if(o ^ 1){
			int f = fa[o], x;
			++val[f];
			x = val[f] + cnt[fa[f]];

			ret ^= x;
			if(fa[f]) T.modify(fa[f], x - 1), T.modify(fa[f], x);
		}
		S = (S + (1ll * id * id + id) % Mod * ret) % Mod;
	}
	printf("%lld\n", S);

	return 0;
}