题解做法 by Mikukuovo
#include<bits/stdc++.h> #define ll long long using namespace std; const int mod=998244353; int n,d; int main(){ ios::sync_with_stdio(false); cin>>n>>d; vector<int>p(n),q(n); for(int i=0;i<n;++i) cin>>p[i]; for(int i=0;i<n;++i) cin>>q[i]; vector dp(d+1,vector<ll>(d+1)); dp[0][0]=1; for(int k=0;k<n;++k){ int w=abs(p[k]-q[k]); vector s1(d+1,vector<ll>(d+1)); s1=dp; auto s2=s1; for(int i=0;i<d;++i){ for(int j=0;j<d;++j){ (s1[i+1][j+1]+=s1[i][j])%=mod; (s2[i+1][j]+=s2[i][j+1])%=mod; } } auto Calcs1=[&](int i,int j)->ll{ if(i<0||j<0) return 0; return s1[i][j]; }; auto Calcs2=[&](int i,int j,int w)->ll{ int prei=i-w-1,prej=j+w+1; if(j<0){ i+=j; j=0; } if(i<0) return 0; ll res=s2[i][j]; if(prei>=0&&prej<=d) res-=s2[prei][prej]; res=(res%mod+mod)%mod; return res; }; for(int i=0;i<d+1;++i){ for(int j=0;j<d+1;++j){ ll res=0; (res+=Calcs1(i-1,j-w-1))%=mod; (res+=Calcs1(i-w-1,j-1))%=mod; (res+=Calcs2(i,j-w,w))%=mod; dp[i][j]=res; } } } ll ans=0; for(int i=0;i<d+1;++i) for(int j=0;j<d+1;++j) (ans+=dp[i][j])%=mod; cout<<ans<<endl; return 0; }
分块做法 by wanghaoze
#include<bits/stdc++.h> using namespace std; int n,q,a[101010],tp[111],l[111],r[111],p[3][111]; int cnt[3][222],mp[101010],pp[3][222],vl[3][222]; bool vis[101010]; long long sm[3][3][222]; void solve(int m) { memset(vis,0,sizeof(vis)); for(int i=0;i<m;i++) { cin>>tp[i]>>l[i]>>r[i]; if(tp[i]==2) { cin>>p[0][i]>>p[1][i]>>p[2][i]; } l[i]--; vis[l[i]]=1; vis[r[i]]=1; } int nn=0; memset(cnt,0,sizeof(cnt)); memset(sm,0,sizeof(sm)); memset(mp,0,sizeof(mp)); for(int i=0;i<=n;i++) { if(vis[i]) { nn++; mp[i]=nn; pp[0][nn]=0; pp[1][nn]=1; pp[2][nn]=2; } sm[0][a[i]][nn]+=cnt[0][nn]; sm[1][a[i]][nn]+=cnt[1][nn]; sm[2][a[i]][nn]+=cnt[2][nn]; cnt[a[i]][nn]++; } for(int i=0;i<m;i++) { if(tp[i]==2) { for(int j=mp[l[i]];j<mp[r[i]];j++) { pp[0][j]=p[pp[0][j]][i]; pp[1][j]=p[pp[1][j]][i]; pp[2][j]=p[pp[2][j]][i]; } } else { long long ans=0; for(int j=mp[l[i]];j<mp[r[i]];j++) { vl[0][j]=0; vl[1][j]=0; vl[2][j]=0; vl[pp[0][j]][j]+=cnt[0][j]; vl[pp[1][j]][j]+=cnt[1][j]; vl[pp[2][j]][j]+=cnt[2][j]; if(pp[0][j]>pp[1][j]) { ans+=sm[0][1][j]; } if(pp[0][j]>pp[2][j]) { ans+=sm[0][2][j]; } if(pp[1][j]>pp[0][j]) { ans+=sm[1][0][j]; } if(pp[1][j]>pp[2][j]) { ans+=sm[1][2][j]; } if(pp[2][j]>pp[0][j]) { ans+=sm[2][0][j]; } if(pp[2][j]>pp[1][j]) { ans+=sm[2][1][j]; } } for(int j=mp[l[i]];j<mp[r[i]];j++) { for(int k=j+1;k<mp[r[i]];k++) { ans+=1ll*vl[1][j]*vl[0][k]+1ll*vl[2][j]*vl[0][k]+1ll*vl[2][j]*vl[1][k]; } } cout<<ans<<endl; } } int lst=0; for(int i=0;i<n;i++) { lst=max(mp[i],lst); if(lst) { a[i]=pp[a[i]][lst]; } } } int main() { cin>>n>>q; for(int i=0;i<n;i++) { cin>>a[i]; } for(int i=0;i<q;i+=100) { solve(min(q,i+100)-i); } return 0; }