CEFDIV-Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Prasant Kumar
Tester: Aryan, Satyam
Editorialist: Devendra Singh

DIFFICULTY:

2533

PREREQUISITES:

Dynamic programming

PROBLEM:

You are given an array A of size N.

A partitioning of the array A is the splitting of A into one or more non-empty contiguous subarrays such that each element of A belongs to exactly one of these subarrays.

Find the number of ways to partition A such that the parity of the sum of elements within the subarrays is alternating. In other words, if S_i denotes the sum of the elements in the i-th subarray, then either

  • S_1 is odd, S_2 is even, S_3 is odd and so on.
  • or S_1 is even, S_2 is odd, S_3 is even and so on.

For example, if A = [1, 2, 3, 3, 5], one way to partition A is [1, 2] [3, 3] [5]. Another way is [1] [2] [3] [3, 5]. Note that there are more ways to partition this array.

Since the answer may be large, output it modulo 998244353.
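To check the definition on small inputs, here is a brute-force sketch (the name bruteForce is hypothetical, not taken from the setter's or testers' code). It tries all 2^(N-1) ways of placing cuts between adjacent elements and keeps the partitions whose subarray sums alternate in parity, so it is only meant for small N. For the example A = [1, 2, 3, 3, 5] it counts 5 valid partitions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Brute force for small N only: enumerate every cut mask.
// Bit b of mask set => cut between a[b] and a[b + 1].
std::int64_t bruteForce(const std::vector<std::int64_t>& a) {
    const int n = static_cast<int>(a.size());
    if (n == 0) return 0;
    std::int64_t count = 0;
    for (std::uint32_t mask = 0; mask < (1u << (n - 1)); ++mask) {
        bool ok = true;
        int prevParity = -1;   // parity of the previous subarray (-1 = none yet)
        std::int64_t cur = 0;  // sum of the subarray being built
        for (int i = 0; i < n && ok; ++i) {
            cur += a[i];
            bool lastElement = (i == n - 1);
            bool cutHere = !lastElement && ((mask >> i) & 1u);
            if (lastElement || cutHere) {
                int parity = static_cast<int>(((cur % 2) + 2) % 2);
                if (parity == prevParity) ok = false;  // same parity twice in a row
                prevParity = parity;
                cur = 0;
            }
        }
        if (ok) ++count;
    }
    return count;
}
```

Because it enumerates partitions directly from the statement, it is a convenient reference to compare a DP against on random small arrays.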

QUICK EXPLANATION

Let dp[i][parity] denote the number of valid partitions of the prefix P_i (the first i elements of A) whose last subarray has a sum of parity parity. Let sum[i] denote the sum of the first i elements of A.
Initialize dp[0][0]=1 and dp[0][1]=1: the empty prefix P_0 acts as a sentinel, so the first subarray may have either parity.

Case 1: sum[i] is even.

  • dp[i][0]= \underset{2 \mid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][1]
  • dp[i][1]= \underset{2 \nmid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][0]

Case 2: sum[i] is odd.

  • dp[i][1]= \underset{2 \mid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][0]
  • dp[i][0]= \underset{2 \nmid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][1]

EXPLANATION:

Let us count the valid partitions of A in which a fixed prefix P_i (the array formed by the first i elements of A) is the first subarray. Suppose the sum of the elements of P_i is even. Then the problem reduces to counting the valid partitions of the array B formed by the remaining N-i elements of A (kept in their original order) whose first subarray has an odd sum.

The case when the sum of the elements of P_i is odd is similar, except that the problem now reduces to counting the valid partitions of B whose first subarray has an even sum.
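The reduction above can be written directly as a suffix DP. In this sketch (the names suffixDp and g are hypothetical), g[i][p] counts the valid partitions of the suffix starting at 0-based index i whose first subarray has parity p. Fixing the first subarray a[i..k-1] with parity p leaves the suffix starting at k, which must begin with parity 1 - p; an empty suffix counts as one way. It runs in O(N^2), so it only illustrates the reduction; the prefix formulation that follows is what leads to an O(N) solution.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 998244353;

// O(N^2) suffix formulation of the reduction, for illustration only.
std::int64_t suffixDp(const std::vector<std::int64_t>& a) {
    const int n = static_cast<int>(a.size());
    if (n == 0) return 0;
    // g[i][p]: valid partitions of a[i..n-1] whose FIRST subarray has parity p.
    std::vector<std::vector<std::int64_t>> g(n + 1, std::vector<std::int64_t>(2, 0));
    g[n][0] = g[n][1] = 1;  // empty suffix: nothing left to partition
    for (int i = n - 1; i >= 0; --i) {
        int parity = 0;  // parity of a[i..k-1]
        for (int k = i + 1; k <= n; ++k) {
            parity ^= static_cast<int>(a[k - 1] & 1);
            // First subarray a[i..k-1] has this parity; the rest of the
            // partition must start with the opposite parity.
            g[i][parity] = (g[i][parity] + g[k][parity ^ 1]) % MOD;
        }
    }
    return (g[0][0] + g[0][1]) % MOD;
}
```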

This problem has both the optimal substructure property and overlapping subproblems, so it can be solved with dynamic programming.

Let dp[i][parity] denote the number of valid partitions of the prefix P_i whose last subarray has a sum of parity parity. Let sum[i] denote the sum of the first i elements of A.
Initialize dp[0][0]=1 and dp[0][1]=1: the empty prefix P_0 acts as a sentinel, so the first subarray may have either parity.

Case 1: sum[i] is even.

  • dp[i][0]= \underset{2 \mid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][1], i.e. the number of valid partitions of shorter prefixes P_j whose last subarray has an odd sum and whose total sum sum[j] is even, so that the new last subarray (elements j+1 to i) has an even sum.
  • dp[i][1]= \underset{2 \nmid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][0], i.e. the number of valid partitions of shorter prefixes P_j whose last subarray has an even sum and whose total sum sum[j] is odd, so that the new last subarray has an odd sum.

Case 2: sum[i] is odd.

  • dp[i][1]= \underset{2 \mid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][0], i.e. the number of valid partitions of shorter prefixes P_j whose last subarray has an even sum and whose total sum sum[j] is even, so that the new last subarray has an odd sum.
  • dp[i][0]= \underset{2 \nmid sum[j]}{\sum^{i-1}_{j=0}}\,dp[j][1], i.e. the number of valid partitions of shorter prefixes P_j whose last subarray has an odd sum and whose total sum sum[j] is odd, so that the new last subarray has an even sum.

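These sums can be maintained incrementally using just four running totals, one for each combination of (parity of sum[j], parity of the last subarray), so each dp[i] is computed in O(1). The answer is dp[N][0]+dp[N][1]. For implementation details, see the code attached below.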
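The four running totals can be sketched as follows. This is a hedged rewrite, not the editorialist's exact code, and the names linearDp and acc are hypothetical. acc[s][p] holds the sum of dp[j][p] over all j processed so far with sum[j] of parity s, so evenend[s] and oddend[s] in the editorialist's code play the roles of acc[s][0] and acc[s][1]. Each step reads the two totals the recurrences need and then adds dp[i] into the totals for the parity of sum[i].

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 998244353;

// O(N) time, O(1) extra memory.
std::int64_t linearDp(const std::vector<std::int64_t>& a) {
    if (a.empty()) return 0;
    // acc[s][p] = sum of dp[j][p] over processed j with parity(sum[j]) == s.
    std::array<std::array<std::int64_t, 2>, 2> acc{};
    acc[0][0] = acc[0][1] = 1;  // j = 0: empty prefix, sum[0] = 0 is even
    int s = 0;                  // parity of sum[i]
    std::int64_t dp0 = 0, dp1 = 0;
    for (std::int64_t x : a) {
        s ^= static_cast<int>(x & 1);
        // Last subarray even <=> sum[j] has parity s; previous block was odd.
        dp0 = acc[s][1];
        // Last subarray odd <=> sum[j] has parity 1 - s; previous block was even.
        dp1 = acc[s ^ 1][0];
        acc[s][0] = (acc[s][0] + dp0) % MOD;
        acc[s][1] = (acc[s][1] + dp1) % MOD;
    }
    return (dp0 + dp1) % MOD;
}
```

Usage: linearDp({1, 2, 3, 3, 5}) returns 5, matching the brute-force count for the statement's example.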

TIME COMPLEXITY:

O(N) for each test case.

SOLUTION:

Setter's Solution
#include<bits/stdc++.h>

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/trie_policy.hpp>
using namespace std;
#define int long long
#define endl "\n"
using namespace __gnu_pbds;


const int sz=2e5+10;
int arr[sz];
int n;
int dp[sz][3][3];

int mod=998244353;

int solve(int i,int prev,int cur){
	
	
	if(i+1==n){
		return (prev != ((cur+arr[i])%2));
	}
	
	if(dp[i][prev][cur] != -1){
		return dp[i][prev][cur];
	}
	
	int ans=solve(i+1,prev,(cur+arr[i])%2);
	
	if(prev != ((cur+arr[i])%2)){
		ans+=solve(i+1,(cur+arr[i])%2,0);
	}
	
	return dp[i][prev][cur]=ans%mod;
}

signed main(){
//	
//	freopen("chef12.txt","r",stdin);
//	freopen("chef12_output.txt","w",stdout);
	ios_base::sync_with_stdio(0) , cin.tie(0);
	int t;cin>>t;
	while(t--){
		cin>>n;
		for(int i=0;i<n;i++){
			cin>>arr[i];
		}
		
		for(int i=0;i<n+5;i++){
			for(int j=0;j<3;j++){
				for(int k=0;k<3;k++){
					dp[i][j][k]=-1;
				}
			}
		}
		// dp(i, parity of the previous block's sum (2 = no block yet), parity of the current ongoing sum)
		cout<<solve(0,2,0)<<endl;
	}
	return 0;
}
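The setter's recursion explores states (i, prev, cur), where prev is the parity of the last completed block (2 meaning none yet) and cur is the parity of the block currently being built. Its recursion depth reaches N, which can be a concern where the default stack is small. Below is a hedged bottom-up sketch of the same state machine (the name forwardStates is hypothetical): cnt[prev][cur] counts the ways to reach each state before element i, and each element either extends the open block or, if the block's parity differs from prev, closes it.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 998244353;

// Bottom-up version of the setter's state machine.
// cnt[prev][cur]: ways to reach position i with the previous closed block
// having parity prev (2 = none closed yet) and the open block's parity cur.
std::int64_t forwardStates(const std::vector<std::int64_t>& a) {
    const int n = static_cast<int>(a.size());
    if (n == 0) return 0;
    using Table = std::array<std::array<std::int64_t, 2>, 3>;
    Table cnt{};
    cnt[2][0] = 1;  // start: nothing closed, empty open block
    for (int i = 0; i + 1 < n; ++i) {
        Table next{};
        for (int prev = 0; prev < 3; ++prev)
            for (int cur = 0; cur < 2; ++cur) {
                std::int64_t ways = cnt[prev][cur];
                if (ways == 0) continue;
                int c = (cur + static_cast<int>(a[i] & 1)) % 2;
                // Option 1: keep extending the open block.
                next[prev][c] = (next[prev][c] + ways) % MOD;
                // Option 2: close the block here if its parity alternates.
                if (prev != c) next[c][0] = (next[c][0] + ways) % MOD;
            }
        cnt = next;
    }
    // The last element must close the final block.
    std::int64_t ans = 0;
    for (int prev = 0; prev < 3; ++prev)
        for (int cur = 0; cur < 2; ++cur) {
            int c = (cur + static_cast<int>(a[n - 1] & 1)) % 2;
            if (prev != c) ans = (ans + cnt[prev][cur]) % MOD;
        }
    return ans;
}
```

Only the counts for the current position are kept, so memory is O(1) and no recursion is involved.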
Tester's Solution
// #pragma GCC optimize("O3")
// #pragma GCC target("popcnt")
// #pragma GCC target("avx,avx2,fma")
// #pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;     
using namespace std;
#define ll long long 
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;   
#define pb push_back                 
#define mp make_pair               
#define nline "\n"                             
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);    
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;   
const ll MAX=500500;   
void solve(){       
    ll n; cin>>n;
    vector<ll> a(n+5),sum(n+5,0);    
    for(ll i=1;i<=n;i++){
        cin>>a[i]; 
        sum[i]=(sum[i-1]+a[i])%2; 
    }
    map<ll,map<ll,pair<ll,ll>>> direct;
    direct[0][0]={0,1};
    direct[1][0]={1,1};
    direct[0][1]={1,0};
    direct[1][1]={0,0};
    vector<vector<ll>> dp(5,vector<ll>(5,0));
    ll val=0;
    for(ll i=1;i<=n;i++){
        val=0;
        vector<vector<ll>> now=dp;
        for(ll j=0;j<2;j++){
            auto it=direct[sum[i]][j]; 
            val+=dp[it.f][it.s];
            now[sum[i]][j]+=dp[it.f][it.s]; 
            now[sum[i]][j]%=MOD; 
        }
        val++;
        now[sum[i]][sum[i]]++;
        swap(dp,now);
    }
    val%=MOD;
    cout<<val<<nline;
    return;          
}                                     
int main()                                                                               
{    
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                              
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif         
    ll test_cases=1;               
    cin>>test_cases; 
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Editorialist's Solution
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(), _obj.end()
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
const int N = 1e5 + 11, mod = 998244353;
ll max(ll a, ll b) { return ((a > b) ? a : b); }
ll min(ll a, ll b) { return ((a > b) ? b : a); }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
void sol(void)
{
    int n;
    cin >> n;
    vll v(n);
    for (int i = 0; i < n; i++)
        cin >> v[i], v[i] %= 2;
    ll p[n + 1], dp[2],oddend[2], evenend[2];
    memset(dp, 0, sizeof(dp));
    memset(oddend, 0, sizeof(oddend));
    memset(evenend, 0, sizeof(evenend));
    memset(p,0,sizeof(p));
    dp[1] = dp[0] = 1;
    evenend[0] = oddend[0] = 1;
    for (int i = 1; i <= n; i++)
        p[i] = p[i - 1] + v[i - 1];
    for (int i = 1; i <= n; i++)
    {
        if(p[i]%2==0)
        {
            dp[0]=oddend[0];
            dp[1]=evenend[1];
            evenend[0]+=dp[0];
            oddend[0]+=dp[1];
        }
        else
        {
            dp[0] = oddend[1];
            dp[1] = evenend[0];
            evenend[1]+=dp[0];
            oddend[1]+=dp[1];
        }
        dp[0] %= mod;
        dp[1] %= mod;
        evenend[0]%=mod;
        evenend[1]%=mod;
        oddend[0]%=mod;
        oddend[1]%=mod;
    }
    cout << (dp[0] + dp[1]) % mod << '\n';
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);
    int test = 1;
    cin >> test;
    while (test--)
        sol();
}

I have seen simpler approaches in some submissions:
Tabulation:
https://www.codechef.com/viewsolution/63618685
Recursive:
https://www.codechef.com/viewsolution/63473335

This is essentially the same DP as in my solution, with the same memory optimization. It works because the state after processing element i depends only on the state after element i-1 (the running totals), so no full table is needed.


Yes, but the second (recursive) solution looks less daunting than the iterative one.


#include <bits/stdc++.h>

using namespace std;

#define int long long int 

const int mod = 998244353 ;

int32_t main()
{
    int t ;
    cin >> t ;
    while(t--){
        int n , sum = 0 , ans = 0 ;
        cin >> n ;
        vector<vector<int>> dp(2 , vector<int>(2 , 0)) ;
        vector<vector<int>> dp_(2 , vector<int>(2)) ;
        // row = parity of the prefix sum (0 = even, 1 = odd),
        // column = parity of the sum of the subarray the partition ends with
        for(int i = 0 , x ; i < n ; i ++){
            cin >> x ;
            sum = ((sum + (x&1))&1) ;
            for(int j = 0 ; j < 2 ; j ++){
                for(int k = 0 ; k < 2 ; k ++){
                    if(sum ^ j ^ k){
                        dp_[sum][sum ^ j] = dp[j][k] % mod ;
                        //cout << sum << ',' << (sum ^ j) << '=' << dp[sum][sum^j] << ' ' ; 
                        if(i == n - 1) ans = (ans + dp[j][k]) % mod ;
                    }
                }
            }
            for(int j = 0 ; j < 2 ; j ++){
                for(int k = 0 ; k < 2 ; k ++){
                    if(sum ^ j ^ k) dp[sum][sum ^ j] = (dp[sum][sum^j] + dp_[sum][sum ^ j]) % mod ;
                }
            }
            dp[sum][sum] = (dp[sum][sum] + 1LL) % mod ;
        } 
        //cout << '\n' ;
        cout << (ans + 1LL) % mod << '\n' ;
    }

    return 0;
}

Here is a similar approach using DP with tabulation and a 2-D dp vector.

Before reading the approach, note that the "type" of a subarray means whether its sum is even (0) or odd (1).

In this approach, I think of it this way: the prefix sum of the given array is either even or odd; if it is even, its last bit is ‘0’, otherwise it is ‘1’.
Since we have to count the ways to divide the whole array into partitions that follow the (0101…/1010…) pattern, we only need to deal with ‘0’ and ‘1’ as the prefix sum.
Suppose the prefix sum from the 0th index to the i-th index is even, so sum = 0. Then there are 2 different ways to cut the last subarray: either we take out an even subarray ending at index i,
or we take out an odd subarray ending at index i.
If we want to cut out an even subarray, we have to check at which earlier indices an even prefix sum appeared: since the sum up to index i is even, the cut-out subarray is even only when
the prefix sum just before it is even too.
Now suppose an even prefix sum occurred at some index j, where j < i. We then have to check whether the part ending at index j finished with an odd or an even subarray. If it finished with an odd subarray, it's fine; otherwise it's
invalid, because the question requires the (0101…/1010…) pattern: since we want to end index i with an even subarray, index j must finish with an odd subarray.
Similarly, if we want to end index i with an odd subarray, we first find the indices j where an odd prefix sum occurred, and then check whether each of them finished with an even subarray.

So, in this problem there are 4 states, determined by two variables:
the type of the prefix sum up to index i (even/odd) and the type of the subarray that ends at index i (even/odd).

So, in the code we use a 2D dp vector whose rows represent whether the prefix sum up to index i is even (‘0’) or odd (‘1’), and whose columns represent whether the subarray ending at index i is even (‘0’) or odd (‘1’).
After calculating ‘sum’, we run 2 loops, where ‘j’ is the row and ‘k’ is the column.
We first check whether sum ^ j ^ k == 1. Here sum tells whether the current prefix sum is even or odd, and j is the type of an earlier prefix sum, so the subarray we cut out has type (sum ^ j): for j == 0, if sum is even (‘0’) then 0 ^ 0 = 0, an even subarray,
and if sum == 1 then 0 ^ 1 = 1, an odd subarray; for j == 1 the types are reversed.
If sum ^ j == 1, then k must be ‘0’: index i ends an odd subarray that starts at some index l + 1 (say).
Precisely, array[l + 1 … i] is odd (‘1’), so the part ending at index l must finish with a k = 0 (even) subarray.

So, if sum ^ j ^ k == 1, we store dp_[sum][sum ^ j] = dp[j][k] in a secondary 2D vector.
We use the secondary dp_ because, for example, if sum == 1 and j == 0, then k == 0,
so sum ^ j == 1; if we updated the main dp vector directly, it would do dp[1][1] += dp[0][0].
Then, when we come to j == 1, we would do dp[1][0] += dp[1][1], which already includes the information stored at this same step.
That would mean some combinations that end at index i with an odd subarray are also used to form an even subarray ending at the same index i,
which is wrong because the subarray ending at index i cannot be both even and odd at the same time.
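The aliasing problem described above can be reproduced with a small sketch (countWays and the useBuffer flag are hypothetical names, not from the original code). It follows the same dp[prefix parity][last-subarray parity] table; with useBuffer = false it writes each update straight into dp, so for an odd prefix the j == 1 step reads the dp[1][1] value already increased by the j == 0 step. On the sample array [1, 2, 3, 3, 5] the buffered version gives the correct 5, while the in-place version overcounts and gives 6.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr std::int64_t MOD = 998244353;

// dp[x][y]: summed counts over earlier prefixes whose prefix-sum parity is x
// and whose last subarray has parity y (the commenter's 2x2 table).
// useBuffer = true mirrors the secondary dp_ vector; false updates in place.
std::int64_t countWays(const std::vector<std::int64_t>& a, bool useBuffer) {
    std::array<std::array<std::int64_t, 2>, 2> dp{};
    const int n = static_cast<int>(a.size());
    int sum = 0;
    std::int64_t ans = 0;
    for (int i = 0; i < n; ++i) {
        sum = (sum + static_cast<int>(a[i] & 1)) & 1;
        std::array<std::int64_t, 2> added{};  // amount added to dp[sum][sum ^ j]
        for (int j = 0; j < 2; ++j) {
            int k = 1 ^ sum ^ j;  // the only k with sum ^ j ^ k == 1
            added[j] = dp[j][k];
            if (!useBuffer)       // in place: a later j may read this write
                dp[sum][sum ^ j] = (dp[sum][sum ^ j] + added[j]) % MOD;
        }
        if (useBuffer)
            for (int j = 0; j < 2; ++j)
                dp[sum][sum ^ j] = (dp[sum][sum ^ j] + added[j]) % MOD;
        if (i == n - 1) ans = (added[0] + added[1]) % MOD;
        dp[sum][sum] = (dp[sum][sum] + 1) % MOD;  // whole prefix as one subarray
    }
    return (ans + 1) % MOD;  // +1: the whole array as a single subarray
}
```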

Finally, we add dp_[sum][sum ^ j] into dp[sum][sum ^ j], and we add dp[sum][sum] += 1 because we can take the whole prefix up to index i as a single subarray of type sum.
In the ans variable we add all (0101…/1010…) combinations that end at some index up to N - 2 and are extended by cutting out one last (even or odd) subarray ending at index N - 1, and we add +1 for taking the whole array as one subarray.
Here N is the size of the array, and we use 0-based indexing.

So, I hope this helps those who have trouble figuring out the subproblems, states and variables that define a DP problem.
Happy coding, guys, and feel free to share your doubts and comments :slight_smile:


I have used an approach similar to the editorial's, but I am getting TLE on some test cases.



long long int ways(int n,int status,int a[],int dp[][2]){
	if(n < 0)
		return 1;
	if(dp[n][status] != -1)
		return dp[n][status];
	long long int total = 0,s=0;
	for(int i=n;i>=0;i--){
		s += a[i];
		if((s%2) == status){
			total = (ways(i-1,(status+1)%2,a,dp)%MOD + total%MOD)%MOD;
		}
	}
	return dp[n][status] = total;
}


int main(){

	/*freopen("input.txt","r",stdin);
	freopen("output.txt","w",stdout);*/
	
	int test,n,a[MAX],dp[MAX][2];
	cin >> test;
	while(test--){
		cin >> n;
		for(int i=0;i<n;i++){
			cin >> a[i];
			dp[i][1] = dp[i][0] = -1;
		}
		long long int ans = (ways(n-1,0,a,dp)%MOD+ways(n-1,1,a,dp)%MOD)%MOD;
		cout << ans << endl;

	}
	return 0;
}

Your code's time complexity is O(N^2): each of the O(N) memoized states ways(n, status) runs a loop over all i from n down to 0.

You can refer to the setter's solution for a recursive version that does O(1) work per state.

import sys
sys.setrecursionlimit(100000000)
from functools import lru_cache

@lru_cache(None)
def dp(start, req):
    if start == n: return 1
    _sum = sol = 0
    for end in range(start + 1, n + 1):
        _sum ^= arr[end - 1]
        if _sum == req:
            sol += dp(end, _sum ^ 1)
    return sol



for _ in range(int(input())):
    n = int(input())
    arr = [int(x) for x in input().split()]
    arr = [x & 1 for x in arr]
    print(dp(0, 0) + dp(0, 1))
    dp.cache_clear()

I am getting TLE on a few cases.
Could someone please point out why?


Because it is O(N^2): each of the O(N) memoized states dp(start, req) loops over every possible end.

Oh yeah… thanks for pointing that out!

import sys
sys.setrecursionlimit(100000000)
from functools import lru_cache
mod=998244353 

@lru_cache(None)
def dp(start, _sum , prev):
    b=0
    _sum^=arr[start]
    if start == n-1: return 1 if prev!=_sum else 0
    a=dp(start+1,_sum,prev)
    if _sum!=prev:
        b=dp(start+1,0,_sum)
    return (a+b)%mod


for _ in range(int(input())):
    n = int(input())
    arr = [int(x) for x in input().split()]
    arr = [x & 1 for x in arr]
    print(dp(0,0,None))
    dp.cache_clear()

Could someone please point out why this throws a Runtime Error (SIGSEGV)?

Fantastic explanation, pal! Do you happen to know a few more problems like this one?


It's a great pleasure to see that you have understood the logic, the process and the code! :slightly_smiling_face:
I'm really happy to see that you, @aamodpandey, are interested in upsolving previous contest problems, came here, and took the time to read this kind of long, eye-straining editorial. Just keep doing your best this way and you will definitely become a highly rated coder in the future. Carry on, bro :+1:.
You're on the right track, bro.
Yes, there are a few more problems like this one on CodeChef, LeetCode, Codeforces and AtCoder.
I can't remember a specific one right now, but if I find one I will definitely add it here in the future! :blush::+1:


I've reached out to you by mail.


ight!