ARMERIA

Rubyと競技プログラミングの話 AtCoderやCodeforcesの問題解説記事が多め。

ゆるふわ競プロオンサイト #3 (Div. 1) Sweets Distribution(Hard)

お題箱より。

Programming Problems and Competitions :: HackerRank

公式解説はこちらで公開されています。

ゆるふわオンサイト#3 解説 - Google スライド

解法

2点をswapするクエリを処理するような問題は、swapであることを上手く活用するような解法も考えられますが(累積和を利用するような場合など)、今回の問題では厳しそうです。単に2点の値が変更されるクエリと思うことにします。

点更新のクエリを処理する代表的な手段はセグメント木です。セグメント木を使う上で必要なのが区間のマージで、今回の問題で最適値を求める仕組みを上手く区間のマージに落とし込めないか考えてみます。

お菓子を  0, ..., N-1 と表記し、人を  0, 1, 2, 3 と表記します。結論から言うと、各区間  \lbrack l, r\rbrack に対して次のような値

 V_{\lbrack l, r\rbrack}(i, j) = お菓子  l を人  i が、お菓子  r を人  j が食べて、その間のお菓子は問題文の条件を満たすような分け方をしているときの、区間  \lbrack l, r\rbrack で得られる美味しさの合計の最大値

を各  (i, j) について求めておけば良いです。これを  0 \le i \le j \le 3 である全ての  (i, j) について求めておけば、2つの区間をマージしたときに新しい区間に対する同様の値を求めることができます。

具体的に、区間  \lbrack l, r\rbrack についての値を  \lbrack l, m-1\rbrack \lbrack m, r\rbrack からマージして求めたいとします。 V_{\lbrack l, r\rbrack}(i, j) の値として採用できる候補となるのは、お菓子  m を食べる人を  k とすると

  • お菓子  m-1 を人  k-1 が食べる: V_{\lbrack l, m-1\rbrack}(i, k-1) + V_{\lbrack m, r\rbrack}(k, j)
  • お菓子  m-1 を人  k が食べる: V_{\lbrack l, m-1\rbrack}(i, k) + V_{\lbrack m, r\rbrack}(k, j)

のどちらかであるので、 k=i, ..., j およびこれら2パターンを全通り試して最も大きいものを取れば  V_{\lbrack l, r\rbrack}(i, j) を求めることができます。

ということで、この  V_{\lbrack l, r\rbrack}(i, j) を各  (i, j) について並べた  4\times 4 の配列をセグメント木の区間  \lbrack l, r\rbrack に対応する要素とすることにします。例えば  i \gt j であるなど、状況としてあり得ないような  (i, j) については -INF などを入れておけば良いでしょう。

単一のお菓子  x に対応する  4\times 4 配列は、 V_{\lbrack x, x\rbrack}(0, 0), ..., V_{\lbrack x, x\rbrack}(3, 3) がそれぞれ  A_{x}, ..., D_{x} であり、他が全て -INF であるようなものです。これをセグメント木の最下段に置いてswapクエリで変更し、 V_{\lbrack 0, N-1\rbrack}(0, 3) を答えとして取り出せば良いです。

また、セグメント木を実装する上では多くの場合単位元が必要です。今回の定義だと単位元の性質を満たす都合の良い  4\times 4 の配列が作れないのですが、結局マージする処理の中で「片方が単位元ならもう片方を返す」という処理を入れておけば良いので、適当に空配列などを使っておけば大丈夫です。

以上より、隣り合う2区間に対応する  4\times 4 の配列のマージ処理はこのように書けます。

typedef vector<vector<int64_t>> V;
V E;

auto merge = [](V& a, V& b){
    if(a.size() == 0) return b;
    if(b.size() == 0) return a;
    V res(4, vector<int64_t>(4, -1e18));
    for(int i=0; i<4; i++) for(int k=i; k<4; k++) for(int j=k; j<4; j++){
        chmax(res[i][j], a[i][k] + b[k][j]);
        if(i<k) chmax(res[i][j], a[i][k-1] + b[k][j]);
    }
    return res;
};

これをマージ関数とするセグメント木を構築すれば問題を解くことができます。抽象化していないと非常に大変なので、しましょう。

ACコード

#include <bits/stdc++.h>
using namespace std;

void chmax(int64_t& a, int64_t b){
    a = max(a, b);
}

template<typename T>
struct Segtree {
    int n, n_org;
    T e;
    vector<T> dat;
    typedef function<T(T& a, T& b)> Func;
    Func f;

    Segtree(int n_input, Func f_input, T e_input, vector<T>& A){
        n_org = n_input;
        f = f_input;
        e = e_input;
        n = 1;
        while(n < n_input) n <<= 1;
        dat.resize(2*n-1, e);
        for(int k=0; k<int(A.size()); k++) dat[k+n-1] = A[k];
        for(int k=n-2; k>=0; k--) dat[k] = f(dat[2*k+1], dat[2*k+2]);
    }

    void update(int k, T a){
        k += n - 1;
        dat[k] = a;
        while(k > 0){
            k = (k - 1)/2;
            dat[k] = f(dat[2*k+1], dat[2*k+2]);
        }
    }

    T get(int k){
        return dat[k+n-1];
    }

    T between(int a, int b){
        return query(a, b+1, 0, 0, n);
    }

    T query(int a, int b, int k, int l, int r){
        if(r<=a || b<=l) return e;
        if(a<=l && r<=b) return dat[k];
        T vl = query(a, b, 2*k+1, l, (l+r)/2);
        T vr = query(a, b, 2*k+2, (l+r)/2, r);
        return f(vl, vr);
    }
};

