楽しかった。こういうのでロリハ使うの楽しい。発想自体は Zero-Sum Ranges (200 点) と似てる。
問題概要
高橋君は、いつも頭の中に長さ 2000000001 の数列 と、整数値 を思い浮かべている。初期状態では、数列の各要素値と、 の値はすべて 0 である。
ここで長さ の文字列 が与えられる。各文字は "+-><" のいずれかであり
- '+' のとき、 をインクリメントする
- '-' のとき、 をデクリメントする
- '>' のとき、 をインクリメントする
- '<' のとき、 をデクリメントする
この 回の操作を行ってできる結果の数列を とする。 の空でない部分区間であって、その区間の操作のみを初期状態の数列と に対して行ってできる結果の数列が、 に一致するものを数え上げよ ( の値は一致しなくてよい)
制約
考えたこと
個考えられる区間を高速に扱う方法なんて限られている。しゃくとり法などが代表的だけど、とてもそれが適用できる雰囲気の問題ではない。一つあるのは、
- 累積和の要領で、区間 [0, ) に関する状態をハッシュ化して、
- 各 に対して、区間 [ , ) が条件を満たすような がどのようなハッシュ値をもつべきかを求め、そのような をカウントする
という考え方。この考え方の最も簡単な場合が、Zero-Sum Ranges といえる。
ハッシュ化
数列の状態は以下のようにして自然にハッシュ化できる。 を適当な値、 を素数として、
- (mod. M)
が負になる部分もあるが、問題ない。これはロリハそのものでもある。 個目の操作を終えた段階でのハッシュ値を 、その段階での の値を とおく。こうしておくと、各 に対して、ハッシュ値が
となるような区間 [0, ) ( ) の個数をカウントすれば OK
ハッシュの衝突確率
ハッシュの衝突確率に関する議論は、公式解説にある。
https://img.atcoder.jp/arc099/editorial.pdf
僕の場合、(MOD, BASE) の組が 1 組のみだと衝突した。2 組にしたら通った。
#include <iostream> #include <string> #include <vector> #include <map> #include <algorithm> using namespace std; long long modinv(long long a, long long mod) { long long b = mod, u = 1, v = 0; while (b) { long long t = a/b; a -= t*b; swap(a, b); u -= t*v; swap(u, v); } u %= mod; if (u < 0) u += mod; return u; } long long modpow(long long a, long long n, long long mod) { long long res = 1; if (n < 0) { a = modinv(a, mod); n = -n; return modpow(a, n, mod); } while (n > 0) { if (n & 1) res = res * a % mod; a = a * a % mod; n >>= 1; } return res; } const vector<long long> MOD = {1000000009, 1000000007}; const vector<long long> BASE = {17, 1009}; long long solve(int N, const string &S) { vector<vector<long long>> hash(2, vector<long long>(N+1, 0)); vector<vector<long long>> pval(2, vector<long long>(N+1, 0)); map<pair<long long, long long>, vector<int>> pos; for (int i = 0; i < N; ++i) { for (int it = 0; it < 2; ++it) { if (S[i] == '>') { hash[it][i+1] = hash[it][i]; pval[it][i+1] = pval[it][i] + 1; } else if (S[i] == '<') { hash[it][i+1] = hash[it][i]; pval[it][i+1] = pval[it][i] - 1; } else if (S[i] == '+') { pval[it][i+1] = pval[it][i]; long long add = modpow(BASE[it], pval[it][i], MOD[it]); hash[it][i+1] = (hash[it][i] + add) % MOD[it]; } else if (S[i] == '-') { pval[it][i+1] = pval[it][i]; long long add = modpow(BASE[it], pval[it][i], MOD[it]); hash[it][i+1] = (hash[it][i] - add + MOD[it]) % MOD[it]; } } pos[{hash[0][i+1], hash[1][i+1]}].push_back(i+1); } long long res = 0; for (int i = 0; i <= N; ++i) { vector<long long> risou_add(2), risou(2); for (int it = 0; it < 2; ++it) { risou_add[it] = hash[it][N] * modpow(BASE[it], pval[it][i], MOD[it]) % MOD[it]; risou[it] = (hash[it][i] + risou_add[it]) % MOD[it]; } auto &v = pos[{risou[0], risou[1]}]; int it = upper_bound(v.begin(), v.end(), i) - v.begin(); res += (int)v.size() - it; } return res; } int main() { int N; string S; while (cin >> N >> S) cout << solve(N, S) << endl; }