【UOJ#34】多项式乘法
模板题，好好学习了一下 FFT。
标准 FFT（复数浮点实现）：
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Maximum number of coefficients per input polynomial (plus slack).
const int maxn=100010;
// M_PI is a POSIX extension, not ISO C++; acos(-1) yields pi portably.
const double pi=acos(-1.0);

// Minimal complex number: real part r, imaginary part i.
struct Complex
{
    double r,i;
    Complex(double _r=0,double _i=0):r(_r),i(_i){}
};

Complex operator + (const Complex &a,const Complex &b) {return Complex(a.r+b.r,a.i+b.i);}
Complex operator - (const Complex &a,const Complex &b) {return Complex(a.r-b.r,a.i-b.i);}
Complex operator * (const Complex &a,const Complex &b) {return Complex(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r);}

// Coefficient / point-value buffers for the two input polynomials.
Complex a[maxn<<2],b[maxn<<2];
// rev[i] = i with its log2(N) low bits reversed.
int rev[maxn<<2];
// n, m: coefficient counts of the two inputs; N: transform length (power of two).
int n,m,N;

// Pick N as the smallest power of two >= n+m-1 (the product's coefficient
// count) and build the bit-reversal table for that length.
void Init()
{
    N=1;
    while (N<n+m-1) N<<=1;
    for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(N>>1));
}

// In-place iterative radix-2 FFT of arr[0..N-1].
// f = 1: forward transform; f = -1: inverse transform. For the inverse, only
// the real parts are divided by N (imaginary parts are left unscaled), which
// is sufficient because main() reads only .r.
void FFT(Complex *arr,int f)
{
    for (int i=0;i<N;i++) if (i<rev[i]) swap(arr[i],arr[rev[i]]);
    // i is half the length of the butterflies combined at this level.
    for (int i=1;i<N;i<<=1)
    {
        // Principal (2i)-th root of unity; f = -1 conjugates it for the inverse.
        Complex wn(cos(pi/i),f*sin(pi/i));
        for (int j=0;j<N;j+=(i<<1))
        {
            Complex w(1,0);
            for (int k=0;k<i;k++,w=w*wn)
            {
                Complex x=arr[j+k],y=w*arr[j+k+i];
                arr[j+k]=x+y;
                arr[j+k+i]=x-y;
            }
        }
    }
    if (f==-1) for (int i=0;i<N;i++) arr[i].r/=N;
}

// Reads degrees n, m and the two coefficient lists, then prints the n+m+1
// coefficients of the product (in the same order as the input), space-separated.
int main()
{
    scanf("%d%d",&n,&m);
    n++,m++;  // degrees -> coefficient counts
    for (int i=0,x;i<n;i++) scanf("%d",&x),a[i].r=x;
    for (int i=0,x;i<m;i++) scanf("%d",&x),b[i].r=x;
    Init();
    FFT(a,1),FFT(b,1);
    for (int i=0;i<N;i++) a[i]=a[i]*b[i];
    FFT(a,-1);
    // Round to nearest: truncating r+0.1 fails for negative coefficients and
    // whenever the floating-point error is below -0.1.
    for (int i=0;i<n+m-1;i++) printf("%ld ",lround(a[i].r));
    printf("\n");
    return 0;
}
快速数论变换（NTT，模数 p = 8650753 = 33·2^18 + 1；只有当乘积的每个系数都小于 p 时结果才是精确的）：
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;

// PN = 2^18: largest supported transform length (also the array size).
// p  = 8650753 = 33 * 2^18 + 1: the NTT modulus (assumed prime).
// gn = g^33 mod p = 68387: used as the 2^18-th root of unity; its
//      multiplicative order must be exactly 2^18 (NOTE(review): confirm).
const LL PN=262144,p=8650753,g=10,gn=68387;
LL a[PN],b[PN];
int rev[PN];
// n, m: coefficient counts of the two inputs; N: transform length (power of two).
int n,m,N;

// Returns x^e mod p by binary exponentiation.
int Pow(int x,int e)
{
    LL res=1,base=x;
    while (e)
    {
        if (e&1) res=res*base%p;
        base=base*base%p;
        e>>=1;
    }
    return res;
}

// In-place iterative NTT of arr[0..N-1] modulo p (N a power of two, N <= PN).
// f = 1: forward transform; f = -1: inverse transform, scaled by N^{-1} mod p.
void FFT(LL *arr,int f)
{
    // Bit-reversal permutation for the current N.
    for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(N>>1));
    for (int i=0;i<N;i++) if (i<rev[i]) swap(arr[i],arr[rev[i]]);
    // i is half the length of the butterflies combined at this level.
    for (int i=1;i<N;i<<=1)
    {
        // gn^(PN/(2i)) is the (2i)-th root of unity; with gn^PN = 1, exponent
        // PN - PN/(2i) gives its inverse for the inverse transform.
        LL wn=Pow(gn,f==1?((PN/i)>>1):PN-((PN/i)>>1));
        for (int j=0;j<N;j+=(i<<1))
        {
            LL w=1;
            for (int k=0;k<i;k++,w=w*wn%p)
            {
                LL x=arr[j+k],y=w*arr[j+k+i]%p;
                arr[j+k]=(x+y)%p;
                arr[j+k+i]=(x-y+p)%p;
            }
        }
    }
    if (f==-1)
    {
        // N^{-1} via Fermat's little theorem (relies on p being prime).
        LL invN=Pow(N,p-2);
        for (int i=0;i<N;i++) arr[i]=(arr[i]*invN)%p;
    }
}

// Reads degrees n, m and both coefficient lists, then prints the product's
// n+m+1 coefficients. Results are exact only while every true coefficient < p.
int main()
{
    scanf("%d%d",&n,&m);
    n++,m++;  // degrees -> coefficient counts
    // a and b hold long long: "%d" here was a format/argument mismatch (UB).
    for (int i=0;i<n;i++) scanf("%lld",&a[i]);
    for (int i=0;i<m;i++) scanf("%lld",&b[i]);
    N=1;
    while (N<n+m-1) N<<=1;
    FFT(a,1),FFT(b,1);
    for (int i=0;i<N;i++) a[i]=a[i]*b[i]%p;
    FFT(a,-1);
    for (int i=0;i<n+m-1;i++) printf("%lld ",a[i]);
    printf("\n");
    return 0;
}