SPLITMAX - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Abhinav Sharma
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

1313

PREREQUISITES:

None

PROBLEM:

We define the value of an array A to be

\sum_{i=1}^N \sum_{\substack{j = 1 \\ j \neq i}}^N (A_i \cdot A_j)

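You are given an array A of N positive integers. In one move, you can pick some A_i and split it into two smaller positive integers that sum to A_i (the array grows by one element); moves may be repeated as often as you like. What’s the maximum possible value you can attain? The answer is printed modulo 998244353.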

EXPLANATION:

Let S = A_1 + A_2 + \ldots + A_N.

The formula for the value specifically excludes the product of each element with itself. Adding those N terms in and subtracting them again, we can rewrite it as

\sum_{i=1}^N \sum_{\substack{j = 1 \\ j \neq i}}^N (A_i \cdot A_j) = \sum_{i=1}^N \sum_{\substack{j = 1}}^N (A_i \cdot A_j) - \sum_{i=1}^N A_i^2 \\ = \sum_{i=1}^N A_i \left ( \sum_{j=1}^N A_j\right ) - \sum_{i=1}^N A_i^2= \sum_{i=1}^N A_i \cdot S - \sum_{i=1}^N A_i^2\\ = S^2 - \sum_{i=1}^N A_i^2

Notice that S does not change when we perform an operation, since an element is replaced by two elements with the same sum. Hence the first part of this formula, S^2, is a constant.

So, maximizing the value of the array is equivalent to minimizing the value of \sum_{i=1}^N A_i^2.

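This is achieved when each A_i = 1. Indeed, splitting an element x = a + b (with a, b \geq 1) changes the sum of squares by a^2 + b^2 - x^2 = -2ab < 0, so every move strictly decreases \sum_{i=1}^N A_i^2, and a move is possible as long as some element exceeds 1. When every A_i = 1, the length of the array will be exactly S, and so \sum_{i=1}^N A_i^2 = \sum_{i=1}^S 1^2 = S.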

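The final answer is hence simply S^2 - S = S(S-1), taken modulo 998244353. Since N can be up to 2 \cdot 10^5 and each element up to 10^9, S can reach 2 \cdot 10^{14}, so S^2 does not fit in a 64-bit integer. In C++, reduce S modulo 998244353 before squaring, as the solutions below do.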

TIME COMPLEXITY

\mathcal{O}(N) per test case.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;

#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 

// Shorthand macros, typedefs and constants from a competitive-programming template.
// NOTE(review): in the visible setter code only ll, el, rep and mod are used;
// the rest (db, ld, rev, rep_a, all, ff, ss, pb, mp, vi, ii, pq, o_set, INF, MAXN)
// appear to be unused template leftovers.
#define ll long long
#define db double
#define el "\n"
#define ld long double
// rep: i = 0 .. n-1; rev: i = n down to 0 (inclusive); rep_a: i = a .. n-1.
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define all(ds) ds.begin(), ds.end()
#define ff first
#define ss second
#define pb push_back
#define mp make_pair
typedef vector< long long > vi;
typedef pair<long long, long long> ii;
// Max-heap of ll (std::priority_queue default ordering).
typedef priority_queue <ll> pq;
// Order-statistics set of ll built on the GNU pb_ds red-black tree.
#define o_set tree<ll, null_type,less<ll>, rb_tree_tag,tree_order_statistics_node_update> 

// Modulus applied to the printed answer (998244353 is prime).
const ll mod = 998244353;
const ll INF = (ll)1e18;
const ll MAXN = 1000006;

// Modular exponentiation by repeated squaring: returns x^n mod `mod`, in [0, mod).
// x may be any long long. It is first reduced into [0, mod), so every product
// below is at most (mod-1)^2 < 2^63 and cannot overflow, and negative bases give
// a non-negative result. Returns 1 when n <= 0 (the loop body never runs).
ll po(ll x, ll n){
    // C++ '%' keeps the sign of the dividend, so shift negative residues as well.
    x%=mod;
    if(x<0) x+=mod;
    ll ans=1;
    while(n>0){
        if(n&1) ans=(ans*x)%mod;
        x=(x*x)%mod;
        n/=2;
    }
    return ans;
}


int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
   
    int T=1;
    cin >> T;
    while(T--){
    
        ll sum = 0;
        ll x;
        int n;
        cin>>n;
        
        rep(i,n){
            cin>>x;
            sum+=x;
        }

        sum%=mod;

        cout<<(sum*sum-sum)%mod<<el;
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
    return 0;
} 
Tester's code (C++)
//Utkarsh.25dec
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
#include <array>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 998244353
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
// Modular exponentiation by repeated squaring: returns a^b mod `mod`, always in [0, mod).
// a may be any long long, including negative values; b must be non-negative (asserted).
ll power(ll a,ll b) {
    ll res=1;
    a%=mod;
    // C++ '%' keeps the sign of the dividend; without this shift a negative base
    // would produce a negative "modular" result.
    if(a<0) a+=mod;
    assert(b>=0);
    for(;b;b>>=1){
        if(b&1) res=res*a%mod;
        a=a*a%mod;
    }
    return res;
}
// Modular inverse via Fermat's little theorem (mod is prime): returns a^(mod-2) mod `mod`.
// Returns 0 when a is a multiple of mod, which has no inverse.
ll modInverse(ll a){return power(a,mod-2);}
// Size bound for the global arrays below.
// NOTE(review): vis and adj are not referenced anywhere in the visible tester code;
// they appear to be leftovers from a reusable template.
const int N=500023;
bool vis[N];
vector <int> adj[N];
// Strict input validator: reads one integer token from stdin with getchar() and
// requires it to be terminated by exactly the character `endd`.
// Accepts an optional leading '-'. Rejects leading zeros (other than the token "0"),
// "-0", empty tokens, tokens that could overflow long long, and any character other
// than a digit, '-' or `endd`. Asserts that the value lies in [l, r].
// Every violation aborts through assert(). Returns the parsed value.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;      // number of digits read so far
    int fi=-1;      // first digit of the token, -1 until a digit has been seen
    bool is_neg=false;
    while(true){
        // Keep getchar()'s int result: narrowing it to char makes EOF
        // indistinguishable from a real byte on some platforms.
        int g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            int d=g-'0';
            if(cnt==0){
                fi=d;
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            // Enforce the digit budget BEFORE growing x. Appending a 20th digit
            // (or a 19th after a leading 2..9) would overflow long long, which is UB.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+d;
        } else if(g==endd){
            // A terminator with no digits before it ("" or "-") is malformed input.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd`, which
// is consumed. Asserts that EOF is not reached before the terminator and that the
// number of characters read lies in [l, r]. Returns the characters read.
std::string readString(int l,int r,char endd){
    std::string ret="";
    int cnt=0;
    while(true){
        // Keep getchar()'s int result. Comparing a char with -1 never detects EOF
        // where char is unsigned (infinite loop), and it wrongly rejects a
        // genuine 0xFF byte where char is signed.
        int g=getchar();
        assert(g!=EOF);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Convenience wrappers around the strict validators. The suffix names the
// terminator the token must end with: Sp = a single space, Ln = a newline.
long long readIntSp(long long l,long long r){
    const char terminator = ' ';
    return readInt(l, r, terminator);
}
long long readIntLn(long long l,long long r){
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
string readStringLn(int l,int r){
    const char terminator = '\n';
    return readString(l, r, terminator);
}
string readStringSp(int l,int r){
    const char terminator = ' ';
    return readString(l, r, terminator);
}
// Running total of N over all test cases, used to enforce the sum-of-N limit.
int sumN=0;
// Validates and solves one test case. Reads N (1 <= N <= 2*10^5, with the sum of N
// over all tests also <= 2*10^5), then N values in [1, 10^9] separated by single
// spaces and terminated by a newline. Prints (S^2 - S) mod `mod`, where S is the
// sum of the values.
void solve()
{
    int n=readInt(1,200000,'\n');
    sumN+=n;
    assert(sumN<=200000);
    // Only the sum is needed, so values are accumulated as they are read. No
    // (variable-length) array is kept on the stack.
    ll sum=0;
    for(int i=1;i<=n;i++)
    {
        const char terminator = (i==n) ? '\n' : ' ';
        sum+=readInt(1,1000000000,terminator);
    }
    // S can reach 2*10^14, so reduce it before squaring to keep sum*sum below 2^63.
    sum%=mod;
    // S^2 - S, written with (mod - sum) so the intermediate never goes negative.
    ll ans=(sum*sum)+(mod-sum);
    ans%=mod;
    cout<<ans<<'\n';
}
// Tester's entry point. When not on the judge it redirects stdin/stdout to local
// files. It then reads the number of test cases T (validated to lie in [1, 1000]),
// runs solve() T times, and asserts that the input ends right after the last test.
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);

    int remaining = readInt(1, 1000, '\n');
    while (remaining--) {
        solve();
    }

    // The last test already consumed its terminating newline; nothing may follow it.
    const int trailing = getchar();
    assert(trailing == -1);

    const double elapsedMs = 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC;
    cerr << "Time : " << elapsedMs << "ms\n";
}
Editorialist's code (Python)
mod = 998244353
for _ in range(int(input())):
    n = int(input())  # array length; the values themselves are on the next line
    total = sum(map(int, input().split()))
    # Value = S^2 - S = S * (S - 1). Python ints are unbounded, so reduce only at the end.
    print(total * (total - 1) % mod)
1 Like

forum

Can anyone explain to me why this is true ?
Is it part of a summation property?
like: sum(sum(A*B)) = sum(A) * sum(B) ?

Thanks :innocent:

Oh, thanks for pointing that out: that’s a typo, the inner index should be j. I’ve fixed it now.

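As for the property, that’s just the standard summation property of pulling out the multiplier when it’s a constant:

\sum_{i=1}^N (kA_i) = k\sum_{i=1}^N A_i

In the inner sum, A_i does not depend on j, so \sum_{j=1}^N (A_i \cdot A_j) = A_i \sum_{j=1}^N A_j = A_i \cdot S. Applying the same rule once more gives \sum_{i=1}^N A_i \cdot S = S \sum_{i=1}^N A_i = S^2.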
1 Like