一中OJ | P3659 保卫家园 | 数学 | 组合数学 / 期望


Description

强尼是蚁群中最聪明的一只蚂蚁。

强尼所在的蚁群共有 $n$ 个蚁穴,总共有 $m$ 只不那么聪明的蚂蚁保卫着这 $n$ 个蚁穴。但是这 $m$ 只蚂蚁并不是兢兢业业的时刻保卫着蚁穴,他们总喜欢从自己所保卫的蚁穴中出来,聚集到一起讨论天空的颜色。而一旦出现危险情况,这 $m$ 只蚂蚁会以最快的速度随机地跑回一个蚁穴。一个蚁穴只要有一只蚂蚁保卫,强尼就认为它是安全的蚁穴。

这里的随机你可以这样理解:每只蚂蚁的行为是独立的。每只蚂蚁等概率地跑回这 $n$ 个蚁穴,即某只蚂蚁跑回某一给定的蚁穴的概率为 $\frac{1}{n}$ 。

现在,强尼想知道,危险来临时,安全的蚁穴的期望个数是多少。为了避免浮点数产生的误差,设最终的期望值为 $E$ ,强尼只需要你告诉它 $E \times n^m \mod 998244353$ 的值即可。(可以证明,$E\times n^m$ 是一个正整数。其中 $998244353=7 \times 17 \times 2^{23}+1$ 是一个质数)。


Input Format

输入只有一行,包含两个正整数 $n,m$ 。


Output Format

输出只有一行,$E \times n^m \mod 998244353$ 的值,其中 $E$ 表示安全蚁穴个数的数学期望。


Input Sample

Sample 01

2 3

Sample 02

5 0


Output Sample

Sample 01

14

Sample 02

0


Data Range

$0 < n \le 5000$。$0 \le m \le 5000$。


Solution

此题计算期望个数 × $n^m$ 的结果,实际上就是计算 $\sum$ 每种情况的安全蚁穴个数。

考虑枚举有保卫的蚁穴的个数 $x \in [1 , n]$,则对于每个 $x$ 种数为 $\binom{n}{x}$ 。接下来的问题就转换成:有 $x$ 个标号的盒子,放 $m$ 个标号的球的方案数。考虑第二类 Stirling 数,有 $x$ 个无标号的盒子,放 $m$ 个标号的球的方案为 $S(m,x)$ 。乘上 $x!$ 即为盒子标号的情况。

这里用了二项式反演后的公式 $S(m,x)x! = \sum_{i=0}^{x} (-1)^{x-i} \binom{x}{i} i^m$。最后的结果为 $\sum_{x=1}^{n} x \binom{n}{x}\sum_{i=0}^{x} (-1)^{x-i} \binom{x}{i} i^m$。时间复杂度 $\Theta(n^2 + n\log_2m)$。


Code

#include<cstdio>
#include<cstdlib>
#include<cstring>

#define min(x,y) ((x)<(y)?(x):(y))
#define max(x,y) ((x)>(y)?(x):(y))
#define swap(x,y) {int t=x; x=y,y=t;}
#define wipe(x,y) memset(x,y,sizeof(x))
#define dbgIn(x) freopen(x".in","r+",stdin)
#define rep(x,y,z) for(int x=y,I=z;x<=I;++x)
#define dbgOut(x) freopen(x".out","w+",stdout)
#define file(x) freopen(x".in","r+",stdin); freopen(x".out","w+",stdout)

#define mod 998244353

typedef long long ll;

inline void Read(int &x){
    x=0; char ch=0; while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+ch-48,ch=getchar(); return;
}

int N;
int M;

ll ans;

ll pw[5005];
ll C[5005][5005];

void Binom(){
    C[0][0]=1;
    rep(i,1,N){
        C[i][0]=1;
        rep(j,1,i)
            C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
    }
    return;
}

ll qkpow(ll tp,ll up){
    if(up==0) return 1;
    ll ret=qkpow(tp,up>>1);
    if(up&1) return ret*ret%mod*tp%mod;
    return ret*ret%mod;
} 

void Power(){
    rep(i,0,N)
        pw[i]=qkpow(i,M);
    return;
}

ll Rc(int Box){
    ll ret=0;
    rep(i,0,Box)
        if((Box-i)&1)
            ret=((ret-(ll)pw[i]*C[Box][i])%mod+mod)%mod;
        else
            ret=(ret+(ll)pw[i]*C[Box][i])%mod;
    return ret;
}

int main(){
    Read(N);
    Read(M);
    Power();
    Binom();
    rep(i,1,N)
        ans=(ans+(ll)C[N][i]*i%mod*Rc(i))%mod;
    printf("%lld\n",ans);
    return 0;
}