题目链接

BZOJ3747

题解

这种找区间最优的问题,一定是枚举一个端点,然后用数据结构维护另一个端点

我们枚举左端点,用线段树维护每个点作为右端点时的答案
当左端点为\(1\)时,我们能\(O(n)\)预处理出每个位置的答案初始化线段树
当左端点右移一位时,该位上的电影就从区间删除了,记\(nxt[i]\)为下一个同类电影的位置,那么从左端点到\(nxt[i] - 1\)的位置的权值都会减少掉\(w[f[i]]\),而\(nxt[i]\)\(nxt[nxt[i]] - 1\)的位置都会增加\(w[f[i]]\)
就是一个简单的线段树维护区间最大值了

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define mp(a,b) make_pair<int,int>(a,b)
#define cls(s) memset(s,0,sizeof(s))
#define cp pair<int,int>
#define LL long long int
#define ls (u << 1)
#define rs (u << 1 | 1)
using namespace std;
const int maxn = 1000005,maxm = 100005,INF = 1000000000;
inline int read(){
	int out = 0,flag = 1; char c = getchar();
	while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
	while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
	return out * flag;
}
LL mx[maxn << 2],tag[maxn << 2],C[maxn];
void upd(int u){mx[u] = max(mx[ls],mx[rs]);}
void pd(int u){
	if (tag[u]){
		mx[ls] += tag[u]; tag[ls] += tag[u];
		mx[rs] += tag[u]; tag[rs] += tag[u];
		tag[u] = 0;
	}
}
void build(int u,int l,int r){
	if (l == r){mx[u] = C[l]; return;}
	int mid = l + r >> 1;
	build(ls,l,mid);
	build(rs,mid + 1,r);
	upd(u);
}
void add(int u,int l,int r,int L,int R,LL v){
	if (l >= L && r <= R){mx[u] += v; tag[u] += v; return;}
	pd(u);
	int mid = l + r >> 1;
	if (mid >= L) add(ls,l,mid,L,R,v);
	if (mid < R) add(rs,mid + 1,r,L,R,v);
	upd(u);
}
LL query(int u,int l,int r,int L,int R){
	if (l >= L && r <= R) return mx[u];
	pd(u);
	int mid = l + r >> 1;
	if (mid >= R) return query(ls,l,mid,L,R);
	if (mid < L) return query(rs,mid + 1,r,L,R);
	return max(query(ls,l,mid,L,R),query(rs,mid + 1,r,L,R));
}
int nxt[maxn],last[maxn],bac[maxn];
int n,m,f[maxn],w[maxn];
LL sum,ans;
int main(){
	n = read(); m = read();
	REP(i,n) f[i] = read();
	REP(i,m) w[i] = read();
	for (int i = n; i; i--) nxt[i] = last[f[i]],last[f[i]] = i;
	for (int i = 1; i <= n; i++){
		if (!bac[f[i]]) sum += w[f[i]];
		else if (bac[f[i]] == 1) sum -= w[f[i]];
		bac[f[i]]++;
		C[i] = sum;
	}
	build(1,1,n);
	for (int i = 1; i <= n; i++){
		ans = max(ans,query(1,1,n,i,n));
		if (nxt[i]){
			if (nxt[i] > i + 1) add(1,1,n,i + 1,nxt[i] - 1,-w[f[i]]);
			add(1,1,n,nxt[i],nxt[nxt[i]] ? nxt[nxt[i]] - 1 : n,w[f[i]]);
		}
		else add(1,1,n,i + 1,n,-w[f[i]]);
	}
	printf("%lld\n",ans);
	return 0;
}