SQING - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: kingmessi
Testers: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Math

PROBLEM:

For a binary string S of length N, define the sqing value of S as follows:

  • Let X = 0 initially.
  • For each 1 \leq i \leq N, let j \lt i be the largest index such that S_j \neq S_i.
    If such a j exists, add (i-j)^2 to X; otherwise, add nothing.
  • The sqing value of S equals the final value of X.

You're given N. Compute the sum of the sqing values across all binary strings of length N, modulo 10^9 + 7.
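To make the definition concrete, here is a minimal brute-force sketch that evaluates the sqing value directly and sums it over all 2^N strings. The helper names `sqing` and `brute_total` are hypothetical, and indices are 0-based (differences i-j are unchanged). It is exponential in N, so it is only useful for checking small cases.

```python
from itertools import product

def sqing(s: str) -> int:
    """Sqing value of s, computed straight from the definition (O(N^2))."""
    x = 0
    for i in range(len(s)):
        # Scan left from i for the nearest index j holding a different character.
        for j in range(i - 1, -1, -1):
            if s[j] != s[i]:
                x += (i - j) ** 2
                break
        # If no such j exists, nothing is added for this i.
    return x

def brute_total(n: int) -> int:
    """Sum of sqing values over all 2^n binary strings of length n (no modulus)."""
    return sum(sqing("".join(bits)) for bits in product("01", repeat=n))

print([brute_total(n) for n in range(1, 6)])  # → [0, 2, 16, 74, 264]
```

For example, "011" contributes 1 (at i = 2, j = 1) plus 4 (at i = 3, j = 1), for a sqing value of 5.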

EXPLANATION:

For convenience of notation, let f(S) denote the sqing value of S.

Observe that f(S) is the sum of squares of several values of (i-j).
Further, (i-j) lies between 1 and N-1, since 1 \leq j \lt i \leq N.

Let's fix the value of (i-j) and count how many times (i-j)^2 appears when f(S) is summed over all S.
Let k = i - j.
Then,

  • For this difference to be possible at all, i should be an index \geq k+1, since we need j \geq 1.
    This gives us N - k choices for what i is, which then uniquely fixes j.
  • We want S_i \neq S_j, so there are two options: either S_i = 0 and S_j = 1, or vice versa.
  • By definition, j must be the closest index before i that has a different value.
    That means, for every index m such that j \lt m \lt i, we are forced to set S_m = S_i.
    Indices \lt j or \gt i, however, have no such restriction and can be anything.
  • Every index from j to i is now fixed, so there are N - (i - j + 1) 'free' indices, each of which can take two values (0 or 1).
    Since k = i-j, we have N - k - 1 free indices.

Putting everything together, the number of times k appears as a difference, counted over the computations of f(S) for all S, is

2\cdot (N-k) \cdot 2^{N-k-1}

Its contribution to the answer is thus this quantity, multiplied by k^2.
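The claimed count 2\cdot (N-k) \cdot 2^{N-k-1} can be sanity-checked for small N. The sketch below (helper names `count_difference` and `count_formula` are hypothetical) enumerates every binary string, finds the nearest differing index j for each i, tallies how often i - j equals k, and compares that tally with the closed form.

```python
from itertools import product

def count_difference(n: int, k: int) -> int:
    """Number of (string, i) pairs, over all binary strings of length n,
    whose nearest earlier differing index j satisfies i - j == k."""
    total = 0
    for bits in product((0, 1), repeat=n):
        for i in range(n):
            j = i - 1
            # Walk left past every index holding the same bit as position i.
            while j >= 0 and bits[j] == bits[i]:
                j -= 1
            if j >= 0 and i - j == k:
                total += 1
    return total

def count_formula(n: int, k: int) -> int:
    # 2 choices for S_i, (n - k) positions for i, 2^(n-k-1) free positions.
    return 2 * (n - k) * 2 ** (n - k - 1)

for n in range(2, 7):
    for k in range(1, n):
        assert count_difference(n, k) == count_formula(n, k)
```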

Summing this over all k gives the overall answer:

\sum_{k=1}^{N-1} 2\cdot (N-k) \cdot 2^{N-k-1} \cdot k^2
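The leading factor of 2 combines with the power of two, which gives a slightly simpler form of the same sum. As a worked check, N = 3 gives 16, which agrees with enumerating the eight strings directly (their sqing values are 0, 1, 2, 5, 5, 2, 1, 0).

```latex
2\cdot (N-k)\cdot 2^{N-k-1}\cdot k^2 = (N-k)\cdot 2^{N-k}\cdot k^2,
\qquad
\text{answer}(N) = \sum_{k=1}^{N-1} (N-k)\cdot 2^{N-k}\cdot k^2

N = 3:\quad (3-1)\cdot 2^{2}\cdot 1^2 + (3-2)\cdot 2^{1}\cdot 2^2 = 8 + 8 = 16
```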

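This is easily computed by looping over k: in \mathcal{O}(N) time if the powers of 2 are updated incrementally (the author's code does this with powers of 2^{-1}), or in \mathcal{O}(N\log N) time if each power is computed with fast exponentiation (as in the editorialist's code).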
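A minimal \mathcal{O}(N) sketch, assuming the answer is wanted modulo 10^9 + 7 and using the simplified term (N-k)\cdot 2^{N-k}\cdot k^2: iterating k downward from N-1 means the power of two just doubles each step, so no exponentiation is needed. The function name `sqing_sum` is hypothetical.

```python
MOD = 10**9 + 7

def sqing_sum(n: int, mod: int = MOD) -> int:
    """Sum of (n - k) * 2^(n - k) * k^2 over k = 1..n-1, modulo `mod`, in O(n)."""
    ans = 0
    power = 1  # becomes 2^(n - k) once doubled for the current k
    # Going from k = n-1 down to k = 1, the exponent n - k grows by one each step,
    # so a single doubling replaces a call to pow().
    for k in range(n - 1, 0, -1):
        power = power * 2 % mod
        ans = (ans + (n - k) * power % mod * k * k) % mod
    return ans

print([sqing_sum(n) for n in range(1, 6)])  # → [0, 2, 16, 74, 264]
```

The downward loop is only a convenience; iterating upward while multiplying by the inverse of 2, as the author does, is equally valid.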

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int M = 1e9+7;

long long binpow(long long a, long long b, long long m = M) {
    a %= m;
    long long res = 1;
    while (b > 0) {
        if (b & 1)
            res = res * a % m;
        a = a * a % m;
        b >>= 1;
    }
    return res;
}

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !(buffer[now] == ' ' || buffer[now] == '\n')) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;

signed main(){
	int t;
	// cin >> t;
	t = inp.readInt(1,1000);
	inp.readEoln();
	int smn = 0;
	while(t--){
		int n;
		// cin >> n;
		n = inp.readInt(1,500'000);
		inp.readEoln();
		smn += n;
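		int ans = 0;
		int tot = binpow(2,n);   // 2^n, factored out of the whole sum
		int inv = binpow(2,M-2); // modular inverse of 2 (Fermat's little theorem)
		int cur = inv;           // equals 2^{-(i-1)} at the start of iteration i
		for(int i = 2;i <= n;i++){
			// difference k = i-1: add (n-k) * k^2 * 2^{-k}
			int sq = ((i-1)*(i-1))%M;
			int res = ((n-i+1)*sq)%M;
			res *= cur;res %= M;
			cur *= inv;cur %= M;
			ans += res;
			ans %= M;
		}
		// multiply back by 2^n: ans = sum over k of (n-k) * k^2 * 2^{n-k}
		ans *= tot;
		ans %= M;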
		ans += M;
		ans %= M;
		cout << ans << endl;
	}
	assert(smn <= 500000);
	inp.readEof();
	return 0;
}
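The author's loop is the same sum in a different order: with k = i-1, it factors 2^N out and multiplies each term by 2^{-k}, so the result equals the sum of (N-k)\cdot k^2\cdot 2^{N-k}. A Python transcription of that idea (function names hypothetical), checked against the direct formula:

```python
MOD = 10**9 + 7

def author_style(n: int, mod: int = MOD) -> int:
    """Mirror of the author's loop: pull 2^n out, multiply by powers of 2^{-1}."""
    tot = pow(2, n, mod)
    inv = pow(2, mod - 2, mod)  # inverse of 2 via Fermat's little theorem (mod is prime)
    cur = inv                   # 2^{-(i-1)} for the current i
    ans = 0
    for i in range(2, n + 1):
        k = i - 1
        ans = (ans + (n - i + 1) * k * k % mod * cur) % mod
        cur = cur * inv % mod
    return ans * tot % mod

def direct(n: int, mod: int = MOD) -> int:
    """Sum of (n - k) * 2^(n - k) * k^2 over k = 1..n-1, modulo `mod`."""
    return sum((n - k) * pow(2, n - k, mod) * k * k for k in range(1, n)) % mod

assert all(author_style(n) == direct(n) for n in range(1, 200))
```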
Editorialist's code (Python)
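mod = 10**9 + 7
for _ in range(int(input())):
    n = int(input())
    ans = 0
    # Loop over the difference k = i - j (called i in this loop), from 1 to n-1.
    for i in range(1, n):
        # 2 choices for the bit at the later index, (n - i) positions for it.
        ch = 2 * (n - i)
        # The n - i - 1 positions outside the fixed block are free.
        rem = n - i - 1
        ch *= pow(2, rem, mod)
        # Each occurrence contributes i^2.
        ans += ch * i * i % mod
    print(ans % mod)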

Amazing.
How do people figure this out in two hours?!

Please do something. Hundreds of users cheated on this question in the last 15 minutes. This is really disappointing.