HackerEarth June Circuit '18 - Choose items

前言

并不是一道怎么棒棒的题,但是鏖战了一定的时间,还是 mark 一下吧……


题目大意

给定长度为 $n$ 的序列 ${c_i}$。
有 $q$ 次询问,每次询问给定 $l,r,k$。你需要回答$$[x^k] \prod_{i=l}^r(1+c_ix)$$答案对 $998244353$ 取模。强制在线。

$1\leq n,q\leq 2^{14}$


题目分析

有一个很 naive 的想法就是直接存下每一个位置的前缀积的各项点值,这个可以 $O(n^2)$ 暴力算。然后每次询问两端相除再 IDFT 回来就好了,特判一下 $0$ 的情况。
这样做会 MLE。考虑分块,设一个阈值 $T$,我们每隔 $T$ 个位置存一下前缀积点值。询问时对于整块直接 IDFT。对于散块,考虑到 $T$ 可以设置得很小(只要在空间允许范围内,并不会影响时间复杂度),而算法的瓶颈主要在于 IDFT,因此我们可以直接用 $O(T^2)$ 的暴力来完成散块多项式的计算。
这样可以做到时间复杂度 $O(qn\log n)$。
可是这样还是通过不了这一题。注意到整块两两的搭配只有 $O\left((\frac nT)^2\right)$ 种,我们直接预处理任意两个块之间的所有双项式的乘积。然后对于整块,我们考虑使用分治 FFT 或者线段树加 FFT 来得到多项式。
这样做的复杂度是 $O\left((\frac nT)^2n\log n+qT\log^2 T\right)$。
考虑阈值均衡,将 $q$ 看成与 $n$ 同阶,粗略地取 $T=\left(\frac{n^2}{\log n}\right)^{\frac 13}$ 即可通过此题。
总用时跑了 80 多秒……并不知道只跑了 9s 的 mcfx 神仙写的是什么……


代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cmath>

using namespace std;

inline int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x*f;
}

int buf[30];

inline void write(int x)
{
if (x<0) putchar('-'),x=-x;
for (;x;x/=10) buf[++buf[0]]=x%10;
if (!buf[0]) buf[++buf[0]]=0;
for (;buf[0];putchar('0'+buf[buf[0]--]));
}

const int P=998244353;
const int L=14;
const int N=1<<L;
const int MAX_BLOCK_SIZE=275;
const int MAX_BLOCK_CNT=65;
const int G=3;

inline void Add(int &x,int y){x=x+y>=P?x+y-P:x+y;}
inline void Sub(int &x,int y){x=x<y?x-y+P:x-y;}

inline int add(int x,int y){return x+y>=P?x+y-P:x+y;}
inline int sub(int x,int y){return x<y?x-y+P:x-y;}

inline int quick_power(int x,int y)
{
int ret=1;
for (;y;y>>=1,x=1ll*x*x%P) if (y&1) ret=1ll*ret*x%P;
return ret;
}

int a[N],t[N],trs[N];
int g[L][1024];
int omega[N+5];
int len,len_;

int f[MAX_BLOCK_CNT*MAX_BLOCK_CNT>>1][N];
int st[MAX_BLOCK_CNT],en[MAX_BLOCK_CNT];
int bel[N],c[N],coef[N];
int n,q,siz,cnt,ans;

inline void NTT_pre()
{
int g=quick_power(G,(P-1)/N);omega[0]=1;
for (int i=1;i<=N;++i) omega[i]=1ll*omega[i-1]*g%P;
}

inline void deg_pre(int deg)
{
for (len=1;len<deg;len<<=1);
for (int i=0,ret;i<len;++i)
{
ret=0;
for (int x=i,j=1;j<len;j<<=1,x>>=1) ret=(ret<<1)|(x&1);
trs[i]=ret;
}
}

