【UOJ#34】多项式乘法

Zarxdy34 posted @ 2016年2月27日 18:49 in UOJ with tags FFT , 540 阅读

  模板题，借此好好学习了一下 FFT（快速傅里叶变换）。

标准 FFT（复数、浮点实现）：

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Upper bound on the number of coefficients per input polynomial; the complex
// arrays below are sized maxn<<2 to leave room for the padded transform length.
const int maxn=100010;
// pi as a double. M_PI is a POSIX/XSI extension, not ISO C++, and is missing
// on some toolchains (e.g. MSVC without _USE_MATH_DEFINES); acos(-1) is portable.
const double pi=std::acos(-1.0);

// Minimal complex number for the FFT: r is the real part, i the imaginary part.
// The arithmetic operators are declared here as friends and defined out of class.
struct Complex
{
	double r; // real part
	double i; // imaginary part

	Complex (double _r = 0, double _i = 0) : r(_r), i(_i) {}

	friend Complex operator + (const Complex &a, const Complex &b);
	friend Complex operator - (const Complex &a, const Complex &b);
	friend Complex operator * (const Complex &a, const Complex &b);
}
a[maxn << 2], // first polynomial's coefficients (real parts); main() leaves the product here
b[maxn << 2]; // second polynomial's coefficients (real parts)

// Component-wise addition.
Complex operator + (const Complex &a,const Complex &b)
{
	return Complex(a.r + b.r, a.i + b.i);
}

// Component-wise subtraction.
Complex operator - (const Complex &a,const Complex &b)
{
	return Complex(a.r - b.r, a.i - b.i);
}

// (a.r + a.i*i)(b.r + b.i*i) = (a.r*b.r - a.i*b.i) + (a.r*b.i + a.i*b.r)*i
Complex operator * (const Complex &a,const Complex &b)
{
	const double re = a.r * b.r - a.i * b.i;
	const double im = a.r * b.i + a.i * b.r;
	return Complex(re, im);
}

int rev[maxn << 2]; // bit-reversal permutation of [0, N), filled by Init()
int n;              // coefficient count of the first polynomial (main() reads the degree, then adds 1)
int m;              // coefficient count of the second polynomial (same convention)
int N;              // transform length chosen by Init(): smallest power of two >= n+m-1

// Picks the transform length N (smallest power of two >= n+m-1, at least 1)
// and fills rev[0..N) with the bit-reversal permutation for log2(N) bits.
void Init()
{
	N = 1;
	while (N < n + m - 1)
		N *= 2;

	// rev[idx] = rev[idx/2] shifted down one bit, with idx's lowest bit moved
	// into the top position (N/2). rev[0] stays 0 because rev is a zero-initialised global.
	for (int idx = 0; idx < N; ++idx)
	{
		const int top = (idx & 1) ? (N >> 1) : 0;
		rev[idx] = (rev[idx >> 1] >> 1) | top;
	}
}

// In-place iterative (bit-reversal + butterflies) FFT over arr[0..N).
// f = 1 : transform with twiddle step e^{+i*pi/len}.
// f = -1: transform with twiddle step e^{-i*pi/len}, then the real parts are
//         divided by N (imaginary parts are left unscaled; main() reads only .r).
// Requires Init() to have set N and rev[].
void FFT(Complex *arr,int f)
{
	// Move every element to its bit-reversed position (each pair swapped once).
	for (int idx = 0; idx < N; ++idx)
	{
		if (idx < rev[idx])
			swap(arr[idx], arr[rev[idx]]);
	}

	// len is the half-size of the butterfly groups merged at this level.
	for (int len = 1; len < N; len <<= 1)
	{
		const double angle = pi / len;
		const Complex step(cos(angle), f * sin(angle));
		const int span = len << 1;
		for (int base = 0; base < N; base += span)
		{
			Complex w(1, 0);
			for (int k = 0; k < len; ++k)
			{
				Complex &lo = arr[base + k];
				Complex &hi = arr[base + k + len];
				const Complex u = lo;
				const Complex v = w * hi;
				lo = u + v;
				hi = u - v;
				w = w * step; // advance the twiddle factor incrementally
			}
		}
	}

	if (f == -1)
	{
		for (int idx = 0; idx < N; ++idx)
			arr[idx].r /= N;
	}
}

int main()
{
	scanf("%d%d",&n,&m);n++,m++;
	for (int i=0,x;i<n;i++) scanf("%d",&x),a[i].r=x;
	for (int i=0,x;i<m;i++) scanf("%d",&x),b[i].r=x;
	Init();
	FFT(a,1),FFT(b,1);
	for (int i=0;i<N;i++) a[i]=a[i]*b[i];
	FFT(a,-1);
	for (int i=0;i<n+m-1;i++) printf("%d ",(int)(a[i].r+0.1));
	printf("\n");
	return 0;
}

 

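快速数论变换（NTT）：模数 p = 8650753 = 33·2^18 + 1，取 g = 10，gn = g^33 mod p = 68387，作为长度 PN = 2^18 的单位根使用。

注意：NTT 的结果是模 p 意义下的，只有当乘积的每个系数都落在 [0, p) 内时，输出才等于真实系数。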

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;

// NTT parameters.
const LL PN = 262144;   // 2^18: largest supported transform length (arrays below have PN slots)
const LL p = 8650753;   // modulus, p = 33 * 2^18 + 1, so 2^18 divides p - 1
const LL g = 10;        // base: gn below equals g^33 mod p (g is not referenced elsewhere in this program)
const LL gn = 68387;    // g^((p-1)/PN) mod p; FFT() uses it as a root of unity of order PN
                        // (order exactly PN presumably holds because g is a primitive root — verify if changing p)

LL a[PN];    // coefficients of the first polynomial; main() leaves the product here
LL b[PN];    // coefficients of the second polynomial
int rev[PN]; // bit-reversal permutation, rebuilt by FFT() on every call
int n;       // coefficient count of the first polynomial (degree + 1)
int m;       // coefficient count of the second polynomial (degree + 1)
int N;       // transform length: smallest power of two >= n+m-1

int Pow(int x,int pow)
{
	LL res=1,xx=x;
	while (pow)
	{
		if (pow&1) res=res*xx%p;
		xx=xx*xx%p;
		pow>>=1;
	}
	return res;
}

// In-place iterative number-theoretic transform of arr[0..N) modulo p.
// f = 1 : twiddle step gn^(PN/(2*len)).
// f = -1: twiddle step gn^(PN - PN/(2*len)) (the inverse power, since gn^PN == 1),
//         then every entry is multiplied by N^(p-2) mod p, i.e. N^(-1) when p is prime.
// Rebuilds rev[] for the current N on each call. Expects N to be a power of two
// with N <= PN and arr[] entries in [0, p).
void FFT(LL *arr,int f)
{
	for (int idx = 0; idx < N; ++idx)
		rev[idx] = (rev[idx >> 1] >> 1) | ((idx & 1) ? (N >> 1) : 0);

	for (int idx = 0; idx < N; ++idx)
	{
		if (idx < rev[idx])
			swap(arr[idx], arr[rev[idx]]);
	}

	// len is the half-size of the butterfly groups merged at this level.
	for (int len = 1; len < N; len <<= 1)
	{
		const LL exponent = (PN / len) >> 1;
		const LL step = Pow(gn, f == 1 ? exponent : PN - exponent);
		for (int base = 0; base < N; base += len << 1)
		{
			LL w = 1;
			for (int k = 0; k < len; ++k)
			{
				LL &lo = arr[base + k];
				LL &hi = arr[base + k + len];
				const LL u = lo;
				const LL v = w * hi % p;
				lo = (u + v) % p;
				hi = (u - v + p) % p; // +p keeps the result non-negative
				w = w * step % p;
			}
		}
	}

	if (f == -1)
	{
		const LL invN = Pow(N, p - 2); // Fermat inverse of N (relies on p being prime)
		for (int idx = 0; idx < N; ++idx)
			arr[idx] = arr[idx] * invN % p;
	}
}

int main()
{
	scanf("%d%d",&n,&m);n++,m++;
	for (int i=0;i<n;i++) scanf("%d",&a[i]);
	for (int i=0;i<m;i++) scanf("%d",&b[i]);
	N=1;while (N<n+m-1) N<<=1;
	FFT(a,1),FFT(b,1);
	for (int i=0;i<N;i++) a[i]=a[i]*b[i]%p;
	FFT(a,-1);
	for (int i=0;i<n+m-1;i++) printf("%d ",a[i]);
	printf("\n");
	return 0;
}

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter