AtCoder ABC 259 Ex - Yet Another Path Counting (橙色, 600 点) - けんちょんの競プロ精進記録

けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 259 Ex - Yet Another Path Counting (橙色, 600 点)

opt さんとのスペースで解いた! いくつか典型を見落としていたのでメモ!!

問題概要

 N \times N のグリッドがあって、マス  (i, j) には値  a_{i, j} が記されている。

いずれかのマスから始めて右または下に隣接するマスへの移動を 0 回以上繰り返すことで得られる経路のうち、始点と終点のラベルが同じものの個数を 998244353 で割った余りを求めよ。

制約

  •  1 \le N \le 400
  •  1 \le a_{i, j} \le N^{2}

考えたこと

 O(N^{4}) なら容易にできる。いかにして計算量を落とすか......

opt さんが注目していたのは、「値の種類がすごく多いとき」と「値の種類がすごく少ないとき」はともに容易に解くことができて、その間が難しいのではないかということだった。

極端なことを言えば、全種類異なっていたら 0 個。全種類同じだったら綺麗に解けそう。

こういう状況では、平方分割的な発想がうまくいきそう。つまり、

  • その値のマスが  N 個以上である値を処理する
  • その値のマスが  N 個未満である値を処理する

とで解法を分けることが考えられる。

その値のマスが  N 個以上である場合

この場合は、値の種類自体が高々  O(N) 個未満であることが大きな特徴だ。

したがって、各値について  O(N^{2}) の計算量で解ければ十分だ。DP でできる。その値を val とする。


  • dp[i][j] ← マス  (i, j) まで到達する経路のうち、始点の値が val であるものの個数 (終点が val でなくてもよい)

最後に、val であるマスについての dp[i][j] の値を合算すればよい。各値 val に対して  O(N^{2}) の計算量で解ける。

その値のマスが  N 個未満である場合

この場合は、ありうる値が最悪  O(N^{2}) 通りにもなり得る。その代わり、各値ごとのマス数が十分小さいのだ。

したがって、愚直に、

  • 各値について
  • その値をとるすべてのマスのペアを考えて
  • 二項係数の値を足していく

という方法でよいのではと考えられる。

実はこのとき、考えるべきペア数の総和が  O(N^{3}) で抑えられることが言える。なぜなら、どのマスを固定して考えても、そのマスと同じ数値を持つマスの個数が  N 個未満だからだ。

提出コード

以上の場合分けによって、全体を通して  O(N^{3}) の計算量で解けた。

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

// Binomial coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;
BiCoef<mint> bc(1000);

int main() {
    int N;
    cin >> N;
    vector<vector<int>> a(N, vector<int>(N));
    map<int, vector<pint>> places;
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j) {
            cin >> a[i][j];
            places[a[i][j]].emplace_back(i, j);
        }
    
    mint res = 0;
    for (auto [val, cells] : places) {
        if (cells.size() >= N) {
            vector<vector<mint>> dp(N, vector<mint>(N, 0));
            for (int i = 0; i < N; ++i) {
                for (int j = 0; j < N; ++j) {
                    if (i > 0) dp[i][j] += dp[i-1][j];
                    if (j > 0) dp[i][j] += dp[i][j-1];
                    if (a[i][j] == val) {
                        dp[i][j] += 1;
                        res += dp[i][j];
                    }
                }
            }
        } else {
            for (int i = 0; i < cells.size(); ++i) {
                for (int j = i; j < cells.size(); ++j) {
                    auto [sx, sy] = cells[i];
                    auto [tx, ty] = cells[j];
                    if (sx > tx) swap(sx, tx), swap(sy, ty);
                    if (sx <= tx && sy <= ty) {
                        int dx = tx - sx, dy = ty - sy;
                        res += bc.com(dx + dy, dx);
                    }
                }
            }
        }
    }
    cout << res << endl;
}

メモ

場合分けしないで  O(N^{3}) で解く方法もあるらしい。形式的冪級数による考察がハマるらしい!

atcoder.jp