ARMERIA

Rubyと競技プログラミングの話 AtCoderやCodeforcesの問題解説記事が多め。

合成数modでの二項係数を用いた数え上げ

先日やっていたバチャで出てきたので、まとめておきます。

この記事では二項係数を  _{n}C_{k} と表記します。また法を  M とします。

はじめに

 _{n}C_{k} が必要となる  n, k の最大値をそれぞれ  N, K として、 O(N^{2}) O(NK) が許される制約であれば、パスカルの三角形で全部計算できます。こちらのほうが圧倒的に楽なので間に合う場合はこちらを使いましょう。

vector<vector<int64_t>> comb(N+1, vector<int64_t>(N+1));
for(int i=0; i<=N; i++){
    comb[i][0] = comb[i][i] = 1;
    for(int j=1; j<i; j++) comb[i][j] = (comb[i-1][j-1] + comb[i-1][j]) % M;
}

やり方

アウトライン

中国剰余定理を使います。

中国剰余定理 (CRT) の解説と、それを用いる問題のまとめ - Qiita

まず  M素因数分解して、素数 P_{i} = p_{i}^{c_{i}} の積に分解します。

求めたい答え(の  \bmod を取る前の値)を  A とします。 A \bmod M を求める代わりに、  A \bmod P_{i} を求めて、中国剰余定理で復元します。

例えば  M = 180 であれば、これは  2^{2}\times 3^{2}\times 5素数冪の積に分解できるので、 A \bmod 4 A \bmod 9 A \bmod 5 をそれぞれ別々に求めて復元します。

この  A が二項係数の足し・引き・掛け算で計算できるようなときに、 A \bmod P_{i} を求める方法を書いていきます。

二項係数

 M が十分大きい素数である場合は、階乗とその逆元を前計算しておいて  O(1) _{n}C_{k} を求める方法が広く使われています。ただし  M素数でない場合や、  M が小さくて  n M 以上になってしまう場合は、一般には逆元が取れると限らないためこの方法は使えません。ですが、これを応用することにします。

素数 p、法とする素数冪を  P = p^{c} と書くことにします。このとき  n! は、

  •  n! の中の素因数  p を全て取り除いたものを、 \bmod P で計算したもの
  •  n! の中に素因数  p がいくつ含まれるか

の2つで特徴付けられます。これらの値を順に  x, y として、 n! \leftrightarrow (x, y) と表記することにします。

例えば  P=2^{2}=4 を法として  6! を計算する場合は、 1\times 2\times 3\times 4\times 5\times 6 の中に素因数  2 4 つ含まれ、それらを除くと  45 \equiv 1 \bmod 4 なので、 6! \leftrightarrow (1, 4) となります。

ここで  x の値が  p の倍数になることはありません。例えば  p=2 として、素因数に  2 を含まない奇数だけを掛け算した値が、 \bmod 4 0, 2 になることはありません。

 n! に対応する値  (x, y) は、 0!, 1!, 2!, ..., n! と順次計算していくことができます。 (n-1)! から  n! を計算する際には、 n に素因数  p がいくつ含まれるかを計算して、 x, y をそれぞれ処理すれば良いです。

そして二項係数

 \displaystyle _{n}C_{k} = \frac{n!}{k!(n-k)!}

 \bmod P で計算する際も、これらの値から計算することができます。

 x については、分母にあるものに対しては  \bmod P での逆元を計算して掛ければ良いです。 P素数とは限らないので、フェルマーの小定理ではなく拡張ユークリッドの互除法を使いましょう。 n! \leftrightarrow (x, y) において  x p の倍数になることはない、つまり  P と互いに素であることは保証されるので、拡張ユークリッドの互除法は適用可能です。

「1000000007 で割ったあまり」の求め方を総特集! 〜 逆元から離散対数まで 〜 - Qiita

 y については、分母にあるものをマイナスとして扱って足し算すれば良いです。二項係数は整数になるので、この値が非負であることは保証されます。

こうして得られた二項係数が  _{n}C_{k}  \leftrightarrow (X, Y) と計算できたとして、ここから  _{n}C_{k} \bmod P の値を求めます。これは  Y \ge c であれば  0 であり、そうでなければ  X p^{Y} \bmod P を実際に計算してあげれば良いです。

これで  _{n}C_{k} \bmod P が計算できたので、これらの足し・引き・掛け算で構成される値は  \bmod P で計算できます。 (x, y) の形式から実際の  \bmod P での値に直してしまった後は、逆元は一般には取れないので注意です。

長いですが実装例です。

// サブルーチン群
void add(int64_t& a, int64_t b, int64_t mod){
    a = (a+b) % mod;
}
void mul(int64_t& a, int64_t b, int64_t mod){
    a = a*b % mod;
}

vector<pair<int64_t, int>> prime_division(int64_t n){
    vector<pair<int64_t, int>> ret;
    for(int64_t i=2; i*i<=n; i++){
        int cnt = 0;
        while(n % i == 0){
            n /= i;
            cnt++;
        }
        if(cnt) ret.emplace_back(i, cnt);
    }
    if(n > 1) ret.emplace_back(n, 1);
    return ret;
}

int64_t extgcd(int64_t a, int64_t b, int64_t& x, int64_t& y){
    int64_t d = a;
    if(b != 0){
        d = extgcd(b, a%b, y, x);
        y -= (a/b) * x;
    }else{
        x = 1; y = 0;
    }
    return d;
}

int64_t inv_mod(int64_t a, int64_t mod){
    int64_t x, y;
    extgcd(a, mod, x, y);
    return (MOD + x%mod) % mod;
}

// 素数冪を (p, c) で表現したもの
vector<pair<int64_t, int>> primes;
// 素数冪 p^c の実際の値
vector<int64_t> ppow;
// 階乗を (x, y) 形式で表現したもの
vector<vector<pair<int64_t, int>>> fact;


// 素因数分解をして素数冪ごとにfactを前計算
void create_composite_mod_table(int N, int64_t M){
    primes = prime_division(M);
    int sz = primes.size();
    ppow.resize(sz, 1);
    fact.resize(sz);
    for(int pi=0; pi<sz; pi++){
        int64_t p = primes[pi].first, cnt = primes[pi].second;
        while(cnt--) ppow[pi] *= p;

        auto& f = fact[pi];
        f.resize(N+1);
        f[0] = {1, 0};
        for(int i=1; i<=N; i++){
            f[i] = f[i-1];
            int n = i;
            while(n%p == 0){
                n /= p;
                f[i].second++;
            }
            mul(f[i].first, n, ppow[pi]);
        }
    }
}

// 二項係数を計算
int64_t comb_mod(int n, int k, int pi){
    auto &a = fact[pi][n], &b = fact[pi][k], &c = fact[pi][n-k];
    int64_t p = primes[pi].first, cnt = primes[pi].second;
    int64_t pp = ppow[pi];
    int pw = a.second - b.second - c.second;
    if(pw >= cnt) return 0;

    int64_t v = a.first;
    mul(v, inv_mod(b.first, pp), pp);
    mul(v, inv_mod(c.first, pp), pp);
    while(pw--) mul(v, p, pp);
    return v;
}

問題例

Codeforces Div.1の問題です。

問題:Problem - D - Codeforces

ACコード:Submission #71631654 - Codeforces