FIZZBUZZ2310 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: naisheel, jalp1428
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2702

PREREQUISITES:

Math

PROBLEM:

The value of an array A of length N is

\sum_{i=1}^{N-1} \left(\sum_{j=1}^i A_j - \sum_{j=i+1}^N A_j \right)

Given an array, find the sum of values of all its subarrays, modulo 998244353.

EXPLANATION:

Let’s write out the entire expression we want to compute.

\sum_{L=1}^N \sum_{R=L}^N \sum_{i=L}^{R-1} \left(\sum_{j=L}^i A_j - \sum_{j=i+1}^R A_j \right)

Notice that this can be broken up into two separate quadruple summations: one with the inner part being \sum_{j=L}^i A_j and the other with the inner part being \sum_{j=i+1}^R A_j.
Both can be computed similarly, so I’ll discuss only the first one. That is, we’ll compute

\sum_{L=1}^N \sum_{R=L}^N \sum_{i=L}^{R-1} \sum_{j=L}^i A_j

Let’s think about contribution instead.
How many times is a certain A_j included in this summation?
It’s not hard to see that A_j is counted once for each choice of L \leq j \leq i \lt R.
That is,

  • Pick a subarray [L, R] that contains j.
  • For this subarray, A_j is added once for each (non-full) prefix that includes it, which translates to picking an i that’s \geq j but \lt R.

Counting this isn’t too hard. Fix the index j. There are j choices for L (any index from 1 to j, including j itself), and \frac{(N-j)\cdot (N-j+1)}{2} choices for the pair (i, R) with j \leq i \lt R \leq N (there are N-j+1 indices from j to N, and we want to pick 2 distinct indices from them).

So, the overall summation comes out to be just

\sum_{j=1}^N A_j \cdot j \cdot \frac{(N-j)\cdot (N-j+1)}{2}

Similarly, find the value of the other summation and subtract it from this to get the final answer.
An easy way to do this is to observe that the second summation is pretty much the same as the first one, just on the reversed array instead (so the multipliers stay exactly the same and don’t have to be recomputed).

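Note that you may multiply up to 4 integers at a time (for example A_j \cdot j \cdot (N-j) \cdot (N-j+1)). Take care that the intermediate products do not overflow, e.g. by reducing modulo 998244353 between multiplications.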

TIME COMPLEXITY

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;

// -------------------- Input Checker Start --------------------

// Reads one integer from stdin, character by character, and returns it as a long long.
// The token must be: an optional single leading '-', then at least one digit, with no
// leading zeroes, not "-0", and at most 19 digits (a 19-digit value must start with '1').
// It must be followed immediately by the delimiter `endd` (which is consumed), and the
// value must lie in [l, r]. l and r should be in [-1e18, 1e18].
// Every violation prints a diagnostic to cerr and fails an assert.
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1; // cnt: digits read so far; fi: first digit, -1 until one is read
    bool is_neg = false;
    // getchar() yields bytes as unsigned char values, so compare the delimiter the same way.
    const int delim = static_cast<unsigned char>(endd);
    while(true)
    {
        // Keep getchar()'s int result: a plain char cannot reliably represent EOF.
        int g = getchar();
        if(g == EOF)
        {
            cerr << "Unexpected end of file while reading an integer\n";
            assert(false);
            return is_neg ? -x : x; // only reached when asserts are disabled
        }
        if(g == '-')
        {
            // A single sign is allowed, and only before the first digit ("--5" and "5-" are invalid).
            if(!(fi == -1 && !is_neg))
                cerr << "- in between integer\n";
            assert(fi == -1 && !is_neg);
            is_neg = true; // It's a negative integer
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0'; // fi is the first digit
            cnt++;

            // There shouldn't be leading zeroes. eg. "02" is not valid and assert will fail here.
            if(!(fi != 0 || cnt == 1))
                cerr << "Leading zeroes found\n";
            assert(fi != 0 || cnt == 1);

            // "-0" is invalid
            if(!(fi != 0 || is_neg == false))
                cerr << "-0 found\n";
            assert(fi != 0 || is_neg == false);

            // At most 19 digits, and a 19-digit value must start with '1', so |x| stays
            // below 2e18 and the accumulation above cannot overflow.
            if(!(!(cnt > 19 || (cnt == 19 && fi > 1))))
                cerr << "Value greater than 1e18 found\n";
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == delim)
        {
            // An empty token, or a lone "-", is not an integer.
            if(cnt == 0)
                cerr << "No digits found before the delimiter\n";
            assert(cnt > 0);
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                // We've reached the end, but the long long isn't in the right range.
                cerr << "Constraint violated: Lower Bound = " << l << " Upper Bound = " << r << " Violating Value = " << x << '\n';
                assert(false);
            }
            return x;
        }
        else if((g == ' ') && (endd == '\n'))
        {
            cerr << "Extra space found. It should instead have been a new line.\n";
            assert(false);
        }
        else if((g == '\n') && (endd == ' '))
        {
            cerr << "A new line found where it should have been a space.\n";
            assert(false);
        }
        else
        {
            cerr << "Something weird has happened.\n";
            assert(false);
        }
    }
}

// Reads characters from stdin up to (not including) the delimiter `endd`, which is consumed.
// Asserts that input does not end before the delimiter and that the resulting length is in [l, r].
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    // getchar() yields bytes as unsigned char values, so compare the delimiter the same way.
    const int delim = static_cast<unsigned char>(endd);
    while(true)
    {
        // int, not char: EOF must stay distinguishable from every valid byte
        // (with an unsigned char, the old `g != -1` check could never fire).
        int g = getchar();

        assert(g != EOF);
        // Also stop at EOF so the loop terminates even when asserts are disabled.
        if(g == EOF || g == delim)
            break;
        cnt++;
        ret += static_cast<char>(g);
    }
    if(!(l <= cnt && cnt <= r))
        cerr << "String length not within constraints\n";
    assert(l <= cnt && cnt <= r);
    return ret;
}

// Delimiter-specific wrappers: the *Sp variants expect the token to be followed by a
// single space, the *Ln variants by a newline.
long long readIntSp(long long l, long long r)
{
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r)
{
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r)
{
    return readString(l, r, '\n');
}
string readStringSp(int l, int r)
{
    return readString(l, r, ' ');
}
// Asserts that the input has been fully consumed: the next getchar() must return EOF.
// Otherwise prints exactly one diagnostic describing the unexpected character and fails.
void readEOF()
{
    // Keep getchar()'s int result: storing it in a char can make EOF unrecognizable
    // (where char is unsigned) or confuse it with a 0xFF byte.
    int g = getchar();
    if(g != EOF)
    {
        // else-if chain: a trailing space must not also report "File didn't end where expected".
        if(g == ' ')
            cerr << "Extra space found where the file shold have ended\n";
        else if(g == '\n')
            cerr << "Extra newline found where the file shold have ended\n";
        else
            cerr << "File didn't end where expected\n";
    }
    assert(g == EOF);
}

// Reads n integers in [l, r] forming one line: the first n-1 are each followed by a space,
// the last by a newline. Values are stored as int, so [l, r] should fit in int. Requires n >= 1.
vector<int> readVectorInt(int n, long long l, long long r)
{
    if(n <= 0)
    {
        // a[n - 1] below would be out of bounds; fail loudly instead of invoking UB.
        cerr << "readVectorInt requires n >= 1\n";
        assert(false);
        return {};
    }
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = static_cast<int>(readIntSp(l, r));
    a[n - 1] = static_cast<int>(readIntLn(l, r));
    return a;
}

// Returns true iff every character of s lies in the inclusive range [l, r];
// otherwise prints "String is not valid" to cerr and returns false.
// s is taken by const reference: it is only read, and this also accepts temporaries/const strings.
bool checkStringContents(const string &s, char l, char r) {
    for(char x: s) {
        if (x < l || x > r) {
            cerr << "String is not valid\n";
            return false;
        }
    }
    return true;
}

// Character-class validators built on checkStringContents; each prints a diagnostic on failure.
bool isStringBinary(string &s)
{
    const char lo = '0', hi = '1';
    return checkStringContents(s, lo, hi);
}

bool isStringLowerCase(string &s)
{
    const char lo = 'a', hi = 'z';
    return checkStringContents(s, lo, hi);
}
bool isStringUpperCase(string &s)
{
    const char lo = 'A', hi = 'Z';
    return checkStringContents(s, lo, hi);
}

// Returns true iff all elements of a are pairwise distinct.
// a is taken by value on purpose: it is sorted locally without touching the caller's vector.
bool isArrayDistinct(vector<int> a) {
    sort(a.begin(), a.end());
    // After sorting, any duplicate values are adjacent.
    return adjacent_find(a.begin(), a.end()) == a.end();
}

// Returns true iff a contains each of 1..a.size() exactly once;
// otherwise prints "Not a valid permutation" to cerr and returns false.
bool isPermutation(vector<int> &a) {
    const int n = a.size();
    vector<bool> seen(n, false);
    for (size_t idx = 0; idx < a.size(); ++idx) {
        const int v = a[idx];
        const bool inRange = v >= 1 && v <= n;
        if (!inRange || seen[v - 1]) {
            cerr << "Not a valid permutation\n";
            return false;
        }
        seen[v - 1] = true;
    }
    return true;
}

// -------------------- Input Checker End --------------------


// 64-bit integer type used by solve(). A type alias instead of a macro, so the name obeys
// scope rules and cannot silently rewrite unrelated identifiers spelled `ll`.
using ll = long long int;
// Modulus applied to all intermediate sums and to the printed answer.
const int MOD=998244353;

void solve(){
    int n=readIntLn(1,1e5);
    int arr[n];
    for(int i=0;i<n;i++){
        if(i!=n-1){
            arr[i]=readIntSp(0,1e9);
        }
        else{
            arr[i]=readIntLn(0,1e9);
        }
    }
    vector<ll> nums(n+1,0);

    ll sum=0;
    vector<ll> pre(n);
    ll mul=1;
    for(int i=0;i<n;i++){
        pre[i]=(i?pre[i-1]:0)+arr[i];
        pre[i]%=MOD;
        sum+=(mul*arr[i])%MOD;
        sum%=MOD;
        if(n%2){
            if(i>=n/2){
                mul--;
            }
            else{
                mul++;
            }
        }
        else{
            if(i+1==n/2){
                
            }
            else if(i+1>n/2){
                mul--;
            }
            else{
                mul++;
            }
        }
    }
    nums[0]=sum%MOD;
    int l=n/2;
    int r=n-1;
    for(int i=1;i<=n;i++){
        sum-=(pre[r]-(l==0?0:pre[l-1]));
        sum+=MOD;
        sum%=MOD;
        r--;
        if((n%2)^(i%2)){
            l--;
        }
        nums[i]=sum;
    }


    //reversed the array subtract same things
    reverse(arr,arr+n);
    sum=0;
    mul=1;
    for(int i=0;i<n;i++){
        pre[i]=(i?pre[i-1]:0)+arr[i];
        pre[i]%=MOD;
        sum+=(mul*arr[i])%MOD;
        sum%=MOD;
        if(n%2){
            if(i>=n/2){
                mul--;
            }
            else{
                mul++;
            }
        }
        else{
            if(i+1==n/2){
                
            }
            else if(i+1>n/2){
                mul--;
            }
            else{
                mul++;
            }
        }
    }
    nums[0]-=sum;
    nums[0]+=MOD;
    nums[0]%=MOD;
    l=n/2;
    r=n-1;
    for(int i=1;i<=n;i++){
        sum-=(pre[r]-(l==0?0:pre[l-1]));
        sum+=MOD;
        sum%=MOD;
        r--;
        if((n%2)^(i%2)){
            l--;
        }
        nums[i]-=sum;
        nums[i]+=MOD;
        nums[i]%=MOD;
    }
    
    ll ans=0;
    for(int i=0;i<=n;i++){
        ans+=(i*nums[i])%MOD;
        ans%=MOD;
    }
    cout<<ans<<endl;
}

// Author's entry point: reads the number of test cases (1..1e4), solves each one,
// then verifies that the input ends exactly there.
int main(){
    for(int remaining = readIntLn(1, 1e4); remaining > 0; --remaining)
        solve();
    readEOF();
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Strict input validator: slurps all of stdin up front, then hands out tokens separated by
// single spaces/newlines while asserting the expected layout and value ranges.
struct input_checker {
    string buffer; // entire input
    int pos;       // index of the next unread character in buffer

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Returns the characters from pos up to (not including) the next ' ' or '\n'.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            // Cast to unsigned char: isspace on a negative char value is undefined behavior.
            assert(!isspace(static_cast<unsigned char>(buffer[pos])));
            pos++;
        }
        return res;
    }

    // True iff s is a canonical decimal integer: an optional '-', then digits only,
    // no leading zeros, and not "-0". stoi/stoll alone would accept "5x", "007", "+3", "-0".
    static bool isCanonicalInteger(const string& s) {
        const size_t start = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (start == s.size()) {
            return false; // empty token or lone "-"
        }
        for (size_t i = start; i < s.size(); i++) {
            if (s[i] < '0' || s[i] > '9') {
                return false;
            }
        }
        if (s[start] == '0') {
            return s.size() == 1; // only a bare "0" may start with '0'
        }
        return true;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        string token = readOne();
        assert(isCanonicalInteger(token));
        int res = stoi(token);
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        string token = readOne();
        assert(isCanonicalInteger(token));
        long long res = stoll(token);
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Modulus for the tester's answer (998244353).
constexpr long long mod = 998'244'353LL;

// Tester's entry point: validates the input format and prints, per test case,
// sum over i of a[i] * coef[i] modulo mod, where coef is the closed-form coefficient below.
int main() {
    input_checker in;
    int tt = in.readInt(1, 1e4);
    in.readEoln();
    int sn = 0; // running total of n across test cases
    for (; tt > 0; --tt) {
        const int n = in.readInt(1, 1e5);
        in.readEoln();
        sn += n;
        const vector<int> a = in.readInts(n, 0, 1e9);
        in.readEoln();

        // Brute-force reference for the coefficients (O(n^4)), kept for documentation:
        //   for l in [0, n): for r in [l, n): k = r - l + 1
        //     for i in [0, k - 1): coef[l + j]++ for j in [0, i], coef[l + j]-- for j in [i + 1, k)
        vector<long long> coef(n);
        // Left half (2i < n): closed form with k = n - 2i - 1.
        for (int i = 0; 2 * i < n; i++) {
            const long long k = n - 2 * i - 1;
            coef[i] = (i + 1) * k * (k + i + 1) / 2 % mod;
        }
        // Right half mirrors the left half with the opposite sign.
        for (int i = (n + 1) / 2; i < n; i++) {
            coef[i] = -coef[n - 1 - i];
        }

        long long ans = 0;
        for (int i = 0; i < n; i++) {
            ans += (mod + a[i]) * coef[i] % mod;
        }
        ans = ((ans % mod) + mod) % mod; // coefficients may be negative
        cout << ans << '\n';
    }
    assert(sn <= 1e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
# Editorialist's solution. In 1-indexed terms, A_i gets coefficient i * (N-i) * (N-i+1) / 2
# from the "prefix" part, and the mirrored element A_{N+1-i} is subtracted with that same
# coefficient, so each term is coef_i * (A_i - A_{N+1-i}); the total is reduced mod 998244353.
mod = 998244353
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    ans = sum(
        (i + 1) * (n - i) * (n - i - 1) // 2 * (a[i] - a[n - 1 - i])
        for i in range(n)
    ) % mod
    print(ans)
1 Like

I don’t know if I can post this under the editorial, sorry if I got the wrong place.

This problem can also be solved with two-parameter Lagrange interpolation. First fix one parameter and find the coefficients of the first n polynomials. Then interpolate each coefficient position across those polynomials to obtain the coefficients of the coefficients.

You just need to write the brute-force code first, and then spend a little time on the calculation.

There’s another way to arrive at the solution: compute the contribution of each index i directly, without expanding the 4 summations as in the editorial. Only subarrays [l, r] with l \le i \le r involve A_i, so all of its contributions come from them. Within such a subarray, A_i is subtracted once for each of the i - l prefixes that end before i. It is added once for each of the r - i prefixes that end at or after i (but before r). Its net contribution is therefore (r - i) - (i - l) = l + r - 2i. By basic sigma properties, the total coefficient of A_i is contri[i] = \sum_{l=1}^{i}\sum_{r=i}^{n} (l + r - 2i) = \frac{(n-i+1)\, i\,(i+1)}{2} + i\left(\frac{n(n+1)}{2} - \frac{i(i-1)}{2}\right) - 2i^2(n-i+1). The final answer is \sum_{i=1}^{n} arr[i] \cdot contri[i].