int main(){
    int N, Q;
    cin >> N >> Q;
    vector<vector<int>> A(N, vector<int>(4));
    for(int j=0; j<4; j++) for(int i=0; i<N; i++) scanf("%d", &A[i][j]);

    typedef vector<vector<int64_t>> V;
    V E;

    auto make_single = [](vector<int>& a){
        V res(4, vector<int64_t>(4, -1e18));
        for(int i=0; i<4; i++) res[i][i] = a[i];
        return res;
    };

    auto merge = [](V& a, V& b){
        if(a.size() == 0) return b;
        if(b.size() == 0) return a;
        V res(4, vector<int64_t>(4, -1e18));
        for(int i=0; i<4; i++) for(int k=i; k<4; k++) for(int j=k; j<4; j++){
            chmax(res[i][j], a[i][k] + b[k][j]);
            if(i<k) chmax(res[i][j], a[i][k-1] + b[k][j]);
        }
        return res;
    };

    vector<V> vs(N);
    for(int i=0; i<N; i++) vs[i] = make_single(A[i]);
    Segtree<V> st(N, merge, E, vs);

    while(Q--){
        int l, r;
        scanf("%d %d", &l, &r);
        l--; r--;
        swap(vs[l], vs[r]);
        st.update(l, vs[l]);
        st.update(r, vs[r]);
        int64_t ans = st.between(0, N-1)[0][3];
        printf("%lld\n", ans);
    }
    return 0;
}

余談

説明がややこしくなるので解説そのものでは採用しませんでしたが、区間を半開区間としてとらえ、区間  \lbrack l, r) に対して

 V_{\lbrack l, r)}(i, j) = お菓子  l を人  i が食べていて、「もしお菓子  r を人  j が食べたとすると、問題文の条件を満たすように区間が繋がる」ような分け方をしているときの、区間  \lbrack l, r) で得られる美味しさの合計の最大値

という値を定義することもできます。こうすると更新処理は単純に

\displaystyle V_{\lbrack l, r)}(i, j) = \max_{i \le k \le j}\left(V_{\lbrack l, m)}(i, k) + V_{\lbrack m, r)}(k, j)\right)

と書くことができて、単位元も自然に定義できるようになります。

ただしそのぶん単一のお菓子に対応する要素が少しだけ複雑になるのと、答えの値を  V_{\lbrack 0, N)}(0, 4) として求めたくなるので、自然な実装だと  5\times 5 の配列を持つことになり定数倍が悪化します。

この方針で実装したのが以下のコードで、こちらでもACを確認しています。

#include <bits/stdc++.h>
using namespace std;

void chmax(int64_t& a, int64_t b){
    a = max(a, b);
}

template<typename T>
struct Segtree {
    int n, n_org;
    T e;
    vector<T> dat;
    typedef function<T(T& a, T& b)> Func;
    Func f;

    Segtree(int n_input, Func f_input, T e_input, vector<T>& A){
        n_org = n_input;
        f = f_input;
        e = e_input;
        n = 1;
        while(n < n_input) n <<= 1;
        dat.resize(2*n-1, e);
        for(int k=0; k<int(A.size()); k++) dat[k+n-1] = A[k];
        for(int k=n-2; k>=0; k--) dat[k] = f(dat[2*k+1], dat[2*k+2]);
    }

    void update(int k, T a){
        k += n - 1;
        dat[k] = a;
        while(k > 0){
            k = (k - 1)/2;
            dat[k] = f(dat[2*k+1], dat[2*k+2]);
        }
    }

    T get(int k){
        return dat[k+n-1];
    }

    T between(int a, int b){
        return query(a, b+1, 0, 0, n);
    }

    T query(int a, int b, int k, int l, int r){
        if(r<=a || b<=l) return e;
        if(a<=l && r<=b) return dat[k];
        T vl = query(a, b, 2*k+1, l, (l+r)/2);
        T vr = query(a, b, 2*k+2, (l+r)/2, r);
        return f(vl, vr);
    }
};

int main(){
    int N, Q;
    cin >> N >> Q;
    vector<vector<int>> A(N, vector<int>(4));
    for(int j=0; j<4; j++) for(int i=0; i<N; i++) scanf("%d", &A[i][j]);

    typedef vector<vector<int64_t>> V;
    V E(5, vector<int64_t>(5, -1e18));
    for(int i=0; i<5; i++) E[i][i] = 0;

    auto make_single = [](vector<int>& a){
        V res(5, vector<int64_t>(5, -1e18));
        for(int i=0; i<4; i++) for(int j=i; j<=i+1; j++) chmax(res[i][j], a[i]);
        return res;
    };

    auto merge = [](V& a, V& b){
        V res(5, vector<int64_t>(5, -1e18));
        for(int i=0; i<5; i++) for(int k=i; k<5; k++) for(int j=k; j<5; j++){
            chmax(res[i][j], a[i][k] + b[k][j]);
        }
        return res;
    };

    vector<V> vs(N);
    for(int i=0; i<N; i++) vs[i] = make_single(A[i]);
    Segtree<V> st(N, merge, E, vs);

    while(Q--){
        int l, r;
        scanf("%d %d", &l, &r);
        l--; r--;
        swap(vs[l], vs[r]);
        st.update(l, vs[l]);
        st.update(r, vs[r]);
        int64_t ans = st.between(0, N-1)[0][4];
        printf("%lld\n", ans);
    }
    return 0;
}