MDSWIN - Editorial

Author: Md Sabbir Rahman
Tester: Yash Chandnani
Editorialist: Michael Nematollahi

HARD

PREREQUISITES:

Nim, Fast Walsh–Hadamard transform, Rabin Miller

PROBLEM:

There are K piles of stones, each of them with initially one stone in it. You are given a set of N integers S.
For each pile i, you are to choose a value H[i] such that H[i] = a for some a \in S.
Two players are playing a game on these piles. A player at their turn chooses one or two piles and does the following operation on each of them:

• An operation on pile i with currently x stones in it consists of adding y > 0 stones to it such that (x+y) | H[i].

The player who cannot make a move loses.

In how many ways can you choose the H array for the piles such that the first player wins the game (provided that the players play optimally)?

QUICK EXPLANATION:

Translate the game into an almost regular Nim game in which a player may take stones from one or two piles on their turn. This requires a fast algorithm for counting the divisors of a number.
The first player wins iff the “3-xor” of the pile sizes (by which I mean summing up the binary digits modulo 3 instead of modulo 2, which would give the regular xor) is not 0.
Finally, use a base-3 variant of the fast Walsh–Hadamard transform to count the ways to choose the pile sizes so that their 3-xor is not 0.
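To make the “3-xor” definition concrete, here is a minimal sketch (the names threeXor and firstPlayerWins are hypothetical, not from the setter's or tester's code). For every bit position it counts how many pile sizes have that bit set and reduces the count modulo 3, where ordinary xor would reduce it modulo 2.

```cpp
#include <cassert>
#include <vector>

// "3-xor" of pile sizes: for every bit position, add up that bit over all
// piles and reduce the sum modulo 3 (ordinary xor would reduce modulo 2).
// Returns the ternary digits, one per bit position, lowest bit first.
std::vector<int> threeXor(const std::vector<int>& piles, int bits = 11) {
    std::vector<int> digit(bits, 0);
    for (int x : piles)
        for (int b = 0; b < bits; b++)
            digit[b] = (digit[b] + ((x >> b) & 1)) % 3;
    return digit;
}

// Winning condition from the editorial: some ternary digit is non-zero.
bool firstPlayerWins(const std::vector<int>& piles) {
    for (int d : threeXor(piles))
        if (d != 0) return true;
    return false;
}
```

For example, with piles 5 (101 in binary) and 3 (011 in binary) the per-bit counts, from the lowest bit, are 2, 1 and 1, so the 3-xor is non-zero. With piles 1, 1, 1 the lowest bit is set three times, the 3-xor is 0, and the first player loses.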

EXPLANATION:

First, let’s talk about what an operation on a pile looks like.

Consider a pile v. Let D be the sorted array of divisors of H[v], indexed from 1 (so D[1] = 1), and assume the pile currently has 1 stone in it. A move can take the pile from its current size to any larger divisor of H[v].
If a player makes the number of stones in v equal to D[j], the first j divisors will never be usable again, i.e. no later move can make the number of stones in v equal to any of them. So the pile starts with d(H[v]) - 1 usable divisors (D[1] = 1 is already unusable), and a move from D[i] to D[j] removes j - i of them. This is exactly the regular Nim game, where a player removes a positive number of stones from a pile on their turn.

So we can replace the i^{th} pile with a new pile that has d(H[i]) - 1 stones in it, where d(x) is the number of divisors of x.
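The pile-to-heap correspondence can be checked by brute force for small H. The sketch below (hypothetical helpers heapSize and behavesLikeNimHeap) defines the heap size of a position as the number of divisors of h still above the current stone count. It then verifies that the legal moves from every reachable position lead to exactly the heap sizes 0, 1, ..., size - 1, which is the move structure of a Nim heap.

```cpp
#include <cassert>
#include <vector>

// Heap size of a pile with target h holding `stones` stones: the number of
// divisors of h strictly greater than `stones` (the values it can still reach).
int heapSize(int h, int stones) {
    int c = 0;
    for (int t = stones + 1; t <= h; t++)
        if (h % t == 0) c++;
    return c;
}

// Checks that from every reachable position (a divisor x of h) the legal moves,
// x -> t with t > x and t | h, lead to heap sizes exactly {0, ..., heapSize(h, x) - 1}.
bool behavesLikeNimHeap(int h) {
    for (int x = 1; x <= h; x++) {
        if (h % x != 0) continue;  // starting from 1 stone, only divisors are reachable
        int size = heapSize(h, x);
        std::vector<bool> hit(size, false);
        for (int t = x + 1; t <= h; t++)  // add y = t - x > 0 stones
            if (h % t == 0) hit[heapSize(h, t)] = true;
        for (bool b : hit)
            if (!b) return false;
    }
    return true;
}
```

For instance, heapSize(12, 1) counts the divisors 2, 3, 4, 6 and 12, giving 5 = d(12) - 1.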
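d(x) can be calculated in O(\frac{x^{1/3}}{log(x)}) (plus one primality test) by precomputing the prime numbers up to x^{1/3}, whose count is O(\frac{x^{1/3}}{log(x)}), and dividing x by each of them. What remains has no prime factor up to x^{1/3}, so it is 1, p, p^2 or pq for distinct primes p and q; a perfect-square check and the Rabin Miller primality test tell these cases apart. Refer to the SpecialTau function in the setter’s code to see how.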
So the complexity of replacing every number with their number of divisors will be O(\frac{N \times MAX^{1/3}}{log(MAX)}), where MAX = 10^9 is the maximum value of a member of S.
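A standalone sketch of the same idea follows. All names (primesUpTo, powMod, isPrime, isPerfectSquare, countDivisors) are hypothetical. In place of the setter's seven SPRP witnesses it uses the deterministic Miller–Rabin base set {2, 3, 5, 7}, which is valid for n < 3,215,031,751 and therefore covers n ≤ 10^9.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Primes up to `limit` (sieve of Eratosthenes).
std::vector<int> primesUpTo(int limit) {
    std::vector<bool> composite(limit + 1, false);
    std::vector<int> primes;
    for (int i = 2; i <= limit; i++) {
        if (composite[i]) continue;
        primes.push_back(i);
        for (long long j = 1LL * i * i; j <= limit; j += i) composite[j] = true;
    }
    return primes;
}

// b^e mod m; overflow-free for m < 2^32 because (m - 1)^2 fits in uint64_t.
uint64_t powMod(uint64_t b, uint64_t e, uint64_t m) {
    uint64_t r = 1 % m;
    b %= m;
    while (e > 0) {
        if (e & 1) r = r * b % m;
        b = b * b % m;
        e >>= 1;
    }
    return r;
}

// Miller-Rabin with bases 2, 3, 5, 7: deterministic for n < 3,215,031,751.
bool isPrime(uint64_t n) {
    if (n < 2) return false;
    for (uint64_t p : {2, 3, 5, 7})
        if (n % p == 0) return n == p;
    uint64_t d = n - 1;
    int s = 0;
    while (d % 2 == 0) { d /= 2; s++; }
    for (uint64_t a : {2, 3, 5, 7}) {
        uint64_t x = powMod(a, d, n);
        if (x == 1 || x == n - 1) continue;
        bool witness = true;  // stays true if base a proves n composite
        for (int i = 1; i < s && witness; i++) {
            x = x * x % n;
            if (x == n - 1) witness = false;
        }
        if (witness) return false;
    }
    return true;
}

bool isPerfectSquare(uint64_t n) {
    uint64_t r = (uint64_t)std::sqrt((double)n);
    while (r * r > n) r--;
    while ((r + 1) * (r + 1) <= n) r++;
    return r * r == n;
}

// d(n) for 1 <= n <= 1e9, given the primes up to 1000 (the cube root of 1e9).
int countDivisors(uint64_t n, const std::vector<int>& primes) {
    int result = 1;
    for (int p : primes) {
        int e = 0;
        while (n % p == 0) { n /= p; e++; }
        result *= e + 1;
    }
    // Every prime factor left is > 1000, so at most two remain: n is 1, p, p^2 or p*q.
    if (n == 1) return result;
    if (isPerfectSquare(n)) return result * 3;
    if (isPrime(n)) return result * 2;
    return result * 4;
}
```

For example, countDivisors(735134400, primesUpTo(1000)) returns 1344, because 735134400 = 2^6 \cdot 3^3 \cdot 5^2 \cdot 7 \cdot 11 \cdot 13 \cdot 17.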

By inspection, you can confirm that a number up to 10^9 has at most 1344 divisors (the maximum is attained, for example, by 735134400 = 2^6 \cdot 3^3 \cdot 5^2 \cdot 7 \cdot 11 \cdot 13 \cdot 17). This means the piles in the translated game have at most 1343 stones. Note that the binary representation of 1343 has 11 bits (excluding leading zeros).

The new game is almost the same as the regular Nim game, except that on their turn a player chooses one or two piles and makes a move on each of them.
This variant is known as Moore's Nim (with k = 2), and the first player wins iff the “3-xor” of the pile sizes (summing the binary digits modulo 3 instead of modulo 2) is not 0.
The proof is similar to the one for regular Nim, where one can make the xor-sum of the piles 0 iff it isn't already 0. Here, from any position with non-zero 3-xor there is a move on at most two piles that makes the 3-xor 0, and from a position with 3-xor 0 every move makes it non-zero.
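The winning condition can also be checked by brute force on small positions. The sketch below (hypothetical names State, win and threeXorIsZero) solves the “remove from one or two piles” game for three piles by memoized search. It can be compared against the 3-xor rule for every position with piles of up to 7 stones.

```cpp
#include <array>
#include <cassert>
#include <map>

using State = std::array<int, 3>;

// True iff the player to move wins from pile sizes s, where a move removes a
// positive number of stones from one pile, or from each of two piles.
bool win(const State& s, std::map<State, bool>& memo) {
    auto it = memo.find(s);
    if (it != memo.end()) return it->second;
    bool result = false;
    for (int i = 0; i < 3 && !result; i++)
        for (int x = 1; x <= s[i] && !result; x++) {
            State t = s;
            t[i] -= x;
            if (!win(t, memo)) result = true;  // move on pile i only
            for (int j = i + 1; j < 3 && !result; j++)
                for (int y = 1; y <= s[j] && !result; y++) {
                    State u = t;
                    u[j] -= y;
                    if (!win(u, memo)) result = true;  // move on piles i and j
                }
        }
    memo[s] = result;
    return result;
}

// 3-xor is zero iff every bit position has a multiple of 3 set bits.
bool threeXorIsZero(const State& s) {
    for (int bit = 0; bit < 11; bit++) {
        int cnt = 0;
        for (int x : s) cnt += (x >> bit) & 1;
        if (cnt % 3 != 0) return false;
    }
    return true;
}
```

The check only covers three small piles, so it is a sanity test of the claim rather than a proof.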

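Let P be the polynomial of degree 3^T-1 with coefficients c_i, where T = 11 is the maximum number of digits in the binary representation of d(a) - 1, and c_i is the number of elements a \in S such that reading the binary digits of d(a) - 1 as a base-3 number gives i (for example, d(a) - 1 = 5 = 101_2 maps to i = 101_3 = 10, as in the setter's tobase3 function).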

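Using the observations above, the problem comes down to finding the sum of the coefficients of the terms of P^K with non-zero degree, where the product of two terms with degrees a and b has degree equal to the 3-xor of a and b (digit-wise addition modulo 3 in base 3). Equivalently, subtract the coefficient of degree 0 from N^K, as the tester's code does. This can be computed with a base-3 variant of the fast Walsh–Hadamard transform in O((T+log K) \cdot 3^T): transform P, raise each of the 3^T point values to the K-th power, and transform back.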

To see an implementation of the solution described, refer to the setter’s code below.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;
ll mod = 1e9+7;

#define FASTIO ios_base::sync_with_stdio(false);cin.tie(NULL);
#define bt(i) (1LL<<(i))

#define debug(x) cerr<<#x<<" = "<<(x)<<"\n"
#define hoise cerr<<"hoise\n"
#define tham getchar()

mt19937 rng((unsigned int) chrono::system_clock::now().time_since_epoch().count());

inline ll MOD(ll x, ll m = mod){
if(x < m && x >= 0) return x;
ll y = x % m;
return (y < 0)? y+m: y;
}

const int nmax = 1e3+5;
///====================== template =========================

inline ll add(ll x, ll y){
x += y;
if(x >= mod) x -= mod;
return x;
}

///utility function for modular multiplication
inline ll mult(ll x, ll y){
x *= y;
if(x >= mod) x%= mod;
return x;
}

///An algebraic extended number system, where w^3 = 1
///A number is a+bw, a and b are in modular field
///multiplication rule is special w*w = - w - 1 (as w^2 + w + 1 = 0)
struct extNum{
ll a, b;
extNum(ll _a = 0, ll _b = 0){
a = _a, b = _b;
}
extNum operator+(extNum x){
return extNum(add(a, x.a), add(b, x.b));   ///componentwise addition mod p
}
extNum operator*(ll k){
return extNum(mult(a, k), mult(b, k));
}
extNum operator*(extNum x){
ll na = MOD(a*x.a - b*x.b);
ll nb = MOD(a*x.b + b*x.a - b*x.b);
return extNum(na, nb);
}
void print(){
cout<<a<<" + w"<<b<<"\n";
}
}w(0, 1), w2(mod-1, mod-1);     ///these are w and w^2 respectively

///Convolution code, Given a list of count of numbers in base 3
///converts them into a point value form like ordinary fwht, except
///unlike fwht, the values used are 1, w, w^2
///Takes time O(n log_3(n))
typedef vector<extNum> poly;
void FWHT(poly &coefs, poly &vals, bool invert = false) {
vals = coefs;
int n = vals.size();
for (int len = 1; len < n; len *= 3) {
int pitch = len*3;
int len2 = len*2;
for (int i = 0; i < n; i += pitch) {
for (int j = 0; j < len; j++) {
extNum a = vals[i + j];
extNum b = vals[i + j + len];
extNum c = vals[i + j + len2];
vals[i+j] = a + b + c;
vals[i+j+len] = a + b*w + c*w2;
vals[i+j+len2] = a + b*w2 + c*w;
if(invert) swap(vals[i+j+len], vals[i+j+len2]);
}
}
}

ll inv3 = (mod+1)/3, inv = 1;
for(int i = 1; i<n; i*= 3)
inv = (inv3*inv) % mod;
if (invert)
for (int i = 0; i < n; i++) vals[i] = vals[i]*inv;
return;
}

///Performing the convolution and multiplication is done
///via divide and conquer, kind of like binary exponentiating
extNum expo(extNum x, ll n){
extNum ret(1);
while(n){
if(n & 1LL) ret = (ret*x);
x = (x*x);
n >>= 1;
}
return ret;
}

///Sieve to calculate primes upto 10^3 = cube_root(10^9)
vector<int> primes;
bool composite[nmax];
int sieve(){
for(int i = 2; i<nmax; i++){
if(composite[i]) continue;
primes.push_back(i);
for(int j = i+i; j<nmax; j+=i) composite[j] = true;
}
return primes.size();
}

///exponentiating in mod
ll modexpo(ll x, ll n, ll m = mod){
if(n == 0) return (m == 1)? 0: 1;
ll y = modexpo(x*x % m, n >> 1, m);
return (n&1)? y*x % m: y;
}

///Rabin_miller to quickly check if a number is prime
///SPRP is a proven list of witnesses that can check prime for
///number upto 1e18
ll SPRP[7] = {2LL, 325LL, 9375LL, 28178LL, 450775LL, 9780504LL, 1795265022LL};
bool RabinMiller(ll p, int t = 7)		//t = 7 for SPRP base
{
if(p < 4) return (p > 1);
if(!(p & 1LL)) return false;
ll x = p - 1;
int e = __builtin_ctzll(x);
x >>= e;
while(t--)
{
//ll witness = (rng() % (p-3)) + 2;	//Using random witness
ll witness = SPRP[t];
witness = modexpo(witness%p, x, p);
if(witness <= 1) continue;
for(int i = 0; i<e && witness != p-1; i++)
witness = (witness * witness) % p;
if(witness != p-1) return false;
}
return true;
}

///check if a number is square, works for number upto 1e9
bool isSquare(ll x){
ll r = sqrtl(x);
for(ll i = r-3; i<=r+3; i++)
if(i*i == x) return true;
return false;
}

///Computes tau(n) = count of divisors of a number, in O(max^(1/3))
///We try to divide by primes till max^(1/3)
///After that only p, p^2, pq or 1 remain
///these are checked by isSquare and Rabin_miller
///We don't need the primes p or q, just their exponent is enough
int SpecialTau(int n){
int ret = 1;
for(int i = 0; i<primes.size(); i++){
int freq = 0;
while(n % primes[i] == 0) freq++, n/= primes[i];
ret *= (freq+1);
}
if(n == 1) return ret;
else if(isSquare(n)) return 3*ret;
else if(RabinMiller(n)) return 2*ret;
else return 4*ret;
}

///converts the binary representation of x to ternary
///5 (101 in binary) is transformed to 10 (101 in ternary)
int tobase3(int x){
int ret = 0;
for(int i = 10; i>=0; i--){
ret*= 3;
if(bt(i)&x) ret += 1;
}
return ret;
}

int MAX = 177147; ///3^11, since maximum divisor count is 1344, which has 11 bits

///FIRST PART OF THE SOLUTION:
///Take the numbers, compute their count of divisors - 1
///As this is the equivalent nim-stack for these numbers
///Then they are converted to base3, for the next step of
///solution
void first(int n, int k, poly &p){
set<int> st;
for(int i = 0; i<n; i++){
int x;
cin>>x;
assert(1 <= x && x <= 1000000000);
st.insert(x);
x = SpecialTau(x)-1;
x = tobase3(x);
p[x].a += 1;
}
assert(st.size() == n);
}

///SECOND PART OF THE SOLUTION:
///A assignment of values (in nim equivalent stone count) is winning
///if we convert the binary representations to base-3 and 3-xor of them
///is non-zero, So we need to perform fwht convolution K times.
///That is done via modified fwht and divide and conquer
void second(int n, int k, poly &p){
FWHT(p, p);
for(int i = 0; i<MAX; i++)
p[i] = expo(p[i], k);
FWHT(p, p, true);
ll ans = 0;
for(int i = 1; i<MAX; i++)   ///every non-zero 3-xor is a first player win
ans = add(ans, p[i].a);
cout<<ans<<"\n";
}

#define time__(f, s) \
{clock_t CLK = clock(); \
f;  \
fprintf(stderr, #s " %.3f\n", (double)(clock() - CLK) / CLOCKS_PER_SEC);}
void solve(){
sieve();

int tc;
cin>>tc;
assert(1 <= tc && tc <= 5);
for(int i = 0; i<tc; i++){
int n, k;
poly p(MAX);
cin>>n>>k;
assert(1 <= n && n <= 100000);
assert(1 <= k && k <= 1000000000);
first(n, k, p);
second(n, k, p);
}
//time__(first(n, k, p), first step:);
//time__(second(n, k, p), second step:);
}

int main(){
FASTIO;
solve();
//time__(solve(), time:);
}

Tester's Solution
#include <bits/stdc++.h>
using namespace std;

void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}

template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto &i: x) cerr << (f++ ? "," : ""), __print(i); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}
#ifndef ONLINE_JUDGE
#define debug(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define debug(x...)
#endif

#define rep(i, n)    for(int i = 0; i < (n); ++i)
#define repA(i, a, n)  for(int i = a; i <= (n); ++i)
#define repD(i, a, n)  for(int i = a; i >= (n); --i)
#define trav(a, x) for(auto& a : x)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()
#define fill(a)  memset(a, 0, sizeof (a))
#define fst first
#define snd second
#define mp make_pair
#define pb push_back
typedef long double ld;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef unsigned long long ull;
const int bits = 10;
// if all numbers are less than 2^k, set bits = 64-k
const ull po = 1 << bits;
ull mod_mul(ull a, ull b, ull &c) {
ull x = a * (b & (po - 1)) % c;
while ((b >>= bits) > 0) {
a = (a << bits) % c;
x += (a * (b & (po - 1))) % c;
}
return x % c;
}
ull mod_pow(ull a, ull b, ull mod) {
if (b == 0) return 1;
ull res = mod_pow(a, b / 2, mod);
res = mod_mul(res, res, mod);
if (b & 1) return mod_mul(res, a, mod);
return res;
}
bool prime(ull p) {
if (p == 2) return true;
if (p == 1 || p % 2 == 0) return false;
ull s = p - 1;
while (s % 2 == 0) s /= 2;
rep(i,8) {
ull a = rand() % (p - 1) + 1, tmp = s;
ull mod = mod_pow(a, tmp, p);
while (tmp != p - 1 && mod != 1 && mod != p - 1) {
mod = mod_mul(mod, mod, p);
tmp *= 2;
}
if (mod != p - 1 && tmp % 2 == 0) return false;
}
return true;
}
const int MAX_PR = 5000000;
bitset<MAX_PR> isprime;
vi eratosthenes_sieve(int lim) {
isprime.set(); isprime[0] = isprime[1] = 0;
for (int i = 4; i < lim; i += 2) isprime[i] = 0;
for (int i = 3; i*i < lim; i += 2) if (isprime[i])
for (int j = i*i; j < lim; j += i*2) isprime[j] = 0;
vi pr;
repA(i,2,lim-1) if (isprime[i]) pr.push_back(i);
return pr;
}
vector<ull> pr;
ull f(ull a, ull n, ull &has) {
return (mod_mul(a, a, n) + has) % n;
}
bool square(int x){
int z = sqrt(x);
return (z*z==x)||((z+1)*(z+1)==x)||((z-1)*(z-1)==x);
}
int factor(int d) {
int res = 1;
for (int i = 0; i < sz(pr) && pr[i]*pr[i]<=d; i++)
if (d % pr[i] == 0) {
int cnt = 1;
while (d % pr[i] == 0) d /= pr[i],cnt++;
res*=cnt;
}
//d is now a product of at most 2 primes.
if (d > 1) {
if (prime(d))
res*=2;
else if(square(d)){
res*=3;
}
else res*=4;
}
return res;
}
void init(int bits) {//how many bits do we use?
vi p = eratosthenes_sieve(1 << ((bits + 2) / 3));
pr.assign(all(p));
}
const ll mod = 1e9+7;
int tf[1500];
typedef vector<pair<ll,ll>> vl;
typedef pair<ll,ll> pll;
pll mul(pll x,pll y){
return mp((x.fst*y.fst-x.snd*y.snd)%mod,(x.fst*y.snd+x.snd*y.fst-x.snd*y.snd)%mod);
}
pll add(pll x,pll y){
return mp((x.fst+y.fst)%mod,(x.snd+y.snd)%mod);
}
pll modpow(pll a, ll e) {
if (e == 0) return mp(1,0);
pll x = modpow(mul(a,a), e >> 1);
return e & 1 ? mul(x,a) : x;
}
ll modpow(ll a, ll e) {
if (e == 0) return 1;
ll x = modpow(a * a % mod, e >> 1);
return e & 1 ? x * a % mod : x;
}
void pre(){
init(30);
rep(i,1345){
rep(j,11) if((1<<j)&i) tf[i]+=modpow(3,j);
}
}
void ifwht(vl& a){
int s = 3;
while(sz(a)>=s){
int i = 0,j=s/3,k=2*j;
while(i<sz(a)){
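pll x = a[i],y=a[j],z=a[k];
pll w2 = mp(0,1);   // roots swapped w.r.t. fwht: conjugate (inverse) butterfly
pll w = mp(-1,-1);
pll yw = mul(y,w), yw2 = mul(y,w2), zw = mul(z,w), zw2 = mul(z,w2);
a[i] = mp((x.fst+y.fst+z.fst)%mod,(x.snd+y.snd+z.snd)%mod);
a[j] = mp((x.fst+yw.fst+zw2.fst)%mod,(x.snd+yw.snd+zw2.snd)%mod);
a[k] = mp((x.fst+yw2.fst+zw.fst)%mod,(x.snd+yw2.snd+zw.snd)%mod);
i++,j++,k++;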
if(i%s==s/3) i+=2*s/3,j+=2*s/3,k+=2*s/3;
}
s*=3;
}
}
void fwht(vl& a){
int s = 3;
while(sz(a)>=s){
int i = 0,j=s/3,k=2*j;
while(i<sz(a)){
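pll x = a[i],y=a[j],z=a[k];
pll w = mp(0,1);    // w, a cube root of unity
pll w2 = mp(-1,-1); // w^2 = -1 - w
// size-3 DFT butterfly: (x, y, z) -> (x+y+z, x+y*w+z*w2, x+y*w2+z*w)
pll yw = mul(y,w), yw2 = mul(y,w2), zw = mul(z,w), zw2 = mul(z,w2);
a[i] = mp((x.fst+y.fst+z.fst)%mod,(x.snd+y.snd+z.snd)%mod);
a[j] = mp((x.fst+yw.fst+zw2.fst)%mod,(x.snd+yw.snd+zw2.snd)%mod);
a[k] = mp((x.fst+yw2.fst+zw.fst)%mod,(x.snd+yw2.snd+zw.snd)%mod);
i++,j++,k++;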
if(i%s==s/3) i+=2*s/3,j+=2*s/3,k+=2*s/3;
}
s*=3;
}
}
vl conv(vl a,vl b){
vl ans(sz(a));
fwht(a),fwht(b);
rep(i,sz(a)){
ans[i] = mul(a[i],b[i]);
}
ifwht(ans);
ll inv = modpow(sz(ans),mod-2);
rep(i,sz(ans)){
ans[i].fst=ans[i].fst*inv%mod;
ans[i].snd=ans[i].snd*inv%mod;
}
return ans;
}

void solve(){
int n,k;cin>>n>>k;
map<int,int> m;
rep(i,n){
int x;cin>>x;
int y = factor(x);
m[y-1]++;
}
vl a(modpow(3,11),mp(0,0)),ans(modpow(3,11),mp(0,0));
trav(i,m) {
a[tf[i.fst]].fst=i.snd;
}
fwht(a);
rep(i,sz(a)){
a[i] = modpow(a[i],k);
}
ifwht(a);
ll inv = modpow(sz(a),mod-2);
cout<<(modpow(n,k)-a[0].fst*inv%mod+mod)%mod<<'\n';

}

int main() {
cin.sync_with_stdio(0); cin.tie(0);
cin.exceptions(cin.failbit);
pre();
int n;cin>>n;
rep(i,n) {
solve();
}
return 0;
}


I figured out that this problem is solved with the Walsh transform, and I even found a blog post on CS Academy explaining how to build a transform matrix for a different operator, but I didn't understand that method. Can anyone point me to a good resource for it?


Here you need a modified version of the Hadamard transform.

In Hadamard transform “mod 2”, the transform can be thought of as multiplying by a matrix with entries in {1, -1}. Note that {1, -1} are exactly roots of x^2 = 1.

If we move to “mod 3”, the matrix we need to consider will contain roots of x^3 = 1. There are three of them (1, and two other complex numbers). As a first step, you need to figure out the right matrix; the next step is to see how to work around dealing with complex numbers, so that you only need to deal with integers.
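For reference, one way to write the matrix the hint points to is the standard size-3 DFT. It matches the a+b+c, a+b\omega+c\omega^2, a+b\omega^2+c\omega butterfly in the setter's code; \omega here is a primitive cube root of unity. The key property is that every row is multiplicative in the column index, which turns "add digits mod 3" into pointwise multiplication.

```latex
% One ternary digit: F_{i,p} = \omega^{ip}, with \omega^3 = 1 and \omega \neq 1
F = \begin{pmatrix} 1 & 1 & 1 \\ 1 & \omega & \omega^2 \\ 1 & \omega^2 & \omega \end{pmatrix},
\qquad
F^{-1} = \frac{1}{3}\begin{pmatrix} 1 & 1 & 1 \\ 1 & \omega^2 & \omega \\ 1 & \omega & \omega^2 \end{pmatrix}

% Each row respects addition of digits modulo 3:
F_{i,p}\,F_{i,q} = \omega^{ip}\,\omega^{iq} = \omega^{i(p+q)} = F_{i,\,(p+q) \bmod 3}

% For T ternary digits the transform is the Kronecker power F^{\otimes T} (size 3^T \times 3^T).
```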


Thank you for the explanation.

I am not very familiar with advanced algebra. It seems the method given in https://csacademy.com/blog/fast-fourier-transform-and-variations-of-it
is more general. I got stuck on the part where you find the coefficients of the transform matrix.

What I have understood is that you write two equations:
the first is the convolution, which expresses the property of our operator, and the second is the pointwise multiplication in the other domain.

Given these two equations, I don't know how to extract the coefficients from them.


Hi, setter here. If anyone needs it, the final part of the solution is explained in depth below:

Once the problem has been reduced to finding the 3-xor convolution, the next step is a modified FWHT. First, a brief summary of FWHT: in FFT, ax^5 and bx^3 multiply to become abx^8, but in FWHT we want the product to be abx^{5 \oplus 3}. To do this, we instead represent the terms as ax_1^1x_2^0x_3^1 and bx_1^0x_2^1x_3^1 (5 and 3 have been broken into bits) and let each x_i take values in the set \{1, -1\}. In this domain x^p \times x^q = x^{p \oplus q} (where p and q are single bits), since x^2 = 1. So FWHT is just a multivariable FFT.

Now to the main problem: here we want 3-xor, so the behaviour we want is x^p \times x^q = x^{(p + q) \mod 3}. For this we use the domain \{1, \omega, \omega^2 \}, where \omega is a primitive cube root of unity (so x^3 = 1 for every x in it). In the setter’s code, I have described this in the comment above my FWHT function. You will notice that whereas normal FWHT uses a+b and a-b, mine uses a+b+c, a+b\omega + c\omega^2 and a+b\omega^2 + c\omega (swapping the last two for the inverse). This is just evaluating the polynomial a+bx+cx^2 at x \in \{1, \omega, \omega^2 \}. So now a polynomial of size 3^{11} is used, with one coefficient for every ternary number of 11 digits, and the modified FWHT performs the 3-xor convolution.

There is a final step to this problem: the answer is required mod 10^9+7, but this modulus has no cube root of unity other than 1 (since 3 does not divide 10^9+6). So instead we create a hypothetical number system a+bw (like complex numbers a+bi), where w is a cube root of unity and a, b are numbers mod 10^9+7. We use the rule w \times w = -w-1 for multiplication, and now we can use this number system for addition, subtraction and multiplication (maybe division too, but I didn’t think about it too much). Since the true answers are integers, after all the calculations we get a+bw with b = 0. In the setter’s code, I have used a structure (extNum) to keep track of these numbers.
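Both facts used in this step can be checked with a short sketch; the helpers cubeRootsOfUnity, AB and mulW are hypothetical, and AB mirrors the setter's extNum. A prime p has a cube root of unity other than 1 exactly when 3 divides p - 1, which is false for 10^9+7. Separately, the rule w \times w = -w - 1 does make w^3 = 1.

```cpp
#include <cassert>
#include <cstdint>

const int64_t P = 1000000007;

// Brute-force count of x in [1, p - 1] with x^3 == 1 (mod p); only for small p.
int cubeRootsOfUnity(int64_t p) {
    int count = 0;
    for (int64_t x = 1; x < p; x++)
        if (x * x % p * x % p == 1) count++;
    return count;
}

// a + b*w with a, b taken mod p, multiplied with the rule w*w = -w - 1.
struct AB { int64_t a, b; };

AB mulW(AB x, AB y, int64_t p) {
    int64_t a = ((x.a * y.a - x.b * y.b) % p + p) % p;
    int64_t b = ((x.a * y.b + x.b * y.a - x.b * y.b) % p + p) % p;
    return {a, b};
}
```

For instance, mod 7 (where 3 divides 6) the cube roots of unity are 1, 2 and 4, while mod 11 only 1 qualifies. With P = 10^9+7, squaring w = (0, 1) gives (P-1, P-1), i.e. -1 - w (the setter's w2), and one more multiplication by w gives (1, 0).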

So convert to d(x)-1, convert that to the FWHT polynomial in the new number system, apply FWHT, exponentiate it to k using divide and conquer, invert the FWHT, and add all the coefficients except the one representing 0.


Can someone please explain with an example, what is meant by 3-xor?


Can you go into depth in why you need roots of unity? Correct me if I’m wrong, but as far as I know, roots of unity are used in classic FFT only to apply divide and conquer and optimize from O(n^2) to O(n \cdot log(n)). Outside of complexity purposes, they’re useless: any invertible matrix could also do the transformation correctly.

In FWHT, we are only doing 3 \times 3 matrix multiplications, so the original purpose for unity roots is gone. Why do we use unity roots then? Is it a mere mathematical coincidence?

I’m asking because you jumped straight into the fact that we need cubic roots of unity, but didn’t explain why. That’s what isn’t clear to me.