[九省联考2018]秘密袭击CoaT

这篇题解写的是官方正解

正解需要以下知识

  • 整体DP
  • 多项式初步
  • 拉格朗日插值法 (可以在这里学)
  • 生成函数
  • 线段树合并

题目大意:

​ 给一颗有$N$个点的树,点权在$1 \sim W$之间,求树的每一个联通块的第$K$大点权之和。

看到这个题时,我的心里是懵逼的。

我们首先开始考虑转化题目。

$(1)$到$(2)$的过程是枚举每个权值的贡献。

$(2)$到$(3)$的过程比较难理解,我们直接考虑从每个$i$在公式$(2)$中会被算多少次,显然是$i$次。

那在公式$(3)$中也是会被算$i$次。

$(3)$到$(4)$的过程就比较显然了,这里不再赘述。

那么到现在,问题就变得比较显然了:

枚举权值$v$,求树上权值$v$出现次数超过$k$的联通块个数。

我们设计一个DP。

$f[i][j][k]$表示以$i$为根的子树中,权值大于等于$j$的权值出现为$k$次的方案数。

转移显然

最后答案就是$\sum\limits{k’=1}^{k}\sum\limits{j=1}^{W}\sum\limits_{i=1}^{N} f[i][j][k’]$

复杂度$\mathcal{O}(N^2*k)$,据说有人大力过了。

我们考虑优化这个转移,不难发现,这个DP其实是背包的一种,而转移就是背包的合并。

那我们不妨直接考虑生成函数。

设$F[i][j]$表示以$i$为根的子树中,权值大于等于$j$的权值的生成函数。

则$F[i][j]=\sum\limits_{k=0}^{n} f[i][j][k] \times x^k$,这是一个$N$次多项式。

但是最后我们要求的是整棵树的所有$F[i][j]$之和,所以我们不妨再设一个$G[i][j]$。

$G[i][j]=\sum\limits_{x \in subtree(i)}F[x][j]$

$F[i][j]$在转移时是多项式卷积,还是很慢,$G[i][j]$在转移时只要维护一下就行了。

所以我们考虑将它转换为$N+1$个点值,这样的话转移时就是普通乘法了。

我们就令$x=1 \sim N+1$,然后将所有$G[i][j]$在$x$时的值都求出来,最后进行拉格朗日插值法将原始的多项式差出来就行了,可具体怎么实现呢?

我们首先在最外层枚举$x \in [1,N+1]$,然后每次进行一次$DFS$,但具体如何进行转移呢?

我们不难发现,$F[i][j] \leftarrow F[son[i]][j]$转移过程中其实就是$[j]$的对应位置相乘。

所以我们可以使用整体$DP$的思想在每个点上都维护一颗线段树,然后在转移时进行线段树合并就可以了。

正解差不多就是这个意思。

具体合并方法如下:

初始化:

转移时: $F[i][j] \times =(F[son[i]][j]+1) , G[i][j]+=G[son[i]][j]$

最后,$G[i][j]+=F[i][j]$。

我们可以将$F[son[i]][j]+1$的操作放在$DFS$ son[i]后进行。

可是你说了这么多,线段树到底应该怎么写?

我们设变换$(a,b,c,d)$可以使$(f,g)$变换为$(a \times f+b,c \times f+d+g)$

然后每个线段树维护一个变换即可。

变换结合的话手推一下即可。

注意:变换在普遍情况下没有交换律,只有结合律。

剩下的实现方法详见代码。

