快速傅里叶变换(FFT)

梦开始的地方

注:以上指噩梦


在刚入门的时候想必我们都学过高精度乘法

仿照高精度乘法的思想,直接将两个 \(n\) 次多项式相乘的时空复杂度为 \(O(n^2)\)

这不够快,但是现在我们拿相乘的两层循环丝毫没有办法

需要另想方法来完成多项式相乘

我们知道,\(n+1\) 个点可以唯一确定一个 \(n\) 次多项式
例如 \(3\) 个点可以确定一个二次多项式 \(ax^2+bx+c\)

以下设相乘的两个多项式为 \(f(x)\) 和 \(g(x)\)
次数分别为 \(n\) 与 \(m\)

那么我们只要在 \(f(x)\) 与 \(g(x)\) 上分别找 \(n+m+1\)

点值相乘的时间复杂度只有 \(O(n)\)!

现在来看,我们找到了解决方案

不过等等,在 \(n\) 次多项式上找 \(n+m+1\) 个点的时间复杂度是 \(O(n^2)\)!

想办法使用分治!

对于一个多项式 \(f(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}\)

(设:\(n=2^k\))

我们将奇次项分为一部分,将偶次项分为一部分
\(f(x)=(a_0+a_2x^2+\dots+a_{n-2}x^{n-2})+(a_1x^1+a_3x^3+\dots+a_{n-1}x^{n-1})\)

中间加号左边设为 \(fl(x)\)
\(fl(x)=a_0+a_2x+\dots+a_{n-2}x^{n/2-1}\)

右边设为 \(fr(x)\)
\(fr(x)=a_1+a_3x+\dots+a_{n-1}x^{n/2-1}\)

显然 \(f(x)=fl(x^2)+xfr(x^2)\)


我们令 \(x=\omega_n^k\)

\[\begin{align*} f(\omega_n^k)=fl((\omega_n^k)^2)+\omega_n^kfr((\omega_n^k)^2) \\ f(\omega_n^k)=fl(\omega_{n/2}^k)+\omega_n^kfr(\omega_{n/2}^k) \end{align*} \]

我们令 \(x=\omega_n^{k+n/2}\)

\[\begin{align*} f(\omega_n^{k+n/2})&=fl((\omega_n^{k+n/2})^2)+\omega_n^{k+n/2}fr((\omega_n^{k+n/2})^2) \\ &=fl(\omega_n^{2k+n})+\omega_n^{k+n/2}fr(\omega_n^{2k+n}) \\ &=fl(\omega_{n/2}^{k})+\omega_n^{k+n/2}fr(\omega_{n/2}^{k}) \\ f(\omega_n^{k+n/2})&=fl(\omega_{n/2}^{k})-\omega_n^kfr(\omega_{n/2}^{k}) \end{align*} \]

此时 \(f(\omega_n^k)\) 与 \(f(\omega_n^{k+n/2})\)

并且我们注意到 \(fl(x)\) 与 \(fr(x)\) 的性质与 \(f(x)\) 完全相同,我们可以对 \(fl(x)\) 和 \(fr(x)\)

至此,我们在 \(O(nlogn)\) 的时间内完成了将一个函数转化为点值表示的过程,可记作 \(\text{DFT}(f)\)


点值相乘是很简单的,之后我们需要把点值表示重新转化为系数表示

即求出 \(\text{IDFT}(\text{DFT}(f))\)

设点值向量 \(\vec{G}=\text{DFT}(f)=\{y_0,y_1,\dots,y_{n-1}\}\)

此时我们将 \(\vec{G}\) 当作系数向量再构造点值向量 \(\vec{H}=\sum\limits_{i=0}^{n-1}G_i(\omega_n^{-k})^i\)

即将单位根的倒数代入 \(G_0+G_1x^1+\dots+G_{n-1}x^{n-1}\),化简

\[\begin{align*} \vec{H}&=\sum\limits_{i=0}^{n-1}G_i(\omega_n^{-k})^i \\ &=\sum\limits_{i=0}^{n-1}(\sum\limits_{j=0}^{n-1}f_j(\omega_n^i)^j)(\omega_n^{-k})^i \\ &=\sum\limits_{i=0}^{n-1}(\sum\limits_{j=0}^{n-1}f_j\omega_n^{ij})\omega_n^{-ik} \\ &=\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{n-1}f_j(\omega_n^{j-k})^i \\ &=\sum\limits_{j=0}^{n-1}f_j(\sum\limits_{i=0}^{n-1}(\omega_n^{j-k})^i) \end{align*} \]

\(\sum\limits_{i=0}^{n-1}(\omega_n^{j-k})^i\)

通过对求和分类讨论可以得出结论:\(H_i=nf_i\Rightarrow f_i=\dfrac{H_i}{n}\)

