ZJT's Blog

大変に気分がいい~

集训队作业之AGC019-F

2017-12-13 18:442017-12-13 18:44
集训队作业AtCoder组合数学

传送门

先考虑朴素的$O(n^2)$DP:$f_{i,j}=\frac{i}{i+j}f_{i-1,j}+\frac{j}{i+j}f_{i,j-1}+\frac{\max(i,j)}{i+j}$,最后一项取max是因为当前最优决策肯定是猜更多的那一项。

现在考虑后面那项$\frac{\max(i,j)}{i+j}$对$f_{n,m}$的贡献。推推推之后得到:$$f_{n,m}=\sum_{0\leq x\leq n}\sum_{0\leq y\leq m}\frac{\binom nx\binom ny}{\binom{n+m}{x+y}}\frac{\max(n-x,m-y)}{n+m-(x+y)}$$

然后因为后面有个$\max$,不能直接卷积,要分治FFT,显然跑不过去这题,最多跑个部分分。。。

结果并没有成功想到更低复杂度的做法。然后去膜了别人的题解,发现有两种做法,都是线性的。一种是题解的做法,比较复杂,我没怎么看,就不管了。

另一种做法是把状态表示在二维平面上,点$(x,y)$表示当前有$x$个Yes,$y$个No。我们考虑直线$y=x$,我们可以钦定直线下的点(即$y<x$)都回答Yes,其他点(即$y\geq x$)都回答No,这样的策略显然是最优的。这里我们假定$n\geq m$。每个题目中的$n+m$长的序列都对应了$(n,m)$到$(0,0)$的一条路径,通过观察可以发现这条路径猜对的次数是$(n+$这条路径在直线$y=x$上往下拐的次数$)$。 由于直线 $y=x$ 经过的点数是$O(n)$的,我们可以直接枚举对角线上的点,计算出每个点的贡献即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 1000010
#define LL long long
using namespace std;

const LL P=998244353;