注意以下坑点:

  • 模数是64123,做乘法时要用unsigned int。
  • unsigned int 取模时模数必须是unsigned类型。
  • 初始化变换时应该是$(1,0,0,0)$

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<set>
#include<map>
#include<queue>
#include<cassert>
#include<stack>
typedef unsigned int uint;
using namespace std;
const uint P=64123;
const int MAXN=1700;
const int MAXM=5000;
struct __edge{
int nxt,v;
}Edge[MAXM];
int head[MAXN],cnt_e,rt[MAXN],stk[MAXN<<6],top,cnt,d[MAXN],n,k,w,u,v,ans[MAXN],f[MAXN],tmp[MAXN];
int inv[P+10];
inline void add(int u,int v) {Edge[++cnt_e].v=v;Edge[cnt_e].nxt=head[u];head[u]=cnt_e;}
struct data{
uint a,b,c,d;
inline void clear(){a=1;b=c=d=0;}
inline data(uint _a=1,uint _b=0,uint _c=0,uint _d=0):a(_a),b(_b),c(_c),d(_d){}
inline data operator * (const data &rhs) const {return data(a*rhs.a%P,(rhs.b+b*rhs.a%P)%P,(a*rhs.c%P+c)%P,(b*rhs.c%P+d+rhs.d)%P);}
inline data operator *= (const data &rhs) {return (*this)=(*this)*rhs;}
};
struct node{
int ls,rs;
data val;
node(){val.a=1;val.b=val.c=val.d=ls=rs=0;}
inline void clear(){val.a=1;val.b=val.c=val.d=ls=rs=0;}
}t[MAXN<<6];
inline int newnode() {if(top) return stk[top--];else return ++cnt;}
inline void delnode(int &x)
{
if(!x) return;
delnode(t[x].ls);delnode(t[x].rs);
stk[++top]=x;t[x].clear();x=0;
}
inline void pushdown(const int &x)
{
if(!t[x].ls) t[x].ls=newnode();
if(!t[x].rs) t[x].rs=newnode();
t[t[x].ls].val*=t[x].val;
t[t[x].rs].val*=t[x].val;
t[x].val.clear();
}
void change(int &x,int l,int r,int cl,int cr,data val)
{
if(!x) x=newnode();
if(cl<=l&&r<=cr) {t[x].val*=val;return;}
int mid=(l+r)>>1;pushdown(x);
if(cl<=mid) change(t[x].ls,l,mid,cl,cr,val);
if(cr>mid) change(t[x].rs,mid+1,r,cl,cr,val);
}
int merge(int &x,int &y)
{
if(!x||!y) return x|y;
if(!t[x].ls&&!t[x].rs) swap(x,y);
if(!t[y].ls&&!t[y].rs)
{
t[x].val*=data(t[y].val.b,0,0,0);
t[x].val*=data(1,0,0,t[y].val.d);
return x;
}
pushdown(x);pushdown(y);
t[x].ls=merge(t[x].ls,t[y].ls);
t[x].rs=merge(t[x].rs,t[y].rs);
return x;
}
uint query(int x,int l,int r)
{
if(l==r) return t[x].val.d;
int mid=(l+r)>>1;
uint ret=0;
pushdown(x);
ret=query(t[x].ls,l,mid);
(ret+=query(t[x].rs,mid+1,r))%=P;
return ret;
}
void dfs(int x,int fa,int k0)
{
change(rt[x],1,w,1,w,data(0,1,0,0));
for(int i=head[x];i;i=Edge[i].nxt)
{
int v=Edge[i].v;
if(v==fa) continue;
dfs(v,x,k0);
merge(rt[x],rt[v]);
delnode(rt[v]);
}
change(rt[x],1,w,1,d[x],data(k0,0,0,0));
change(rt[x],1,w,1,w,data(1,0,1,0));
change(rt[x],1,w,1,w,data(1,1,0,0));
}
uint Lagrange_interpolation()
{
uint ret=0;
memset(f,0,sizeof f);f[0]=1;
for(int i=1;i<=n+1;i++)
{
for(int j=n+1;j>=1;j--)
f[j]=f[j]*(P-i)%P,f[j]=(f[j]+f[j-1])%P;
f[0]=f[0]*(P-i)%P;
}
for(int i=1;i<=n+1;i++)
{
memcpy(tmp,f,sizeof(uint)*(n+1));
for(int j=0;j<=n;j++) //divide by (x-i)
tmp[j]=P-tmp[j]*inv[i]%P,tmp[j+1]=(tmp[j+1]-tmp[j]+P)%P;
uint tans=0;
for(int j=k;j<=n;j++) tans=(tans+tmp[j])%P;
for(int j=1;j<=n+1;j++)
{
if(i==j) continue;
if(j<i) tans=tans*inv[i-j]%P;
else tans=tans*(P-inv[j-i])%P;
}
ret=(ret+tans*ans[i]%P)%P;
}
return ret;
}
int main()
{
scanf("%d%d%d",&n,&k,&w);
for(int i=1;i<=n;i++) scanf("%d",&d[i]);
for(int i=1;i<n;i++) scanf("%d%d",&u,&v),add(u,v),add(v,u);
for(int i=1;i<=n+1;i++)
{
dfs(1,0,i);
ans[i]=query(rt[1],1,w);
delnode(rt[1]);
}
inv[1]=1;
for(int i=2;i<P;i++)
inv[i]=(P-(P/i)*inv[P%i]%P)%P,assert(inv[i]>0);
printf("%u\n",Lagrange_interpolation());
}