「 番目の値を求めよ」「メディアンを求めよ」といった問題では、二分探索法が有効なことが多々ありますね。
問題概要
長さ の数列 が与えられます。
この数列の連続する区間として考えられるものは 個あります。そのそれぞれの区間について、区間内の値のメディアンをとります。
このようにして得られる 個の値のメディアンを求めてください。
制約
二分探索へ
次の判定問題を解くことを考えてみましょう。この判定問題に対する答えが "Yes" となる最小の が答えとなります。
数列の連続する区間 () であって、その区間内の数値のメディアンが 以下であるようなものが 個以上存在するならば "Yes"、存在しないならば "No" と判定してください
判定問題を解く方針
まず、各数列の値に対して、次の変換を施して考えましょう。
- x 以上の値は、1 に置き換える
- x 未満の値は、-1 に置き換える
このとき、
区間 () のメディアンが 以下
⇔ 区間 において 1 の個数が -1 の個数より少なくない
⇔ 区間 の総和が 0 以上
と言い換えられます。よって、数列の値を 1, -1 に置き換えた状態で、「総和が 0 以上であるような区間」の個数を数えて、それが 個以上であるかどうかを判定する問題へと帰着されました。
累積和をとる
置き換えた数列を改めて [tex: a_{0}, \dots, a_{N-1}} として、
としましょう。このとき、数列 における区間 の総和が 0 以上という条件は
S[r] - S[l] >= 0
⇔ S[l] <= S[r]
と言い換えられます。ここまで来るととても明快です。判定問題は最終的には次のように言い換えられます。
累積和 において、 を満たす () の組の個数が 個以上であるかどうかを判定してください
これとは反対の を満たす の組の個数は転倒数と呼ばれます。転倒数の求め方については
などを参考にしてください。 の計算量で求められます。
コード
転倒数を求めるのには通常 BIT と呼ばれるデータ構造を活用します。コードは次のように実装できます。
最後に計算量を評価しましょう。数列の値の最大値を とすると、二分探索の反復回数は 回と評価できます。よって全体としては と評価できます。
#include <iostream> #include <vector> using namespace std; template <class Abel> struct BIT { const Abel UNITY_SUM = 0; // to be set vector<Abel> dat; /* [1, n] */ BIT(int n) : dat(n + 1, UNITY_SUM) { } void init(int n) { dat.assign(n + 1, UNITY_SUM); } /* a is 1-indexed */ inline void add(int a, Abel x) { for (int i = a; i < (int)dat.size(); i += i & -i) dat[i] = dat[i] + x; } /* [1, a], a is 1-indexed */ inline Abel sum(int a) { Abel res = UNITY_SUM; for (int i = a; i > 0; i -= i & -i) res = res + dat[i]; return res; } /* [a, b), a and b are 1-indexed */ inline Abel sum(int a, int b) { return sum(b - 1) - sum(a - 1); } /* debug */ void print() { for (int i = 1; i < (int)dat.size(); ++i) cout << sum(i, i + 1) << ","; cout << endl; } }; int main() { long long N; cin >> N; vector<int> a(N); for (int i = 0; i < N; ++i) cin >> a[i]; int low = 0, high = 1<<30; const int geta = N+1; while (high - low > 1) { int mid = (low + high) / 2; long long num = 0; BIT<long long> bit(N*2+10); int sum = 0; bit.add(sum+geta, 1); for (int i = 0; i < N; ++i) { int val; if (a[i] <= mid) val = 1; else val = -1; sum += val; num += bit.sum(1, sum+geta); bit.add(sum+geta, 1); } if (num > (N+1)*N/2/2) high = mid; else low = mid; } cout << high << endl; }