PROBLEM LINK:
Setter: Shashwat Chandra
Tester: Yash Chandnani
Editorialist: Rajarshi Basu
DIFFICULTY:
Medium-Hard
PREREQUISITES:
Sqrt Decomposition
PROBLEM:
You are given two sequences A_1, A_2, \ldots, A_N and C_1, C_2, \ldots, C_N. For each valid i, C_i is the colour of A_i.
You should answer Q queries. In each query:
- You are given two colours x and y.
- Consider the subsequence of A which contains only elements A_i such that C_i = x or C_i = y (in the original order). Let’s denote it by B.
- For each contiguous subsequence of B (including empty subsequence), consider the sum of all its elements. Find the maximum of these sums.
Constraints
- 1 \le T \le 10^2
- 1 \le N, Q \le 3*10^5
- |A_i| \le 10^9 for each valid i
- 1 \le C_i \le N for each valid i
- 1 \le x \ne y \le N
- the sum of N over all test cases does not exceed 3*10^5
- the sum of Q over all test cases does not exceed 3*10^5
EXPLANATION:
My apologies for the delay. The full editorial should be up by tomorrow. Meanwhile, feel free to read the setter’s brief notes for ideas.
The main idea for this problem is to use square root decomposition. Well… actually, more than one kind of it.
First, we classify colours into heavy and light. Let the frequency of a color C be denoted by F[C].
- If F[C] > \sqrt N, then we call C a heavy color.
- Otherwise, C is a light color.
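The classification can be sketched as below. This is a minimal illustration, not the setter's code: the function name `classifyHeavy` and the 1-based colour convention are assumptions. Since the frequencies sum to N, fewer than \sqrt N colours can be heavy.

```cpp
#include <cassert>
#include <vector>

// Returns heavy[c] == true when colour c occurs more than sqrt(N) times.
// Colours are assumed to lie in [1, N], as in the constraints.
std::vector<bool> classifyHeavy(const std::vector<int>& colour) {
    const int n = static_cast<int>(colour.size());
    std::vector<int> freq(n + 1, 0);
    for (int c : colour) {
        assert(1 <= c && c <= n);
        ++freq[c];
    }
    std::vector<bool> heavy(n + 1, false);
    for (int c = 1; c <= n; ++c) {
        // freq[c] > sqrt(n)  <=>  freq[c]^2 > n, which avoids floating point.
        heavy[c] = 1LL * freq[c] * freq[c] > n;
    }
    return heavy;
}
```

Comparing squares instead of calling a square-root routine keeps the threshold exact for every N.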
In a query, let the colors chosen be C_1 and C_2.
If both are light?
If both are light, that means we can just create the subsequence B since it will have O(\sqrt N) terms. Then, it just becomes a simple application of Kadane’s algorithm.
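A light-light query could look like the sketch below, assuming the positions of each colour are stored in increasing order. The function name and the 0-based indexing are illustrative choices. It merges the two position lists on the fly and runs Kadane's algorithm, allowing the empty subarray (sum 0).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// a: values (0-based); posX, posY: sorted positions of the two colours.
long long maxSubarrayTwoColours(const std::vector<long long>& a,
                                const std::vector<int>& posX,
                                const std::vector<int>& posY) {
    long long best = 0;  // the empty subarray is allowed
    long long cur = 0;   // best sum of a subarray of B ending at the last element seen
    std::size_t i = 0, j = 0;
    while (i < posX.size() || j < posY.size()) {
        // Take whichever colour has the smaller next position.
        bool takeX = (j == posY.size()) ||
                     (i < posX.size() && posX[i] < posY[j]);
        int idx = takeX ? posX[i++] : posY[j++];
        cur = std::max(cur, 0LL) + a[idx];
        best = std::max(best, cur);
    }
    return best;
}
```

The cost is proportional to the total number of occurrences of the two colours, which is O(\sqrt N) when both are light.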
If (at least) one of them is heavy
First, define:
- dp[i][c] = the maximum sum of a subarray whose right endpoint is at i, where the colors being considered are C[i] and c. Take care that C[i] \neq c. We assume that c is heavy, so there are only O(\sqrt N) choices for c.
- There are a few cases in the transition for dp[i][c]. Let j be the maximum index such that C[j] = C[i] and j < i. Then,
- dp[i][c] = \max(A[i], A[i] + sufMax, A[i] + block_{sum} + dp[j][c])
- Basically, between j and i, there will be some indices x where C[x] = c. If you want to include A[j] in your dp[i][c], it means that all such indices x need to be included as well. From that, we get:
block_{sum} = \sum\limits_{j < x < i,\ C[x] = c} A[x]
- The second term in the \max takes a suffix of these aforementioned x values. The maximum such suffix sum is denoted by sufMax. How can we compute it? Well, the maximum suffix sum ending at a given right endpoint equals the prefix sum at that endpoint minus the minimum prefix sum within the block (just a subtraction away), right? Try to think along this direction.
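One way to realise block_{sum} and sufMax is sketched below. It assumes P[k] is the sum of the first k occurrences of colour c (with P[0] = 0), and that occurrences lo+1..hi are exactly the ones strictly between j and i. The `MinSparseTable` type and both function names are hypothetical helpers, not taken from the setter's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Range-minimum sparse table over a fixed array.
struct MinSparseTable {
    std::vector<std::vector<long long>> t;  // t[k][i] = min of v[i .. i + 2^k - 1]
    explicit MinSparseTable(const std::vector<long long>& v) {
        t.push_back(v);
        for (std::size_t len = 2; len <= v.size(); len *= 2) {
            const auto& prev = t.back();
            std::vector<long long> cur(v.size() - len + 1);
            for (std::size_t i = 0; i < cur.size(); ++i)
                cur[i] = std::min(prev[i], prev[i + len / 2]);
            t.push_back(std::move(cur));
        }
    }
    // Minimum of v[l..r], inclusive; requires l <= r.
    long long query(std::size_t l, std::size_t r) const {
        assert(l <= r && r < t[0].size());
        std::size_t k = 0;
        while ((std::size_t{2} << k) <= r - l + 1) ++k;
        return std::min(t[k][l], t[k][r + 1 - (std::size_t{1} << k)]);
    }
};

// Sum of all occurrences of c strictly between j and i.
long long blockSum(const std::vector<long long>& P, std::size_t lo, std::size_t hi) {
    return P[hi] - P[lo];
}

// Best (possibly empty) suffix of those occurrences: P[hi] - min(P[lo..hi]).
long long sufMax(const std::vector<long long>& P, const MinSparseTable& rmq,
                 std::size_t lo, std::size_t hi) {
    return P[hi] - rmq.query(lo, hi);
}
```

Including P[hi] itself in the minimum corresponds to taking the empty suffix, so sufMax is never negative.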
Is this enough? No. Exceptions?
After this dp[.][.] construction, it should be easy to find the answer, right? But we are still missing some cases.
We are assuming that our optimal sequence ends at a light color or at some other heavy color that is not equal to c (if this is not clear, revisit the dp[.][.] definition). However, this is not enough, since there might be some more occurrences of c after i that we should have taken in dp[i][c]: in particular, the greatest prefix in the next block. So, we need two cases for the dp, where dp_1[i][c] accounts for sequences ending at i, and dp_2[i][c] accounts for sequences which also include the greatest prefix in the block (of indices x with C[x] = c) after i.
Details to be inserted
Setter’s Notes
We do sqrt decomposition based on no. of occurrences of a color.
If both sz[i] and sz[j] are less than root N, answer the query directly in O(root N).
For each color i with sz[i] > root N, we preprocess its answer with every other color. For a particular color i, the preprocessing works as follows:
- Now, build a min sparse table on the prefix sums of the values of this color.
- For any index k such that color[k] != i, dp[k] = max(dp[j]+a[k],a[k]) if j is the closest index with the same color to the left of k and there is no index with color i between j and k.
- Otherwise, dp[k] = max(dp[j]+a[k]+blocksum,a[k]), where j is the closest index with the same color to the left of k and blocksum is the sum of values of all indices with color i in between.
- The other case is dp[k] = max(a[k]+suffixofblock,a[k]). Here suffixofblock is pref[x]-minprefixsum, where x is the closest index with color i to the left of k and minprefixsum is the minimum prefix sum across all valid indices in the block.
Every index is preprocessed at most root N times giving a final complexity of O(N root N).
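The preprocessing for one pair (heavy colour, other colour d) might be sketched as follows. For readability, this version finds the range minimum and maximum with plain scans. The real preprocessing would use sparse tables, so that each colour d costs time proportional to its own number of occurrences. The function name `bestWithHeavy` and the 0-based indexing are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Best sum of a contiguous piece of B that contains at least one element of
// colour d (returns 0 if every such piece is negative).
long long bestWithHeavy(const std::vector<long long>& a,
                        const std::vector<int>& colour, int heavy, int d) {
    const int n = static_cast<int>(a.size());
    std::vector<long long> P{0};      // P[k] = sum of first k heavy occurrences
    std::vector<int> cnt(n + 1, 0);   // cnt[i] = heavy occurrences in a[0..i-1]
    std::vector<int> posD;            // positions of colour d
    for (int i = 0; i < n; ++i) {
        cnt[i + 1] = cnt[i];
        if (colour[i] == heavy) { P.push_back(P.back() + a[i]); ++cnt[i + 1]; }
        else if (colour[i] == d) posD.push_back(i);
    }
    long long ans = 0, dp = 0;
    int lo = 0;  // heavy occurrences before the previous d element
    for (std::size_t t = 0; t < posD.size(); ++t) {
        int hi = cnt[posD[t]];
        int nxt = (t + 1 < posD.size()) ? cnt[posD[t + 1]]
                                        : static_cast<int>(P.size()) - 1;
        long long minP = *std::min_element(P.begin() + lo, P.begin() + hi + 1);
        long long maxP = *std::max_element(P.begin() + hi, P.begin() + nxt + 1);
        long long blockSum = P[hi] - P[lo];
        long long sufMax = P[hi] - minP;  // >= 0, so "a[k] alone" is covered
        // For the first d element dp == 0 and blockSum <= sufMax, so the
        // extension case is harmless there.
        dp = std::max(dp + blockSum, sufMax) + a[posD[t]];
        // Extend to the right with the best prefix of the next heavy block.
        ans = std::max(ans, dp + (maxP - P[hi]));
        lo = hi;
    }
    return ans;
}
```

Note that pieces consisting only of the heavy colour are not counted here, which is why a separate case is needed when the other colour is light.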
Another case is when freq[i] >= sqrtN and freq[j] < sqrtN. In this case no preprocessing considers the case when the answer consists of only indices with color i. For this we have another sqrt decomp (on blocks this time). We mark the blocks which have color j and for each color i with freq[i] >= sqrtN, we maintain information like max overall sum subarray in the block of that color, maxprefix till index x of that block and maxsuffix till index x etc.
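A simplified variant of this block decomposition is sketched below. Instead of the setter's exact tables, it keeps one summary (sum, best prefix, best suffix, best subarray) per block and answers a range query by merging summaries. The `Summary` and `BlockedMaxSubarray` names are hypothetical, and `v` stands for the values of one heavy colour in order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

struct Summary { long long sum, pre, suf, best; };  // empty pieces count as 0

Summary single(long long x) {
    long long p = std::max(x, 0LL);
    return Summary{x, p, p, p};
}

Summary merge(const Summary& L, const Summary& R) {
    return Summary{L.sum + R.sum,
                   std::max(L.pre, L.sum + R.pre),
                   std::max(R.suf, R.sum + L.suf),
                   std::max({L.best, R.best, L.suf + R.pre})};
}

struct BlockedMaxSubarray {
    std::size_t blockSize;
    std::vector<long long> v;
    std::vector<Summary> blocks;  // one summary per block of blockSize values

    BlockedMaxSubarray(std::vector<long long> values, std::size_t bs)
        : blockSize(bs), v(std::move(values)) {
        assert(bs > 0);
        for (std::size_t i = 0; i < v.size(); ++i) {
            if (i % blockSize == 0) blocks.push_back(Summary{0, 0, 0, 0});
            blocks.back() = merge(blocks.back(), single(v[i]));
        }
    }

    // Maximum (possibly empty) subarray sum inside v[l..r], inclusive.
    long long query(std::size_t l, std::size_t r) const {
        assert(l <= r && r < v.size());
        Summary acc{0, 0, 0, 0};
        std::size_t i = l;
        while (i <= r) {
            if (i % blockSize == 0 && i + blockSize - 1 <= r) {
                acc = merge(acc, blocks[i / blockSize]);  // whole block at once
                i += blockSize;
            } else {
                acc = merge(acc, single(v[i]));           // partial block
                ++i;
            }
        }
        return acc.best;
    }
};
```

With a block size near the square root of the array length, each query touches O(\sqrt N) summaries and single elements.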
SOLUTIONS:
Setter’s Code
/*input
1
12 1
292044772 659496817 836243974 921839771 -992521372 496511382 670210154 340982537 524243020 863201960 773286685 -506588905
2 4 3 1 4 2 1 3 1 1 1 1
1 2
1 3
1 4
*/
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define f first
#define s second
const int INF = 1e18;
const int BL = 310;
const int SZ = 1010;
const int N = 3e5+10;
const int LG = 19;
int a[N];
int inv[N];
int c[N];
int n,q;
vector<int> bigs;
vector<int> all[N];
int na[N];
int overall[SZ][SZ];
int inbest[BL];
int sufbest[BL][SZ];
int prebest[BL][SZ];
int totsum[BL];
int whichblock[N];
int blockstart[BL];
int blockend[BL];
int blockpos[N];
int ans[N];
int spma[N][LG];
int spmi[N][LG];
int pre[N];
int dp[N];
int bestforthis[BL][N];
int logg[N];
int stpref[BL][SZ];
int stsuf[BL][SZ];
vector< pair<int,int> > adj[N];
void ini(){
for(int i = 1;i <= n;i++){
all[i].clear();
adj[i].clear();
}
bigs.clear();
}
#define print(arr) for(auto it:arr)cout << it << " "; cout << endl;
int brute(int x,int y){
int res = 0;
int cur = 0;
int i = 0;
int j = 0;
//print(all[x]);
//print(all[y]);
while(i < (int)all[x].size() or j < (int)all[y].size()){
//cout << i << " " << j << " " << cur << " " << res << endl;
if(j == (int)all[y].size()){
cur = max(a[all[x][i]],cur+a[all[x][i]]);
res = max(res,cur);
i++;
}
else if(i == (int)all[x].size()){
cur = max(a[all[y][j]],cur+a[all[y][j]]);
res = max(res,cur);
j++;
}
else if(all[x][i] > all[y][j]){
cur = max(a[all[y][j]],cur+a[all[y][j]]);
res = max(res,cur);
j++;
}
else{
cur = max(a[all[x][i]],cur+a[all[x][i]]);
res = max(res,cur);
i++;
}
}
return res;
}
void setsqrt(int ind){
int color = bigs[ind];
int nn = all[color].size();
for(int i = 1;i <= nn;i++){
whichblock[i] = ((i-1)/SZ)+1;
blockend[whichblock[i]] = i;
totsum[whichblock[i]] = 0;
if(whichblock[i] != whichblock[i-1]){
blockstart[whichblock[i]] = i;
blockpos[i] = 1;
}
else{
blockpos[i] = blockpos[i-1]+1;
}
na[i] = a[all[color][i-1]];
}
for(int i= 1;i <= nn;i++)totsum[whichblock[i]] += na[i];
pre[0] = 0;
for(int i = 1;i <= n;i++){
if(c[i] == color){
pre[i] = pre[i-1]+1;
}
else pre[i] = pre[i-1];
}
for(int i = 1;i <= nn;i++){
if(blockstart[whichblock[i]] == i){
int wbl = whichblock[i];
prebest[wbl][0] = 0;
int suma = 0;
for(int j = 1;j <= SZ;j++){
if(i+j-1 > nn)break;
suma += na[i+j-1];
prebest[wbl][j] = max(prebest[wbl][j-1],suma);
}
}
if(blockend[whichblock[i]] == i){
int wbl = whichblock[i];
sufbest[wbl][0] = 0;
int suma = 0;
for(int j = 1;j <= SZ;j++){
if(i-j+1 < 1)break;
suma += na[i-j+1];
sufbest[wbl][j] = max(sufbest[wbl][j-1],suma);
}
}
}
}
void justoverall(int ind,int wb){
int color = bigs[ind];
int nn = all[color].size();
for(int i = blockstart[wb];i <= blockend[wb];i++){
int wbl = whichblock[i];
int thispos = blockpos[i];
int cur = 0;
for(int j = 1;j <= SZ;j++){
if(i+j-1 > nn)break;
if(whichblock[i+j-1] != whichblock[i])break;
cur = max(cur,0LL)+na[i+j-1];
//cout << thispos << " " << thispos+j-1 << " OOO " << cur << endl;
overall[thispos][thispos+j-1] = cur;
}
overall[thispos][thispos] = max(0LL,overall[thispos][thispos]);
for(int j = 2;j <= SZ;j++){
if(i+j-1 > nn)break;
if(whichblock[i+j-1] != whichblock[i])break;
cur = max(cur,0LL)+na[i+j-1];
overall[thispos][thispos+j-1] = max(overall[thispos][thispos+j-1],overall[thispos][thispos+j-2]);
}
inbest[wbl] = max(overall[1][blockpos[i]],inbest[wbl]);
}
for(int j = 1;j <= blockpos[blockend[wb]];j++){
stpref[wb][j] = overall[1][j];
stsuf[wb][j] = overall[j][blockpos[blockend[wb]]];
//cout << wb << " " << j << " " << stpref[wb][j] << " " << stsuf[wb][j] << endl;
}
}
void domx(int l,int r,int qind){
ans[qind] = max(ans[qind],overall[blockpos[l]][blockpos[r]]);
}
int ansqrt(int i,int j){
int iwb = whichblock[i];
int jwb = whichblock[j];
//cout << i << " " << iwb << " " << j << " " << jwb << endl;
int res =0;
res = max(res,stsuf[iwb][blockpos[i]]);
//res = max(res,overall[iwb][blockpos[i]][blockpos[blockend[iwb]]]);
res = max(res,stpref[jwb][blockpos[j]]);
//res = max(res,overall[jwb][1][blockpos[j]]);
int cur = sufbest[iwb][blockend[iwb]-i+1];
for(int wb = whichblock[i]+1;wb < whichblock[j];wb++){
//cout << totsum[wb] << endl;
res = max(res,cur+prebest[wb][blockend[wb]-blockstart[wb]+1]);
cur = max(cur+totsum[wb],sufbest[wb][blockend[wb]-blockstart[wb]+1]);
res = max(res,inbest[wb]);
}
res = max(res,cur+prebest[jwb][blockpos[j]]);
return res;
}
void prelog(){
logg[1] = 0;
for(int i = 2;i < N;i++){
if(1<<(logg[i-1]+1) <= i)logg[i] = logg[i-1]+1;
else logg[i] = logg[i-1];
}
}
int rngmax(int l,int r){
if(l > r)return -INF;
int j = logg[r - l + 1];
return max(spma[l][j],spma[r - (1 << j) + 1][j]);
}
int rngmin(int l,int r){
if(l > r)return INF;
int j = logg[r - l + 1];
return min(spmi[l][j],spmi[r - (1 << j) + 1][j]);
}
void sparsepre(int inds){
pre[0] = 0;
for(int i = 1;i <= n;i++){
if(c[i] == bigs[inds]){
pre[i] = pre[i-1]+1;
}
else pre[i] = pre[i-1];
}
int color = bigs[inds];
int nn = all[color].size();
for(int i = 0;i <= nn;i++){
if(i)spmi[i][0] = spma[i][0] = a[all[color][i-1]]+spma[i-1][0];
}
for(int j = 1;j < LG;j++){
for(int i = 0;i <= nn;i++){
if(i+(1<<j)-1 > nn)continue;
spmi[i][j] = min(spmi[i][j-1],spmi[i+(1<<(j-1))][j-1]);
spma[i][j] = max(spma[i][j-1],spma[i+(1<<(j-1))][j-1]);
}
}
for(int i = 1;i <= n;i++){
if(i == color)continue;
int ans =0;
for(int j = 0;j < (int)all[i].size();j++){
int ind = all[i][j];
int prv = 0;
if(j){
prv = all[i][j-1];
}
int one = pre[ind];
int two = pre[prv];
int three = nn;
if(j+1 < (int)all[i].size())three = pre[all[i][j+1]];
int suffix = rngmax(one,one)-rngmin(two,one);
int prefix = rngmax(one,three)-rngmin(one,one);
int allsum = spma[one][0]-spma[two][0];
dp[ind] = max(dp[prv]+allsum,suffix)+a[ind];
ans = max(ans,dp[ind]+prefix);
}
bestforthis[inds][i] = ans;
}
}
void solve(){
ini();
cin >> n >> q;
//cout << n << " " << q << endl;
a[n+1] = 0;
for(int i = 1;i <= n;i++)cin >> a[i];
for(int i = 1;i <= q;i++)ans[i] = 0;
for(int i = 1;i <= n;i++){
whichblock[i] = ((i-1)/SZ)+1;
blockend[whichblock[i]] = i;
if(whichblock[i] != whichblock[i-1]){
blockstart[whichblock[i]] = i;
blockpos[i] = 1;
}
else{
blockpos[i] = blockpos[i-1]+1;
}
cin >> c[i];
all[c[i]].push_back(i);
}
int cnt = 0;
for(int i = 1;i <= n;i++){
if(all[i].size() >= SZ){
bigs.push_back(i);
inv[i] = cnt++;
}
else{
inv[i] = -1;
}
}
for(int i= 0;i < (int)bigs.size();i++){
sparsepre(i);
}
for(int i = 1;i <= q;i++){
int x,y;cin >> x >> y;
if(all[x].size() < all[y].size())swap(x,y);
if(all[x].size() < SZ){
ans[i] = brute(x,y);
}
else if(all[y].size() < SZ){
adj[inv[x]].push_back({y,i});
}
else{
int wowsie = max(bestforthis[inv[x]][y],bestforthis[inv[y]][x]);
ans[i] = wowsie;
}
}
for(int i = 0;i < (int)bigs.size();i++){
int color = bigs[i];
setsqrt(i);
vector< pair< pair<int,int> ,int> > queries;
vector< pair< pair<int,int> ,int> > spec;
for(auto it:adj[i]){
ans[it.s] = max(ans[it.s],bestforthis[i][it.f]);
int prev = 0;
int best = 0;
for(auto cc:all[it.first]){
int l = pre[prev]+1;
int r = pre[cc];
if(l <= r){
if(whichblock[l] != whichblock[r])queries.push_back({{l,r},it.second});
else{
spec.push_back({{l,r},it.second});
}
}
prev = cc;
}
int l = pre[prev]+1;
int r = pre[n];
if(l <= r){
if(whichblock[l] != whichblock[r])queries.push_back({{l,r},it.second});
else{
spec.push_back({{l,r},it.second});
}
}
}
sort(spec.begin(),spec.end());
int prvbl = 0;
for(int ix = 0;ix < spec.size();ix++){
if(prvbl != whichblock[spec[ix].f.f]){
justoverall(i,whichblock[spec[ix].f.f]);
}
prvbl = whichblock[spec[ix].f.f];
if(whichblock[spec[ix].f.f] == whichblock[spec[ix].f.s]){
domx(spec[ix].f.f,spec[ix].f.s,spec[ix].s);
}
}
for(int ix = 1;ix <= whichblock[all[color].size()];ix++){
//cout << i << " " << ix << endl;
justoverall(i,ix);
}
for(int ix = 0;ix < queries.size();ix++){
ans[queries[ix].s] = max(ans[queries[ix].s],ansqrt(queries[ix].f.f,queries[ix].f.s));
}
}
//cout << "HERE" << endl;
for(int i = 1;i <= q;i++){
/*int x;cin >> x;
cout << x << " " << ans[i] << endl;
assert(x == ans[i]);*/
cout << ans[i] << endl;
}
}
//2830941819
//4541337744
signed main(){
//freopen("rt.txt","r",stdin);
//freopen("bb4out.txt","w",stdout);
prelog();
ios_base::sync_with_stdio(0);cin.tie(0);
int t = 1;cin >> t;
//cout << t << endl;
while(t--){
solve();
}
}
Tester’s Code
#include <bits/stdc++.h>
using namespace std;
void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto &i: x) cerr << (f++ ? "," : ""), __print(i); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}
#ifndef ONLINE_JUDGE
#define debug(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define debug(x...)
#endif
#define rep(i, n) for(int i = 0; i < (n); ++i)
#define repA(i, a, n) for(int i = a; i <= (n); ++i)
#define repD(i, a, n) for(int i = a; i >= (n); --i)
#define trav(a, x) for(auto& a : x)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()
#define fill(a) memset(a, 0, sizeof (a))
#define fst first
#define snd second
#define mp make_pair
#define pb push_back
typedef long double ld;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
void pre(){
}
long long readInt(long long l,long long r,char endd){
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);
assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}
assert(l<=x && x<=r);
return x;
} else {
assert(false);
}
}
}
string readString(int l,int r,char endd){
string ret="";
int cnt=0;
while(true){
char g=getchar();
assert(g!=-1);
if(g==endd){
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt && cnt<=r);
return ret;
}
long long readIntSp(long long l,long long r){
return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}
struct Node {
Node *l = 0, *r = 0;
int lo, hi;
ll l_sum = 0, r_sum = 0, val = 0, sum=0;
Node(int lo=0,int hi=300000):lo(lo),hi(hi){
if (lo + 1 < hi) {
int mid = lo + (hi - lo)/2;
l = new Node(lo, mid); r = new Node(mid, hi);
}
} // Large interval of -inf
void set(int p, int x) {
if (p+1 <= lo || hi <= p) return;
if (p <= lo && hi <= p+1) {
l_sum=r_sum=sum=val=x;
}
else {
l->set(p, x), r->set(p, x);
val = max(l->val,r->val);
l_sum = max(l->l_sum,l->sum+r->l_sum);
r_sum = max(r->r_sum,r->sum+l->r_sum);
sum = l->sum + r->sum;
val = max(l->val,r->val);
val = max(val,l->r_sum+r->l_sum);
}
}
};
vi g[300009];
ll a[300009];
Node N;
int c[300009];
void ins(int x){
trav(i,g[x]){
N.set(i,a[i]);
}
}
void del(int x){
trav(i,g[x]){
N.set(i,0);
}
}
ll ans[300009];
bool fg[300009];
int main() {
cin.sync_with_stdio(0); cin.tie(0);
cin.exceptions(cin.failbit);
pre();
int nq=readIntLn(1,100);
ll sumn=0,sumq=0;
rep(qq,nq){
int n=readIntSp(1,300000);
int q=readIntLn(1,300000);
sumn+=n,sumq+=q;
rep(i,n-1) {
a[i]=readIntSp(-1e9,1e9),g[i+1].clear();
fg[i+1]=0;
}
fg[n]=0;
a[n-1]=readIntLn(-1e9,1e9),g[n].clear();
rep(i,n-1){
int x=readIntSp(1,n);
c[i] = x;
g[x].pb(i);
}
c[n-1]=readIntLn(1,n);
g[c[n-1]].pb(n-1);
vector<pair<pii,pii>> v;
map<pii,ll> prev_ans;
rep(i,q){
int x=readIntSp(1,n);
int y=readIntLn(1,n);
fg[x]=fg[y]=1;
assert(x!=y);
if(sz(g[x])+sz(g[y])<=10000){
if(prev_ans.find(mp(x,y))!=prev_ans.end()){
ans[i] = prev_ans[mp(x,y)];
continue;
}
int ix = 0,iy=0;
ll cns=0,gns=0;
while(ix<sz(g[x])&&iy<sz(g[y])){
if(g[x][ix]<g[y][iy]){
cns = max(a[g[x][ix]],a[g[x][ix]]+cns);
gns=max(gns,cns);
ix++;
}
else{
cns = max(a[g[y][iy]],a[g[y][iy]]+cns);
gns=max(gns,cns);
iy++;
}
}
while(iy<sz(g[y])){
cns = max(a[g[y][iy]],a[g[y][iy]]+cns);
gns=max(gns,cns);
iy++;
}
while(ix<sz(g[x])){
cns = max(a[g[x][ix]],a[g[x][ix]]+cns);
gns=max(gns,cns);
ix++;
}
prev_ans[mp(x,y)]=prev_ans[mp(y,x)]=gns;
ans[i]=gns;
}
else{
if(mp(sz(g[x]),x)<mp(sz(g[y]),y)) swap(x,y);
v.pb(mp(mp(sz(g[x]),x),mp(y,i)));
}
}
sort(all(v));
reverse(all(v));
pii lst = mp(0,0);
trav(i,v){
if(lst.snd==i.fst.snd) swap(lst.fst,lst.snd);
if(lst.fst==i.snd.fst) swap(lst.fst,lst.snd);
if(lst.fst!=i.fst.snd){
del(lst.fst);
ins(i.fst.snd);
lst.fst=i.fst.snd;
}
if(lst.snd!=i.snd.fst){
del(lst.snd);
ins(i.snd.fst);
lst.snd=i.snd.fst;
}
ans[i.snd.snd] = N.val;
}
del(lst.fst),del(lst.snd);
rep(i,q){
cout<<ans[i]<<'\n';
}
}
assert(sumn<=300000);
assert(sumq<=300000);
return 0;
}