SEGTHREE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: adhoom
Tester: satyam_343
Editorialist: iceknight1093

DIFFICULTY:

2089

PREREQUISITES:

Observation

PROBLEM:

You’re given an array A.
In one move, you can increase one of its elements by 1.

Find the minimum number of moves needed to reach an array for which every subarray of size 3 has a sum divisible by 3.

EXPLANATION:

First, observe that every element will be increased by either 0, 1, \text{ or } 2.
Performing three or more increments on a single element is never useful: only each element's remainder modulo 3 affects divisibility, so dropping three of those increments gives the same remainders with fewer moves.
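Since only remainders mod 3 matter, the cheapest number of increments that takes a value x to a target remainder r has a closed form. A small Python sketch (the helper names are hypothetical, not from the editorial) that checks the closed form against the literal increment loop and confirms it never exceeds 2:

```python
def min_increments(x: int, r: int) -> int:
    """Fewest +1 steps that turn x into a value congruent to r (mod 3).

    Python's % returns a value in [0, 3) for modulus 3, so the result
    is 0, 1 or 2 even when r < x % 3.
    """
    return (r - x) % 3


def min_increments_loop(x: int, r: int) -> int:
    """Same quantity, computed by literally incrementing (for comparison)."""
    steps = 0
    while x % 3 != r:
        x += 1
        steps += 1
    return steps


# The closed form agrees with the loop and never exceeds 2.
for x in range(20):
    for r in range(3):
        assert min_increments(x, r) == min_increments_loop(x, r) <= 2
```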

Let B be the final array we attain after increments, where each size-3 subarray has its sum divisible by 3.
Notice that if we fix B_1 and B_2 (and always use as few increments as possible), the rest of the elements of B are uniquely determined!
In fact, we can even find them all easily in \mathcal{O}(N) time.

How?

Suppose B_1 and B_2 are fixed. Then,

  • (B_1 + B_2 + B_3) must be divisible by 3.
    Since B_1 and B_2 are fixed, we already have the sum (B_1 + B_2 + A_3) — and we can only increase A_3.
    Since the number of moves required should be minimum, our best bet is to increase A_3 by either 0, 1, \text{ or } 2. Exactly one of these will let us reach a sum that’s divisible by 3.
    Notice that this means B_3 is fixed uniquely, to one of \{A_3, A_3+1, A_3+2\}.
  • Next, (B_2 + B_3 + B_4) should be divisible by 3.
    Once again, B_2 and B_3 are fixed, so B_4 is determined uniquely.
  • Continuing this way from left to right, each B_i is fixed in order, hence fixing the entire array.
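The left-to-right procedure in the spoiler can be sketched in Python as below. This is an illustrative sketch rather than the editorialist's code; `complete_b` is a hypothetical helper name, and it assumes 0-indexed lists with len(a) >= 2, b1 >= a[0] and b2 >= a[1].

```python
def complete_b(a, b1, b2):
    """Given A (0-indexed) and chosen values B_1 = b1, B_2 = b2, fix every
    later element greedily so that each window of 3 sums to 0 mod 3.

    Each later element is raised by the unique amount in {0, 1, 2}
    that makes its window divisible by 3.
    """
    b = [b1, b2]
    for i in range(2, len(a)):
        need = -(b[i - 2] + b[i - 1] + a[i]) % 3  # always 0, 1 or 2
        b.append(a[i] + need)
    return b
```

For example, complete_b([1, 2, 5, 8], 1, 2) raises 5 to 6 and then 8 to 10; the number of moves used is sum(b) - sum(a).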

Combining this with our first observation (that each A_i will be increased by at most 2) we see that there aren’t too many options to check.
In particular:

  • B_1 is one of (A_1, A_1+1, A_1+2)
  • B_2 is one of (A_2, A_2+1, A_2+2)

This gives us 3\times 3 = 9 options in total for B_1 and B_2.
For each option, the entire B array can be computed in \mathcal{O}(N) time, as detailed in the spoiler above.
So, simply try all 9 options, compute the B array for them all, and find which of them uses the least number of increments in total.
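To gain some confidence that trying only these 9 options is enough, here is a hedged Python sketch (function names are hypothetical) that implements the 9-option search and cross-checks it on small random arrays against an exhaustive search over every increment vector in {0, 1, 2}^N. The exhaustive search itself relies on the earlier observation that no element needs more than 2 increments.

```python
import itertools
import random


def min_moves(a):
    """Try all 9 increment choices (0..2) for A_1 and A_2, fix the rest
    greedily, and return the cheapest total. Assumes len(a) >= 2."""
    best = None
    for x, y in itertools.product(range(3), repeat=2):
        b = [a[0] + x, a[1] + y]
        moves = x + y
        for i in range(2, len(a)):
            need = -(b[i - 2] + b[i - 1] + a[i]) % 3
            b.append(a[i] + need)
            moves += need
        if best is None or moves < best:
            best = moves
    return best


def brute_force(a):
    """Exhaustively try every increment vector in {0,1,2}^n (tiny n only)."""
    best = None
    for inc in itertools.product(range(3), repeat=len(a)):
        b = [v + d for v, d in zip(a, inc)]
        if all(sum(b[i:i + 3]) % 3 == 0 for i in range(len(b) - 2)):
            total = sum(inc)
            if best is None or total < best:
                best = total
    return best


random.seed(1)
for _ in range(200):
    n = random.randint(3, 6)
    arr = [random.randint(1, 10) for _ in range(n)]
    assert min_moves(arr) == brute_force(arr)
```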

TIME COMPLEXITY:

\mathcal{O}(9\cdot N) per testcase.
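A side consequence of the construction, which also explains why the tester's code can safely start from ans = n: subtracting two consecutive window conditions shows that the target residues repeat with period 3.

```latex
\begin{aligned}
B_i + B_{i+1} + B_{i+2} &\equiv 0 \pmod{3}, \\
B_{i+1} + B_{i+2} + B_{i+3} &\equiv 0 \pmod{3} \\
\Longrightarrow\quad B_{i+3} &\equiv B_i \pmod{3}.
\end{aligned}
```

So the residues of B are (r_1, r_2, -(r_1 + r_2)) repeated. Over the 9 choices of (r_1, r_2), each position's target residue is 0, 1 and 2 exactly three times each, so each position costs 3(0 + 1 + 2) = 9 in total across all 9 options. The 9 totals therefore sum to 9N, and the best of them is at most N.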

CODE:

Author's code (C++)
///       ______        __________                    _____   _____        _____
///      ///  \\\      ||__||   \\\    |||     |||  ||     || |||\\\      ///|||
///     ///    \\\     ||__||    \\\   |||_____|||  ||     || ||| \\\    /// |||
///    ///______\\\    ||__||     \\\  |||_____|||  ||     || |||  \\\  ///  |||
///   ///________\\\   ||__||     ///  |||_____|||  ||     || |||   \\\///   |||
///  ///          \\\  ||__||    ///   |||     |||  ||     || |||            |||
/// ///            \\\ ||__||___///    |||     |||  ||_____|| |||            |||

#include<bits/stdc++.h>
#define FIO ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define endl "\n"
using namespace std;
typedef long long ll;
typedef long double ld;
const ll N=1e5+5;
ll a[N];
ll dp[N][4][4];
ll vis[N][4][4];
ll cur=2;
ll n;
ll solve(ll idx,ll prv1,ll prv2)
{
    if(idx==n)return 0;
    ll &ans=dp[idx][prv1][prv2];
    ll &v=vis[idx][prv1][prv2];
    if(v==cur)return ans;
    v=cur;
    ans=4e18;
    for(int i=0;i<3;i++)
    {
        ll x=a[idx];
        ll c=0;
        while(x%3!=i)x++,c++;
        if(prv1==3)ans=min(ans,solve(idx+1,i,prv2)+c);
        else if(prv2==3)ans=min(ans,solve(idx+1,prv1,i)+c);
        else
        {
            if((i+prv1+prv2)%3==0)
            {
                ans=min(ans,solve(idx+1,prv2,i)+c);
            }
        }
    }
    return ans;
}
void test_case()
{
    cin>>n;
    for(int i=0;i<n;i++)cin>>a[i];
    cur++;
    ll ans=4e18;
    ans=min(ans,solve(0,3,3));
    cout<<ans<<endl;
}
int main()
{
//    FIO
//  freopen("input.txt","rt",stdin);
//  freopen("output.txt","wt",stdout);
    ll t;
    t=1;
    cin>>t;
    while(t--)
    {
        test_case();
    }
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif 
#define ll long long 
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){   
        char g=getchar();
        if(g=='-'){  
            assert(fi==-1);     
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
const ll MOD=1e9+7;
vector<ll> readv(ll n,ll l,ll r){
    vector<ll> a;
    ll x;
    for(ll i=1;i<n;i++){  
        x=readIntSp(l,r);  
        a.push_back(x);   
    }
    x=readIntLn(l,r);
    a.push_back(x);
    return a;  
}
const ll MAX=3000300;   
ll sum_n=0;     
void dbug(vector<ll> a){
    for(auto t:a){
        cout<<t<<" ";
    }   
    cout<<endl; 
}
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;  
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
ll gt(ll n,ll freq,ll k){
    ll pw=(binpow(2,k,MOD-1)*freq)%(MOD-1);
    ll now=(binpow(n,pw+1,MOD)-binpow(n,freq,MOD)+MOD)*inverse(n-1,MOD);
    now%=MOD;
    return now;
}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
bool check_distinct(vector<ll> a){
    sort(a.begin(),a.end());
    ll n=a.size();
    for(ll i=1;i<n;i++){
        assert(a[i]!=a[i-1]);
    }
    return true;
}
ll g(ll x){
    return x;  
}
struct dsu{
    vector<ll> parent,height;
    ll n,len;
    dsu(ll n){
        this->n=n;
        parent.resize(n);
        height.resize(n);
        len=n;
        for(ll i=0;i<n;i++){
            parent[i]=i;
            height[i]=1;
        }
    }
    ll find_set(ll x){
        return find_set(x,x); 
    }
    ll find_set(ll x,ll orig){
        if(parent[x]==x){
            return x;
        }
        parent[orig]=find_set(parent[x]);
        return parent[orig]; 
    }
    void union_set(ll u,ll v){
        u=find_set(u),v=find_set(v);
        if(u==v){
            return;
        }
        len--; 
        if(height[u]<height[v]){
            swap(u,v); 
        }
        parent[v]=u;
        height[u]+=height[v]; 
    }
    ll getv(ll l){
        l=find_set(l);
        return height[l]; 
    }
};
void solve(){    
    ll n; cin>>n;
    vector<ll> a(n);
    for(auto &i:a){
        cin>>i;
        i%=3;
    }    
    ll ans=n;
    for(ll l=0;l<=2;l++){  
        for(ll r=0;r<=2;r++){
            ll now=0;
            vector<ll> b={l,r};
            for(ll i=2;i<n;i++){
                b.push_back((6-b[i-1]-b[i-2])%3);
            }
            for(ll i=0;i<n;i++){
                ll cur=a[i];
                while((cur%3)!=b[i]){
                    now++;
                    cur++;
                }
            }
            ans=min(ans,now);
        }
    }
    cout<<ans<<"\n";
    return;  
}  
int main(){
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                   
    freopen("input.txt", "r", stdin);                                                
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif         
    ll test_cases; cin>>test_cases;
    while(test_cases--){
        solve();
    }
    assert(sum_n<=g(1e5)); 
    assert(getchar()==-1);
    return 0;
}
Editorialist's code (Python)
from itertools import product
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    ans = 10**9
    for x, y in product([0, 1, 2], repeat=2):
        moves = x + y
        p1, p2 = a[0] + x, a[1] + y
        for i in range(2, n):
            cur = p1 + p2 + a[i]
            moves += (-cur)%3
            p1, p2 = p2, (a[i]-cur)%3
        ans = min(ans, moves)
    print(ans)

Can someone please explain the solution more clearly?


So we have only 3 options for incrementing a value: [0, 1, 2].

If we only ever increment the last (3rd) element of each window of three consecutive elements, the first two elements A[0] and A[1] never get adjusted, since neither of them is the last element of any window. So their increments have to be tried explicitly.

Each of them has 3 options, so that is 3 x 3 = 9 different combinations:

(0, 0)
(0, 1)
(0, 2)
(1, 0)
(1, 1)
(1, 2)
(2, 0)
(2, 1)
(2, 2)

We add one of these pairs to A[0], A[1] and then count the increments needed throughout the array so that every 3 consecutive elements have a sum divisible by 3. This takes O(n) per combination, making the overall complexity 9 * O(n).

Out of the 9 combinations that we try, the one that has the min increments is the answer.

It’s sort of a brute force approach.

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    
    ans = 10**9
    for x in [0, 1, 2]:
        for y in [0, 1, 2]:
        
            moves = x + y # since we are adding x and y to A[0] and A[1]
            p1 = a[0] + x
            p2 = a[1] + y
            
            # now count the moves required to adjust the remaining values so that every 3 consecutive numbers have a sum divisible by 3
            # first iteration will check if sum(A[0], A[1], A[2]) is divisible by 3 or not
            for i in range(2, n): 
                
                p3 = a[i]
                total = p1 + p2 + p3
                
                if total%3 != 0: # if sum is not divisible by 3, we find the increment
                    remainder = total % 3
                    increment = 3-remainder
                    
                    p3 += increment # since we want to carry this increment forward to other upcoming subsets
                    total += increment # making sure the total is divisible
                    moves += increment # adding the increment to moves
                
                p1 = p2 # A[1] is the first element now
                p2 = p3 # A[2] is the second the element now
                
                # next iteration will check if sum(A[i-1], A[i], A[i+1]) is divisible by 3 or not
            
            ans = min(ans, moves)
        
    print(ans)

I, a dumbass, thought this could only be solved using Dynamic Programming, so I started figuring out the dp state, recurrence relation and stuff. After several failed attempts, I realised it's a simple problem whose solution has just nine different states. Man, did I overthink it!!!


Us bro

In general, the question focuses on making the sum of every 3 consecutive elements of the array divisible by 3.

Important claim: once the first two elements are fixed, every other element is decided. We take each group of three consecutive elements in turn and increment only the last element of the group until the group's sum is divisible by 3.

Consider an example:
4
1 2 5 8

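Here, 1 + 2 + 5 = 8; incrementing 5 → 6 makes the sum 9, which is divisible by 3.
Similarly, 2 + 6 + 8 = 16; incrementing 8 twice (8 → 10) makes the sum 18.
1 2 6 10 (a possible final array, costing 1 + 2 = 3 moves)
This is just one possibility, and not the cheapest one: increasing the first element instead gives 2 2 5 8 (2 + 2 + 5 = 9, 2 + 5 + 8 = 15) with only 1 move. Just try playing around with different values for the first two numbers and you will get the hang of it.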

So in general, the first element of the final array can have remainder 0, 1 or 2 (mod 3),
and similarly, the second element of the final array can have remainder 0, 1 or 2.
Hence 3 x 3 = 9 states are possible.
So we are just applying brute force: for each of the 9 states, we adjust the remaining numbers accordingly, count the increments, and keep the minimum.
You can imagine it as a window of size 3 sliding across the array.
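A small Python sketch of the sliding-window idea on the 1 2 5 8 example (`try_start` is a hypothetical helper name): for each of the 9 starting increments it raises only the last element of each window and reports the moves and the final array.

```python
from itertools import product


def try_start(a, x, y):
    """Add x to a[0] and y to a[1], then slide a window of size 3, raising
    only its last element until the window sum is divisible by 3.
    Returns (moves, final_array)."""
    b = [a[0] + x, a[1] + y]
    moves = x + y
    for i in range(2, len(a)):
        need = -(b[-2] + b[-1] + a[i]) % 3
        b.append(a[i] + need)
        moves += need
    return moves, b


a = [1, 2, 5, 8]
for x, y in product(range(3), repeat=2):
    moves, b = try_start(a, x, y)
    print(x, y, moves, b)
```

The (0, 0) line reproduces 1 2 6 10 at 3 moves, while (1, 0) reaches 2 2 5 8 with a single move, which is the minimum over all 9 states.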

C++ Code
#include <bits/stdc++.h>
using namespace std;
#define int long long int

int32_t main(){
    int test;
    cin >> test;

    while(test--){
        int size;
        cin >> size;

        vector<int> arr(size); // std::vector instead of a non-standard, stack-allocated variable-length array

        //in order to input the elements
        for(int idx = 0; idx < size; idx += 1){
            cin >> arr[idx];
        }

        //main logic of the problem
        //there are in total nine possible states

        int result = INT_MAX;

        for(int first = 0; first < 3; first += 1){
            for(int second = 0; second < 3; second += 1){
                int increment = 0;

                int x = arr[0];
                int y = arr[1];

                //now we have to adjust the value of x
                while(x % 3 != first){
                    increment += 1;
                    x += 1;
                }

                //now we have to adjust the value of y
                while(y % 3 != second){
                    increment += 1;
                    y += 1;
                }

                for(int idx = 2; idx < size; idx += 1){
                    int z = arr[idx];
                    int total = x + y + z;

                    while(total % 3 != 0){
                        z += 1;
                        total += 1;
                        increment += 1;
                    }

                    x = y;
                    y = z;
                }

                result = min(result, increment);
            }
        }

        cout << result << '\n';
    }
}

Hope I was able to clear up the doubt!

Note: I came up with this solution from the solution video and a little bit of brainstorming. Feel free to ping me with any doubts. Cheers !! ;D


Wow, nice solution! I butchered this problem with straightforward dp sadly XD
My DP solution+code here: Solution

UPD: My dp solution is an extravagant way of writing the editorial solution in hindsight XD


Can you explain a bit more how you thought of that dp method while tackling the problem? Basically, how did you see that there could be repeated common subproblems in this question?

Yeah, I would also like to know how.

Why does the following approach fail?

import java.util.*;
import java.lang.*;
import java.io.*;

/* Name of the class has to be "Main" only if the class is public. */
class Codechef
{
    static int recUtil(int i,int arr[]){
        if(i+2==arr.length){
            return 0;
        }
        int sum=arr[i]+arr[i+1]+arr[i+2];
        int toAdd=3-sum%3;
        int min=Integer.MAX_VALUE/2;
        if(toAdd==3){
            return recUtil(i+1,arr);
        }
        //incrementing single
        for(int k=i;k<=i+2;k++){
            arr[k]+=toAdd;
            min=Math.min(min,toAdd+recUtil(i+1,arr));
            arr[k]-=toAdd;
        }
        if(toAdd==2){
            arr[i]+=1;
            arr[i+1]+=1;
            min=Math.min(min,toAdd+recUtil(i+1,arr));
            arr[i]-=1;
            arr[i+1]-=1;
            
            arr[i+1]+=1;
            arr[i+2]+=1;
            min=Math.min(min,toAdd+recUtil(i+1,arr));
            arr[i+1]-=1;
            arr[i+2]-=1;
            
            arr[i]+=1;
            arr[i+2]+=1;
            min=Math.min(min,toAdd+recUtil(i+1,arr));
            arr[i]-=1;
            arr[i+2]-=1;
        }
        return min;
    }
	public static void main (String[] args) throws java.lang.Exception
	{
		Scanner input= new Scanner(System.in);
		int T=input.nextInt();
		while(T--!=0){
		    int size=input.nextInt();
		    int arr[]=new int[size];
		    for(int i=0;i<size;i++){
		        arr[i]=input.nextInt();
		    }
		    System.out.println(recUtil(0,arr));
		}
	}
}

I’ve added more detail. I’m confused, do you mean overlapping subproblem property? If you can’t see why it’s overlapping then you should learn dp basics and recursion. If you meant you can’t see the optimal substructure property intuitively, same applies.

By the way, for anyone wondering about the math formula mentioned in dan's solution: for any general mod, the number of 1's you have to add to i to make (i % mod) == (j % mod) is (j - i + mod) % mod, assuming i and j are already reduced modulo mod.
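A quick Python check of this formula against the literal "add 1 until it matches" loop (helper names are hypothetical). One caveat worth keeping in mind: in C++ and Java, % can return a negative value, so (j - i + mod) % mod is only safe when i and j already lie in [0, mod); for arbitrary i and j, ((j - i) % mod + mod) % mod is the safe form there. Python's % with a positive modulus is never negative, so (j - i) % mod suffices.

```python
def steps_to_match(i: int, j: int, mod: int) -> int:
    """Number of +1 steps applied to i so that i % mod == j % mod.
    Python's % with a positive modulus is never negative, so no extra
    '+ mod' is needed even when j < i."""
    return (j - i) % mod


def steps_loop(i: int, j: int, mod: int) -> int:
    """Reference: literally add 1 to i until the remainders match."""
    steps = 0
    while i % mod != j % mod:
        i += 1
        steps += 1
    return steps


for mod in (2, 3, 7):
    for i in range(30):
        for j in range(30):
            assert steps_to_match(i, j, mod) == steps_loop(i, j, mod)
```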


Can you explain why your solution passes without taking too long?

Bro, can you please explain more about why you thought of it as a dp problem? And can you please explain your code?

Why doesn't the sliding window solution get accepted? In fact, it passed 2 test cases but ultimately failed.

Was this solution accepted?

You should learn about time complexity. With a naive sliding window, the time limit gets exceeded.

If you look at the first part of the code, I have defined int itself as long long int, and that's why I used int32_t for the main function. This is a hack I usually use, as I find it very convenient.

I now understand how you thought of the recursive approach, which finally led to a dp approach. Basically, I wasn't clear from the start about which dp states I would need to choose, but your solution makes it clear. Is there a good way of analyzing which dp states one needs to consider in a problem?

Okay, solving more dp problems makes things more intuitive. There are various blogs out there on how to come up with a dp solution, e.g. Dynamic Programming: Prologue (Codeforces); there are 2 other blogs linked inside that one, btw.

I also started out by mastering recursive-style dp first, which is more intuitive, especially when you already understand recursion well, so you could try that.
