【BZOJ3509】【CodeChef】COUNTARI
暴力优化一下是n^2的,但这样并不能过。
标算好像是分块做,设每块的大小为V。
对于三个数都在同一块内以及只有一个数在块外的情况,用暴力优化处理,时间复杂度\[\frac{n}{V} \cdot {V^2} = nV\]
对于中间数在块内的情况,设这个块的范围是[L,R],对所有的\[{a_i}(i < L)\]和\[{a_i}(i > R)\]分别构造生成函数并相乘,得到多项式,指数表示两数之和,系数表示方案数。然后枚举每个块内的数,看它作为中间数的方案有多少个并加入答案。
这一步需要用FFT,所以时间复杂度为\[{n \cdot maxa \cdot \log maxa} \over V\],取个合适的V就好了。
刚开始我用NTT来做,发现本机上跑得飞快,linux下也跑得飞快,然而交上去就是T。后来改成复数运算后A了。事实证明long long+取模运算在某些地方是不能乱用的。
#include <cmath> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long LL; const int maxn=100010,V=2000; const double pi=M_PI; inline void read(int &x) {char ch;while ((ch=getchar())<'0' || ch>'9');x=ch-'0';while ((ch=getchar())<='9' && ch>='0') x=x*10+ch-'0';} struct Complex { double r,i; Complex (double _r=0,double _i=0):r(_r),i(_i){} friend Complex operator + (const Complex &a,const Complex &b); friend Complex operator - (const Complex &a,const Complex &b); friend Complex operator * (const Complex &a,const Complex &b); }; Complex operator + (const Complex &a,const Complex &b) {return Complex(a.r+b.r,a.i+b.i);} Complex operator - (const Complex &a,const Complex &b) {return Complex(a.r-b.r,a.i-b.i);} Complex operator * (const Complex &a,const Complex &b) {return Complex(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r);} int rev[maxn]; void FFT(Complex *a,int N,int f) { for (int i=0;i<N;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); for (int i=1;i<N;i<<=1) { Complex wn(cos(M_PI/i),f*sin(M_PI/i)); for (int j=0;j<N;j+=(i<<1)) { Complex w(1,0); for (int k=0;k<i;k++,w=w*wn) { Complex x=a[j+k],y=w*a[j+k+i]; a[j+k]=x+y,a[j+k+i]=x-y; } } } if (f==-1) for (int i=0;i<N;i++) a[i].r/=N; } Complex tempA[maxn],tempB[maxn],tempC[maxn]; void Polynomial_Multyply(int *a,int *b,int n,int m,LL *c) { int N=1; while (N<n+m-1) N<<=1; for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(N>>1)); for (int i=0;i<n;i++) tempA[i].r=a[i],tempA[i].i=0;for (int i=n;i<N;i++) tempA[i].r=tempA[i].i=0; for (int i=0;i<m;i++) tempB[i].r=b[i],tempB[i].i=0;for (int i=m;i<N;i++) tempB[i].r=tempB[i].i=0; FFT(tempA,N,1);FFT(tempB,N,1); for (int i=0;i<N;i++) tempC[i]=tempA[i]*tempB[i]; FFT(tempC,N,-1); for (int i=0;i<N;i++) c[i]=(LL)(tempC[i].r+0.1); } int a[maxn]; int lcnt[maxn],rcnt[maxn]; int n; long long ans; void Solve1() { for (int i=1;i<=n;i++) rcnt[a[i]]++; for (int st=1;st<n;st+=V) { int ed=min(n,st+V-1); for (int i=st;i<=ed;i++) rcnt[a[i]]--; for (int i=st;i<=ed;i++) { for (int j=i+1;j<=ed;j++) { if (a[i]<=(a[j]<<1))ans+=rcnt[(a[j]<<1)-a[i]]; if (a[j]<=(a[i]<<1))ans+=lcnt[(a[i]<<1)-a[j]]; } lcnt[a[i]]++; } } } int A[maxn],B[maxn]; LL C[maxn]; void Solve2() { for (int st=V+1;st<n;st+=V) { memset(A,0,sizeof(A)); memset(B,0,sizeof(B)); int ed=st+V-1; if (ed>=n) return; int na=0,nb=0; for (int i=1;i<st;i++) A[a[i]]++,na=max(na,a[i]); for (int i=ed+1;i<=n;i++) B[a[i]]++,nb=max(nb,a[i]); na++,nb++; Polynomial_Multyply(A,B,na,nb,C); for (int i=st;i<=ed;i++) ans+=C[a[i]*2]; } } int main() { read(n); for (int i=1;i<=n;i++) read(a[i]); Solve1(); Solve2(); printf("%lld\n",ans); return 0; }