PROBLEM LINK:
Setter: Ritesh Gupta
Tester: Radoslav Dimitrov
Editorialist: Raja Vardhan Reddy
DIFFICULTY:
Easy
PREREQUISITES:
NIL
PROBLEM:
You are given a sequence A_1,A_2,…,A_N. For each k (1≤k≤N), let’s define a function f(k) in the following way:
- Consider a sequence B_1,B_2,…,B_N, which is created by setting A_k=0. Formally, B_k=0 and B_i=A_i for each valid i≠k .
- f(k) is the number of ways to split the sequence B into two non-empty contiguous subsequences with equal sums.
Find the sum S=f(1)+f(2)+…+f(N).
EXPLANATION
Let the i^{th} split be the division of the sequence into A_1, A_2, \cdots, A_i and A_{i+1}, \cdots, A_n. Here A_1, A_2, \cdots, A_i form the left half of the split, and A_{i+1}, \cdots, A_n form the right half.
Let diff[i] be the sum of the elements in the left half minus the sum of the elements in the right half of the i^{th} split,
i.e. diff[i]=\sum_{j=1}^{i} A_j-\sum_{j=i+1}^{n} A_j.
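We call a split valid if it divides the sequence into two non-empty contiguous parts with equal sums.
Hence, the i^{th} split is initially valid if diff[i]=0 and 1 \le i < n (both halves must be non-empty).
When a nonzero A_i is changed to 0, every diff[j] shifts by \pm A_i, so the initially valid splits become invalid and some invalid splits become valid. (If A_i=0, nothing changes, and the formula below simply counts the initially valid splits.)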
Let’s look at what kind of splits become valid!
j^{th} split becomes valid if:
- j<i and diff[j]=-A_i: if j<i, A_i falls in the right half of the split. Hence, when A_i is changed to 0, the right sum decreases by A_i, so diff[j] increases by A_i and becomes 0.
- j \ge i, j \ne n and diff[j]=A_i: if j \ge i, A_i falls in the left half of the split. Hence, when A_i is changed to 0, the left sum decreases by A_i, so diff[j] decreases by A_i and becomes 0.
Hence, letting f(i) be the number of valid splits after A_i is changed to 0:
f(i) = (number of j with 1 \le j < i and diff[j]=-A_i) + (number of j with i \le j < n and diff[j]=A_i).
For each i, this takes O(\log n) time with two maps: one storing the counts of diff[j] for j<i, and another storing the counts of diff[j] for i \le j < n. When moving from i to i+1, diff[i] is moved from the second map to the first.
TIME COMPLEXITY:
Computation of diff[i] can be done by calculating prefix sums in O(n) time.
Computation of f(i) : O(log(n)) for each i. Therefore, O(nlog(n)) for all i.
Total Complexity: O(n+nlog(n)) = O(nlog(n)) for each test case.
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Setter's solution.
// Zeroing a[i] leaves B with total sum rest = sum - a[i]. A split of B is balanced iff
// the side that does NOT contain position i (whose sum is unchanged) equals rest / 2.
//   m1: prefix sums a[0..j] for j < i     (left parts that exclude a[i])
//   m2: suffix sums a[j..n-1] for j > i   (right parts that exclude a[i])
// Every split left = a[0..j], right = a[j+1..n-1] (0 <= j <= n-2) is counted exactly once:
// through m1 when j < i, through m2 when j >= i.
// Note: `int` is #defined as `long long` in this program, so all sums are 64-bit.
int32_t main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Read-only lookup; operator[] would insert a zero entry for every missing key.
    auto count_of = [](const std::unordered_map<int,int> &m, int key) -> int {
        auto it = m.find(key);
        return it == m.end() ? 0 : it->second;
    };

    int t;
    std::cin >> t;
    while(t--)
    {
        int n;
        std::cin >> n;
        // Sized per test case instead of a fixed global buffer, so any n fits.
        std::vector<int> a(n > 0 ? n : 0);
        for(auto &x : a)
            std::cin >> x;

        std::unordered_map<int,int> m1, m2;
        int sum = 0;
        for(int i = n-1; i >= 0; i--)
        {
            sum += a[i];
            m2[sum]++;
        }

        int ans = 0;
        int prefix = 0;  // a[0] + ... + a[i-1]
        for(int i = 0; i < n; i++)
        {
            m2[sum - prefix]--;  // the suffix starting at i contains a[i]; drop it
            int rest = sum - a[i];
            if(rest % 2 == 0)  // an odd remaining total can never be split evenly
                ans += count_of(m1, rest / 2) + count_of(m2, rest / 2);
            prefix += a[i];
            m1[prefix]++;
        }
        std::cout << ans << '\n';
    }
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
using namespace std;
// If x > y, overwrite x with y and report 1; otherwise leave x untouched and report 0.
template<class T, class T1> int chkmin(T &x, const T1 &y) {
    if (!(x > y)) return 0;
    x = y;
    return 1;
}
// If x < y, overwrite x with y and report 1; otherwise leave x untouched and report 0.
template<class T, class T1> int chkmax(T &x, const T1 &y) {
    if (!(x < y)) return 0;
    x = y;
    return 1;
}
const int MAXN = (1 << 20);
int read_int();
int n;
int a[MAXN];
// Loads one test case through the fast reader: first the length n, then n values
// stored into a[0..n-1] in input order.
void read() {
    n = read_int();
    int filled = 0;
    while (filled < n) {
        a[filled] = read_int();
        ++filled;
    }
}
unordered_map<int64_t, int> scnt, pcnt;
// Solves the test case held in the globals n and a[0..n-1] and prints S.
//
// Rather than fixing the zeroed position and counting splits, this walks over split points
// and counts positions. For the split left = a[0..mid], right = a[mid+1..n-1] with sums
// pref and suff:
//   zeroing a left element x balances it  iff  pref - x == suff, i.e. x == pref - suff;
//   zeroing a right element x balances it iff  suff - x == pref, i.e. x == suff - pref.
// pcnt / scnt are the multisets of values currently left / right of the split. Summing over
// all splits counts every balancing (position, split) pair, which is S = f(1)+...+f(N).
void solve() {
    pcnt.clear();
    scnt.clear();

    int64_t pref = 0;
    int64_t suff = 0;
    for (int idx = 0; idx < n; ++idx) {
        suff += a[idx];
        ++scnt[a[idx]];
    }

    int64_t total = 0;
    for (int mid = 0; mid + 1 < n; ++mid) {
        const int moved = a[mid];
        // a[mid] crosses from the right half into the left half.
        pref += moved;
        suff -= moved;
        --scnt[moved];
        ++pcnt[moved];

        const int64_t gap = pref - suff;
        const auto in_left = pcnt.find(gap);
        if (in_left != pcnt.end()) total += in_left->second;
        const auto in_right = scnt.find(-gap);
        if (in_right != scnt.end()) total += in_right->second;
    }

    cout << total << endl;
}
// Tester's driver: reads the number of test cases, then loads and solves each one.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    for (int T = read_int(); T != 0; --T) {
        read();
        solve();
    }
    return 0;
}
const int maxl = 100000;
char buff[maxl];
int ret_int, pos_buff = 0;
// Set once fread() returns no data at all; lets read_int() stop instead of spinning at EOF.
bool input_exhausted = false;

// Advances to the next byte of the input buffer, refilling it from stdin at the end.
// Bytes past the end of a short (final) read are zeroed, so stale data left over from
// the previous chunk can never be parsed as input.
void next_char() {
    if(++pos_buff == maxl) {
        std::size_t got = fread(buff, 1, maxl, stdin);
        if(got < static_cast<std::size_t>(maxl)) {
            std::memset(buff + got, 0, static_cast<std::size_t>(maxl) - got);
            if(got == 0) input_exhausted = true;
        }
        pos_buff = 0;
    }
}

// Parses the next decimal integer from stdin, skipping any non-digit separators.
// A '-' makes the number negative only if it directly precedes the digits.
// Returns 0 once the input is exhausted. Assumes the value fits in int.
int read_int()
{
    ret_int = 0;
    int mns = 0;
    for(; buff[pos_buff] < '0' || buff[pos_buff] > '9'; next_char()) {
        if(input_exhausted) return 0;
        mns = buff[pos_buff] == '-';
    }
    for(; buff[pos_buff] >= '0' && buff[pos_buff] <= '9'; next_char())
        ret_int = ret_int * 10 + buff[pos_buff] - '0';
    if(mns) ret_int *= -1;
    return ret_int;
}
Editorialist's Solution
//raja1999
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
using namespace __gnu_pbds;
// Competitive-programming shorthand macros from the editorialist's template.
// Object-like macros (e.g. ss, ff, ll, inf, mod) rewrite EVERY later occurrence of that
// identifier; function-like ones (f, rep, fd, sz, all) rewrite it only when followed by '('.
// Loop helpers: f(i,a,b) runs i over [a, b); rep(i,n) over [0, n); fd(i,a,b) counts i down from a to b inclusive.
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
// Container and pair shorthands.
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
// NOTE: any later identifier spelled `ss` / `ff` is silently renamed to `second` / `first`.
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
// inf = 10^9 + 5 (parenthesized so it composes safely inside expressions).
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
// mod = 10^9 + 7 (parenthesized like inf).
#define mod (1000*1000*1000+7)
// pqueue: max-heap of int; pdqueue: min-heap of int (ordered by greater<int>).
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
// Disabled: enabling it would make every `int` below mean long long.
//#define int ll
// GNU pb_ds order-statistics tree over int (provides find_by_order / order_of_key);
// relies on the ext/pb_ds headers and `using namespace __gnu_pbds` above.
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
// Leftover commented-out I/O setup (main() performs its own).
//std::ios::sync_with_stdio(false);
// Editorialist's solution (0-indexed).
// diff[j] = (a[0]+...+a[j]) - (a[j+1]+...+a[n-1]) describes split j (left = a[0..j]);
// only 0 <= j <= n-2 are real splits because the right half must be non-empty.
// Zeroing a[i] shifts diff[j] by +a[i] for j < i and by -a[i] for j >= i, so
//   f(i) = #{j < i : diff[j] == -a[i]} + #{i <= j <= n-2 : diff[j] == a[i]}.
// Both counts come from ordered maps keyed by the full 64-bit diff, so no assumption on
// the range of the values (and no narrowing to int) is needed.
int main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t;
    std::cin>>t;
    while(t--){
        long long n;
        std::cin>>n;
        if(n<0) n=0;  // defensive: a negative length must not size the vectors
        std::vector<long long> a(n), diff(n);
        long long total=0;
        for(long long i=0;i<n;i++){
            std::cin>>a[i];
            total+=a[i];
        }
        long long prefix=0;
        for(long long i=0;i<n;i++){
            prefix+=a[i];
            diff[i]=prefix-(total-prefix);
        }
        // at_or_after: counts of diff[j] for i <= j <= n-2; before: counts of diff[j] for j < i.
        std::map<long long,long long> at_or_after, before;
        for(long long j=0;j+1<n;j++){
            at_or_after[diff[j]]++;
        }
        // Read-only lookup: operator[] would insert a zero entry for every missing key.
        auto count_of=[](const std::map<long long,long long>& m,long long key){
            auto it=m.find(key);
            return it==m.end()?0LL:it->second;
        };
        long long ans=0;
        for(long long i=0;i<n;i++){
            ans+=count_of(at_or_after,a[i]);
            ans+=count_of(before,-a[i]);
            if(i+1<n){
                // Split i is about to lie strictly left of the next position.
                at_or_after[diff[i]]--;
                before[diff[i]]++;
            }
        }
        std::cout<<ans<<'\n';
    }
    return 0;
}
Feel free to share your approach if it differs. Suggestions are always welcome.