SIMPLEARRAY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Sahil Tiwari
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

2326

PREREQUISITES:

Frequency arrays, Basic combinatorics

PROBLEM:

You are given an array A and an integer K. Count the number of subsequences of A (the empty one included) that don't contain any pair of elements whose sum is divisible by K. The answer is computed modulo 10^9+7.

EXPLANATION:

First, notice that the condition "A_i + A_j is divisible by K" can be written as A_i + A_j \equiv 0 \pmod K.

In particular, we can work with all the array elements modulo K so that they’re all between 0 and K-1.

Now, consider what happens when x+y \equiv 0 \pmod K with 0 \leq x, y \lt K. There are three possibilities:

  • First, we can have x = y = 0
  • Second, if K is even we can have x = y = K/2
  • Finally, if neither of the above holds, we must have y = K-x; and in particular x \neq y.
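A quick brute-force check of this case analysis: for a small K, list every pair of residues 0 ≤ x ≤ y < K whose sum is a multiple of K. The function name `zero_sum_residue_pairs` is hypothetical, introduced only for this sketch.

```python
def zero_sum_residue_pairs(k):
    """Return all residue pairs (x, y) with 0 <= x <= y < k and (x + y) % k == 0."""
    return [(x, y) for x in range(k) for y in range(x, k) if (x + y) % k == 0]


# K = 6: (0, 0) and (3, 3) are the self-paired cases; every other pair has y = K - x.
print(zero_sum_residue_pairs(6))  # [(0, 0), (1, 5), (2, 4), (3, 3)]
# K = 5 is odd, so only x = y = 0 is self-paired.
print(zero_sum_residue_pairs(5))  # [(0, 0), (1, 4), (2, 3)]
```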

Let’s leave the first two cases alone for now, and look at the third.
For convenience, assume x \lt K-x.
Note that for each x, any good subsequence can have either some occurrences of x, or some occurrences of K-x: never both.
In particular, we can take as many x-s as we like, or as many (K-x)-s as we like, without affecting any other sums (since K-(K-x) = x). Essentially, we ‘pair up’ x with K-x, and then different pairs are completely independent.
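The pairing can be made concrete with a small sketch that splits the residues 0..K-1 into independent groups: the self-paired residues (0, and K/2 when K is even) on their own, and every other residue x with x < K-x grouped with K-x. `residue_groups` is a hypothetical helper name, not taken from the reference solutions.

```python
def residue_groups(k):
    """Split the residues 0..k-1 into independent groups (hypothetical helper).

    Returns (singles, pairs): singles are the self-paired residues (0, plus k/2
    when k is even); pairs are the tuples (x, k - x) with x < k - x.
    """
    singles = [0]
    if k % 2 == 0:
        singles.append(k // 2)
    pairs = [(x, k - x) for x in range(1, k) if x < k - x]
    return singles, pairs


print(residue_groups(6))  # ([0, 3], [(1, 5), (2, 4)])
print(residue_groups(7))  # ([0], [(1, 6), (2, 5), (3, 4)])
```

Every residue lands in exactly one group, which is why choices made in different groups never interact.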

So, the choices of which of the x's or (K-x)'s we take are completely independent across different x.
This means that any subsequence can be constructed as follows:

  • Choose a subset of 1's or a subset of (K-1)'s
  • Then, choose a subset of 2's or a subset of (K-2)'s
  • Then, choose a subset of 3's or a subset of (K-3)'s
    \vdots

Thus, the total number of good subsequences can be found by multiplying the number of choices for different x.
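Before trusting the product formula, it helps to have an exponential brute force for tiny inputs to compare against. This is a minimal sketch: `count_good_brute` is a hypothetical name, and it assumes subsequences are distinguished by index and that the empty subsequence is counted (as the product formula implies).

```python
from itertools import combinations


def count_good_brute(a, k):
    """Count subsequences of a (by index, empty one included) with no pair summing to a multiple of k.

    Exponential in len(a): only meant for cross-checking the formula on tiny inputs.
    """
    n = len(a)
    total = 0
    for mask in range(1 << n):
        chosen = [a[i] for i in range(n) if (mask >> i) & 1]
        if all((x + y) % k != 0 for x, y in combinations(chosen, 2)):
            total += 1
    return total


print(count_good_brute([1, 2, 3, 4, 5], 3))  # 14
```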

This brings us to the question: how many choices are there for a fixed x?

Answer

Let freq(x) be the number of occurrences of x in the array.

Note that we can choose any subset of the x's, or any subset of the (K-x)'s.
The first one gives us 2^{freq(x)} choices, while the second gives us 2^{freq(K-x)} choices.

The empty set is counted in both, so we need to subtract 1 to avoid overcounting.
This brings the total to 2^{freq(x)} + 2^{freq(K-x)} - 1.
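The count 2^{freq(x)} + 2^{freq(K-x)} - 1 can be sanity-checked by enumerating every take/skip pattern over the freq(x) + freq(K-x) positions and rejecting the ones that mix both sides. Both function names below are hypothetical, chosen for this sketch.

```python
from itertools import product


def pair_choices(fx, fy):
    """Ways to choose from fx copies of x and fy copies of K-x without mixing the two."""
    return 2 ** fx + 2 ** fy - 1


def pair_choices_brute(fx, fy):
    """Enumerate every take/skip pattern over the fx + fy positions and keep the valid ones."""
    count = 0
    for bits in product((0, 1), repeat=fx + fy):
        takes_x = any(bits[:fx])
        takes_complement = any(bits[fx:])
        if not (takes_x and takes_complement):
            count += 1
    return count


print(pair_choices(2, 2), pair_choices_brute(2, 2))  # 7 7
```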

The number of subsequences is thus just the product of (2^{freq(x)} + 2^{freq(K-x)} - 1) across all x such that x \lt K-x.

The only exceptions here are x = 0 and (if K is even) x = K/2, which shouldn’t be included in the above product because they behave slightly differently. Do you see how to deal with them?

Answer

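x = 0 and x = K/2 follow a simple rule: the subsequence can contain at most one element from each of these classes, since any two elements with remainder 0 sum to 0 \pmod K, and any two with remainder K/2 sum to K.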

So, we have 1+freq(0) choices for 0 (choose none of them, or choose exactly one), and similarly 1 + freq(K/2) choices for K/2.

Multiply the previous product by these two quantities to obtain the final answer.
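Putting the pieces together, here is a sketch of the full product as a single function. It assumes the modulus 10^9+7 used by all three reference solutions; `count_good` is a hypothetical name.

```python
MOD = 10**9 + 7  # assumed modulus, as in the reference solutions


def count_good(a, k):
    """Closed-form count of good subsequences modulo MOD, following the editorial's product."""
    freq = [0] * k
    for x in a:
        freq[x % k] += 1
    ans = (freq[0] + 1) % MOD                    # remainder 0: take none or exactly one
    if k % 2 == 0:
        ans = ans * (freq[k // 2] + 1) % MOD     # remainder K/2: take none or exactly one
    for x in range(1, (k + 1) // 2):             # every x with x < K - x
        ways = pow(2, freq[x], MOD) + pow(2, freq[k - x], MOD) - 1
        ans = ans * ways % MOD
    return ans


print(count_good([1, 2, 3, 4, 5], 3))  # 14
```

On small inputs its output should agree with an exponential brute force over all subsets.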

Notice that the factor for a given x requires us to compute a power of 2 modulo 10^9+7.
There are several ways to do this: the simplest is to just precompute the value of 2^x\pmod {MOD} for every 0 \leq x \leq 5\cdot 10^5 before processing any test cases, after which these can be used in \mathcal{O}(1).
Alternately, you can use binary exponentiation.
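Both options are sketched below: a table of 2^i \pmod {MOD} built once up to the 5\cdot 10^5 bound quoted above, and an iterative binary exponentiation. The names `build_pow2`, `POW2`, `MAX_EXP` and `binpow` are illustrative; in Python the built-in three-argument `pow(2, e, MOD)` does the same job as `binpow`.

```python
MOD = 10**9 + 7
MAX_EXP = 5 * 10**5  # bound quoted in the editorial; adjust if the constraints differ


def build_pow2(limit, mod=MOD):
    """pow2[i] = 2^i mod `mod` for 0 <= i <= limit, built in O(limit)."""
    pow2 = [1] * (limit + 1)
    for i in range(1, limit + 1):
        pow2[i] = pow2[i - 1] * 2 % mod
    return pow2


def binpow(base, exp, mod=MOD):
    """Iterative binary exponentiation: O(log exp) multiplications."""
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:          # lowest bit set: multiply this power of base in
            result = result * base % mod
        base = base * base % mod
        exp >>= 1
    return result


POW2 = build_pow2(MAX_EXP)   # precompute once, before reading any test case
print(POW2[10], binpow(2, 10))  # 1024 1024
```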

TIME COMPLEXITY

\mathcal{O}(N + K) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
#define int long long int
using namespace std;

#define mod 1000000007

int power(int a , int b) {
	if(b == 0)
		return 1;
	int res = power(a , b>>1);
	if(b & 1)
		return (res * res % mod) * a % mod;
	return res * res % mod;
}

signed main() {

	int t;
	cin>>t;
	while(t--) {
		int n , k;
		cin>>n>>k;
		vector<int> a(k);
		for(int i=0;i<n;i++) {
			int x;
			cin>>x;
			a[x % k]++;
		}
		int ans = 1;
		for(int i=1;i<(k+1)/2;i++) {
			int c = (power(2 , a[i]) + power(2 , a[k-i]) - 1);
			ans = ans * c % mod;
		}
		if(k % 2 == 0) {
			ans = ans * (a[k/2] + 1) % mod;
		}
		ans = ans * (a[0]+1) % mod;
		cout<<ans<<endl;
	}
	return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>

using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

const long long mod = (int) 1e9 + 7;

struct mint {
    long long value;

    mint(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }

    mint &operator+=(const mint &other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }

    mint &operator-=(const mint &other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }

    mint &operator*=(const mint &other) {
        value = value * other.value % mod;
        return *this;
    }

    mint &operator/=(const mint &other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }

    friend mint operator+(const mint &lhs, const mint &rhs) { return mint(lhs) += rhs; }

    friend mint operator-(const mint &lhs, const mint &rhs) { return mint(lhs) -= rhs; }

    friend mint operator*(const mint &lhs, const mint &rhs) { return mint(lhs) *= rhs; }

    friend mint operator/(const mint &lhs, const mint &rhs) { return mint(lhs) /= rhs; }

    mint &operator++() { return *this += 1; }

    mint &operator--() { return *this -= 1; }

    mint operator++(int) {
        mint result(*this);
        *this += 1;
        return result;
    }

    mint operator--(int) {
        mint result(*this);
        *this -= 1;
        return result;
    }

    mint operator-() const { return mint(-value); }

    bool operator==(const mint &rhs) const { return value == rhs.value; }

    bool operator!=(const mint &rhs) const { return value != rhs.value; }

    bool operator<(const mint &rhs) const { return value < rhs.value; }
};

string to_string(const mint &x) {
    return to_string(x.value);
}

ostream &operator<<(ostream &stream, const mint &x) {
    return stream << x.value;
}

istream &operator>>(istream &stream, mint &x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(1 / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e4);
    in.readEoln();
    int sn = 0, sk = 0;
    int mn = 2e9, mx = -1;
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readSpace();
        int k = in.readInt(1, 5e5);
        in.readEoln();
        sn += n;
        sk += k;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            a[i] = in.readInt(1, 1e9);
            mn = min(mn, a[i]);
            mx = max(mx, a[i]);
            a[i] %= k;
            (i == n - 1 ? in.readEoln() : in.readSpace());
        }
        vector<int> c(k);
        for (int i = 0; i < n; i++) {
            c[a[i]]++;
        }
        mint ans = c[0] + 1;
        for (int i = 1; i < k; i++) {
            int j = k - i;
            if (i == j) {
                ans *= c[i] + 1;
            } else if (i < j) {
                ans *= power(2, c[i]) + power(2, c[j]) - 1;
            }
        }
        cout << ans << '\n';
    }
    assert(sn <= 1e6);
    assert(sk <= 1e6);
    cerr << sn + sk << " " << mn << " " << mx << endl;
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    freq = [0]*k
    for x in a:
        freq[x%k] += 1
    ans = 1
    for i in range(k):
        if i == 0 or 2*i == k:
            ans *= 1 + freq[i]
        else:
            if i > k-i: break
            ans *= pow(2, freq[i], mod) + pow(2, freq[k-i], mod) - 1
        ans %= mod
    print(ans%mod)

Hey, can anyone explain this approach in simpler words? I am not able to comprehend and follow the approach that has been stated here.


same🤔

#include <bits/stdc++.h>
using namespace std;

#define int     long long int

int M = 1e9+7;

int binExpIter(int a,int b){ 
    int ans = 1;
    while(b){ 
        if(b&1) ans = (ans * a) %M; 
        a = (a*a) % M; 
        b >>= 1; 
    } 
    return ans;
}

void solve(){
    int n,k;
    cin>>n>>k;
    vector<int> v(n);
    map<int, int> mp;
    for(int i=0;i<v.size();i++){
        cin>>v[i];
        v[i]= v[i]%k;
        mp[v[i]]++;
    }

    int ans = 1; // empty sub-sequence

    // pair remainder i with k-i for every 1 <= i < (k+1)/2, i.e. every i with i < k-i
    for(int i=1;i<(k+1)/2;i++){
        int first = mp[i];
        int counter = mp[(k-i)];

        int currAns = (binExpIter(2,first) + binExpIter(2,counter) - 1)%M;

        ans = ans*currAns;
        ans %= M;
    }

    // check for 0 (whose remainder was zero)
    ans = (ans*(mp[0] + 1))%M;

    // check for the middle remainder k/2, which only exists when k is even
    if(k%2==0) ans = (ans*(mp[(k)/2] + 1))%M;

    cout<<ans<<endl;    
}

signed main() {
    int t=1;
    cin>>t;
    while(t--) solve();
    return 0;
}

The most important tool is binary exponentiation (Luv's YouTube channel explains it well). Reduce each value modulo k as you read it; this simplifies the rest of the calculation. Count the frequency of each remainder at the same time using a map, so that for a remainder i you can look up both freq(i) and the frequency of its counterpart k - i (since i + (k - i) = k) quickly: O(\log n) per lookup with std::map, or O(1) with an array of size k.

Example input (n = 5, k = 3):
[1, 2, 3, 4, 5] -> after taking mod 3 -> [1, 2, 0, 1, 2]

Declare ans = 1, because the empty subsequence is always part of the answer.

Iterate i from 1 up to (but not including) the middle of k, pairing each remainder i with its counterpart k - i. If the counterpart exists, the choices split into two sides: take elements from one side or the other, never both.

For instance, with k = 4 the remainders split as:
0 -> handled later
1 and 3 -> a pair, since i + (k - i) = k, i.e. 1 + 3 = 4
2 -> the middle value for even k, also handled later

Each pair (i, k - i) is independent of every other pair, so its number of choices simply multiplies into the answer:
currAns = (2^first + 2^counter - 1) mod M
A group of n elements has 2^n subsequences, so the i's give 2^first choices and the (k - i)'s give 2^counter choices. We subtract 1 because the empty choice [] is counted in both.

Time to handle remainder 0. At most one such element can be taken, so multiply by (frequency of 0) + 1:
ans = (ans*(mp[0] + 1))%M;

Finally, the middle value k/2 (only when k is even) works the same way as 0:
ans = (ans*(mp[(k)/2] + 1))%M;
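Putting these steps together for the example input above (n = 5, k = 3, values 1 2 3 4 5) gives the following hand calculation, written as a small Python sketch; the variable names are only illustrative.

```python
a, k = [1, 2, 3, 4, 5], 3
rem = [x % k for x in a]                    # [1, 2, 0, 1, 2]
freq = [rem.count(r) for r in range(k)]     # [1, 2, 2]: one 0, two 1s, two 2s

pair_factor = 2 ** freq[1] + 2 ** freq[2] - 1   # subsets of the 1s or of the 2s: 4 + 4 - 1 = 7
zero_factor = freq[0] + 1                       # take the single 0-remainder element or not: 2
# k = 3 is odd, so there is no middle remainder k/2 to handle.
answer = pair_factor * zero_factor
print(answer)  # 14
```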

New thing I learned: binary exponentiation.

Please like this post if it helped; it took me 2 hours to understand, code, and write up.


Hi, can anyone maybe explain this in a simpler, babyish way? I can't really understand the intuition behind this approach!

Thanks @bakru_k78