ZJT's Blog

大変に気分がいい~

[自选题 #127] Ball

2018-1-2 17:212018-1-2 17:22
集训队作业自选题生成函数FFT线性递推

传送门

考虑DP:$f_{i,j}$表示$i$个球分$j$组的方案数。显然$f_{i,j}=f_{i-1,j}+f_{i-1,j-1}+f_{i-2,j-1}$。

如果我们把$f_i$看成一个$x$多项式,则$f_i(x)=(x+1)f_{i-1}(x)+xf_{i-2}(x)$。这是个常系数线性递推,特征方程为$\lambda^2-(x+1)\lambda-x=0$,解出特征根为$\lambda_1=\frac{x+1+\sqrt{x^2+6x+1}}{2},\lambda_2=\frac{x+1-\sqrt{x^2+6x+1}}{2}$。

我们钦定$f_{-1}(x)=0$,此时转移方程在所有正整数处都成立。我们设$f_i(x)=c_1(x)\lambda_1^{i+1}(x)+c_2(x)\lambda_2^{i+1}(x)$,把$f_{-1},f_0$代进去,得到$c_1=\frac 1{\lambda_1-\lambda_2},c_2=\frac 1{\lambda_2-\lambda_1}$。所以$f_n$就能表示成:$$f_n(x)=\frac {\lambda_1^{n+1}(x)-\lambda_2^{n+1}(x)}{\lambda_1(x)-\lambda_2(x)}$$

直接开根把$\lambda$算出来之后$\ln+\exp$求幂再求逆就行了。

由于$m\leq n$(当$m>n$时后面那些显然就是$0$),而$\lambda_2$的常数项是$0$,所以$\lambda_2^{n+1}(x)\equiv 0\pmod {x^{m+1}}$,这意味着分子后面那项根本没有贡献,可以直接忽略,从而减小常数。

#include <bits/stdc++.h>
#define MAXN 1100000
#define MAXW 1048576
#define LL long long 
using namespace std;

const LL P=998244353;

int n,m,sizew;
LL w[MAXW],inv_v[MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

void init(){
    w[0]=1;
    w[1]=getPow(3,(P-1)/MAXW);
    for(int i=2;i<MAXW;i++) w[i]=w[i-1]*w[1]%P;
    inv_v[1]=1;
    for(int i=2;i<MAXN;i++)
        inv_v[i]=-inv_v[P%i]*(P/i)%P;
}

void FFT(LL *a,int n,int flag){
    static int rev[MAXN],revn;
    if(revn!=n){
        revn=n;
        for(int i=1;i<n;i++) rev[i]=rev[i>>1]>>1|((i&1)?(n>>1):0);
    }
    for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int l=2;l<=n;l<<=1){
        int l2=l>>1;
        for(int i=0;i<n;i+=l){
            for(int j=0;j<l2;j++){
                LL t1=a[i+j],t2=a[i+j+l2]*w[MAXW/l*j];
                a[i+j]=(t1+t2)%P;
                a[i+j+l2]=(t1-t2)%P;
            }
        }
    }
    if(flag==-1){
        LL invn=getPow(n,P-2);
        for(int i=0;i<n;i++) a[i]=a[i]*invn%P;
        for(int i=1;i<n;i++) if(i<n-i) swap(a[i],a[n-i]);
    }
}

void getInv(LL *b,LL *a,int len){
    if(len==1){
        b[0]=getPow(a[0],P-2);
        return;
    }
    static LL t1[MAXN],t2[MAXN];
    getInv(t1,a,len>>1);
    for(int i=len>>1;i<len<<1;i++) t1[i]=t2[i]=0;
    for(int i=0;i<len;i++) t2[i]=a[i];
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<len<<1;i++) t1[i]=(2*t1[i]-t2[i]*t1[i]%P*t1[i])%P;
    FFT(t1,len<<1,-1);
    for(int i=0;i<len;i++) b[i]=t1[i];
}

void getLn(LL *b,LL *a,int len){
    static LL t1[MAXN],t2[MAXN];
    for(int i=0;i<len;i++) t1[i]=a[i+1]*(i+1)%P,t2[i]=a[i];
    t1[len-1]=0;
    getInv(t2,t2,len);
    for(int i=0;i<len;i++) t1[i+len]=t2[i+len]=0;
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<len<<1;i++) t1[i]=t1[i]*t2[i]%P;
    FFT(t1,len<<1,-1);
    for(int i=1;i<len;i++) b[i]=t1[i-1]*inv_v[i]%P;
    b[0]=0;
}