所以我们可以将一开始对 \(f\) 进行DFT的取值取倒,再对 \(G\)

综上,本问题得到解决

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#define N 3000001
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define R register int
using namespace std;
const double pi=3.14159265358979323846264;
int n,m,x=1;
struct complex{
	double real,img;
	complex(double a=0,double b=0){real=a,img=b;}
	complex operator + (const complex &b) const{return complex(real+b.real,img+b.img);}
	complex operator - (const complex &b) const{return complex(real-b.real,img-b.img);}
	complex operator * (const complex &b) const{return complex(real*b.real-img*b.img,real*b.img+img*b.real);}
}f[N],g[N],save[N];
fx(void,FFT)(complex *f,int len,short s){
	if(len==1) return;
	int hlen=len>>1;
	complex *fl=f,*fr=f+hlen;
	for(int i=0;i<len;i++) save[i]=f[i];
	for(int i=0;i<hlen;i++){
		fl[i]=save[i<<1];
		fr[i]=save[i<<1|1];
	}
	FFT(fl,hlen,s);FFT(fr,hlen,s);
	complex dw(cos(2*pi/len),sin(2*pi/len)),now(1,0);
	dw.img*=s;
	for(int i=0;i<hlen;i++){
		save[i]=fl[i]+now*fr[i];
		save[i+hlen]=fl[i]-now*fr[i];
		now=now*dw;
	}
	for(int i=0;i<len;i++) f[i]=save[i];
}
signed main(){
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;i++) scanf("%lf",&f[i].real);
	for(int i=0;i<=m;i++) scanf("%lf",&g[i].real);
	while(x<=n+m) x<<=1;
	FFT(f,x,1),FFT(g,x,1);
	for(int i=0;i<x;i++) f[i]=f[i]*g[i];
	FFT(f,x,-1);
	for(int i=0;i<=n+m;i++) printf("%d ",(int)(f[i].real/x+0.5));
}

我们发现对于每层递归,我们都进行了数组拷贝,拷贝的重要目的就是为了实现 \(f(x)=fl(x^2)+xfr(x^2)\)

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
第一次变为
0 2 4 6 8 10 12 14|1 3 5 7 9 11 13 15
第二次变为
0 4 8 12|2 6 10 14|1 5 9 13|3 7 11 15
...

第一次其实就是把二进制上第 \(0\) 位为 \(0\) 的分成一组,为 \(1\)

第二次把二进制上第 \(1\) 位为 \(0\) 的分成一组,为 \(1\)

以此类推...

观察发现这其实就是二进制反转,我们可以 \(O(n)\)

for(R i=0;i<x;i++) bf[i]=(bf[i>>1]>>1)|((i&1)?x>>1:0);

最后我们将递归实现改为迭代实现

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#define N 2600010
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define R register int
using namespace std;
const double pi=3.14159265358979323846264;
int n,m,x=1,bf[N];
struct complex{
	double real,img;
	complex(double a=0,double b=0){real=a,img=b;}
	complex operator + (const complex &b) const{return complex(real+b.real,img+b.img);}
	complex operator - (const complex &b) const{return complex(real-b.real,img-b.img);}
	complex operator * (const complex &b) const{return complex(real*b.real-img*b.img,real*b.img+img*b.real);}
}f[N],g[N];
fx(void,FFT)(complex *f,short r){
	R l,hl,s,i;
	for(i=0;i<x;i++) if(i<bf[i]) swap(f[i],f[bf[i]]);
	for(l=2,hl=1;l<=x;hl=l,l<<=1){
		complex dw(cos(2*pi/l),r*sin(2*pi/l));
		for(s=0;s<x;s+=l){
			complex now(1,0);
			for(i=s;i<s+hl;i++){
				complex uni=now*f[i+hl];
				f[i+hl]=f[i]-uni;f[i]=f[i]+uni;
				now=now*dw;
			}
		}
	}
}
signed main(){
	scanf("%d%d",&n,&m);
	for(R i=0;i<=n;i++) scanf("%lf",&f[i].real);
	for(R i=0;i<=m;i++) scanf("%lf",&g[i].real);
	while(x<=n+m) x<<=1;
	for(R i=0;i<x;i++) bf[i]=(bf[i>>1]>>1)|((i&1)?x>>1:0);
	FFT(f,1),FFT(g,1);
	for(R i=0;i<x;i++) f[i]=f[i]*g[i];
	FFT(f,-1);
	for(R i=0;i<=n+m;i++) printf("%d ",(int)(f[i].real/x+0.5));
}

快速数论变换(NTT)

显然,因为各种三角函数参加计算,FFT会有精度丢失问题

所以我们需要找到一个单位根的替代品

