传送门

我们把$B_i$差分得到$D_i=B_i-B_{fa_i}$,发现题目的限制其实就是$\sum D_i\leq Y$。考虑把一个树高为$i$的子树,把根的$D_i$加一后产生的贡献。记$S_i$为这个贡献,则$(S_i,X^i)$可以表示成$(S_{i-1},X^{i-1})$的线性转移。由于这两个东西都是在模$P$意义下的,所以这个$S_i$的循环节长度就是$P^2$级别的。这样,我们就能直接预处理出对于每个$k$,有多少个点$i$满足$S_{h_i}=k$。

然后,我们做一个三维的DP,设$f_{i,j,k}$表示当前处理了$S$的取值在$[0,i)$之间的的点,当前$D$的总和为$j$,且当前$\sum A_i\times B_i$模$P$为$k$。直接枚举如何选取贡献为$i$的点来转移就行了,复杂度是$O(Y^2P^2)$。

#include <bits/stdc++.h>
#define LL long long
#define MAXN 510
#define MAXM 12

const LL P=1000000007;

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

LL H,inv2=getPow(2,P-2);
int C,n,m;
LL cnt[MAXN];
LL c[MAXN][MAXM];
LL f[MAXN][MAXM][MAXN];

LL getC(LL x,LL y){
    LL t1=1,t2=1;
    for(int i=0;i<y;i++){
        t1=t1*(x-i)%P;
        t2=t2*(i+1)%P;
    }
    return t1*getPow(t2,P-2)%P;
}

void pre_gao(){
    static bool visit[MAXN][MAXN];
    static int v[MAXN*MAXN][2];
    int numv=1;
    int nx=C,ny=C;
    v[1][0]=nx;
    v[1][1]=getPow(2,H-1);
    visit[nx][ny]=1;
    for(int i=2;;i++){
        ny=ny*C%m;
        nx=(nx*2+ny)%m;
        if(visit[nx][ny]) break;
        v[++numv][0]=nx;
        v[numv][1]=v[numv-1][1]*inv2%P;
        visit[nx][ny]=1;
    }
    LL step=getPow(inv2,numv);
    for(int i=1;i<=numv;i++){
        LL t1=H/numv;
        if(i<=H%numv) t1++;
        LL t2=v[i][1]*(getPow(step,t1)-1)%P*getPow(step-1,P-2)%P;
        cnt[v[i][0]]=(cnt[v[i][0]]+t2)%P;
    }
    for(int i=0;i<=n;i++){
        c[i][0]=1;
        for(int j=1;j<=i;j++)
            c[i][j]=(c[i-1][j-1]+c[i-1][j])%P;
    }
    for(int i=0;i<m;i++)
        for(int j=0;j<=n;j++)
            c[i][j]=getC(cnt[i]+j-1,j);
}

inline void update(LL &x,LL y){ x=(x+y)%P; }

LL dp(){
    f[0][0][0]=1;
    for(int i=0;i<m;i++)
        for(int j=0;j<=n;j++)
            for(int k=0;k<m;k++)
                for(int x=0;x<=n-j;x++)
                    update(f[i+1][j+x][(k+i*x)%m],f[i][j][k]*c[i][x]);
    LL ans=0;
    for(int i=0;i<=n;i++)
        update(ans,f[m][i][0]);
    return ans;
}

int main(){
#ifdef DEBUG
    freopen("145.in","r",stdin);
#endif
    scanf("%lld%d%d%d",&H,&C,&n,&m);
    pre_gao();
    printf("%lld\n",dp());
    return 0;
}