void getExp(LL *b,LL *a,int len){
    if(len==1){
        b[0]=1;
        return;
    }
    static LL t1[MAXN],t2[MAXN],t3[MAXN];
    getExp(t1,a,len>>1);
    for(int i=0;i<len>>1;i++) t3[i]=t1[i],t1[i+(len>>1)]=t3[i+(len>>1)]=0;
    getLn(t2,t1,len);
    for(int i=0;i<len>>1;i++) t2[i]=t2[i+(len>>1)]-a[i+(len>>1)],t2[i+(len>>1)]=0;
    FFT(t3,len,1); FFT(t2,len,1);
    for(int i=0;i<len;i++) t3[i]=t3[i]*t2[i]%P;
    FFT(t3,len,-1);
    for(int i=0;i<len>>1;i++) b[i]=t1[i],b[i+(len>>1)]=-t3[i];
}

void getPow(LL *b,LL *a,int len,LL k){
    static LL t1[MAXN];
    getLn(t1,a,len);
    for(int i=0;i<len;i++) t1[i]=t1[i]*k%P;
    getExp(b,t1,len);
}

void solve(){
    for(sizew=1;sizew<=n;sizew<<=1);
    static LL t1[MAXN],t2[MAXN];
    t1[0]=1; t1[1]=6; t1[2]=1;
    getPow(t1,t1,sizew,inv_v[2]);
    for(int i=0;i<sizew;i++) t2[i]=t1[i];
    t2[0]=(t2[0]+1)%P; t2[1]=(t2[1]+1)%P;
    for(int i=0;i<sizew;i++) t2[i]=t2[i]*inv_v[2]%P;
    getPow(t2,t2,sizew,m+1);
    getInv(t1,t1,sizew);
    for(int i=0;i<sizew;i++) t1[i+sizew]=t2[i+sizew]=0;
    FFT(t1,sizew<<1,1); FFT(t2,sizew<<1,1);
    for(int i=0;i<sizew<<1;i++) t1[i]=t1[i]*t2[i]%P;
    FFT(t1,sizew<<1,-1);
    for(int i=1;i<=n;i++) printf("%lld ",(t1[i]+P)%P);
}

int main(){
#ifdef DEBUG
    freopen("127.in","r",stdin);
#endif
    scanf("%d%d",&m,&n);
    int _n=n;
    if(n>m) n=m;
    init();
    solve();
    for(int i=1;i<=_n-n;i++) printf("0 ");
    return 0;
}

查看详细内容

