传送门

比较水的树形DP,可以直接统计出一条边走过去之后还能走回来的概率。然后把问题转化成求每个点被经过的概率,以及期望的经过次数就行了。这两个都很好DP。

#include <bits/stdc++.h>
#define MAXN 1000010
#define LL long long

const LL P=998244353;

struct edge{
    int to,next;
    edge(int _to=0,int _next=0):to(_to),next(_next){}
}e[MAXN<<1];

int n;
int c[MAXN],d[MAXN];
int g[MAXN],nume;
LL w1[MAXN],w2[MAXN],w3[MAXN];
LL sw2[MAXN];
LL inv_v[MAXN];

LL getPow(LL x,LL y){
    LL res=1;
    while(y){
        if(y&1) res=res*x%P;
        x=x*x%P;
        y>>=1;
    }
    return res;
}

void addEdge(int u,int v){
    e[nume]=edge(v,g[u]);
    g[u]=nume++;
}

void dfs(int x,int p){
    LL temp=0;
    for(int i=g[x];~i;i=e[i].next)
        if(e[i].to^p){
            dfs(e[i].to,x);
            temp=(temp+w1[e[i].to]);
        }
    w3[x]=temp;
    if(p){
        if(d[x]==1) return;
        temp=temp*inv_v[d[x]]%P;
        w1[x]=getPow(1-temp,P-2)*inv_v[d[x]]%P;
    }
}

void dfs2(int x,int p){
    LL temp=p?w2[x]:0;
    w3[x]=(w3[x]+w2[x])*inv_v[d[x]]%P;
    for(int i=g[x];~i;i=e[i].next)
        if(e[i].to^p)
            temp=(temp+w1[e[i].to])%P;
    for(int i=g[x];~i;i=e[i].next)
        if(e[i].to^p){
            LL temp2=(temp-w1[e[i].to])*inv_v[d[x]]%P;
            w2[e[i].to]=getPow(1-temp2,P-2)*inv_v[d[x]]%P;
            sw2[e[i].to]=sw2[x]*w2[e[i].to]%P;
            dfs2(e[i].to,x);
        }
}

int main(){
#ifdef DEBUG
    freopen("136.in","r",stdin);
#endif
    memset(g,-1,sizeof g);
    inv_v[1]=1;
    for(int i=2;i<MAXN;i++) inv_v[i]=inv_v[P%i]*(P-P/i)%P;
    scanf("%d",&n);
    static char str[MAXN];
    scanf("%s",str+1);
    for(int i=1;i<=n;i++) c[i]=str[i]=='1';
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        addEdge(u,v);
        addEdge(v,u);
        d[u]++; d[v]++;
    }
    dfs(1,0);
    sw2[1]=1;
    dfs2(1,0);
    LL ans=0;
    for(int i=1;i<=n;i++)
        if(c[i]==0 || d[i]==1){
            ans=(ans+sw2[i])%P;
        }else{
            LL temp=getPow(1-w3[i],P-2);
            temp=temp%P*sw2[i]%P;
            ans=(ans+temp)%P;
        }
    ans=(ans%P+P)%P;
    printf("%lld\n",ans);
}