int n,m;
LL fac[MAXN],invfac[MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

void init(){
    fac[0]=1;
    for(int i=1;i<MAXN;i++) fac[i]=fac[i-1]*i%P;
    invfac[MAXN-1]=getPow(fac[MAXN-1],P-2);
    for(int i=MAXN-2;i>=0;i--) invfac[i]=invfac[i+1]*(i+1)%P;
}

LL getC(int x,int y){
    return fac[x]*invfac[y]%P*invfac[x-y]%P;
}

LL calc(int x,int y){
    return getC(x+y,x);
}

int main(){
    scanf("%d%d",&n,&m);
    init();
    if(n<m) n^=m^=n^=m;
    LL ans=calc(n,m)*n%P;
    for(int i=1;i<=m;i++)
        ans=(ans+calc(n-i,m-i)*calc(i-1,i))%P;
    ans=ans*getPow(calc(n,m),P-2)%P;
    printf("%lld\n",ans);
    return 0;
}

查看详细内容

集训队作业之AGC018-E

2017-11-23 13:552017-11-23 13:55
集训队作业AtCoder容斥组合数学FFT

传送门

看了题就先瞎推了一下,不知道怎么就推成了一个卷积的形式,然后一看模数$10^9+7$,范围是$10^6$,感觉AtCoder的机子那么牛逼应该可以强行任意模数FFT跑过去。然后就写了一个胡乱卷积,结果爆精度了。。。double改成long double之后又T了。。。然后尝试各种卡常+卡精度技巧,结果一点效果都没有。然后又仔细想了一下,发现这题随便$O(n)$。。。心情复杂.jpg

首先把起点、终点的限制容斥一下,问题转化成了在一个大框里选两个点作为起点和终点,在里面的一个小框里面选一个中间点,要求起点->中间点->终点路径数的和。对于某个中间点,设它到大框左端的距离是$d_x$,到顶端的距离是$d_y$,则在它左上方选一个起点走到它的路径总数就是:

$$ \begin{aligned} \sum_{x=0}^{d_x}\sum_{y=0}^{d_y}\binom{x+y}{x}&=\sum_{x=0}^{d_x}\binom{x+d_y+1}{x+1}\\ &=\sum_{x=0}^{d_x+1}\binom{x+d_y+1}{x+1}-1\\ &=\binom{d_x+d_y+2}{d_x+1}-1 \end{aligned} $$

后面那个$-1$在容斥的时候会被抵消掉,我们可以不用管他。

然而直接枚举中间点是$O(n^2)$的,为了减少这个复杂度,我想到了两个做法,一个是一开始那个爆精度的傻逼卷积做法(过不了),一个是比较简单的线性做法(能过)。

先讲一下那个过不了的弱智做法,就是考虑每条起点到终点的路径,如果穿过了中间那个小框,那么这条路径的贡献就是路径跟小框的交集的点数(这里面的每个点都能作为中间点产生$1$的贡献)。穿过小框的路径一定是从左边或者上方进入小框,从右边或者下方离开小框。我们根据入口、出口的位置,分左右、左下、上右、上下四种情况考虑,发现其实是一个卷积。然而这题模数是$10^9+7$,用按$\sqrt P$分段的那种任意模数FFT的话,double会爆精度,long double会T。。。这还是在AtCoder的机器上,在我的机子上double都要跑10s。。。

线性的做法也是类似,考虑每条路径穿过中间那个框的点数,但是把中间那个框容斥了一下,小框变成了大框左上角的一部分,这样就只用考虑出口了。由于随便选起点的路径总数能表示成一个组合数$\binom{d_x+d_y+2}{d_x+1}$,所以可以直接把起点看成固定点,终点也是一样。起点固定的情况下,起点->出口经过的点数也是确定的,所以直接枚举出口统计一下就行了,复杂度是$O(n)$(虽然因为容斥带了一个64的常数)。

代码($O(n)$,能过):

#include <cstdio>
#include <cstring>
#include <cassert>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#define MAXN 2100000
#define LL long long
#define y1 zjtsb_y1
using namespace std;

const LL P=1000000007;

LL fac[MAXN],invfac[MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

LL getC(int x,int y){
    if(x<y) return 0;
    return fac[x]*invfac[y]%P*invfac[x-y]%P;
}

void init(){
    fac[0]=1;
    for(int i=1;i<MAXN;i++) fac[i]=fac[i-1]*i%P;
    invfac[MAXN-1]=getPow(fac[MAXN-1],P-2);
    for(int i=MAXN-2;i>=0;i--) invfac[i]=invfac[i+1]*(i+1)%P;
}

LL calc(int x1,int x2,int y1,int y2){
    if(!x1 || !x2 || !y1 || !y2) return 0;
    LL res=0;
    for(int i=1;i<=x1;i++)
        res=(res+getC(y1+i-2,i-1)*getC(y2+x1+x2-i-1,y2-1)%P*(y1+i-1))%P;
    for(int i=1;i<=y1;i++)
        res=(res+getC(x1+i-2,i-1)*getC(x2+y1+y2-i-1,x2-1)%P*(x1+i-1))%P;
    return res;
}

int x1,x2,x3,x4,x5,x6;
int y1,y2,y3,y4,y5,y6;

LL gao(int sx,int sy,int ex,int ey){
    LL res=0;
    res+=calc(x4-sx+1,ex-x4,y4-sy+1,ey-y4);
    res-=calc(x3-sx,ex-x3+1,y4-sy+1,ey-y4);
    res-=calc(x4-sx+1,ex-x4,y3-sy,ey-y3+1);
    res+=calc(x3-sx,ex-x3+1,y3-sy,ey-y3+1);
    return (res%P+P)%P;
}

int main(){
#ifdef DEBUG
    freopen("E.in","r",stdin);
#endif
    scanf("%d%d%d%d%d%d",&x1,&x2,&x3,&x4,&x5,&x6);
    scanf("%d%d%d%d%d%d",&y1,&y2,&y3,&y4,&y5,&y6);
    init();
    LL ans=0;
    for(int i=0;i<16;i++){
        int c=0;
        int sx=x1-1,sy=y1-1,ex=x6+1,ey=y6+1;
        if(i&1){
            c++;
            sx=x2;
        }
        if(i&2){
            c++;
            sy=y2;
        }
        if(i&4){
            c++;
            ex=x5;
        }
        if(i&8){
            c++;
            ey=y5;
        }
        if(c&1) ans-=gao(sx,sy,ex,ey);
        else ans+=gao(sx,sy,ex,ey);
    }
    ans=(ans%P+P)%P;
    printf("%lld\n",ans);
    return 0;
}

代码($O(n\log n)$,爆精度):

#include <cstdio>
#include <cstring>
#include <cassert>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#define MAXN 2100000
#define LL long long
#define y1 zjtsb_y1
using namespace std;

const int MAXW=2097152;
const LL P=1000000007;
const long double PI=acos(-1.0);

namespace FFT{
    struct cplx{
        long double r,i;
        cplx(long double _r=0,long double _i=0):r(_r),i(_i){}
        friend cplx operator+(cplx x,cplx y){ return cplx(x.r+y.r,x.i+y.i); }
        friend cplx operator-(cplx x,cplx y){ return cplx(x.r-y.r,x.i-y.i); }
        friend cplx operator*(cplx x,cplx y){ return cplx(x.r*y.r-x.i*y.i,x.r*y.i+x.i*y.r); }
    }wn[MAXW];

    void init(){
        for(int i=0;i<MAXW;i++) wn[i]=cplx(cos(2*PI/MAXW*i),sin(2*PI/MAXW*i));
    }

    void fft(cplx *a,int len,int flag){
        static int rev[MAXN],revlen;
        if(revlen!=len){
            revlen=len;
            for(int i=1;i<len;i++) rev[i]=rev[i>>1]>>1|((i&1)?(len>>1):0);
        }
        for(int i=0;i<len;i++)
            if(i<rev[i])
                swap(a[i],a[rev[i]]);
        for(int l=2;l<=len;l<<=1){
            int l2=l>>1;
            for(int i=0;i<len;i+=l)
                for(int j=0;j<l2;j++){
                    cplx t1=a[i+j],t2=a[i+j+l2]*wn[MAXW/l*j];
                    a[i+j]=t1+t2;
                    a[i+j+l2]=t1-t2;
                }
        }
        if(flag==-1){
            for(int i=0;i<len;i++) a[i].r/=len;
            for(int i=1;i<len;i++)
                if(i<len-i) swap(a[i],a[len-i]);
        }
    }
}

LL fac[MAXN],invfac[MAXN];
int n,m;
int x1,x2,x3,x4,x5,x6;
int y1,y2,y3,y4,y5,y6;
LL f1[MAXN],f2[MAXN],f3[MAXN],f4[MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

LL getC(int x,int y){
    if(x<y) return 0;
    return fac[x]*invfac[y]%P*invfac[x-y]%P;
}

LL calcG(int x,int y){
    return getC(x+y+2,x+1)-1;
}

void init(){
    fac[0]=1;
    for(int i=1;i<MAXN;i++) fac[i]=fac[i-1]*i%P;
    invfac[MAXN-1]=getPow(fac[MAXN-1],P-2);
    for(int i=MAXN-2;i>=0;i--) invfac[i]=invfac[i+1]*(i+1)%P;
    FFT::init();
}

void gaoF(){
    for(int i=1;i<=m;i++){
        int x=x3-1,y=y3-1+i;
        f1[i]=((calcG(x-x1,y-y1)-calcG(x-x2-1,y-y1)-calcG(x-x1,y-y2-1)+calcG(x-x2-1,y-y2-1))%P+P)%P;
    }
    for(int i=1;i<=n;i++){
        int x=x3-1+i,y=y3-1;
        f2[i]=((calcG(x-x1,y-y1)-calcG(x-x2-1,y-y1)-calcG(x-x1,y-y2-1)+calcG(x-x2-1,y-y2-1))%P+P)%P;
    }
    for(int i=1;i<=m;i++){
        int x=x4+1,y=y3-1+i;
        f3[i]=((calcG(x6-x,y6-y)-calcG(x5-x-1,y6-y)-calcG(x6-x,y5-y-1)+calcG(x5-x-1,y5-y-1))%P+P)%P;
    }
    for(int i=1;i<=n;i++){
        int x=x3-1+i,y=y4+1;
        f4[i]=((calcG(x6-x,y6-y)-calcG(x5-x-1,y6-y)-calcG(x6-x,y5-y-1)+calcG(x5-x-1,y5-y-1))%P+P)%P;
    }
}

void mul(LL *a,LL *b,LL *c,int l1,int l2){
    using namespace FFT;
    const LL M=32000;
    static cplx t1[MAXN],t2[MAXN];
    static LL s1[MAXN],s2[MAXN],s3[MAXN];
    int sizew;
    for(sizew=1;sizew<=l1+l2;sizew<<=1);
    for(int i=0;i<sizew;i++) t1[i]=t2[i]=cplx();
    for(int i=0;i<=l1;i++) t1[i].r=(a[i]/M)+(a[i]%M);
    for(int i=0;i<=l2;i++) t2[i].r=(b[i]/M)+(b[i]%M);
    fft(t1,sizew,1); fft(t2,sizew,1);
    for(int i=0;i<sizew;i++) t1[i]=t1[i]*t2[i];
    fft(t1,sizew,-1);
    for(int i=0;i<sizew;i++) s1[i]=t1[i].r+0.5;

    for(int i=0;i<sizew;i++) t1[i]=t2[i]=cplx();
    for(int i=0;i<=l1;i++) t1[i].r=a[i]/M;
    for(int i=0;i<=l2;i++) t2[i].r=b[i]/M;
    fft(t1,sizew,1); fft(t2,sizew,1);
    for(int i=0;i<sizew;i++) t1[i]=t1[i]*t2[i];
    fft(t1,sizew,-1);
    for(int i=0;i<sizew;i++) s2[i]=t1[i].r+0.5;

    for(int i=0;i<sizew;i++) t1[i]=t2[i]=cplx();
    for(int i=0;i<=l1;i++) t1[i].r=a[i]%M;
    for(int i=0;i<=l2;i++) t2[i].r=b[i]%M;
    fft(t1,sizew,1); fft(t2,sizew,1);
    for(int i=0;i<sizew;i++) t1[i]=t1[i]*t2[i];
    fft(t1,sizew,-1);
    for(int i=0;i<sizew;i++) s3[i]=t1[i].r+0.5;

    for(int i=0;i<sizew;i++){
        s1[i]%=P;
        s2[i]%=P;
        s3[i]%=P;
    }
    for(int i=0;i<sizew;i++) s1[i]-=s2[i]+s3[i];
    for(int i=0;i<sizew;i++) c[i]=(s3[i]+M*s1[i]+M*M%P*s2[i])%P;
}

LL gao(){
    static LL t1[MAXN],t2[MAXN],t3[MAXN];
    LL res=0;
    //f1*f3
    for(int i=1;i<=m;i++) t1[i]=f1[i],t2[i]=f3[m-i+1];
    mul(t1,t2,t3,m,m);
    for(int i=0;i<=m-1;i++){
        LL t=t3[m+1-i];
        res=(res+t*(n+i)%P*getC(n-1+i,i))%P;
    }
    //f2*f4
    for(int i=1;i<=n;i++) t1[i]=f2[i],t2[i]=f4[n-i+1];
    mul(t1,t2,t3,n,n);
    for(int i=0;i<=n-1;i++){
        LL t=t3[n+1-i];
        res=(res+t*(m+i)%P*getC(m-1+i,i))%P;
    }
    //f1*f4
    for(int i=0;i<m;i++) t1[i]=f1[m-i]*invfac[i]%P;
    for(int i=0;i<n;i++) t2[i]=f4[i+1]*invfac[i]%P;
    mul(t1,t2,t3,m-1,n-1);
    for(int i=0;i<=n+m-2;i++){
        LL t=t3[i];
        res=(res+t*fac[i]%P*(i+1))%P;
    }
    //f2*f3
    for(int i=0;i<n;i++) t1[i]=f2[n-i]*invfac[i]%P;
    for(int i=0;i<m;i++) t2[i]=f3[i+1]*invfac[i]%P;
    mul(t1,t2,t3,n-1,m-1);
    for(int i=0;i<=n+m-2;i++){
        LL t=t3[i];
        res=(res+t*fac[i]%P*(i+1))%P;
    }
    return res;
}

int main(){
#ifdef DEBUG
    freopen("E.in","r",stdin);
#endif
    scanf("%d%d%d%d%d%d",&x1,&x2,&x3,&x4,&x5,&x6);
    scanf("%d%d%d%d%d%d",&y1,&y2,&y3,&y4,&y5,&y6);
    n=x4-x3+1;
    m=y4-y3+1;
    init();
    gaoF();
    LL ans=gao();
    printf("%lld\n",(ans%P+P)%P);
    return 0;
}

查看详细内容

集训队作业之AGC019-E

2017-10-26 15:362017-10-30 7:31
集训队作业AtCoderFFT生成函数组合数学

传送门

感觉这题出的很niubi啊,tourist还是太强了。。。

记有$n_1$位满足$A_i=B_i=1$,有$n_2$位满足$A_i=1,B_i=0$(也等于满足$A_i=0,B_i=1$的位数)。

考虑一个$n_1$中的位$i$(就是$A_i,B_i$都是1),如果他跟另一个$n_1$中的位$j$交换了,那另一位就再也不能跟任何满足$A_k=0,B_k=1$的位$k$交换了,因为跟$k$交换之后会导致第$j$位变成$A_j=0,B_j=1$,而在$b$序列中,$j$这一位已经在跟$i$交换时用掉了,所以第$j$位再也不能参与交换,自然$A_j,B_j$就不可能相等了。

所以,我们可以把这$n1$位中,所有参与了互相交换的位拿出来,这些位跟剩下的东西是独立的,他们可以随便交换,反正不影响结果。设剩下有$k$位满足$A_i=B_i=1$,这些位只能跟$A_i=0,B_i=1$的位交换。设当前有$x$位满足$A_i=0,B_i=1$,$x$初始为$n_2$,则剩下来的每一位都只能在这$x$位中选一位交换,并且$n_2$中的位参与交换之后会导致$x$减1。当这$k+n_2$位在$a$中的顺序确定后,对答案的贡献一定是类似

$$n_2\times n_2\times\dots\times n_2\times(n_2-1)\times(n_2-1)\times\dots\times1\times1$$

的形式,这里总共有$k+n_2$个数乘在一起。由于每个值至少出现一次(来自$n_2$中的那些位的贡献),我们可以把每种值去掉一个,最后再乘上去。记$i$个数的这种乘积的和为$s_i$,它的生成函数是$S=\sum s_ix^i$,则:

$$ \begin{align} S&=(1+n_2x+n_2^2x^2+...)(1+(n_2-1)x+(n_2-1)^2x^2+...)\\ &...(1+x+x^2+...)\\ &=\frac1{1-n_2x}\frac1{1-(n_2-1)x}...\frac1{1-x}\\ &=\frac1{\prod_{1\leq i\leq n_2}(1-ix)} \end{align} $$

这个东西直接分治乘起来,再做一个多项式求逆就能算出来了。算完这个就好办了,直接枚举$k$,把所有的贡献加起来就行了:

$$ ans=\sum_k\binom{n_1}kk!(n_2!)^2((n_1-k)!)^2\binom{n_1+n_2}{n_1-k}[x^k]S $$

#include <cstdio>
#include <cstring>
#include <cassert>
#include <algorithm>
#define MAXN 66010
#define LL long long 
using namespace std;

const LL P=998244353;

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

int n,n1,n2,sizew;
LL fac[MAXN],invfac[MAXN];
LL *a[MAXN],b[MAXN];

void init(){
    static char s1[MAXN],s2[MAXN];
    scanf("%s",s1+1);
    scanf("%s",s2+1);
    int len=strlen(s1+1);
    for(int i=1;i<=len;i++)
        if(s1[i]=='1' && s2[i]=='1') n1++;
        else if(s1[i]=='1' || s2[i]=='1') n2++;
    n2/=2;
    n=n1+n2;
    fac[0]=1;
    for(int i=1;i<MAXN;i++) fac[i]=fac[i-1]*i%P;
    invfac[MAXN-1]=getPow(fac[MAXN-1],P-2);
    for(int i=MAXN-2;i>=0;i--) invfac[i]=invfac[i+1]*(i+1)%P;
}

void FFT(LL *a,int len,int flag){
    static int rev[MAXN];
    for(int i=1;i<len;i++){
        rev[i]=rev[i>>1]>>1|((i&1)?(len>>1):0);
        if(i<rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int l=2;l<=len;l<<=1){
        LL w=getPow(3,(P-1)/l);
        int l2=l/2;
        for(int i=0;i<len;i+=l){
            LL temp=1;
            for(int j=0;j<l2;j++){
                LL t1=a[i+j],t2=a[i+j+l2]*temp;
                a[i+j]=(t1+t2)%P;
                a[i+j+l2]=(t1-t2)%P;
                temp=temp*w%P;
            }
        }
    }
    if(flag==-1){
        LL invn=getPow(len,P-2);
        for(int i=1;i<len;i++) if(i<len-i) swap(a[i],a[len-i]);
        for(int i=0;i<len;i++) a[i]=(a[i]*invn%P+P)%P;
    }
}

void mul(LL *a,LL *b,int len){
    static LL t1[MAXN],t2[MAXN];
    for(int i=0;i<len;i++) t1[i]=a[i],t2[i]=b[i],t1[i+len]=t2[i+len]=0;
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<(len<<1);i++) t1[i]=t1[i]*t2[i]%P;
    FFT(t1,len<<1,-1);
    for(int i=0;i<(len<<1);i++) a[i]=t1[i];
}

void getInv(LL *b,LL *a,int len){
    if(len==1){
        b[0]=1;
        return;
    }
    static LL t1[MAXN],t2[MAXN];
    getInv(t1,a,len>>1);
    for(int i=(len>>1);i<len;i++) t1[i]=0;
    for(int i=0;i<len;i++) t2[i]=a[i];
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<(len<<1);i++) t1[i]=(2*t1[i]-t2[i]*t1[i]%P*t1[i]%P+P)%P;
    FFT(t1,len<<1,-1);
    for(int i=0;i<len;i++) b[i]=t1[i];
}

LL getC(int x,int y){
    return fac[x]*invfac[y]%P*invfac[x-y]%P;
}

LL gao(){
    for(sizew=1;sizew<=n1 || sizew<=n2;sizew<<=1);
    static LL space[MAXN];
    for(int i=1;i<=sizew;i++){
        a[i]=space+(i-1)*2;
        a[i][0]=1;
        if(i<=n2) a[i][1]=-i;
    }
    for(int l=2;l<=sizew;l<<=1)
        for(int i=1;i<=sizew;i+=l)
            mul(a[i],a[i+l/2],l);
    getInv(b,a[1],sizew);
    assert(b[0]==1);
    LL ans=0;
    for(int i=0;i<=n1;i++){
        LL temp=getC(n1,i)*b[i]%P*fac[i]%P*fac[n2]%P*fac[n2]%P*fac[n1-i]%P*fac[n1-i]%P*getC(n1+n2,n1-i)%P;
        ans=(ans+temp)%P;
    }
    return ans;
}

int main(){
#ifdef DEBUG
    freopen("E.in","r",stdin);
#endif
    init();
    printf("%lld\n",gao());
}

查看详细内容

集训队作业之ARC058-D

2017-10-25 10:252017-10-25 10:25
集训队作业AtCoder容斥组合数学

传送门

考虑容斥掉每个不符合条件的方案,枚举进入非法区域时的纵坐标,然后几个组合数随便乘一乘减一减就行了。

#include <cstdio>
#include <cstring>
#define MAXN 1000010
#define LL long long

const LL P=1000000007;
int n,m,a,b;
LL fac[MAXN],invfac[MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

void init(){
    fac[0]=1;
    for(int i=1;i<MAXN;i++) fac[i]=fac[i-1]*i%P;
    invfac[MAXN-1]=getPow(fac[MAXN-1],P-2);
    for(int i=MAXN-2;i>=0;i--) invfac[i]=invfac[i+1]*(i+1)%P;
}

LL getC(int x,int y){
    return fac[x]*invfac[y]%P*invfac[x-y]%P;
}

LL g(int x,int y){ return getC(x+y-2,y-1); }

int main(){
    init();
    scanf("%d%d%d%d",&n,&m,&a,&b);
    LL ans=g(n,m);
    for(int i=1;i<=b;i++)
        ans=(ans-g(n-a,i)*g(a,m-i+1)%P+P)%P;
    printf("%lld\n",ans);
}

查看详细内容

集训队作业之AGC002-F

2017-9-27 19:172017-9-27 19:17
集训队作业AtCoder组合数学DP

传送门

本来打算早上先把题目看一下,然后上午上课的时候想题,中午再写。结果看题的时候时间紧没看数据范围,一直以为是100000。。。然后想题的过程中想了一个$O(n^2logn)$的做法和一个$O(n^2)$的做法,然而怎么也没想出能过100000的做法(废话这题就出到2000怎么可能想得出100000的)。直到中午又把题目看了一遍,才发现$n=2000$。。。

既然2000就可以随便做啦。考虑钦定除0以外的所有颜色的首次出现的顺序,即从前往后1先出现,2再出现,然后是3,4,···,n。最后把答案乘一个$n!$就行了。

钦定了顺序之后,就可以把最前面的0按顺序分配给$1$到$n$了。考虑从$n$到$1$,对于每个$k$,每次插$1$个$0$和$m-1$个$k$。这时$0$必须插到序列最前面,而第一个$k$必须插在原序列中第一个非零数的前面。

注意到插法跟第一个非零数的位置有关。受到这点的启发,我们可以直接dp,设$f[i][j]$表示已经插了$n-i+1$到$n$,序列前端有$j$个0的方案数。该dp的转移是:$$f[i][j+1]=\sum_{k>=j}f[i-1][k]\times g(m*(i-1)-j+1,m-2)$$

其中$g(x,y)$表示把$y$个数插到$x$个空当中(可以相邻),即$g(x,y)=\binom{x+y-1}{y}$

注意到这个转移后面的那个组合数跟$k$根本没有关系。所以我们直接维护$f[i]$的后缀和进行转移就行啦。

代码:

#include <cstdio>
#include <cstring>
#define MAXN 2010
#define LL long long
#define P 1000000007

int n,m,N;
LL fac[MAXN*MAXN],invfac[MAXN*MAXN];
LL f[MAXN][MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

LL getC(LL x,LL y){
    if(x<y) return 0;
    return fac[x]*invfac[y]%P*invfac[x-y]%P;
}

LL calc_g(LL x,LL y){ return getC(x+y-1,y); }

void init(){
    fac[0]=1;
    for(int i=1;i<=N;i++)
        fac[i]=fac[i-1]*i%P;
    invfac[N]=getPow(fac[N],P-2);
    for(int i=N-1;i>=0;i--)
        invfac[i]=invfac[i+1]*(i+1)%P;
}

int main(){
#ifdef DEBUG
    freopen("F.in","r",stdin);
#endif
    scanf("%d%d",&n,&m);
    if(m==1){
        puts("1");
        return 0;
    }
    N=n*m;
    init();
    f[1][0]=f[1][1]=1;
    for(int i=2;i<=n;i++){
        for(int j=0;j<i;j++)
            f[i][j+1]=f[i-1][j]*calc_g(m*(i-1)-j+1,m-2)%P;
        for(int j=i-1;j;j--)
            f[i][j]=(f[i][j]+f[i][j+1])%P;
        f[i][0]=f[i][1];
    }
    LL ans=f[n][0]*fac[n]%P;
    printf("%lld\n",ans);
    return 0;
}

查看详细内容