AGC021-F Trinity

题目大意

有一个$n\times m$的网格,有的格子被涂成了黑色,有的被涂成了白色。
定义三个数组$\{A_n\},\{B_m\},\{C_m\}$,分别表示第$i$行第一个黑格子的列编号,第$i$列第一个/最后一个黑格子的行编号。特殊地,如果这一行/列没有黑格子,值就分别是$m+1,n+1$和$0$。
求所有可能的数组三元组$(A,B,C)$的个数。
答案对$998244353$取模。

$1\leq n\leq8\times10^3,1\leq m\leq200$


题目分析

令$f_{n,m}$表示$n\times m$的网格,每一行都有黑色格子,能形成的不同的三元组个数。
行的约束比列的约束要少,考虑按列转移,每次加入一列,然后加入所有第一个黑格子在这一列的行,假设有$k$个,那么就换转移到$f_{n+k,m+1}$。
如果$k=0$,当新的一列没有黑格子的时候转移系数是$1$,否则就是$1\leq i\leq j\leq n$的$(i,j)$的对数,即${n+1\choose 2}$。
如果$k>0$,那么就会插入$k$个新的行,但是最小值最大值可能是在之前已经加入过的行里面。这个很简单,我们假定最小值向前移动一位另成一行,最大值向后移动一位另成一行,然后新增的行数$+2$,可以发现这样是等价的,因为我们只要将极值与另一个新增行连着当成新增行为极值,否则当成连着的那一行为极值。于是转移的系数是${n+k+2\choose k+2}$。
这样直接dp就是$O(n^2m)$的了。优化的话将组合数展开,写成卷积的形式,使用NTT加速就好了。
时间复杂度是$O(mn\log n)$。


代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <iostream>
#include <cstdio>
using namespace std;
const int P=998244353;
const int G=3;
const int L=16384;
int fact[L],invf[L],f[L],g[L],h[L],t[L],trs[L];
int omega[L+5];
int n,m,len,ans;
inline int quick_power(int x,int y)
{
int ret=1;
for (;y;y>>=1,x=1ll*x*x%P) if (y&1) ret=1ll*ret*x%P;
return ret;
}
void pre()
{
fact[0]=1;
for (int i=1;i<L;++i) fact[i]=1ll*fact[i-1]*i%P;
invf[L-1]=quick_power(fact[L-1],P-2);
for (int i=L-1;i>=1;--i) invf[i-1]=1ll*invf[i]*i%P;
}
inline int C(int n,int m){return n>=m?1ll*fact[n]*invf[m]%P*invf[n-m]%P:0;}
void NTT_pre()
{
for (len=1;len<=n<<1;len<<=1);
for (int i=0,ret;i<len;++i)
{
ret=0;
for (int x=i,j=1;j<len;j<<=1,x>>=1) ret=(ret<<1)|(x&1);
trs[i]=ret;
}
int g=quick_power(G,(P-1)/len);omega[0]=1;
for (int i=1;i<=len;++i) omega[i]=1ll*omega[i-1]*g%P;
}
void DFT(int *a,int sig)
{
for (register int i=0;i<len;++i) t[trs[i]]=a[i];
for (register int l=2;l<=len;l<<=1)
for (register int h=l>>1,j=0,wn=omega[sig>0?len/l:len-len/l];j<len;j+=l)
for (register int i=0,w=1;i<h;++i,w=1ll*w*wn%P)
{
register int u=t[i+j],v=1ll*t[i+j+h]*w%P;
t[i+j]=(u+v)%P,t[i+j+h]=(u-v+P)%P;
}
for (register int i=0;i<len;++i) a[i]=t[i];
if (sig<0) for (register int i=0,inv=quick_power(len,P-2);i<len;++i) a[i]=1ll*a[i]*inv%P;
}
int main()
{
freopen("trinity.in","r",stdin),freopen("trinity.out","w",stdout);
scanf("%d%d",&n,&m),pre(),NTT_pre();
f[0]=1;
for (int i=1;i<=n;++i) g[i]=invf[i+2];
DFT(g,1);
for (int i=1;i<=m;++i)
{
for (int j=0;j<=n;++j) h[j]=1ll*f[j]*invf[j]%P;
for (int j=n+1;j<len;++j) h[j]=0;
DFT(h,1);
for (int j=0;j<len;++j) h[j]=1ll*h[j]*g[j]%P;
DFT(h,-1);
for (int j=0;j<=n;++j) f[j]=(1ll*fact[j+2]*h[j]%P+1ll*(C(j+1,2)+1)*f[j]%P)%P;
}
ans=0;
for (int i=0;i<=n;++i) (ans+=1ll*C(n,i)*f[i]%P)%=P;
printf("%d\n",ans);
fclose(stdin),fclose(stdout);
return 0;
}