传送门
一看 2s 1e6 就想 \(nlog^2n\) 去了,成功避开正解
考虑枚举左端点,在合法的右端点中取最大值
我一直在想如何把原序列扔进线段树里,利用pushup维护
但这样每换一个左端点都要整体pushup一次显然不对
考虑暴力找右端点的过程,发现它统计了一个前缀和
一种颜色第一次出现贡献为正,第二次为负,第三次为0
然后就要从这个前缀和里找最大值
因为这里没有一个显式的在序列里找最大值的过程我就没想到线段树优化
只要对于每个位置维护出这个颜色下一次和下下次出现的位置
每次移动左端点的时候区间加减确保这个序列仍满足这个性质
就变成了每次取一遍最大值了
Code:
#include <bits/stdc++.h> using namespace std; #define INF 0x3f3f3f3f #define N 1000010 #define ll long long #define reg register int #define max2(a, b) ((a)>(b)?(a):(b)) //#define int long long char buf[1<<21], *p1=buf, *p2=buf; #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++) inline int read() { int ans=0, f=1; char c=getchar(); while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();} while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();} return ans*f; } int n, m; int c[N]; ll d[N]; bool ext[N]; namespace force{ int cnt[N], sta[N], top; void solve() { ll ans=0, sum; for (reg i=1; i<=n; ++i) { if (i!=n && c[i]==c[i+1]) continue; sum=0; for (reg j=i; j<=n; ++j) { if (++cnt[c[j]]==1) {sta[++top]=c[j]; sum+=d[c[j]];} else if (cnt[c[j]]==2) sum-=d[c[j]]; ans = max2(ans, sum); } while (top) cnt[sta[top--]]=0; } printf("%lld\n", ans); exit(0); } } namespace task{ int nxt1[N], nxt2[N], pos[N]; int tl[N<<2], tr[N<<2]; ll sum[N], tag[N<<2], maxn[N<<2], len[N<<2], ans; #define tl(p) tl[p] #define tr(P) tr[p] #define tag(p) tag[p] #define len(p) len[p] #define maxn(p) maxn[p] #define pushup(p) maxn(p)=max(maxn(p<<1), maxn(p<<1|1)); void spread(int p) { if (!tag(p)) return ; maxn(p<<1)+=tag(p), tag(p<<1)+=tag(p); maxn(p<<1|1)+=tag(p), tag(p<<1|1)+=tag(p); tag(p)=0; } void build(int p, int l, int r) { tl(p)=l; tr(p)=r; len(p)=r-l+1; if (l==r) {maxn(p)=sum[l]; return ;} int mid=(l+r)>>1; build(p<<1, l, mid); build(p<<1|1, mid+1, r); pushup(p); } void upd(int p, int l, int r, ll tem) { if (l<=tl(p)&&r>=tr(p)) {maxn(p)+=tem; tag(p)+=tem; return ;} int mid=(tl(p)+tr(p))>>1; spread(p); if (l<=mid) upd(p<<1, l, r, tem); if (r>mid) upd(p<<1|1, l, r, tem); pushup(p); } void solve() { for (int i=n; i; --i) nxt1[i]=pos[c[i]], nxt2[i]=nxt1[pos[c[i]]], pos[c[i]]=i; for (int i=1; i<=n; ++i) { if (pos[c[i]]==i) { sum[i]+=d[c[i]]; if (nxt1[i]) sum[nxt1[i]]-=d[c[i]]; } } for (int i=1; i<=n; ++i) sum[i]=sum[i-1]+sum[i]; //for (int i=1; i<=n; ++i) cout<<sum[i]<<' '; cout<<endl; build(1, 1, n); for (int i=1; i<=n; ++i) { //cout<<"at: "<<i<<' '<<maxn(1)<<endl; ans = max(ans, maxn(1)); upd(1, i, n, -d[c[i]]); if (nxt1[i]) upd(1, nxt1[i], n, 2*d[c[i]]); if (nxt2[i]) upd(1, nxt2[i], n, -d[c[i]]); } printf("%lld\n", ans); exit(0); } } signed main() { n=read(); m=read(); for (int i=1; i<=n; ++i) c[i]=read(); for (int i=1; i<=m; ++i) d[i]=read(); //force::solve(); task::solve(); return 0; }