ARMERIA

Rubyと競技プログラミングの話 AtCoderやCodeforcesの問題解説記事が多め。

個数制限付き部分和問題+条件を満たす選び方の数え上げ

先日のコンテストの問題における、最後の高速化について知りたいというリクエストがあったため、解説を書きます。より一般的な「個数制限付き部分和問題+条件を満たす選び方の数え上げ」という枠組みで説明します。

扱う問題

正整数  N, S と、正整数からなる長さ  N の数列  \lbrace a_{0}, ..., a_{N-1}\rbrace, \lbrace b_{0}, ..., b_{N-1}\rbrace が与えられる。

 a_{0} 0 個以上  b_{0} 個以下、 a_{1} 0 個以上  b_{1} 個以下…と選ぶ方法のうち、その総和が  S であるような方法の個数を  10^{9}+7 で割った余りを求めよ。

より形式的には、以下の条件を満たす非負整数列  \lbrace n_{0}, ..., n_{N-1}\rbrace の個数を  10^{9}+7 で割った余りを求めよ。

  • 全ての  i に対して  0 \le n_{i} \le b_{i}
  •  \sum_{i}n_{i}a_{i} = S

この問題を全体計算量  O(NS) で解きます。

解法

普通の部分和問題と同様にDPをします。状態を以下のように定義します。

  •  dp\lbrack i\rbrack\lbrack j\rbrack a_{0}, ..., a_{i-1} をそれぞれいくつ使うか決めて、それまでの総和が  j であるような場合の数

このDPテーブルのサイズは  (N+1)(S+1) です。

 a_{i} をいくつ使うかを決める時の遷移を、貰うDPで考えます。 dp\lbrack i+1\rbrack\lbrack j\rbrack に遷移するのは、

  •  dp\lbrack i\rbrack\lbrack j\rbrack から、 a_{i} 0 個使うことにして遷移
  •  dp\lbrack i\rbrack\lbrack j-a_{i}\rbrack から、 a_{i} 1 個使うことにして遷移
  •  dp\lbrack i\rbrack\lbrack j-b_{i}a_{i}\rbrack から、 a_{i} b_{i} 個使うことにして遷移

という  b_{i}+1 個の遷移元のうち、添字が非負であるものです。

これらを全て個別に遷移していては最悪ケースで  O(NS^{2}) 掛かってしまうので、以下のいずれかの手法で高速化します。

高速化手法1:累積和

先ほどの説明から、DPの遷移式は以下のようになります。

 dp\lbrack i+1\rbrack\lbrack j\rbrack = dp\lbrack i\rbrack\lbrack j\rbrack + dp\lbrack i\rbrack\lbrack j-a_{i}\rbrack + \cdots + dp\lbrack i\rbrack\lbrack j-b_{i}a_{i}\rbrack

 j j \bmod a_{i} が等しいグループに分けて、グループごとに  dp\lbrack i\rbrack\lbrack j\rbrack の累積和を取っておきます。そうするとこの右辺は1つのグループ内での区間和になるので、累積和から  O(1) で計算できます。これで全体計算量  O(NS) を達成できます。

高速化手法2:差分更新

本質的には高速化手法1とほぼ同じですが、より実装が楽だと思います。

 dp\lbrack i+1\rbrack\lbrack j\rbrack dp\lbrack i+1\rbrack\lbrack j-a_{i}\rbrack の遷移元がほとんど同じで1つだけずれていることを利用すると、 j を小さい方から求めていきながら以下の式で差分更新することができます。

 dp\lbrack i+1\rbrack\lbrack j\rbrack = dp\lbrack i+1\rbrack\lbrack j-a_{i}\rbrack + dp\lbrack i\rbrack\lbrack j\rbrack - dp\lbrack i\rbrack\lbrack j-(b_{i}+1)a_{i}\rbrack

添字が負である領域の値は全て  0 とします。この計算は  O(1) でできるので、同様に全体計算量  O(NS) を達成できます。

f:id:betrue12:20201005120022p:plain

実装例

高速化手法2を使っています。

#include <bits/stdc++.h>
using namespace std;

int64_t MOD = 1e9+7;
void add(int64_t& a, int64_t b){
    a = (a+b) % MOD;
}
void mul(int64_t& a, int64_t b){
    a = a*b % MOD;
}

int main(){
    int N, S;
    vector<int> A(N), B(N);
    
    vector<vector<int64_t>> dp(N+1, vector<int64_t>(S+1));
    dp[0][0] = 1;
    for(int i=0; i<N; i++) for(int j=0; j<=S; j++) {
        dp[i+1][j] = dp[i][j];
        if(j-A[i] >= 0) add(dp[i+1][j], dp[i+1][j-A[i]]);
        if(j-(B[i]+1)*A[i] >= 0) add(dp[i+1][j], MOD-dp[i][j-(B[i]+1)*A[i]]);
    }
}

ACコード

冒頭にリンクを貼った問題のACコードリンクも貼っておきます。これが元のリクエストの回答になればと思います。

Submission #17202414 - AtCoder Regular Contest 104