[自选题 #153] Comb Avoiding Trees

2017-11-2 19:512018-1-2 17:22
集训队作业自选题生成函数FFT线性递推

传送门

这题有毒。。。先是写了一个$O(k^2+n\log n)$的,然后Mike神犇说可以把$O(k^2)$变成$O(k\log^2k)$,然后我仔细一想这又可以变成$O(k\log k)$,接着又发现了神秘规律,常数又小了几倍。。。于是这个本来$n,k\leq5000$的题就强行做到了$n,k\leq10^6$。。。然后接着又发现可以出到$O(k\log k\log n)$求单项(orz Mike)。。。

先考虑$O(kn^2)$的做法,直接进行dp,设$f_{i,j}$为$i$个叶子,最多含有$j$连树(即不含$j+1$连树)的方案数,我们有转移$f_{i,j}=\sum_{k=1}^{i-1} f_{k,j-1}f_{i-k,j}$。

考虑优化这个东西,我们先找一找它的生成函数,设$f_{i,j}$的生成函数为$F_j(x)=\sum_i f_{i,j}x^i$。通过一些技巧可以发现,$F_j(x)=\frac x{1-F_{j-1}(x)}$。如果强行多项式求逆算的话,复杂度是$O(kn\log n)$,显然跑不过$5000$。其实我们直接把$F_j(x)$当成分式$A_j(x)/B_j(x)$算就行了。我们把$A,B$代入之前的式子,得到:$$F_j(x)=\frac x{1-F_{j-1}(x)}=\frac{xB_{j-1}(x)}{B_{j-1}(x)-A_{j-1}(x)}$$

这样,我们就只用处理多项式的减法和移位了,成功把复杂度降到了$O(k^2+n\log n)$。后面有个$O(n\log n)$是因为我懒得写暴力多项式求逆了,就直接写了FFT快速求逆。其实如果只是想跑过$5000$,直接暴力多项式求逆就行了。

为了进一步降到$O(k\log k+n\log n)$,我们观察一下刚才的$A,B$的转移:$A_j=xB_{j-1},B_j=B_{j-1}-A_{j-1}$,发现这是个常系数线性递推。所以可以直接矩阵快速幂套多项式乘法做到$O(k\log^2k)$,也可以直接把$O(k)$个单位根代进去跑,跑完之后插值回来,这样就是$O(k\log k)$了,加上后面的多项式求逆就是$O(k\log k+n\log n)$了。(其实还有一种常数更小的$O(k\log k)$的方法,可以找一找$A,B$的规律,具体是什么规律我就不说了,说了就没意思啦233)

然而对这道题的加强并没有结束,我们甚至可以在$O(k\log k\log n)$的时间内算出$A_k(x)/B_k(x)$的第$n$项(这里只能算出单项),大概就是根据暴力算多项式求逆的过程,我们可以把多项式的逆元表示成一个常系数线性递推的形式。这样,我们根据特征多项式和多项式取模的那一套理论,就可以在$O(k\log k\log n)$的时间内算出$B_k(x)$的逆元在$x^{n-k}$在$x^n$之间的系数。然后再卷积一下,就能算出$A_k(x)/B_k(x)$在$x^n$的系数啦。

代码(矩阵快速幂+插值+多项式求逆):

#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 530000
#define LL long long
using namespace std;

const LL P=998244353;

int n,m,sizew;
LL a[MAXN],b[MAXN];

struct Mat{
    LL a[2][2];
    Mat(){ memset(a,0,sizeof a); }
    friend Mat operator*(Mat x,Mat y){
        Mat res;
        res.a[0][0]=(x.a[0][0]*y.a[0][0]+x.a[0][1]*y.a[1][0])%P;
        res.a[0][1]=(x.a[0][0]*y.a[0][1]+x.a[0][1]*y.a[1][1])%P;
        res.a[1][0]=(x.a[1][0]*y.a[0][0]+x.a[1][1]*y.a[1][0])%P;
        res.a[1][1]=(x.a[1][0]*y.a[0][1]+x.a[1][1]*y.a[1][1])%P;
        return res;
    }
};

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

Mat getPow(Mat x,int y){
    Mat res;
    res.a[0][0]=res.a[1][1]=1;
    while(y){
        if(y&1) res=res*x;
        x=x*x;
        y>>=1;
    }
    return res;
}

void FFT(LL *a,int len,int flag){
    static int rev[MAXN],revlen;
    if(revlen^len){
        revlen=len;
        for(int i=1;i<len;i++) rev[i]=rev[i>>1]>>1|((i&1)?(len>>1):0);
    }
    for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int l=2;l<=len;l<<=1){
        LL w=getPow(3,(P-1)/l);
        int l2=l>>1;
        for(int i=0;i<len;i+=l){
            LL temp=1;
            for(int j=0;j<l2;j++){
                LL t1=a[i+j],t2=a[i+j+l2]*temp;
                a[i+j]=(t1+t2)%P;
                a[i+j+l2]=(t1-t2)%P;
                temp=temp*w%P;
            }
        }
    }
    if(flag==-1){
        LL invn=getPow(len,P-2);
        for(int i=0;i<len;i++) a[i]=a[i]*invn%P;
        for(int i=1;i<len;i++) if(i<len-i) swap(a[i],a[len-i]);
    }
}

void pre_gao(){
    LL wn=getPow(3,(P-1)/sizew),x=1;
    for(int i=0;i<sizew;i++){
        Mat trans;
        trans.a[0][0]=0; trans.a[0][1]=-1;
        trans.a[1][0]=x; trans.a[1][1]=1;
        x=x*wn%P;
        trans=getPow(trans,m-1);
        a[i]=trans.a[1][0];
        b[i]=trans.a[1][1];
    }
    FFT(a,sizew,-1);
    FFT(b,sizew,-1);
}

void getInv(LL *b,LL *a,int len){
    if(len==1){
        b[0]=1;
        return;
    }
    static LL t1[MAXN],t2[MAXN];
    getInv(t1,a,len>>1);
    for(int i=len>>1;i<len<<1;i++) t1[i]=0;
    for(int i=0;i<len;i++) t2[i]=a[i],t2[i+len]=0;
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<len<<1;i++) t1[i]=(2*t1[i]-t2[i]*t1[i]%P*t1[i])%P;
    FFT(t1,len<<1,-1);
    for(int i=0;i<len;i++) b[i]=t1[i];
}

void getMul(LL *b,LL *a,int len){
    LL t1[MAXN],t2[MAXN];
    for(int i=0;i<len;i++){
        t1[i]=b[i];
        t2[i]=a[i];
        t1[i+len]=t2[i+len]=0;
    }
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<len<<1;i++) t1[i]=t1[i]*t2[i]%P;
    FFT(t1,len<<1,-1);
    for(int i=0;i<len;i++) b[i]=t1[i];
}

void gao(){
    static LL t1[MAXN];
    getInv(t1,b,sizew);
    getMul(a,t1,sizew);
}

int main(){
    scanf("%d%d",&m,&n);
    for(sizew=1;sizew<=n || sizew<=m;sizew<<=1);
    pre_gao();
    gao();
    for(int i=1;i<=n;i++) printf("%lld\n",(a[i]+P)%P);
}

查看详细内容

集训队作业之AGC013-E

2017-11-2 17:72018-1-2 17:21
集训队作业AtCoderDP容斥线性递推

传送门

