传送门

考虑DP:$f_{i,j}$表示$i$个球分$j$组的方案数。显然$f_{i,j}=f_{i-1,j}+f_{i-1,j-1}+f_{i-2,j-1}$。

如果我们把$f_i$看成一个$x$多项式,则$f_i(x)=(x+1)f_{i-1}(x)+xf_{i-2}(x)$。这是个常系数线性递推,特征方程为$\lambda^2-(x+1)\lambda-x=0$,解出特征根为$\lambda_1=\frac{x+1+\sqrt{x^2+6x+1}}{2},\lambda_2=\frac{x+1-\sqrt{x^2+6x+1}}{2}$。

我们钦定$f_{-1}(x)=0$,此时转移方程在所有正整数处都成立。我们设$f_i(x)=c_1(x)\lambda_1^{i+1}(x)+c_2(x)\lambda_2^{i+1}(x)$,把$f_{-1},f_0$代进去,得到$c_1=\frac 1{\lambda_1-\lambda_2},c_2=\frac 1{\lambda_2-\lambda_1}$。所以$f_n$就能表示成:$$f_n(x)=\frac {\lambda_1^{n+1}(x)-\lambda_2^{n+1}(x)}{\lambda_1(x)-\lambda_2(x)}$$

直接开根把$\lambda$算出来之后$\ln+\exp$求幂再求逆就行了。

由于$m\leq n$(当$m>n$时后面那些显然就是$0$),而$\lambda_2$的常数项是$0$,所以$\lambda_2^{n+1}(x)\equiv 0\pmod {x^{m+1}}$,这意味着分子后面那项根本没有贡献,可以直接忽略,从而减小常数。

#include <bits/stdc++.h>
#define MAXN 1100000
#define MAXW 1048576
#define LL long long 
using namespace std;

const LL P=998244353;

int n,m,sizew;
LL w[MAXW],inv_v[MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

void init(){
    w[0]=1;
    w[1]=getPow(3,(P-1)/MAXW);
    for(int i=2;i<MAXW;i++) w[i]=w[i-1]*w[1]%P;
    inv_v[1]=1;
    for(int i=2;i<MAXN;i++)
        inv_v[i]=-inv_v[P%i]*(P/i)%P;
}

void FFT(LL *a,int n,int flag){
    static int rev[MAXN],revn;
    if(revn!=n){
        revn=n;
        for(int i=1;i<n;i++) rev[i]=rev[i>>1]>>1|((i&1)?(n>>1):0);
    }
    for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int l=2;l<=n;l<<=1){
        int l2=l>>1;
        for(int i=0;i<n;i+=l){
            for(int j=0;j<l2;j++){
                LL t1=a[i+j],t2=a[i+j+l2]*w[MAXW/l*j];
                a[i+j]=(t1+t2)%P;
                a[i+j+l2]=(t1-t2)%P;
            }
        }
    }
    if(flag==-1){
        LL invn=getPow(n,P-2);
        for(int i=0;i<n;i++) a[i]=a[i]*invn%P;
        for(int i=1;i<n;i++) if(i<n-i) swap(a[i],a[n-i]);
    }
}

void getInv(LL *b,LL *a,int len){
    if(len==1){
        b[0]=getPow(a[0],P-2);
        return;
    }
    static LL t1[MAXN],t2[MAXN];
    getInv(t1,a,len>>1);
    for(int i=len>>1;i<len<<1;i++) t1[i]=t2[i]=0;
    for(int i=0;i<len;i++) t2[i]=a[i];
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<len<<1;i++) t1[i]=(2*t1[i]-t2[i]*t1[i]%P*t1[i])%P;
    FFT(t1,len<<1,-1);
    for(int i=0;i<len;i++) b[i]=t1[i];
}

void getLn(LL *b,LL *a,int len){
    static LL t1[MAXN],t2[MAXN];
    for(int i=0;i<len;i++) t1[i]=a[i+1]*(i+1)%P,t2[i]=a[i];
    t1[len-1]=0;
    getInv(t2,t2,len);
    for(int i=0;i<len;i++) t1[i+len]=t2[i+len]=0;
    FFT(t1,len<<1,1); FFT(t2,len<<1,1);
    for(int i=0;i<len<<1;i++) t1[i]=t1[i]*t2[i]%P;
    FFT(t1,len<<1,-1);
    for(int i=1;i<len;i++) b[i]=t1[i-1]*inv_v[i]%P;
    b[0]=0;
}

void getExp(LL *b,LL *a,int len){
    if(len==1){
        b[0]=1;
        return;
    }
    static LL t1[MAXN],t2[MAXN],t3[MAXN];
    getExp(t1,a,len>>1);
    for(int i=0;i<len>>1;i++) t3[i]=t1[i],t1[i+(len>>1)]=t3[i+(len>>1)]=0;
    getLn(t2,t1,len);
    for(int i=0;i<len>>1;i++) t2[i]=t2[i+(len>>1)]-a[i+(len>>1)],t2[i+(len>>1)]=0;
    FFT(t3,len,1); FFT(t2,len,1);
    for(int i=0;i<len;i++) t3[i]=t3[i]*t2[i]%P;
    FFT(t3,len,-1);
    for(int i=0;i<len>>1;i++) b[i]=t1[i],b[i+(len>>1)]=-t3[i];
}

void getPow(LL *b,LL *a,int len,LL k){
    static LL t1[MAXN];
    getLn(t1,a,len);
    for(int i=0;i<len;i++) t1[i]=t1[i]*k%P;
    getExp(b,t1,len);
}

void solve(){
    for(sizew=1;sizew<=n;sizew<<=1);
    static LL t1[MAXN],t2[MAXN];
    t1[0]=1; t1[1]=6; t1[2]=1;
    getPow(t1,t1,sizew,inv_v[2]);
    for(int i=0;i<sizew;i++) t2[i]=t1[i];
    t2[0]=(t2[0]+1)%P; t2[1]=(t2[1]+1)%P;
    for(int i=0;i<sizew;i++) t2[i]=t2[i]*inv_v[2]%P;
    getPow(t2,t2,sizew,m+1);
    getInv(t1,t1,sizew);
    for(int i=0;i<sizew;i++) t1[i+sizew]=t2[i+sizew]=0;
    FFT(t1,sizew<<1,1); FFT(t2,sizew<<1,1);
    for(int i=0;i<sizew<<1;i++) t1[i]=t1[i]*t2[i]%P;
    FFT(t1,sizew<<1,-1);
    for(int i=1;i<=n;i++) printf("%lld ",(t1[i]+P)%P);
}

int main(){
#ifdef DEBUG
    freopen("127.in","r",stdin);
#endif
    scanf("%d%d",&m,&n);
    int _n=n;
    if(n>m) n=m;
    init();
    solve();
    for(int i=1;i<=_n-n;i++) printf("0 ");
    return 0;
}