但是数学家们已经证明了在复数域 \(\mathbb{C}\)

我们所有的计算都是在模意义下的

我们可以引入原根

此时只需要证明原根满足单位根的性质

证明大都很简单,故此处只证明单位根其中一条性质: \(\omega_n^k=-\omega_n^{k+\frac{n}{2}}\)

换成原根就是 \((g^{\frac{p-1}{n}})^k=-(g^{\frac{p-1}{n}})^{k+\frac{n}{2}}\pmod{p}\)

先进行简单的化简:

\[\begin{align*} &(g^{\frac{p-1}{n}})^k=-(g^{\frac{p-1}{n}})^{k+\frac{n}{2}}\pmod{p} \\ &(g^{\frac{p-1}{n}})^k(1+g^{\frac{p-1}{2}})=0\pmod{p} \end{align*} \]

我们看到原式中有一个负号,这不由得使我们想起了Wilson定理,接着我们逆推回去

\[\begin{align*} -1&\equiv\prod_{i=1}^{p-1}g^i\pmod{p} \\ &\equiv g^\frac{p(p-1)}{2}\pmod{p} \end{align*} \]

拆开,由费马小定理得:\(g^{\frac{p-1}{2}}\equiv-1\pmod{p}\)

所以原式得证

接下来把代码中的单位根全部换掉即可

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 2500001
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define R register int
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int x=1,n,m,br[N],f[N],g[N],prp[N],invp,invx;
fx(int,pow)(int a,int b){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,short r){
	R l,hl,exp,uni,s,i;
	for(i=0;i<x;++i) if(i<br[i]) swap(f[i],f[br[i]]);
	for(l=2,hl=1;l<=x;hl=l,l<<=1){
		exp=pow(r==1?pr:invp,(mod-1)/l);
		for(i=1;i<hl;i++) prp[i]=prp[i-1]*exp%mod;
		for(s=0;s<x;s+=l){
			for(i=0;i<hl;++i){
				uni=prp[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
}
signed main(){
	scanf("%d%d",&n,&m);
	for(R i=0;i<=n;++i) scanf("%d",&f[i]);
	for(R i=0;i<=m;++i) scanf("%d",&g[i]);
	while(x<=n+m) x<<=1;
	prp[0]=1,invp=pow(pr,mod-2),invx=pow(x,mod-2);
	for(R i=0;i<x;++i) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	NTT(f,1),NTT(g,1);
	for(R i=0;i<x;++i) (f[i]*=g[i])%=mod;
	NTT(f,-1);
	for(R i=0;i<=n+m;++i) printf("%d ",f[i]*invx%mod);
}

多项式乘法逆

由多项式乘法,我们可知如果要求一个多项式 \(f(x)\) 的乘法逆 \(g(x)\),就要:

\[\begin{align*} x_n&=\sum\limits_{i=0}^ng_if_{n-i}=[n=0] \\ &=\sum\limits_{i=0}^{n-1}g_if_{n-i}+g_nf_0=0 \\ g_n&=-\dfrac{\sum\limits_{i=0}^{n-1}g_if_{n-i}}{f_0} \end{align*} \]

递推,复杂度 \(O(n^2)\)

这个复杂度是很难让人满意的,所以我们想结合NTT,将其复杂度优化至 \(O(n\log n)\)

上文介绍FFT/NTT时,曾提到一个核心思想:分治

我们又可以很容易的发现常数项的逆元是很好求的

所以我们假设已经求出了 \(f^\prime(x)\),\(f^\prime(x)*f(x)=1 \pmod{x^\frac{n}{2}},g(x)*f(x)=1\pmod{x^n}\)

\[\begin{align*} fg-ff^\prime&\equiv0&\pmod{x^\frac{n}{2}} \\ g-f^\prime&\equiv0&\pmod{x^\frac{n}{2}} \\ g^2-2gf^\prime+f^{\prime2}&\equiv0&\pmod{x^n} \\ g-2f^\prime+gf^{\prime2}&\equiv0&\pmod{x^n} \\ g&\equiv2f^\prime-gf^{\prime2}&\pmod{x^n} \end{align*} \]

多项式最高项次数每次扩大 \(1\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 2000001
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define cpy(f,t,len) memcpy(t,f,sizeof(int)*len)
#define R register int
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int x=1,n,m,br[N],f[N],g[N],prp[N],invp,invx,h[N];
fx(int,pow)(int a,int b){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,short r,const int x){
	R l,hl,exp,uni,s,i;
	for(R i=0;i<x;++i) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	for(i=0;i<x;++i) if(i<br[i]) swap(f[i],f[br[i]]);
	for(l=2,hl=1;l<=x;hl=l,l<<=1){
		exp=pow(r==1?pr:invp,(mod-1)/l);
		for(i=1;i<hl;i++) prp[i]=prp[i-1]*exp%mod;
		for(s=0;s<x;s+=l){
			for(i=0;i<hl;++i){
				uni=prp[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
	if(r==-1){
		invx=pow(x,mod-2);
		for(R o=0;o<x;++o) (f[o]*=invx)%=mod;
	}
}
fx(void,inv)(int *f,int x){
	static int j[N],k[N];
	h[0]=pow(f[0],mod-2);
	for(int len=2,hl=1;len<=x;hl=len,len<<=1){
		for(int o=0;o<hl;o++) j[o]=(h[o]<<1)%mod;
		cpy(f,k,len);
		NTT(h,1,len<<1);
		for(R o=0;o<(len<<1);++o) (h[o]*=h[o])%=mod;
		NTT(k,1,len<<1);
		for(R o=0;o<(len<<1);++o) (h[o]*=k[o])%=mod;
		NTT(h,-1,len<<1);
		for(R o=0;o<len;++o) h[o]=(j[o]-h[o]+mod)%mod;
		memset(h+len,0,sizeof(int)*len);
		
	}
}
signed main(){
	scanf("%d",&n);
	for(R i=0;i<n;++i) scanf("%d",&f[i]);
	while(x<n) x<<=1;
	prp[0]=1,invp=pow(pr,mod-2);
	inv(f,x);
	for(R o=0;o<n;++o) cout<<h[o]<<" ";
}

多项式对数函数(多项式ln)

直接推导:

\[\begin{align*} B(x)&\equiv \ln A(x)\pmod{x^n} \\ B^\prime(x)&\equiv\dfrac{A^\prime(x)}{A(x)}\pmod{x^n} \\ B(x)&\equiv\int\dfrac{A^\prime(x)}{A(x)}\text{d}x\pmod{x^n} \end{align*} \]

按步骤直接做就可以

  1. 对 \(A(x)\)
  2. 对 \(A(x)\)
  3. NTT相乘
  4. 对 \(\dfrac{A^\prime(x)}{A(x)}\)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 1000001
#define INF 1100000000
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define R register int
#define C const
#define int long long
using namespace std;
C int mod=998244353,pr=3;
int n,br[N],ppr[N],x=1,invp,invx,A[N],B[N];
fx(int,gi)(){
	char c=getchar();int s=0,f=1;
	while(c<'0'||c>'9'){
		if(c=='-') f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
	return s*f;
}
fx(int,pow)(int a,int b=mod-2){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,C short r,C int x){
	R len,hl,exp,uni,s,i;
	for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
	for(len=2,hl=1;len<=x;hl=len,len<<=1){
		exp=pow(r==1?pr:invp,(mod-1)/len);
		for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*exp%mod;
		for(s=0;s<x;s+=len){
			for(i=0;i<hl;i++){
				uni=ppr[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
	if(r==-1){
		invx=pow(x);
		for(i=0;i<x;i++) (f[i]*=invx)%=mod;
	}
}
fx(void,INV)(int *f,const int x){
	static int le[N],ri[N],inv[N];
	inv[0]=pow(f[0]);
	for(R len=2,hl=1,o;len<=x;hl=len,len<<=1){
		for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
		cpy(f,ri,int,len);
		NTT(inv,1,len<<1);NTT(ri,1,len<<1);
		for(o=0;o<(len<<1);o++) (((inv[o]*=inv[o])%=mod)*=ri[o])%=mod;
		NTT(inv,-1,len<<1);
		for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
		set(inv+len,0,int,len);
	}
	cpy(inv,f,int,n);
}
fx(void,DER)(int *f,C int len){
	for(int i=1;i<len;i++) f[i-1]=f[i]*i%mod;
	f[len-1]=0;
}
fx(void,INT)(int *f,C int len){
	for(R i=len-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
	f[0]=0;
}
signed main(){
	n=gi();
	for(R i=0;i<n;i++) A[i]=gi();
	cpy(A,B,int,n);
	while(x<n) x<<=1;
	invp=pow(pr);ppr[0]=1;
	INV(A,x);DER(B,n);
	while(x<n+n) x<<=1;
	NTT(A,1,x);NTT(B,1,x);
	for(R i=0;i<x;i++) (A[i]*=B[i])%=mod;
	NTT(A,-1,x);INT(A,n);
	for(R i=0;i<n;i++) printf("%lld ",A[i]);
}

多项式指数函数(多项式exp)

多项式牛顿迭代在这里不做过多叙述

此处 \(\exp A(x)\equiv B(x)\pmod{x^n}\)

两边取对数得 \(A(x)\equiv\ln B(x)\pmod{x^n}\)

即 \(\ln B(x)-A(x)\equiv0\pmod{x^n}\)

利用多项式牛顿迭代结果:

\[G(x)=G_0(x)-\dfrac{F(G_0(x))}{F^\prime(G_0(x))} \]

\(A(x)\) 是确定的,故 \(F(B(x))=\ln B(x)-A(x),F^\prime(B(x))=\dfrac1{B(x)}\)

\(B(x)=B_0(x)-B_0(x)(\ln B_0(x)-A(x))=B_0(x)(1-\ln B_0(x)+A(x))\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define int long long
#define R register int
#define C const
using namespace std;
C int mod=998244353,pr=3;
int f[N],ppr[N],br[N],expcp[N],exp[N],n,hx=1,x=1,invx,invp;
fx(int,gi)(){
	char c=getchar();int s=0,f=1;
	while(c>'9'||c<'0'){
		if(c=='-') f=-f;
		c=getchar();
	}
	while(c>='0'&&c<='9') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
	return s*f;
}
fx(int,pow)(int a,int b=mod-2){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
	R i,len,hl,s,expr,uni;
	for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
	for(len=2,hl=1;len<=x;hl=len,len<<=1){
		expr=pow(r?pr:invp,(mod-1)/len);
		for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
		for(s=0;s<x;s+=len){
			for(i=0;i<hl;i++){
				uni=ppr[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
	if(!r){
		invx=pow(x);
		for(i=0;i<x;i++) (f[i]*=invx)%=mod;
	}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
	inv[0]=pow(f[0]);
	for(R dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
		set(le+hl,0,int,len+hl);set(ri+len,0,int,len);
		for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
		cpy(f,ri,int,len);
		NTT(inv,1,dl),NTT(ri,1,dl);
		for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
		NTT(inv,0,dl);
		for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
		set(inv+len,0,int,len);
	}
	cpy(inv,f,int,x);set(inv,0,int,x);
}
fx(void,DER)(int *f,C int x){
	for(R i=1;i<x;i++) f[i-1]=f[i]*i%mod;
	f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
	for(R i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
	f[0]=0;
}
int lncp[N];
fx(void,LN)(int *f,C int hx,C int x){
	cpy(f,lncp,int,hx);set(lncp+hx,0,int,hx);
	INV(f,hx);DER(lncp,hx);
	NTT(f,1,x);NTT(lncp,1,x);
	for(R i=0;i<x;i++) (f[i]*=lncp[i])%=mod;
	NTT(f,0,x);INT(f,hx);
	set(f+hx,0,int,hx);
}
fx(void,EXP)(int *f,C int x){
	exp[0]=1;
	for(R hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
		set(expcp,0,expcp,1);
		cpy(exp,expcp,int,hl);
		LN(expcp,len,dl);
		for(i=0;i<len;i++) expcp[i]=(f[i]-expcp[i]+mod)%mod;
		(expcp[0]+=1)%mod;
		NTT(exp,1,dl);NTT(expcp,1,dl);
		for(i=0;i<dl;i++) (exp[i]*=expcp[i])%=mod;
		NTT(exp,0,dl);
		set(exp+len,0,int,len);
	}
	cpy(exp,f,int,n);
}
signed main(){
	ppr[0]=1;invp=pow(pr);
	n=gi();hx=1,x=1;
	set(f,0,f,1);
	for(R i=0;i<n;i++) f[i]=gi();
	while(x<n) x<<=1;
	EXP(f,x);
	for(R i=0;i<n;i++) printf("%lld ",f[i]);
	printf("\n");
}

多项式快速幂

\(A^k(x)\equiv B(x)\pmod{x^n}\)

两边取对数:\(\ln A^k(x)\equiv\ln B(x)\pmod{x^n}\)

\(k\ln A(x)\equiv\ln B(x)\pmod{x^n}\)

\(\text e^{k\ln A(x)}\equiv B(x)\pmod{x^n}\)

先取对数,之后 \(\exp\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define int long long
#define R register int
#define C const
using namespace std;
C int mod=998244353,pr=3;
int f[N],ppr[N],br[N],expcp[N],exp[N],n,x=1,invx,invp,k;
fx(int,gi)(){
	char c=getchar();int s=0,f=1;
	while(c>'9'||c<'0'){
		if(c=='-') f=-f;
		c=getchar();
	}
	while(c>='0'&&c<='9') s=((s<<3)+(s<<1)+(c-'0'))%mod,c=getchar();
	return s*f%mod;
}
fx(int,pow)(int a,int b=mod-2){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
	R i,len,hl,s,expr,uni;
	for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
	for(len=2,hl=1;len<=x;hl=len,len<<=1){
		expr=pow(r?pr:invp,(mod-1)/len);
		for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
		for(s=0;s<x;s+=len){
			for(i=0;i<hl;i++){
				uni=ppr[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
	if(!r){
		invx=pow(x);
		for(i=0;i<x;i++) (f[i]*=invx)%=mod;
	}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
	inv[0]=pow(f[0]);
	for(R dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
		set(le+hl,0,int,len+hl);set(ri+len,0,int,len);
		for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
		cpy(f,ri,int,len);
		NTT(inv,1,dl),NTT(ri,1,dl);
		for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
		NTT(inv,0,dl);
		for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
		set(inv+len,0,int,len);
	}
	cpy(inv,f,int,x);set(inv,0,int,x);
}
fx(void,DER)(int *f,C int x){
	for(R i=1;i<x;i++) f[i-1]=f[i]*i%mod;
	f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
	for(R i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
	f[0]=0;
}
int lncp[N];
fx(void,LN)(int *f,C int hx,C int x){
	cpy(f,lncp,int,hx);set(lncp+hx,0,int,hx);
	INV(f,hx);DER(lncp,hx);
	NTT(f,1,x);NTT(lncp,1,x);
	for(R i=0;i<x;i++) (f[i]*=lncp[i])%=mod;
	NTT(f,0,x);INT(f,hx);
	set(f+hx,0,int,hx);
}
fx(void,EXP)(int *f,C int x){
	exp[0]=1;
	for(R hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
		set(expcp,0,expcp,1);
		cpy(exp,expcp,int,hl);
		LN(expcp,len,dl);
		for(i=0;i<len;i++) expcp[i]=(f[i]-expcp[i]+mod)%mod;
		(expcp[0]+=1)%mod;
		NTT(exp,1,dl);NTT(expcp,1,dl);
		for(i=0;i<dl;i++) (exp[i]*=expcp[i])%=mod;
		NTT(exp,0,dl);
		set(exp+len,0,int,len);
	}
	cpy(exp,f,int,n);
}
signed main(){
	ppr[0]=1;invp=pow(pr);
	n=gi();k=gi();
	set(f,0,f,1);
	for(R i=0;i<n;i++) f[i]=gi();
	while(x<n) x<<=1;
	LN(f,x,x<<1);
	for(R i=0;i<n;i++) (f[i]*=k)%=mod;
	EXP(f,x);
	for(R i=0;i<n;i++) printf("%lld ",f[i]);
	printf("\n");
}

多项式开根

尝试使用牛顿迭代,\(F(G_0(x))=B_0^2(x)-A(x)\)

直接无脑推式子就可以

\[\begin{align*} G(x)=&G_0(x)-\dfrac{F(G_0(x))}{F^\prime(G_0(x))} \\ \Rightarrow B(x)=&B_0(x)-\dfrac{B_0^2(x)-A(x)}{2B_0(x)} \\ =&\dfrac{B_0^2(x)+A(x)}{2B_0(x)} \\ =&\dfrac{1}{2}\left(\dfrac{A(x)}{B_0(x)}+B_0(x)\right) \end{align*} \]

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define M 5001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define R register
#define C const
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int br[N],x=1,n,f[N],invp,invx,invt,ppr[N];
fx(int,gi)(){
	R char c=getchar();R int s=0,f=1;
	while(c>'9'||c<'0'){
		if(c=='-') f=-f;
		c=getchar();
	}
	while(c<='9'&&c>='0') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
	return s*f;
}
fx(int,pow)(int a,int b=mod-2){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
	R int i,len,hl,s,uni,expr;
	for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
	for(hl=1,len=2;len<=x;hl=len,len<<=1){
		expr=pow(r?pr:invp,(mod-1)/len);
		for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
		for(s=0;s<x;s+=len){
			for(i=0;i<hl;i++){
				uni=ppr[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
	if(!r){
		invx=pow(x);
		for(i=0;i<x;i++) (f[i]*=invx)%=mod;
	}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
	inv[0]=pow(f[0]);
	for(R int dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
		for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
		cpy(f,ri,int,len);
		NTT(inv,1,dl),NTT(ri,1,dl);
		for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
		NTT(inv,0,dl);
		for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
		set(inv+len,0,int,len);
	}
	cpy(inv,f,int,x);set(inv,0,int,x);
	set(le,0,int,x>>1);set(ri,0,int,x<<1);
}
int sqrt[N],scp[N],sac[N];
fx(void,SQRT)(int *f,int x){
	sqrt[0]=1;
	for(R int hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
		cpy(sqrt,scp,int,len);cpy(f,sac,int,len);
		INV(scp,len);
		NTT(sac,1,dl);NTT(scp,1,dl);
		for(i=0;i<dl;i++) (sac[i]*=scp[i])%=mod;
		NTT(sac,0,dl);
		for(i=0;i<len;i++) sqrt[i]=(sqrt[i]+sac[i])*invt%mod;
	}
	cpy(sqrt,f,int,x);
}
signed main(){
	ppr[0]=1;invp=pow(pr);invt=pow(2);
	n=gi();
	for(R int i=0;i<n;i++) f[i]=gi();
	while(x<n) x<<=1;
	SQRT(f,x);
	for(R int i=0;i<n;i++) printf("%lld ",f[i]);
}

多项式三角函数

根据欧拉公式:

\[\sin x=\dfrac{\text e^{ix}-\text e^{-ix}}{2i} \\ \cos x=\dfrac{\text e^{ix}+\text e^{-ix}}2 \]

类似NTT使用原根表示单位根,在模意义下 \(i=g^\frac{p-1}4\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define int long long
#define R register int
#define C const
using namespace std;
C int mod=998244353,pr=3;
int f[N],g[N],ppr[N],br[N],expcp[N],exp[N],n,x=1,invx,invp,ty,I,invt,invi;
fx(int,gi)(){
	char c=getchar();int s=0,f=1;
	while(c>'9'||c<'0'){
		if(c=='-') f=-f;
		c=getchar();
	}
	while(c>='0'&&c<='9') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
	return s*f;
}
fx(int,pow)(int a,int b=mod-2){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
	R i,len,hl,s,expr,uni;
	for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
	for(len=2,hl=1;len<=x;hl=len,len<<=1){
		expr=pow(r?pr:invp,(mod-1)/len);
		for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
		for(s=0;s<x;s+=len){
			for(i=0;i<hl;i++){
				uni=ppr[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
	if(!r){
		invx=pow(x);
		for(i=0;i<x;i++) (f[i]*=invx)%=mod;
	}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
	inv[0]=pow(f[0]);
	for(R dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
		for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
		cpy(f,ri,int,len);
		NTT(inv,1,dl),NTT(ri,1,dl);
		for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
		NTT(inv,0,dl);
		for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
		set(inv+len,0,int,len);
	}
	cpy(inv,f,int,x);set(inv,0,int,x);
	set(le,0,int,x>>1);set(ri,0,int,x<<1);
}
fx(void,DER)(int *f,C int x){
	for(R i=1;i<x;i++) f[i-1]=f[i]*i%mod;
	f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
	for(R i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
	f[0]=0;
}
int lncp[N];
fx(void,LN)(int *f,C int hx,C int x){
	cpy(f,lncp,int,hx);set(lncp+hx,0,int,hx);
	INV(f,hx);DER(lncp,hx);
	NTT(f,1,x);NTT(lncp,1,x);
	for(R i=0;i<x;i++) (f[i]*=lncp[i])%=mod;
	NTT(f,0,x);INT(f,hx);
	set(f+hx,0,int,hx);
}
fx(void,EXP)(int *f,C int x){
	set(expcp,0,expcp,1);set(exp,0,exp,1);exp[0]=1;
	for(R hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
		cpy(exp,expcp,int,hl);set(expcp+hl,0,int,hl);
		LN(expcp,len,dl);
		for(i=0;i<len;i++) expcp[i]=(f[i]-expcp[i]+mod)%mod;
		(expcp[0]+=1)%mod;
		NTT(exp,1,dl);NTT(expcp,1,dl);
		for(i=0;i<dl;i++) (exp[i]*=expcp[i])%=mod;
		NTT(exp,0,dl);
		set(exp+len,0,int,len);
	}
	cpy(exp,f,int,x);
}
fx(void,SIN)(int *f,int *g,C int x){
	EXP(f,x);EXP(g,x);
	for(R i=0;i<x;i++) f[i]=(f[i]-g[i]+mod)*invt%mod*invi%mod;
}
fx(void,COS)(int *f,int *g,C int x){
	EXP(f,x);EXP(g,x);
	for(R i=0;i<x;i++) f[i]=(f[i]+g[i])*invt%mod;
}
signed main(){
	ppr[0]=1;invp=pow(pr);I=pow(pr,(mod-1)/4);invt=pow(2);invi=pow(I);
	n=gi();ty=gi();
	for(R i=0;i<n;i++) f[i]=gi(),(f[i]*=I)%=mod,g[i]=mod-f[i];
	while(x<n) x<<=1;
	if(ty) COS(f,g,x);
	else SIN(f,g,x);
	for(R i=0;i<n;i++) printf("%lld ",f[i]);
	printf("\n");
}

多项式反三角函数

对于 \(\arcsin x\) 与 \(\arctan x\),我们可以将其求导,变成初等函数能表示的形式,然后积分

虽然字面上看这样没什么意义,求导再积分嘛,就和+1再-1一样,但是求导跟加减可不一样

\[\begin{align*} F(x)&\equiv\arcsin A(x)\pmod{x^n} \\ F^\prime(x)&\equiv\dfrac{A^\prime(x)}{\sqrt{1-A(x)^2}}\pmod{x^n} \\ F(x)&\equiv\int\dfrac{A^\prime(x)}{\sqrt{1-A(x)^2}}\pmod{x^n} \\ \\ F(x)&\equiv\arctan A(x)\pmod{x^n} \\ F^\prime(x)&\equiv\dfrac{A^\prime(x)}{1+A(x)^2}\pmod{x^n} \\ F(x)&\equiv\int\dfrac{A^\prime(x)}{1+A(x)^2}\pmod{x^n} \end{align*} \]

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define M 5001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define R register
#define C const
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int br[N],x=1,n,f[N],invp,invx,invt,ppr[N],g[N],ty;
fx(int,gi)(){
	R char c=getchar();R int s=0,f=1;
	while(c>'9'||c<'0'){
		if(c=='-') f=-f;
		c=getchar();
	}
	while(c<='9'&&c>='0') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
	return s*f;
}
fx(int,pow)(int a,int b=mod-2){
	int ans=1;
	while(b){
		if(b&1) (ans*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
	R int i,len,hl,s,uni,expr;
	for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
	for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
	for(hl=1,len=2;len<=x;hl=len,len<<=1){
		expr=pow(r?pr:invp,(mod-1)/len);
		for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
		for(s=0;s<x;s+=len){
			for(i=0;i<hl;i++){
				uni=ppr[i]*f[i|s|hl]%mod;
				f[i|s|hl]=(f[i|s]-uni+mod)%mod;
				f[i|s]=(f[i|s]+uni)%mod;
			}
		}
	}
	if(!r){
		invx=pow(x);
		for(i=0;i<x;i++) (f[i]*=invx)%=mod;
	}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
	inv[0]=pow(f[0]);
	for(R int dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
		for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
		cpy(f,ri,int,len);
		NTT(inv,1,dl),NTT(ri,1,dl);
		for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
		NTT(inv,0,dl);
		for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
		set(inv+len,0,int,len);
	}
	cpy(inv,f,int,x);set(inv,0,int,x);
	set(le,0,int,x>>1);set(ri,0,int,x<<1);
}
fx(void,DER)(int *f,C int x){
	for(R int i=1;i<x;i++) f[i-1]=f[i]*i%mod;
	f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
	for(R int i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
	f[0]=0;
}
int sqrt[N],scp[N],sac[N];
fx(void,SQRT)(int *f,C int x){
	sqrt[0]=1;
	for(R int len=2,dl=4,i;len<=x;len=dl,dl<<=1){
		cpy(sqrt,scp,int,len);cpy(f,sac,int,len);
		INV(scp,len);
		NTT(sac,1,dl);NTT(scp,1,dl);
		for(i=0;i<dl;i++) (sac[i]*=scp[i])%=mod;
		NTT(sac,0,dl);
		for(i=0;i<len;i++) sqrt[i]=(sqrt[i]+sac[i])*invt%mod;
	}
	cpy(sqrt,f,int,x);
}
fx(void,ARCSIN)(int *f,C int x){
	cpy(f,g,int,x);
	DER(g,x);NTT(f,1,x<<1);NTT(g,1,x<<1);
	for(R int i=0;i<(x<<1);i++) (f[i]*=f[i])%=mod;
	NTT(f,0,x<<1);
	set(f+x,0,int,x);
	for(R int i=0;i<x;i++) f[i]=mod-f[i];
	(f[0]+=1)%=mod;
	SQRT(f,x);INV(f,x);NTT(f,1,x<<1);
	for(R int i=0;i<(x<<1);i++) (f[i]*=g[i])%=mod;
	NTT(f,0,x<<1);INT(f,x);
}
fx(void,ARCTAN)(int *f,C int x){
	cpy(f,g,int,x);
	DER(g,x);NTT(f,1,x<<1);NTT(g,1,x<<1);
	for(R int i=0;i<(x<<1);i++) (f[i]*=f[i])%=mod;
	NTT(f,0,x<<1);set(f+x,0,int,x);
	(f[0]+=1)%=mod;
	INV(f,x);NTT(f,1,x<<1);
	for(R int i=0;i<(x<<1);i++) (f[i]*=g[i])%=mod;
	NTT(f,0,x<<1);INT(f,x);
}
signed main(){
	ppr[0]=1;invp=pow(pr);invt=pow(2);
	n=gi();ty=gi();
	for(R int i=0;i<n;i++) f[i]=gi();
	while(x<n) x<<=1;
	if(ty) ARCTAN(f,x);
	else ARCSIN(f,x);
	for(R int i=0;i<n;i++) printf("%lld ",f[i]);
}