オンライン分割統治 FFT 面白いね!!
公式解説がとても丁寧なので、備忘録程度に
問題概要
頂点数 のが与えられます。 組の頂点対に対して、次のように無向辺を張っていきます。 組めの頂点対に対しては
- 長さ 1 の辺が 本
- 長さ 2 の辺が 本
- ...
- 長さ の辺が 本
というように辺を張っていきます。このように作られたグラフにおいて、頂点 0 から出発して頂点 0 へと戻ってくる長さ のウォークの本数を 998244353 で割ったあまりを求めてください。
制約
まずは DP
いかにも計算量的に間に合わないことはわかりつつも、まずは愚直な DP を立てることが大事な気がする。
dp[v][t]
← 距離 だけ進んで、頂点 に到達するような場合の数
このとき、 を終点にもつような各辺 に対して、
dp[v][t]
+= dp[u][t - i]
× p[e][i]
という遷移が立てられる。この時点で にはなっているけど、とても間に合わないということで悩んでいた。そのジレンマは公式解説にまさに書かれていた。
僕もこの式が畳み込みっぽいとは思っていて、でも dp[t]
の計算に dp[0]
, dp[1]
, ..., dp[t-1]
も必要で難しいな...となった。それと類似の状況は下に貼る問題でも生じていた。畳み込みの添字に i < j という順序関係があって難しいという問題だった。そしてそのときに FFT が役に立ったことを思い出してはいた。なので今回も FFT 的にできないかな...という気持ちはあった。
ただそのあとを詰められなかった。コンテスト後に解説を見て、分割統治 FFT (オンライン FFT) という考え方を知った。上に貼った問題もオンライン FFT とは違うけど、似た感じなのかなと思った。分割統治 FFT の解説については公式解説がとても丁寧なのでそちらに。
備忘録として、統治部分の式を。
t = mid, mid + 1, ..., right - 1
に対して
dp[v][t]
+= dp[u][i]
× p[e][t - i]
計算量は となる。なんか順序っぽい構造が入った FFT は分割統治法を使うとうまくいくことがあるよ、ということで頭に留めたいと思う。
コード
#include <bits/stdc++.h> using namespace std; #include "atcoder/convolution.hpp" #include "atcoder/modint.hpp" using namespace atcoder; using mint = modint998244353; int main() { // 入力 int N, M, T; cin >> N >> M >> T; vector<int> A(M * 2), B(M * 2); vector<vector<mint>> P(M * 2, vector<mint>(T + 1)); for (int i = 0; i < M; ++i) { cin >> A[i] >> B[i]; --A[i], --B[i]; B[i + M] = A[i], A[i + M] = B[i]; for (int t = 1; t <= T; ++t) { long long p; cin >> p; P[i + M][t] = P[i][t] = p; } } // 分割統治 FFT vector<vector<mint>> dp(N, vector<mint>(T + 1, 0)); dp[0][0] = 1; auto rec = [&](auto self, int left, int right) -> void { if (right - left <= 1) return; int mid = (left + right) / 2; // まず左半分を更新 self(self, left, mid); // 左半分から右半分への遷移を更新 for (int e = 0; e < M * 2; ++e) { int u = A[e], v = B[e]; vector<mint> L(mid - left, 0), R(right - left, 0); for (int t = left; t < mid; ++t) L[t - left] = dp[u][t]; for (int t = 0; t < right - left; ++t) R[t] = P[e][t]; auto seki = convolution(L, R); for (int t = mid; t < right; ++t) { dp[v][t] += seki[t - left]; } } // 最後に右半分を更新 self(self, mid, right); }; rec(rec, 0, T + 1); cout << dp[0][T].val() << endl; }