【BZOJ3771】Triple
首先构造三个生成函数。
\[A(x) = {x^{{a_1}}} + {x^{{a_2}}} + {x^{{a_3}}} + \cdots + {x^{{a_n}}}\]
\[B(x) = {x^{2{a_1}}} + {x^{2{a_2}}} + {x^{2{a_3}}} + \cdots + {x^{2{a_n}}}\]
\[C(x) = {x^{3{a_1}}} + {x^{3{a_2}}} + {x^{3{a_3}}} + \cdots + {x^{3{a_n}}}\]
然后合并一下同类项。其中A(x)表示只有一个斧头,B(x)表示有两个斧头,C(x)表示有三个斧头。
多项式相乘后就可以发现它们的指数表示的是斧头的价值和,系数表示方案数。
那么答案就是如下得到的多项式\[A(x) + \frac{{{A^2}(x) - B(x)}}{2} + \frac{{{A^3}(x) - 3A(x)B(x) + 2C(x)}}{6}\]
指数为斧头价值和,系数为方案数。式中的三项分别为一个、两个、三个斧头的方案数。
#include <cstdio> #include <cstring> #include <algorithm> #include <complex> #include <cmath> using namespace std; const int maxn=262144; typedef complex<double> CP; int rev[maxn]; void FFT(CP *a,int N,int f) { for (int i=0;i<N;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); for (int i=1;i<N;i<<=1) { CP wn(cos(M_PI/i),f*sin(M_PI/i)); for (int j=0;j<N;j+=(i<<1)) { CP w(1,0); for (int k=0;k<i;k++,w*=wn) { CP x=a[j+k],y=w*a[j+k+i]; a[j+k]=x+y,a[j+k+i]=x-y; } } } if (f==-1) for (int i=0;i<N;i++) a[i]/=N; } CP tempA[maxn],tempB[maxn],tempC[maxn]; void Polynomial_Multyply(int *a,int *b,int n,int m,int *c) { int N=1; while (N<n+m-1) N<<=1; for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(N>>1)); for (int i=0;i<n;i++) tempA[i]=a[i];for (int i=n;i<N;i++) tempA[i]=0; for (int i=0;i<m;i++) tempB[i]=b[i];for (int i=m;i<N;i++) tempB[i]=0; FFT(tempA,N,1);FFT(tempB,N,1); for (int i=0;i<N;i++) tempC[i]=tempA[i]*tempB[i]; FFT(tempC,N,-1); for (int i=0;i<N;i++) c[i]=(int)(tempC[i].real()+0.1); } int A[maxn],B[maxn],C[maxn],temp[maxn],temp2[maxn],Ans[maxn]; int a[maxn]; int n,maxv; int main() { scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%d",&a[i]),maxv=max(maxv,a[i]); for (int i=1;i<=n;i++) A[a[i]]++,B[a[i]*2]++,C[a[i]*3]++; for (int i=1;i<=maxv;i++) Ans[i]+=A[i]; Polynomial_Multyply(A,A,maxv+1,maxv+1,temp); for (int i=1;i<=maxv*2;i++) Ans[i]+=(temp[i]-B[i])/2; Polynomial_Multyply(A,temp,maxv+1,maxv*2+1,temp); Polynomial_Multyply(A,B,maxv+1,maxv*2+1,temp2); for (int i=1;i<=maxv*3;i++) Ans[i]+=(temp[i]-3*temp2[i]+2*C[i])/6; for (int i=1;i<=maxv*3;i++) if (Ans[i]) printf("%d %d\n",i,Ans[i]); return 0; }