感觉这题非常不错啊,中间的思路挺妙的(虽然做的时候卡了一会儿)。

先考虑没有标记的情况,我们来搞一个dp,设$f_n$为总长为$n$的所有方案的美丽度之和,则我们有转移:$f_n=\sum k^2f_{n-k}$。考虑用矩乘算这个东西,我们记$$s_2=\sum_{k=1}^n f_k(n-k)^2,s_1=\sum_{k=1}^n f_k(n-k),s_0=\sum_{k=1}^n f_k$$,则每次$n:=n+1$时,所有$(n-k)^2$会变成$(n-k)^2+2(n-k)+1$,这时拿$s_1,s_0$去更新一下$s_2$就行了,同理也可以用$s_0$去更新$s_1$。新的$f_{n+1}$也可以从这几个东西里面得到,直接加到$s_0$里面去就行了。这个东西是个常系数线性递推,所以可以直接矩乘,这样我们就可以在$O(\log n)$的时间里面算出没有标记的一段的贡献了。

考虑有标记的情况,我们还是dp,设$f_i$表示完全覆盖左端点到第$i$个标记这一段的所有方案的美丽度之和(这里第$i$个之前的标记不能作为正方形的端点,但第$i$个标记必须作为最右边的正方形的右端点)。我们考虑容斥,用$g(len)$表示上面矩乘算出来的长度为$len$的答案,则$f_i=g(X_i)-\sum_{j<i}f_jg(X_i-X_j)$,这里$j$是枚举第一个不合法的标记的位置。直接这样容斥dp,维护一下矩阵的前缀乘积和前缀乘积的逆,就可以做到$O(M^2)$的复杂度。

为了优化这个做法,我们考虑每个$f_j$对后面所有$f_i$的贡献。当 $i:=i+1$ 时,贡献由$g(X_i-X_j)$变成了$g(X_{i+1}-X_j)$。由于$g(x)$可以表示成某个矩阵$T^x$中的某个位置的值,所以$g(X_{i+1}-X_j)$可以由$T^{X_i-X_j}$乘上$T^{X_{i+1}-X_i}$得到。由于每次 $i:=i+1$ ,所有$j$乘上的矩阵都是$T^{X_{i+1}-X_i}$,再根据矩阵乘法的分配律,我们可以直接带着矩阵来做,每次转移的时候都乘上$T^{X_{i+1}-X_i}$,就可以在$O(m\log n)$的时间里完成dp了。

#include <cstdio>
#include <cstring>
#define MAXN 100010
#define LL long long

const LL P=1000000007;

struct Mat{
    struct Data{
        LL a[3][3];
    }*data;

    Mat(){
        data=new Data;
        memset(data->a,0,sizeof data->a);
    }

    void release(){ delete data; }

    friend void mul(Mat &x,Mat y){
        static Mat res;
        for(int i=0;i<3;i++) 
            for(int j=0;j<3;j++){
                res.data->a[i][j]=0;
                for(int k=0;k<3;k++)
                    res.data->a[i][j]+=x.data->a[i][k]*y.data->a[k][j];
                res.data->a[i][j]%=P;
            }
        for(int i=0;i<3;i++)
            for(int j=0;j<3;j++)
                x.data->a[i][j]=res.data->a[i][j];
    }
}trans;

Mat getPow(Mat _x,int y){
    Mat res;
    for(int i=0;i<3;i++) res.data->a[i][i]=1;
    static Mat x;
    for(int i=0;i<3;i++) for(int j=0;j<3;j++) x.data->a[i][j]=_x.data->a[i][j];
    while(y){
        if(y&1) mul(res,x);
        mul(x,x);
        y>>=1;
    }
    return res;
}

void init(){
    trans.data->a[0][0]=1; trans.data->a[0][1]=1; trans.data->a[0][2]=1;
    trans.data->a[1][0]=0; trans.data->a[1][1]=1; trans.data->a[1][2]=2;
    trans.data->a[2][0]=1; trans.data->a[2][1]=1; trans.data->a[2][2]=2;
}

LL getG(int x){
    Mat temp=getPow(trans,x);
    return temp.data->a[0][2];
}

int n,m;
int p[MAXN];
LL f[MAXN];

void gao(){
    Mat temp;
    f[1]=getG(p[1]);
    temp.data->a[0][0]=f[1];
    for(int i=2;i<=m;i++){
        mul(temp,getPow(trans,p[i]-p[i-1]));
        f[i]=(getG(p[i])-temp.data->a[0][2]+P)%P;
        temp.data->a[0][0]=(temp.data->a[0][0]+f[i])%P;
    }
}

int main(){
#ifdef DEBUG
    freopen("E.in","r",stdin);
#endif
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++) scanf("%d",p+i);
    p[++m]=n;
    init();
    gao();
    printf("%lld\n",f[m]);
    return 0;
}

查看详细内容