Codeforces 995F Cowmpany Cowmpensation

题目大意

给定一棵 $1$ 为根的树,你要给每一个点分配一个 $1$ 到 $D$ 之间的权值,使得每一个点的权值都小于等于其父亲节点的权值。
求出方案数,对 $10^9+7$ 取模。

$1\leq n\leq3\times10^3,1\leq D\leq10^9$


题目分析

考虑最简单的 dp,令 $f_{x,i}$ 表示做到第 $x$ 个点为根的子树中,$x$ 填 $i$ 时的方案总数。
使用前缀和优化一下就可以做到 $O(nD)$。

经过观察,可以发现其实 $f_{x,i}$ 是一个关于 $i$ 的最高次数为 $size(x)-1$ 的多项式。
这个用归纳法很好证明:在叶子节点处 $f_{x,i}=1$ 显然成立;对于非叶子节点,假设其所有儿子节点都满足,定义 $g_{x,i}$ 是 $f_{x,i}$ 的前缀和,那么其所有儿子 $y$ 的 $g_{y,i}$ 都是最高次数为 $size(y)$ 的多项式,dp 的转移是对应位置相乘,因此点 $f_{x,i}$ 就是一个最高次数为 $size(x)-1$ 的多项式。

有了这个结论,我们 dp 的时候只需要保留数组的前 $0…n$ 项,然后直接插值插出 $g_{1,D}$ 就好了。
时间复杂度 $O(n^2)$。


代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <iostream>
#include <cstdio>

using namespace std;

const int P=1000000007;
const int N=3005;

inline void add(int &x,int y){x=x+y>=P?x+y-P:x+y;}

inline int quick_power(int x,int y)
{
int ret=1;
for (;y;y>>=1,x=1ll*x*x%P) if (y&1) ret=1ll*ret*x%P;
return ret;
}

inline int sig(int x){return x&1?P-1:1;}

int fact[N],invf[N],prex[N],sufx[N],last[N],fa[N],tov[N],nxt[N];
int f[N][N];
int n,D,tot;

inline void insert(int x,int y){tov[++tot]=y,nxt[tot]=last[x],last[x]=tot;}

void pre()
{
fact[0]=1;
for (int i=1;i<=n;++i) fact[i]=1ll*fact[i-1]*i%P;
invf[n]=quick_power(fact[n],P-2);
for (int i=n;i>=1;--i) invf[i-1]=1ll*invf[i]*i%P;
}

void dp(int x,int fa=0)
{
for (int i=1;i<=n;++i) f[x][i]=1;
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa)
{
dp(y,x);
for (int j=1;j<=n;++j) f[x][j]=1ll*f[x][j]*f[y][j]%P;
}
for (int i=1;i<=n;++i) add(f[x][i],f[x][i-1]);
}

inline int F(int *f,int x)
{
prex[0]=x;
for (int i=1;i<=n;++i) prex[i]=1ll*prex[i-1]*(x-i+P)%P;
sufx[n]=(x-n+P)%P;
for (int i=n-1;i>=0;--i) sufx[i]=1ll*sufx[i+1]*(x-i+P)%P;
int ret=0;
for (int i=0;i<=n;++i) add(ret,1ll*(i?prex[i-1]:1)*(i<n?sufx[i+1]:1)%P*invf[i]%P*invf[n-i]%P*sig(n-i)%P*f[i]%P);
return ret;
}

int main()
{
freopen("cow.in","r",stdin),freopen("cow.out","w",stdout);
scanf("%d%d",&n,&D),pre();
for (int x=2;x<=n;++x) scanf("%d",&fa[x]),insert(fa[x],x);
dp(1),printf("%d\n",F(f[1],D));
fclose(stdin),fclose(stdout);
return 0;
}