PROBLEM LINK:
Author: Md Sabbir Rahman
Tester: Yash Chandnani
Editorialist: Michael Nematollahi
DIFFICULTY:
HARD
PREREQUISITES:
Nim, Fast Walsh–Hadamard transform, Rabin Miller
PROBLEM:
There are K piles of stones, each of them with initially one stone in it. You are given a set of N integers S.
For each pile i, you are to choose a value H[i] such that H[i] = a for some a \in S.
Two players are playing a game on these piles. A player at their turn chooses one or two piles and does the following operation on each of them:
- An operation on pile i with currently x stones in it consists of adding y > 0 stones to it such that (x+y) | H[i].
The player who cannot make a move loses.
In how many ways can you choose the H array for the piles such that the first player wins the game (provided that the players play optimally)?
QUICK EXPLANATION:
Translate the game into an almost regular Nim game where you can take stones off of one or two piles in your turn. You will need to use a quick algorithm to find the number of divisors of a number.
The winning condition for the first player is that the “3-xor” (by which I mean summing up the binary digits modulo 3, instead of 2, which would be the regular xor) be not 0.
Finally, use the fast Walsh-Hadamard transform to find the number of ways to choose the number of stones in each pile so that their 3-xor is not 0.
EXPLANATION:
First, let’s talk about what an operation on a pile looks like.
Consider a pile v. Let D be the sorted array of divisors of H[v]. Assume it currently has 1 stone in it.
If a player makes the number of stones in v equal to D[j], the first j divisors will not be usable in the future; that is, no later move can make the number of stones in v equal to any of them. This can be interpreted as removing the first j-1 divisors from D (j-1, because 1 is already unusable). This is the same as the regular Nim game, where a player at their turn removes a positive number of stones from a pile.
So we can replace the i^{th} pile with a new pile that has d(H[i]) - 1 stones in it, where d(x) is the number of divisors of x.
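d(x) can be calculated in O(\frac{x^{1/3}}{\log(x)}) by precomputing the prime numbers up to x^{1/3}, whose count is O(\frac{x^{1/3}}{\log(x)}), and dividing them out of x. What remains afterwards is either 1, a prime p, a prime square p^2, or a product pq of two distinct primes; a perfect-square check together with the Rabin Miller algorithm (to determine if a number is prime) tells these cases apart. Refer to the SpecialTau function in the setter’s code to see how.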
So the complexity of replacing every number with their number of divisors will be O(\frac{N \times MAX^{1/3}}{log(MAX)}), where MAX = 10^9 is the maximum value of a member of S.
By inspection, you can confirm that a number under the given constraints can have at most 1344 divisors. Which means, the piles in the translated game have at most 1343 stones in them. Note that the binary representation of 1343 has 11 bits (excluding the leading 0's).
The new game is almost the same as the regular Nim game, except that a player can choose one or two piles at their turn and make a move on each of them.
As mentioned here, the first player in this version of Nim (known in the literature as Moore's Nim with k = 2, where a move may touch at most k piles) wins iff the “3-xor” (by which I mean summing up the binary digits modulo 3, instead of 2, which would be the regular xor) is not 0.
The proof is similar to the regular Nim, where one could make the xor-sum of the piles 0 iff it wasn’t already 0.
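Let P be the polynomial of degree 3^T-1 whose coefficients are c_i, where T = 11 is the maximum number of digits in the binary representation of d(a) - 1 and c_i is the number of elements a \in S for which the binary digits of d(a)-1, read as a base-3 number, give i (for example, d(a)-1 = 5 = 101_2 contributes to c_{10}, because 101_3 = 10).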
Utilizing the observations made above, the problem comes down to finding the sum of the coefficients of the terms in P^K with non-zero degrees, where the product of two terms with degrees a and b is the 3-xor of a and b. This problem can be solved using the fast Walsh-Hadamard transform in O((T+log K)*3^T).
To see an implementation of the solution described, refer to the setter’s code below.
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
///64-bit signed integer used for all modular arithmetic in this solution.
typedef long long int ll;
///Modulus of the answer (1e9+7); read by MOD, add, mult and the FWHT code.
ll mod = 1e9+7;
///Turns off synchronisation with C stdio and unties cin from cout.
#define FASTIO ios_base::sync_with_stdio(false);cin.tie(NULL);
///bt(i) is 2^i computed as a long long.
#define bt(i) (1LL<<(i))
///Prints "<expression> = <value>" to stderr.
#define debug(x) cerr<<#x<<" = "<<(x)<<"\n"
///Prints a fixed marker line to stderr.
#define hoise cerr<<"hoise\n"
///Reads one character from stdin (presumably used to pause while debugging).
#define tham getchar()
///Random generator seeded from the system clock; in the visible source it is
///only referenced from a commented-out line inside RabinMiller.
mt19937 rng((unsigned int) chrono::system_clock::now().time_since_epoch().count());
///Maps x to its representative in [0, m); m defaults to the global modulus.
///Values that are already in range are returned without a division.
inline ll MOD(ll x, ll m = mod){
    const bool alreadyReduced = (x >= 0) && (x < m);
    if(alreadyReduced){
        return x;
    }
    ll rem = x % m;
    if(rem < 0){
        rem += m;
    }
    return rem;
}
///Exclusive upper bound of the prime sieve (slightly above 10^3).
const int nmax = 1e3+5;
///====================== template =========================
///Modular addition helper.
///Only one conditional subtraction of mod is performed, so the result is
///fully reduced when the plain sum x + y is below 2*mod.
inline ll add(ll x, ll y){
    const ll sum = x + y;
    return (sum >= mod) ? sum - mod : sum;
}
///Modular multiplication helper.
///The product is formed in 64 bits and reduced only when it reaches mod.
inline ll mult(ll x, ll y){
    const ll prod = x * y;
    return (prod >= mod) ? prod % mod : prod;
}
///An algebraic extended number system, where w^3 = 1
///A number is a+bw, a and b are in modular field
///multiplication rule is special w*w = - w - 1 (as w^2 + w + 1 = 0)
struct extNum{
ll a, b;
extNum(ll _a = 0, ll _b = 0){
a = _a, b = _b;
}
extNum operator+(extNum x){
return extNum(add(a, x.a), add(b, x.b));
}
extNum operator*(ll k){
return extNum(mult(a, k), mult(b, k));
}
extNum operator*(extNum x){
ll na = MOD(a*x.a - b*x.b);
ll nb = MOD(a*x.b + b*x.a - b*x.b);
return extNum(na, nb);
}
void print(){
cout<<a<<" + w"<<b<<"\n";
}
}w(0, 1), w2(mod-1, mod-1); ///these are w and w^2 respectively
///Base-3 analogue of the Walsh-Hadamard transform.
///coefs holds counts indexed by base-3 numbers; vals receives the point-value
///form, built from the cube roots of unity 1, w, w^2 instead of +1/-1.
///With invert = true the roles of w and w^2 are exchanged in every butterfly
///and the result is multiplied by the modular inverse of 3 once per stage.
///Runs in O(n log_3(n)) butterflies for n = coefs.size().
typedef vector<extNum> poly;
void FWHT(poly &coefs, poly &vals, bool invert = false) {
    vals = coefs;
    const int total = vals.size();
    for (int third = 1; third < total; third *= 3) {
        const int block = 3*third;
        for (int base = 0; base < total; base += block) {
            for (int off = 0; off < third; off++) {
                const int p0 = base + off;
                const int p1 = p0 + third;
                const int p2 = p1 + third;
                extNum u = vals[p0];
                extNum v = vals[p1];
                extNum t = vals[p2];
                extNum plain = u + v + t;
                extNum direct = u + v*w + t*w2;
                extNum conj = u + v*w2 + t*w;
                vals[p0] = plain;
                ///the inverse transform stores the two twisted sums swapped
                vals[p1] = invert ? conj : direct;
                vals[p2] = invert ? direct : conj;
            }
        }
    }
    if (!invert) return;
    ///one factor of 3^{-1} for every stage executed above
    const ll inv3 = (mod+1)/3;
    ll scale = 1;
    for (int i = 1; i < total; i *= 3)
        scale = (inv3*scale) % mod;
    for (int i = 0; i < total; i++)
        vals[i] = vals[i]*scale;
}
///Raises an extNum to the n-th power by square-and-multiply.
///Used to take the K-fold convolution pointwise in the transformed domain.
extNum expo(extNum x, ll n){
    extNum acc(1);
    for(; n; n >>= 1){
        if(n & 1LL){
            acc = acc*x;
        }
        x = x*x;
    }
    return acc;
}
///Sieve of Eratosthenes over [2, nmax): appends every prime found to `primes`
///and marks the multiples in `composite`. Returns the size of `primes`.
///10^3 is the cube root of the largest input value, 10^9.
vector<int> primes;
bool composite[nmax];
int sieve(){
    int cand = 2;
    while(cand < nmax){
        if(!composite[cand]){
            primes.push_back(cand);
            for(int multiple = 2*cand; multiple < nmax; multiple += cand){
                composite[multiple] = true;
            }
        }
        ++cand;
    }
    return primes.size();
}
///Computes x^n modulo m recursively (m defaults to the global modulus).
///The exponent is halved at every level while the base is squared.
ll modexpo(ll x, ll n, ll m = mod){
    if(n == 0){
        if(m == 1) return 0;
        return 1;
    }
    const ll squared = x*x % m;
    const ll half = modexpo(squared, n >> 1, m);
    if(n & 1){
        return half*x % m;
    }
    return half;
}
///Rabin-Miller primality test with a fixed witness list.
///SPRP is the 7-base list that is known to classify every 64-bit integer
///correctly when all seven bases are used.
///Every modular product goes through a 128-bit intermediate, so the test is
///valid for the whole positive range of long long (the earlier 64-bit
///products overflowed once p exceeded roughly 3*10^9).
long long SPRP[7] = {2LL, 325LL, 9375LL, 28178LL, 450775LL, 9780504LL, 1795265022LL};
///(a * b) mod m for 0 <= a, b < m, without 64-bit overflow.
static inline long long rmMulMod(long long a, long long b, long long m){
    return (long long)((unsigned __int128)a * (unsigned __int128)b % (unsigned __int128)m);
}
///base^e mod m by iterative square-and-multiply, built on rmMulMod.
static inline long long rmPowMod(long long base, long long e, long long m){
    long long result = 1 % m;
    base %= m;
    while(e > 0){
        if(e & 1) result = rmMulMod(result, base, m);
        base = rmMulMod(base, base, m);
        e >>= 1;
    }
    return result;
}
///Returns true iff p is prime.
///t is the number of SPRP bases to try (the first t entries, taken from the
///back); values above 7 are clamped so the table is never read out of bounds.
///The default t = 7 gives the deterministic test.
bool RabinMiller(long long p, int t = 7) //t = 7 for SPRP base
{
    if(p < 4) return (p > 1);
    if(!(p & 1LL)) return false;
    if(t > 7) t = 7;
    ///write p - 1 = d * 2^e with d odd
    long long d = p - 1;
    int e = __builtin_ctzll(d);
    d >>= e;
    while(t-- > 0)
    {
        long long cur = rmPowMod(SPRP[t] % p, d, p);
        ///0: the base is a multiple of p, 1: the base passes trivially
        if(cur <= 1) continue;
        for(int i = 0; i<e && cur != p-1; i++)
            cur = rmMulMod(cur, cur, p);
        if(cur != p-1) return false;
    }
    return true;
}
///Perfect-square test: takes the floating-point root and tries every integer
///within 3 of it, so small rounding errors cannot hide the true root.
///Intended for values up to 1e9.
bool isSquare(ll x){
    const ll approx = sqrtl(x);
    ll cand = approx - 3;
    while(cand <= approx + 3){
        if(cand*cand == x){
            return true;
        }
        ++cand;
    }
    return false;
}
///tau(n): number of divisors of n, in O(max^(1/3)) divisions.
///Every sieved prime (up to about max^(1/3)) is divided out first; what is
///left can only be 1, p, p^2 or p*q, contributing a factor 1, 2, 3 or 4.
///isSquare and RabinMiller distinguish these cases; the primes themselves
///are never needed, only their exponents.
int SpecialTau(int n){
    int divisors = 1;
    for(int p : primes){
        int exponent = 0;
        for(; n % p == 0; n /= p){
            ++exponent;
        }
        divisors *= exponent + 1;
    }
    if(n == 1) return divisors;
    if(isSquare(n)) return 3*divisors;
    return RabinMiller(n) ? 2*divisors : 4*divisors;
}
///Reinterprets the low 11 binary digits of x as base-3 digits.
///Example: 5 (101 in binary) becomes 10 (101 in ternary).
int tobase3(int x){
    int ternary = 0;
    int i = 10;
    while(i >= 0){
        ternary = ternary*3 + ((bt(i)&x) ? 1 : 0);
        --i;
    }
    return ternary;
}
int MAX = 177147; ///3^11, since maximum divisor count is 1344, which has 11 bits
///FIRST PART OF THE SOLUTION:
///Take the numbers, compute their count of divisors - 1
///As this is the equivalent nim-stack for these numbers
///Then they are converted to base3, for the next step of
///solution
void first(int n, int k, poly &p){
set<int> st;
for(int i = 0; i<n; i++){
int x;
cin>>x;
assert(1 <= x && x <= 1000000000);
st.insert(x);
x = SpecialTau(x)-1;
x = tobase3(x);
p[x].a += 1;
}
assert(st.size() == n);
}
///SECOND PART OF THE SOLUTION:
///An assignment is winning iff the 3-xor of the piles (binary digits added
///modulo 3) is non-zero. The K-fold 3-xor convolution of the counts is taken
///by transforming, raising every point value to the k-th power and
///transforming back; the answer is the sum of the coefficients at all
///non-zero indices, printed on its own line. n is not used here.
void second(int n, int k, poly &p){
    FWHT(p, p);
    int idx = 0;
    while(idx < MAX){
        p[idx] = expo(p[idx], k);
        ++idx;
    }
    FWHT(p, p, true);
    ll ans = 0;
    idx = 1;
    while(idx < MAX){
        ans = add(ans, p[idx].a);
        ++idx;
    }
    cout<<ans<<"\n";
}
///Timing helper: executes the statement f and prints the processor time it
///took (measured with clock(), in seconds) to stderr, prefixed by the label s.
///In the visible source it is only referenced from commented-out lines.
#define time__(f, s) \
{clock_t CLK = clock(); \
f; \
fprintf(stderr, #s " %.3f\n", (double)(clock() - CLK) / CLOCKS_PER_SEC);}
///Driver: builds the prime table once, then handles every test case by
///reading n and k, filling a fresh count array (first) and printing the
///answer (second). Input limits are checked with assertions.
void solve(){
    sieve();
    int tc;
    cin>>tc;
    assert(1 <= tc && tc <= 5);
    int done = 0;
    while(done < tc){
        int n, k;
        poly p(MAX);
        cin>>n>>k;
        assert(1 <= n && n <= 100000);
        assert(1 <= k && k <= 1000000000);
        first(n, k, p);
        second(n, k, p);
        ++done;
    }
    ///optional profiling of the two phases:
    //time__(first(n, k, p), first step:);
    //time__(second(n, k, p), second step:);
}
int main(){
FASTIO;
solve();
//time__(solve(), time:);
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
// Debug printing helpers. Everything below writes to cerr only.
// __print has one overload per scalar type; the templates further down
// handle pairs and any range that supports range-for.
void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
// Characters are wrapped in single quotes, strings in double quotes.
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
// A pair is printed as {first,second}.
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
// Any other type is treated as a range and printed as {e0,e1,...}.
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto &i: x) cerr << (f++ ? "," : ""), __print(i); cerr << "}";}
// _print writes its arguments separated by ", " and finishes with "]\n";
// the overload without arguments ends the recursion.
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}
// debug(...) prints "[names] = [values]" to cerr unless ONLINE_JUDGE is
// defined, in which case it expands to nothing.
#ifndef ONLINE_JUDGE
#define debug(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define debug(x...)
#endif
// Loop shorthands: rep counts 0..n-1, repA counts a..n upwards (inclusive),
// repD counts a..n downwards (inclusive), trav iterates a range by reference.
#define rep(i, n) for(int i = 0; i < (n); ++i)
#define repA(i, a, n) for(int i = a; i <= (n); ++i)
#define repD(i, a, n) for(int i = a; i >= (n); --i)
#define trav(a, x) for(auto& a : x)
// Container shorthands: whole range, size as int, zero-fill of an array.
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()
#define fill(a) memset(a, 0, sizeof (a))
// Abbreviations for pair members and common calls.
#define fst first
#define snd second
#define mp make_pair
#define pb push_back
// Type abbreviations used throughout the tester's solution.
typedef long double ld;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef unsigned long long ull;
// Width of the chunks in which mod_mul consumes its second factor.
const int bits = 10;
// if all numbers are less than 2^k, set bits = 64-k
const ull po = 1 << bits;
// Returns a * b mod c. b is processed `bits` bits at a time: each round
// shifts a left by one chunk (mod c) and adds a times the next chunk of b,
// which keeps every intermediate product small.
ull mod_mul(ull a, ull b, ull &c) {
    const ull mask = po - 1;
    ull acc = a * (b & mask) % c;
    for (b >>= bits; b > 0; b >>= bits) {
        a = (a << bits) % c;
        acc += (a * (b & mask)) % c;
    }
    return acc % c;
}
// Returns a^b mod `mod`, built on mod_mul.
// The exponent is scanned from its most significant set bit downwards:
// square at every bit, then multiply by a when the bit is set. An exponent
// of 0 yields 1 without any multiplication.
ull mod_pow(ull a, ull b, ull mod) {
    ull res = 1;
    if (b == 0) return res;
    int top = 63;
    while (!((b >> top) & 1)) --top;
    for (int bit = top; bit >= 0; --bit) {
        res = mod_mul(res, res, mod);
        if ((b >> bit) & 1) res = mod_mul(res, a, mod);
    }
    return res;
}
// Deterministic Miller-Rabin primality test.
// The previous version drew 8 witnesses from rand(), so a composite could
// occasionally be reported as prime. The first twelve primes are used as
// witnesses instead; they are known to classify every 64-bit integer
// correctly. Correctness still relies on mod_mul/mod_pow being exact for p.
// Returns true iff p is prime.
bool prime(ull p) {
	if (p < 2) return false;
	static const ull bases[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
	// Trial division settles every p that shares a factor with a witness,
	// so below each witness a satisfies 0 < a < p and gcd(a, p) = 1.
	for (ull a : bases) {
		if (p == a) return true;
		if (p % a == 0) return false;
	}
	// write p - 1 = s * 2^r with s odd
	ull s = p - 1;
	int r = 0;
	while (s % 2 == 0) s /= 2, r++;
	for (ull a : bases) {
		ull x = mod_pow(a, s, p);
		if (x == 1 || x == p - 1) continue;
		// a is a witness of compositeness unless some x^(2^i) hits p - 1
		bool composite = true;
		for (int i = 1; i < r && composite; i++) {
			x = mod_mul(x, x, p);
			if (x == p - 1) composite = false;
		}
		if (composite) return false;
	}
	return true;
}
// Capacity of the sieve table; lim passed to the sieve must not exceed it.
const int MAX_PR = 5000000;
bitset<MAX_PR> isprime;
// Sieve of Eratosthenes: returns all primes below lim in increasing order
// and leaves the primality flags in `isprime`. Even numbers are cleared
// first, then odd multiples of every odd prime starting from its square.
vi eratosthenes_sieve(int lim) {
	isprime.set();
	isprime[1] = 0;
	isprime[0] = 0;
	for (int even = 4; even < lim; even += 2) isprime[even] = 0;
	for (int p = 3; p*p < lim; p += 2) {
		if (!isprime[p]) continue;
		const int step = p*2;
		for (int m = p*p; m < lim; m += step) isprime[m] = 0;
	}
	vi found;
	for (int v = 2; v < lim; ++v)
		if (isprime[v]) found.push_back(v);
	return found;
}
// Primes produced by init(); used by factor() for trial division.
vector<ull> pr;
// Returns (a*a + has) mod n. Not called anywhere in the visible source.
ull f(ull a, ull n, ull &has) {
	const ull sq = mod_mul(a, a, n);
	return (sq + has) % n;
}
// Returns true iff x is a perfect square.
// The truncated floating-point root and both of its neighbours are tested
// against x, so an estimate that is off by one in either direction is still
// caught. (The earlier code compared (z+1)^2 with z instead of x, which
// disabled the guard against a root that was rounded down.)
// Candidates are squared in 64 bits so that (root+1)^2 cannot overflow, and
// negative inputs are rejected before sqrt is called.
bool square(int x){
	if (x < 0) return false;
	const long long target = x;
	const long long root = static_cast<long long>(std::sqrt(static_cast<double>(x)));
	const long long first = (root > 0) ? root - 1 : 0;
	for (long long cand = first; cand <= root + 1; ++cand)
		if (cand * cand == target) return true;
	return false;
}
// Returns the number of divisors of d.
// Primes from `pr` are divided out while their square does not exceed the
// remaining value; each contributes (exponent + 1). Whatever is left above 1
// is then classified as a prime (x2), a prime square (x3) or otherwise a
// product of two primes (x4).
int factor(int d) {
	int divisors = 1;
	int idx = 0;
	while (idx < sz(pr) && pr[idx]*pr[idx]<=d) {
		if (d % pr[idx] == 0) {
			int exponentPlusOne = 1;
			do {
				d /= pr[idx];
				++exponentPlusOne;
			} while (d % pr[idx] == 0);
			divisors *= exponentPlusOne;
		}
		++idx;
	}
	//d is now a product of at most 2 primes.
	if (d <= 1) return divisors;
	if (prime(d)) return divisors*2;
	if (square(d)) return divisors*3;
	return divisors*4;
}
void init(int bits) {//how many bits do we use?
vi p = eratosthenes_sieve(1 << ((bits + 2) / 3));
pr.assign(all(p));
}
// Modulus of the answer.
const ll mod = 1e9+7;
// tf[v]: the binary digits of v re-read as a base-3 number (filled by pre()).
int tf[1500];
typedef vector<pair<ll,ll>> vl;
typedef pair<ll,ll> pll;
// Product of two numbers first + second*w, where w^2 = -w - 1:
// (a + b w)(c + d w) = (ac - bd) + (ad + bc - bd) w.
// Each coordinate is reduced with %, so it may be negative.
pll mul(pll x,pll y){
	const ll re = (x.first*y.first-x.second*y.second)%mod;
	const ll im = (x.first*y.second+x.second*y.first-x.second*y.second)%mod;
	return pll(re, im);
}
// Coordinate-wise sum of two numbers first + second*w, reduced with %.
pll add(pll x,pll y){
	x.first = (x.first+y.first)%mod;
	x.second = (x.second+y.second)%mod;
	return x;
}
// a^e for a number first + second*w, by recursive squaring; e = 0 gives 1.
pll modpow(pll a, ll e) {
	if (e == 0) return pll(1, 0);
	const pll half = modpow(mul(a,a), e >> 1);
	if (e & 1) return mul(half, a);
	return half;
}
// a^e modulo mod for plain integers, by recursive squaring; e = 0 gives 1.
ll modpow(ll a, ll e) {
	if (e == 0) return 1;
	const ll squared = a * a % mod;
	const ll half = modpow(squared, e >> 1);
	if (e & 1) return half * a % mod;
	return half;
}
void pre(){
init(30);
rep(i,1345){
rep(j,11) if((1<<j)&i) tf[i]+=modpow(3,j);
}
}
// In-place inverse of fwht, up to a factor of sz(a): identical butterflies
// with the roles of w = (0,1) and w^2 = (-1,-1) exchanged. The caller
// multiplies by the modular inverse of sz(a) afterwards.
void ifwht(vl& a){
	const int n = sz(a);
	const pll w2 = mp(0,1);
	const pll w = mp(-1,-1);
	for (int s = 3; n >= s; s *= 3) {
		const int third = s/3;
		for (int base = 0; base < n; base += s) {
			for (int i = base; i < base + third && i < n; ++i) {
				const int j = i + third;
				const int k = j + third;
				const pll x = a[i], y = a[j], z = a[k];
				a[i] = add(x,add(y,z));
				a[j] = add(x,add(mul(y,w),mul(z,w2)));
				a[k] = add(x,add(mul(y,w2),mul(z,w)));
			}
		}
	}
}
// In-place base-3 Walsh-Hadamard transform with w = (0,1), w^2 = (-1,-1).
// For every block size s = 3, 9, 27, ... the three thirds of each block are
// combined as x+y+z, x+y*w+z*w^2 and x+y*w^2+z*w.
void fwht(vl& a){
	const int n = sz(a);
	const pll w = mp(0,1);
	const pll w2 = mp(-1,-1);
	for (int s = 3; n >= s; s *= 3) {
		const int third = s/3;
		for (int base = 0; base < n; base += s) {
			for (int i = base; i < base + third && i < n; ++i) {
				const int j = i + third;
				const int k = j + third;
				const pll x = a[i], y = a[j], z = a[k];
				a[i] = add(x,add(y,z));
				a[j] = add(x,add(mul(y,w),mul(z,w2)));
				a[k] = add(x,add(mul(y,w2),mul(z,w)));
			}
		}
	}
}
// 3-xor convolution of a and b: transform both, multiply pointwise,
// transform back and divide by the length. Not called in the visible source.
vl conv(vl a,vl b){
	fwht(a);
	fwht(b);
	const int n = sz(a);
	vl ans(n);
	for (int i = 0; i < n; ++i) {
		ans[i] = mul(a[i],b[i]);
	}
	ifwht(ans);
	const ll inv = modpow(sz(ans),mod-2);
	for (pll &term : ans) {
		term.first = term.first*inv%mod;
		term.second = term.second*inv%mod;
	}
	return ans;
}
// Handles one test case.
// Reads n and k and the n set elements, and counts how many elements have
// each value of (number of divisors - 1). The counts are placed at the
// base-3 index tf[...], transformed, raised pointwise to the k-th power and
// transformed back. a[0] / sz(a) is then the number of choices whose 3-xor
// is zero; the printed answer is n^k minus that, reduced into [0, mod).
// (A 3^11-element vector `ans` that was allocated here on every test case
// and never read has been removed.)
void solve(){
	int n,k;cin>>n>>k;
	map<int,int> m;
	rep(i,n){
		int x;cin>>x;
		int y = factor(x);
		m[y-1]++;
	}
	vl a(modpow(3,11),mp(0,0));
	trav(i,m) {
		a[tf[i.fst]].fst=i.snd;
	}
	fwht(a);
	rep(i,sz(a)){
		a[i] = modpow(a[i],k);
	}
	ifwht(a);
	ll inv = modpow(sz(a),mod-2);
	// a[0].fst may be negative, hence the "+mod" before the final reduction
	cout<<(modpow(n,k)-a[0].fst*inv%mod+mod)%mod<<'\n';
}
int main() {
cin.sync_with_stdio(0); cin.tie(0);
cin.exceptions(cin.failbit);
pre();
int n;cin>>n;
rep(i,n) {
solve();
}
return 0;
}