inline void DFT(int *a,int sig)
{
for (register int i=0;i<len;++i) t[trs[i]]=a[i];
for (register int l=2;l<=len;l<<=1)
for (register int h=l>>1,j=0,wn=omega[sig>0?N/l:N-N/l];j<len;j+=l)
for (register int i=0,w=1;i<h;++i,w=1ll*w*wn%P)
{
register int u=t[i+j],v=1ll*t[i+j+h]*w%P;
t[i+j]=add(u,v),t[i+j+h]=sub(u,v);
}
for (register int i=0;i<len;++i) a[i]=t[i];
if (sig<0) for (register int i=0,inv=quick_power(len,P-2);i<len;++i) a[i]=1ll*a[i]*inv%P;
}

inline void FFT(int *a,int *b,int *c)
{
DFT(a,1),DFT(b,1);
for (register int i=0;i<len;++i) c[i]=1ll*a[i]*b[i]%P;
DFT(c,-1);
}

void block()
{
siz=max((int)trunc(pow(1.*n*n/(log(n)/log(2)),1./3.)),1);
for (int src=0,tar;src<n;src=tar,++cnt)
{
tar=min(src+siz,n),st[cnt]=src,en[cnt]=tar-1;
for (int i=src;i<tar;++i) bel[i]=cnt;
}
}

inline int getid(int l,int r)
{
int en=cnt,st=cnt-l+1;
return (1ll*(st+en)*(en-st+1)>>1)+r-l+1;
}

void pre()
{
NTT_pre(),deg_pre(n),block();
for (int i=0,id;i<cnt;++i)
{
id=getid(i,i);
for (int j=0;j<len;++j) f[id][j]=1;
for (int j=st[i];j<=en[i];++j)
for (int k=0;k<len;++k)
f[id][k]=1ll*f[id][k]*(1ll*c[j]*omega[N/len*k]%P+1)%P;
}
for (int i=0;i<cnt;++i)
{
for (int j=i+1;j<cnt;++j)
for (int k=0,id=getid(i,j),id_=getid(i,j-1),id__=getid(j,j);k<len;++k)
f[id][k]=1ll*f[id_][k]*f[id__][k]%P;
for (int j=i;j<cnt;++j) DFT(f[getid(i,j)],-1);
}
}

void solve(int l,int r,int d=0)
{
if (l>r) return;
if (l==r){g[d][0]=1,g[d][1]=coef[l];return;}
int mid=l+r>>1,deg=r-l+2,deg0=mid-l+2,deg1=r-mid+1;
solve(l,mid,d+1),copy(g[d+1],g[d+1]+deg0,g[d]),solve(mid+1,r,d+1);
deg_pre(deg),fill(g[d]+deg0,g[d]+len,0),fill(g[d+1]+deg1,g[d+1]+len,0);
FFT(g[d],g[d+1],g[d]);
}

inline int calc(int l,int r,int k)
{
int ret=0,idx=0;
for (;l!=st[bel[l]]&&l<=r;++l) coef[idx++]=c[l];
for (;r!=en[bel[r]]&&l<=r;--r) coef[idx++]=c[r];
if (idx>30) solve(0,idx-1);
else
{
for (int i=0;i<=idx;++i) g[0][i]=0;
g[0][0]=1;
for (int i=0;i<idx;++i)
for (int j=idx;j>=1;--j)
Add(g[0][j],1ll*g[0][j-1]*coef[i]%P);
}
if (l<=r)
{
int id=getid(bel[l],bel[r]);
for (int i=0;i<=idx&&i<k;++i) Add(ret,1ll*g[0][i]*(k-i==len_?sub(f[id][0],1):f[id][k-i])%P);
if (k<=idx) Add(ret,g[0][k]);
}else ret=g[0][k];
return ret;
}

int main()
{
freopen("choose.in","r",stdin),freopen("choose.out","w",stdout);
n=read(),q=read();
for (int i=0;i<n;++i) c[i]=read();
pre(),len_=len,ans=0;
for (int i=1,l,r,k;i<=q;++i)
{
l=(read()+ans)%n,r=(read()+ans)%n,k=read();
ans=calc(l,r,k),write(ans),putchar('\n');
}
fclose(stdin),fclose(stdout);
return 0;
}