PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: Yash Kulkarni
Tester: Harris Leung
Editorialist: Nishank Suresh
DIFFICULTY:
TBD
PREREQUISITES:
Binomial Theorem, Binary Exponentiation
PROBLEM:
You have N balls with 2 written on them, and M balls with 0 written on them. Find the number of non-empty subsets S of these N+M balls such that sum(S) - |S| is a multiple of 3, where sum(S) is the sum of the numbers written on the balls in S (modulo 10^9+7).
EXPLANATION:
Suppose our subset S contains x balls with 2 written on them, and y balls with 0 written on them.
Then, sum(S) - |S| = 2x - (x+y) = x-y. This is a multiple of 3 if and only if x \equiv y \pmod{3}.
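To sanity-check this reduction on tiny inputs, here is a brute-force sketch. The helper names `countByEnumeration`, `binom` and `countByReduction` are hypothetical and not taken from the setter's or tester's code. It enumerates every subset directly and compares the result with counting pairs (x, y) satisfying x \equiv y \pmod 3, excluding the empty subset as the problem requires. It is exponential, so keep N + M at most about 20.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical checker: balls 0..N-1 carry a 2, balls N..N+M-1 carry a 0.
// Enumerates every non-empty subset S and counts those where sum(S) - |S|
// is a multiple of 3. Exponential: keep N + M small (about 20 at most).
long long countByEnumeration(int N, int M) {
    const int total = N + M;
    long long count = 0;
    for (std::uint32_t mask = 1; mask < (std::uint32_t{1} << total); ++mask) {
        int sum = 0, size = 0;
        for (int b = 0; b < total; ++b) {
            if ((mask >> b) & 1u) {
                sum += (b < N) ? 2 : 0;
                ++size;
            }
        }
        // sum - size may be negative; a negative multiple of 3 still has remainder 0 in C++.
        if ((sum - size) % 3 == 0) ++count;
    }
    return count;
}

// Exact binomial coefficient for small arguments.
long long binom(int n, int k) {
    if (k < 0 || k > n) return 0;
    long long r = 1;
    for (int i = 1; i <= k; ++i) r = r * (n - k + i) / i;  // r becomes C(n - k + i, i)
    return r;
}

// Same count via the reduction: only x (twos chosen) and y (zeros chosen)
// matter, and the subset is valid iff x and y are congruent modulo 3.
long long countByReduction(int N, int M) {
    long long count = 0;
    for (int x = 0; x <= N; ++x)
        for (int y = 0; y <= M; ++y)
            if ((x - y) % 3 == 0) count += binom(N, x) * binom(M, y);
    return count - 1;  // remove the empty subset (x = y = 0)
}
```

For example, with N = M = 2 both functions return 5: four subsets with one ball of each kind, plus the subset containing all four balls.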
So, what we need to do boils down to this:
- For i = 0, 1, 2, let a_i denote the number of ways to choose x balls with 2 written on them, such that the remainder upon dividing x by 3 is i.
- Similarly, let b_i be the number of ways to choose y balls with 0 written on them, such that the remainder upon dividing y by 3 is i.
- Then, for each i, add a_i \cdot b_i to the answer.
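The decomposition above can be checked with a straightforward sketch that sums binomial coefficients by residue class before any closed form is used. This is a hypothetical reference implementation, not the intended solution: `residueSumsNaive` and `answerNaive` are illustrative names, the modulus 10^9+7 is the one used in the setter's and tester's code, and the loop runs over every k up to N, so it is only practical for moderate N and M (it also assumes N, M < 10^9+7 so that k+1 is invertible).

```cpp
#include <array>
#include <cassert>

const long long MOD = 1000000007LL;

// Binary exponentiation: base^exp mod MOD.
long long power(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// counts[r] = sum of C(n, k) over all k with k mod 3 == r, reduced mod MOD.
std::array<long long, 3> residueSumsNaive(long long n) {
    std::array<long long, 3> counts{0, 0, 0};
    long long c = 1;  // C(n, 0)
    for (long long k = 0; k <= n; ++k) {
        counts[k % 3] = (counts[k % 3] + c) % MOD;
        // C(n, k + 1) = C(n, k) * (n - k) / (k + 1), division via Fermat inverse
        c = c * ((n - k) % MOD) % MOD * power(k + 1, MOD - 2) % MOD;
    }
    return counts;
}

// Pair up equal residues: a_0 b_0 + a_1 b_1 + a_2 b_2, minus the empty subset.
long long answerNaive(long long N, long long M) {
    const auto a = residueSumsNaive(N);
    const auto b = residueSumsNaive(M);
    long long total = 0;
    for (int i = 0; i < 3; ++i) total = (total + a[i] * b[i]) % MOD;
    return (total - 1 + MOD) % MOD;
}
```

For instance, `residueSumsNaive(3)` gives {2, 3, 3} (that is, C(3,0)+C(3,3), C(3,1), C(3,2)), and `answerNaive(3, 3)` gives 4 + 9 + 9 - 1 = 21.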
Now all that remains is to compute the values of a_i and b_i for each i. Below I’ll describe how to compute a_0 — the other values will follow a similar process.
We want to compute a_0, which is the number of ways to choose x balls from N such that x is a multiple of 3.
It should be immediately obvious that this is nothing but
a_0 = \binom{N}{0} + \binom{N}{3} + \binom{N}{6} + \dots = \sum_{k \ge 0} \binom{N}{3k}
So, how do we calculate this?
It turns out that a_0 = \frac{1}{3}\left(2^N + 2\cos\left(\frac{\pi N}{3}\right)\right).
Proof
This is pure equation manipulation.
A common technique when dealing with binomial coefficients is, of course, the binomial expansion.
Consider (1+x)^N = \sum_{i=0}^N \binom{N}{i} x^i. We want to ‘filter out’ the binomial coefficients \binom{N}{i} whose index i is a multiple of 3. One way to do this is to evaluate this expression at x = \omega, a primitive third root of unity.
So, we have
(1+\omega)^N = \sum_{i=0}^N \binom{N}{i} \omega^i
Note that this expression also equals (-\omega^2)^N, since 1+\omega+\omega^2 = 0.
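Similarly, evaluating the binomial expansion at x = 1 and x = \omega^2 gives
2^N = \sum_{i=0}^N \binom{N}{i}
and
(-\omega)^N = (1+\omega^2)^N = \sum_{i=0}^N \binom{N}{i} \omega^{2i}
Adding up all 3 equations, and using the fact that 1 + \omega^i + \omega^{2i} equals 3 when 3 divides i and 0 otherwise (once again because 1 + \omega + \omega^2 = 0), we obtain
2^N + (-\omega)^N + (-\omega^2)^N = \sum_{i=0}^N \binom{N}{i}\left(1 + \omega^i + \omega^{2i}\right) = 3\sum_{k \ge 0}\binom{N}{3k} = 3a_0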
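Now, use the fact that \omega = \cos(\frac{2\pi}{3}) + i \sin(\frac{2\pi}{3}). Then -\omega = \cos(\frac{\pi}{3}) - i \sin(\frac{\pi}{3}) and -\omega^2 = \cos(\frac{\pi}{3}) + i \sin(\frac{\pi}{3}), so by De Moivre's formula the imaginary parts cancel and (-\omega)^N + (-\omega^2)^N = 2\cos(\frac{\pi N}{3}). Finally, divide both sides by 3 to obtain the desired result.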
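The roots-of-unity filter can also be checked numerically. The sketch below (hypothetical helpers `filterA0` and `directA0`) evaluates 2^N + (1+\omega)^N + (1+\omega^2)^N with `std::complex<double>` and compares one third of its real part with the directly summed binomial coefficients. Because it uses floating point, it is only an illustration for small N (up to about 40 here), not something to use in the actual solution.

```cpp
#include <cassert>
#include <cmath>
#include <complex>

// (2^N + (1 + w)^N + (1 + w^2)^N) / 3 with w = e^{2 pi i / 3}, rounded to an integer.
// Floating point: intended only for small N (about 40 or less).
long long filterA0(int N) {
    const double PI = std::acos(-1.0);
    const std::complex<double> w = std::polar(1.0, 2.0 * PI / 3.0);
    const std::complex<double> total = std::complex<double>(std::ldexp(1.0, N), 0.0)  // 2^N, exact
                                     + std::pow(1.0 + w, N)
                                     + std::pow(1.0 + w * w, N);
    return std::llround(total.real() / 3.0);  // imaginary parts cancel up to rounding error
}

// Exact C(N,0) + C(N,3) + C(N,6) + ... via Pascal's rule; requires N <= 60.
long long directA0(int N) {
    long long row[61] = {1};  // row[k] holds C(n, k) for the current n
    for (int n = 1; n <= N; ++n)
        for (int k = n; k >= 1; --k) row[k] += row[k - 1];
    long long sum = 0;
    for (int k = 0; k <= N; k += 3) sum += row[k];
    return sum;
}
```

For example, `filterA0(6)` and `directA0(6)` both give C(6,0) + C(6,3) + C(6,6) = 22.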
Note that 2\cos(\frac{\pi N}{3}) is in fact always an integer. It is periodic in N with period 6, and for N \bmod 6 = 0, 1, 2, 3, 4, 5 it equals 2, 1, -1, -2, -1, 1 respectively, so it is enough to look up its value using N modulo 6.
This computes a_0 (and similarly b_0, by replacing N with M).
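Since only N \bmod 6 matters, the cosine term fits in a six-entry table (the same values as the setter's `x` array). The following sketch, with hypothetical names `TWO_COS`, `twoCosRounded`, `closedFormA0` and `directA0`, checks the table against `std::cos` and the closed form for a_0 against a direct Pascal's-triangle sum, in exact 64-bit arithmetic. It assumes N \le 60 so that 2^N fits in a `long long`.

```cpp
#include <cassert>
#include <cmath>

// 2cos(pi * n / 3) indexed by n mod 6 (same values as the setter's x[] array).
const long long TWO_COS[6] = {2, 1, -1, -2, -1, 1};

// Floating-point evaluation, used only to double-check the table.
long long twoCosRounded(int n) {
    const double PI = std::acos(-1.0);
    return std::llround(2.0 * std::cos(PI * n / 3.0));
}

// a_0 = (2^N + 2cos(pi N / 3)) / 3 in exact integer arithmetic; requires 0 <= N <= 60.
long long closedFormA0(int N) {
    const long long twoPow = 1LL << N;
    return (twoPow + TWO_COS[N % 6]) / 3;  // the numerator is always a multiple of 3
}

// Reference value C(N,0) + C(N,3) + C(N,6) + ... via Pascal's rule; requires N <= 60.
long long directA0(int N) {
    long long row[61] = {1};  // row[k] holds C(n, k) for the current n
    for (int n = 1; n <= N; ++n)
        for (int k = n; k >= 1; --k) row[k] += row[k - 1];
    long long sum = 0;
    for (int k = 0; k <= N; k += 3) sum += row[k];
    return sum;
}
```

Replacing N with M gives b_0 in the same way.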
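a_1 and a_2 (and hence b_1 and b_2) can be computed similarly. To isolate the indices i \equiv 1 \pmod 3, evaluate x^2(1+x)^N = \sum_{i=0}^N \binom{N}{i} x^{i+2} at 1, \omega and \omega^2 and add the results; for i \equiv 2 \pmod 3, do the same with x(1+x)^N. Simplifying exactly as before gives
a_1 = \frac{1}{3}\left(2^N + 2\cos\left(\frac{\pi (N-2)}{3}\right)\right), \quad a_2 = \frac{1}{3}\left(2^N + 2\cos\left(\frac{\pi (N+2)}{3}\right)\right)
and both cosine terms again take only the integer values \pm 1, \pm 2, depending on N \bmod 6. (The setter's code writes 2\cos(\frac{\pi (N-2)}{3}) in the equivalent form 2\sin(\frac{(2N-1)\pi}{6}).)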
Note that you can also just compute a_1, and then compute a_2 as 2^N - a_0 - a_1.
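A sketch of the a_1 table together with the complement trick for a_2, again in exact integers for N \le 60. The names `Residues`, `TWO_COS_MINUS2`, `closedFormResidues` and `directResidues` are hypothetical. The second table holds 2\cos(\frac{\pi (N-2)}{3}), which equals 2\sin(\frac{(2N-1)\pi}{6}) and so matches the setter's `y` array.

```cpp
#include <array>
#include <cassert>

using Residues = std::array<long long, 3>;  // {a_0, a_1, a_2}

// Indexed by N mod 6 (same values as the setter's x[] and y[] arrays).
const long long TWO_COS[6] = {2, 1, -1, -2, -1, 1};         // 2cos(pi N / 3)
const long long TWO_COS_MINUS2[6] = {-1, 1, 2, 1, -1, -2};  // 2cos(pi (N - 2) / 3)

// Closed forms in exact integer arithmetic; requires 0 <= N <= 60.
Residues closedFormResidues(int N) {
    const long long twoPow = 1LL << N;
    const long long a0 = (twoPow + TWO_COS[N % 6]) / 3;
    const long long a1 = (twoPow + TWO_COS_MINUS2[N % 6]) / 3;
    return {a0, a1, twoPow - a0 - a1};  // a_2 via the complement 2^N - a_0 - a_1
}

// Reference: sum C(N, k) grouped by k mod 3 using Pascal's rule; requires N <= 60.
Residues directResidues(int N) {
    long long row[61] = {1};  // row[k] holds C(n, k) for the current n
    for (int n = 1; n <= N; ++n)
        for (int k = n; k >= 1; --k) row[k] += row[k - 1];
    Residues sums{0, 0, 0};
    for (int k = 0; k <= N; ++k) sums[k % 3] += row[k];
    return sums;
}
```

For N = 5 this yields {11, 10, 11}: 1 + 10, then 5 + 5, then 10 + 1.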
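All arithmetic is done modulo 10^9+7. Computing 2^N must be done in \mathcal{O}(\log N) using binary exponentiation, and division by 3 must be done by multiplying by the modular inverse of 3, which exists because 10^9+7 is prime (by Fermat's little theorem it equals 3^{10^9+5}). Since the \cos term can be negative, add the modulus before reducing so that intermediate values stay non-negative.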
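A minimal sketch of the two modular tools, assuming the modulus 10^9+7 used in the setter's and tester's code. `powMod`, `inverseOf3` and `divideBy3` are hypothetical names; the inverse relies on Fermat's little theorem, which applies because the modulus is prime.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;  // prime modulus used by the setter's and tester's code

// Binary exponentiation: base^exp mod MOD using O(log exp) multiplications.
long long powMod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    if (base < 0) base += MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// MOD is prime, so 3^(MOD-2) is the inverse of 3 (Fermat's little theorem).
long long inverseOf3() { return powMod(3, MOD - 2); }

// Modular "division" by 3: first bring v into [0, MOD), then multiply by 3^{-1}.
long long divideBy3(long long v) {
    v %= MOD;
    if (v < 0) v += MOD;
    return v * inverseOf3() % MOD;
}
```

Here `inverseOf3()` is 333333336, since 3 \cdot 333333336 = 10^9+8 \equiv 1.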
Finally, remember to subtract 1 from the final answer: we counted the empty set as part of our answer, which is not allowed.
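Putting the pieces together, here is a sketch of the whole computation modulo 10^9+7. The helpers `residueSums` and `countSubsets` are hypothetical names, and the cosine tables are the ones described above. Each numerator is shifted by the modulus before the multiplication by 3^{-1}, and the final result drops the empty subset.

```cpp
#include <array>
#include <cassert>

const long long MOD = 1000000007LL;

// Binary exponentiation: base^exp mod MOD.
long long powMod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// {a_0, a_1, a_2} mod MOD for a pile of n balls, via the closed forms.
std::array<long long, 3> residueSums(long long n) {
    static const long long TWO_COS[6] = {2, 1, -1, -2, -1, 1};         // 2cos(pi n / 3)
    static const long long TWO_COS_MINUS2[6] = {-1, 1, 2, 1, -1, -2};  // 2cos(pi (n - 2) / 3)
    const long long inv3 = powMod(3, MOD - 2);
    const long long p = powMod(2, n);
    // Adding MOD keeps each numerator non-negative before multiplying by inv3.
    const long long a0 = (p + TWO_COS[n % 6] + MOD) % MOD * inv3 % MOD;
    const long long a1 = (p + TWO_COS_MINUS2[n % 6] + MOD) % MOD * inv3 % MOD;
    const long long a2 = ((p - a0 - a1) % MOD + MOD) % MOD;
    return {a0, a1, a2};
}

// Answer for one test case: a_0 b_0 + a_1 b_1 + a_2 b_2 - 1 (mod MOD).
long long countSubsets(long long N, long long M) {
    const auto a = residueSums(N);
    const auto b = residueSums(M);
    long long total = 0;
    for (int i = 0; i < 3; ++i) total = (total + a[i] * b[i]) % MOD;
    return (total - 1 + MOD) % MOD;  // exclude the empty subset
}
```

In a full submission, `main` would read T and print `countSubsets(N, M)` for each test case. For example, `countSubsets(3, 3)` returns 21.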
TIME COMPLEXITY:
\mathcal{O}(\log N + \log M) per test case.
CODE:
Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define mod 1000000007
// binary exponentiation: x^y mod p in O(log y)
ll powerm(ll x,ll y){ ll res=1; x%=mod; while(y){ if(y&1) res=(res*x)%mod; y=y>>1; x=(x*x)%mod;} return res%mod; }
int main() {
    ll T;
    cin >> T;
    const ll inv3=powerm(3,mod-2); // modular inverse of 3 (mod is prime)
    // values of 2cos(n.pi/3), indexed by n mod 6
    vector<ll>x={2,1,-1,-2,-1,1};
    // values of 2sin((2n-1)pi/6), indexed by n mod 6
    vector<ll>y={-1,1,2,1,-1,-2};
    while(T--){
        ll N,M;
        cin >> N >> M;
        // values of nC0 + nC3 + nC6 +..., nC1 + nC4 + nC7 +... and nC2 + nC5 + nC8 +...
        // (+mod keeps the numerator non-negative, e.g. when 2^N % mod == 1 and y[N%6] == -2)
        ll a0=(powerm(2,N)+x[N%6]+mod)%mod*inv3%mod;
        ll a1=(powerm(2,N)+y[N%6]+mod)%mod*inv3%mod;
        ll a2=((powerm(2,N)-a0-a1)%mod+mod)%mod;
        // values of mC0 + mC3 + mC6 +..., mC1 + mC4 + mC7 +... and mC2 + mC5 + mC8 +...
        ll b0=(powerm(2,M)+x[M%6]+mod)%mod*inv3%mod;
        ll b1=(powerm(2,M)+y[M%6]+mod)%mod*inv3%mod;
        ll b2=((powerm(2,M)-b0-b1)%mod+mod)%mod;
        // pair up equal residues, then subtract 1 for the empty subset
        ll ans=(a0*b0%mod+a1*b1%mod+a2*b2%mod-1+mod)%mod;
        cout << ans << '\n';
    }
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const ll mod=1e9+7;
// -------------------- Input Checker Start --------------------
long long readInt(long long l, long long r, char endd)
{
long long x = 0;
int cnt = 0, fi = -1;
bool is_neg = false;
while(true)
{
char g = getchar();
if(g == '-')
{
assert(fi == -1);
is_neg = true;
continue;
}
if('0' <= g && g <= '9')
{
x *= 10;
x += g - '0';
if(cnt == 0)
fi = g - '0';
cnt++;
assert(fi != 0 || cnt == 1);
assert(fi != 0 || is_neg == false);
assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
}
else if(g == endd)
{
if(is_neg)
x = -x;
if(!(l <= x && x <= r))
{
cerr << l << ' ' << r << ' ' << x << '\n';
assert(false);
}
return x;
}
else
{
assert(false);
}
}
}
string readString(int l, int r, char endd)
{
string ret = "";
int cnt = 0;
while(true)
{
char g = getchar();
assert(g != -1);
if(g == endd)
break;
cnt++;
ret += g;
}
assert(l <= cnt && cnt <= r);
return ret;
}
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
vector<int> readVectorInt(int n, long long l, long long r)
{
vector<int> a(n);
for(int i = 0; i < n - 1; i++)
a[i] = readIntSp(l, r);
a[n - 1] = readIntLn(l, r);
return a;
}
// -------------------- Input Checker End --------------------
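// entry r counts the ways to pick balls with (#twos - #zeros) mod 3 == r,
// i.e. a polynomial in z taken modulo z^3 - 1
typedef array<ll,3> arin;
// multiply two such polynomials: exponents add modulo 3
arin operator*(arin x,arin y){
arin z;z[0]=z[1]=z[2]=0;
for(int i=0; i<3 ;i++){
for(int j=0; j<3 ;j++){
z[(i+j)%3]=(z[(i+j)%3]+x[i]*y[j])%mod;
}
}
return z;
}
arin one={1,0,0};  // the constant polynomial 1
arin two={1,1,0};  // one ball with 2: skip it (z^0) or take it (z^1)
arin zero={1,0,1}; // one ball with 0: skip it (z^0) or take it (z^-1 = z^2)
// x^y by recursive fast exponentiation, O(log y) multiplications
arin pw(arin x,ll y){
if(y==0) return one;
if(y%2) return x*pw(x,y-1);
arin res=pw(x,y/2);
return res*res;
}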
int main(){
ios::sync_with_stdio(false);cin.tie(0);
int t;t=readInt(1,1000,'\n');
while(t--){
ll n,m;n=readInt(0,1e9,' ');m=readInt(0,1e9,'\n');
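// (1+z)^n (1+z^2)^m: res[0] counts subsets with #twos and #zeros congruent mod 3, including the empty one
arin res=pw(two,n)*pw(zero,m);
cout << (res[0]+mod-1)%mod << '\n';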
}readEOF();
}