GCDMASK - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Shahjalal Shohag
Tester: Radoslav Dimitrov
Editorialist: Taranpreet Singh

DIFFICULTY

Hard

PREREQUISITES

Number theory, Lagrange Interpolation and Faulhaber’s formula

PROBLEM

For a non-negative integer X, the function G(X) is defined as the greatest common divisor of all the integers generated by all submasks of the binary representation of X, i.e. integers obtained by replacing some (possibly none or all) 1-s in this binary representation by 0-s.

You are given an integer N. Find S = \sum_{X=1}^N X^{G(X)} modulo 998,244,353.

QUICK EXPLANATION

  • The GCD of all submasks of N equals 2^{lb(N)}, where lb(N) denotes the index of the lowest set bit of N. So we need to compute \sum_{x = 1}^N x^{2^{lb(x)}}.
  • Group all x by lb(x). There are at most \lfloor \log_2 N \rfloor + 1 groups, and the values in the group with lb(x) = k are exactly the odd multiples of 2^k. Taking out the common factor reduces the expression to power sums ps(n, p) = \sum_{x=1}^n x^p, where p is always a power of 2 and n \cdot p \leq N.
  • To compute power sums, use Lagrange interpolation for small p. For large p, n \leq N/p is small, so those power sums can be precomputed beforehand.

EXPLANATION

Submasks and powers seem daunting, so let’s try to turn this into something more pleasant.

Let’s list all submasks of a number and find the smallest positive one among them. The GCD divides that number, so it cannot exceed it. The smallest positive submask is 2^{lb(N)}, where lb(N) denotes the index of the lowest set bit of N. But is that number guaranteed to divide all the other submasks as well?

Yes. Let S denote the set of positions of the set bits of N. Every submask is a sum of terms 2^p with p \in S, and every such p satisfies p \geq lb(N). Hence every term, and therefore every submask, is a multiple of 2^{lb(N)}. (The empty submask 0 does not change the GCD.) So G(N) = 2^{lb(N)}.
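A brute-force check of this claim for small X can be sketched in C++ as follows. It enumerates every submask with the usual s = (s - 1) & x walk and compares the GCD with x & -x, which equals 2^{lb(x)} in two's complement. The helper names `bruteG`, `closedG` and `checkClaim` are illustrative only (not from the reference solutions), and `std::gcd` assumes C++17.

```cpp
#include <cassert>
#include <numeric>  // std::gcd (C++17)

// Brute force: gcd of every submask of x. The empty submask 0 is visited
// last; gcd(g, 0) == g, so it never changes the result.
long long bruteG(long long x) {
    long long g = 0;  // gcd(0, s) == s, so 0 is a neutral start value
    for (long long s = x;; s = (s - 1) & x) {  // walks all submasks of x, descending
        g = std::gcd(g, s);
        if (s == 0) break;
    }
    return g;
}

// Claimed closed form: G(x) = 2^{lb(x)}, i.e. the lowest set bit x & -x.
long long closedG(long long x) { return x & -x; }

// True if the closed form matches the brute force for every x in [1, limit].
bool checkClaim(long long limit) {
    for (long long x = 1; x <= limit; ++x)
        if (bruteG(x) != closedG(x)) return false;
    return true;
}
```

Enumerating submasks costs 2^{popcount(x)} steps, so this is only a sanity check for small x. The solutions themselves just use the lowest set bit directly.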

Hence, for each query we need to compute \displaystyle\sum_{x = 1}^N x^{2^{lb(x)}}. For subtask 1, we can simply precompute this for all N and answer each query with a lookup.
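For subtask 1, the precomputation can be sketched as a prefix-sum table. This assumes N stays small (around 10^6, as one of the comments below suggests for subtask 1). The names `power` and `buildPrefix` are hypothetical helpers, not taken from the reference solutions.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// b^e mod MOD by binary exponentiation.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// pre[n] = sum_{x=1}^{n} x^{G(x)} mod MOD, with G(x) = x & -x.
// Built once in O(maxN log maxN); afterwards every query is a lookup.
std::vector<long long> buildPrefix(int maxN) {
    std::vector<long long> pre(maxN + 1, 0);
    for (int x = 1; x <= maxN; ++x)
        pre[x] = (pre[x - 1] + power(x, x & -x)) % MOD;
    return pre;
}
```

A query N is then answered with pre[N]. For example, the first terms are 1, 2^2, 3, 4^4, so pre[4] = 264.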

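Another observation: lb(x) takes at most \lfloor \log_2 N \rfloor + 1 distinct values, so let us group all values of x by lb(x). A number whose lowest set bit has value p (a power of 2) can be written as p \cdot x with x odd, and (p \cdot x)^p = p^p \cdot x^p, so we get

\displaystyle \sum_{p = 1}^N p^p \cdot \sum_{x = 1}^{\lfloor N/p \rfloor} x^p where p runs over powers of 2 and x takes only odd values.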
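To sanity-check the grouping, the sketch below evaluates both the definition and the grouped double sum for small N and compares them. It uses plain loops, so it only illustrates the identity and is not meant to be fast. The function names are hypothetical.

```cpp
#include <cassert>

const long long MOD = 998244353;

// b^e mod MOD by binary exponentiation.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Reference value straight from the definition: sum_{x=1}^{N} x^{x & -x}.
long long direct(long long N) {
    long long s = 0;
    for (long long x = 1; x <= N; ++x) s = (s + power(x, x & -x)) % MOD;
    return s;
}

// Grouped form: each number is p * y with p = 2^{lb} and y odd, so
// (p * y)^p = p^p * y^p. Sum over powers of two p <= N and odd y <= floor(N / p).
long long grouped(long long N) {
    long long total = 0;
    for (long long p = 1; p <= N; p *= 2) {
        long long inner = 0;
        for (long long y = 1; y <= N / p; y += 2) inner = (inner + power(y, p)) % MOD;
        total = (total + power(p, p) * inner) % MOD;
    }
    return total;
}
```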

Let’s subtract the even terms from the sum over all terms to get the sum over odd terms:
\displaystyle \sum_{p = 1}^N p^p \cdot \Bigg[ \sum_{x = 1}^{\lfloor N/p \rfloor} x^p - \sum_{x = 1}^{\lfloor N/(2p) \rfloor} (2x)^p \Bigg] where p is a power of 2

\displaystyle \sum_{p = 1}^N p^p \cdot \Bigg[ \sum_{x = 1}^{\lfloor N/p \rfloor} x^p - 2^p \cdot \sum_{x = 1}^{\lfloor N/(2p) \rfloor} x^p \Bigg] where p is a power of 2

If ps(n, p) denotes \displaystyle\sum_{x = 1}^n x^p, then the above can be written as
\displaystyle \sum_{p = 1}^N p^p \cdot \Bigg[ ps(\lfloor N/p \rfloor, p) - 2^p \cdot ps(\lfloor N/(2p) \rfloor, p)\Bigg] where p is a power of 2
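The final formula can be checked the same way. In this sketch `ps` is computed by a naive loop, so only the algebra is being tested, not the speed. All names are illustrative.

```cpp
#include <cassert>

const long long MOD = 998244353;

// b^e mod MOD by binary exponentiation.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// ps(n, p) = sum_{x=1}^{n} x^p mod MOD (naive loop, for checking only).
long long ps(long long n, long long p) {
    long long s = 0;
    for (long long x = 1; x <= n; ++x) s = (s + power(x, p)) % MOD;
    return s;
}

// Reference value straight from the definition.
long long direct(long long N) {
    long long s = 0;
    for (long long x = 1; x <= N; ++x) s = (s + power(x, x & -x)) % MOD;
    return s;
}

// Sum over powers of two p <= N of p^p * [ps(N/p, p) - 2^p * ps(N/(2p), p)],
// where / is floor division.
long long viaPowerSums(long long N) {
    long long total = 0;
    for (long long p = 1; p <= N; p *= 2) {
        long long even = power(2, p) * ps(N / (2 * p), p) % MOD;  // even multiples
        long long odd = (ps(N / p, p) - even + MOD) % MOD;          // odd ones remain
        total = (total + power(p, p) * odd) % MOD;
    }
    return total;
}
```

In a real solution power(p, p) and 2^p \cdot p^p do not depend on N, so they can be precomputed once.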

Now, we only need to compute power sums efficiently.

Notice that in every call ps(n, p), p is a power of 2 and n \cdot p \leq N. Also, Faulhaber’s formula states that the sum of p-th powers ps(n, p) is a polynomial in n of degree p+1. It is therefore determined by p+2 values, and once those are precomputed, Lagrange interpolation evaluates ps(n, p) in O(p).
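A minimal sketch of the interpolation step, assuming the sample points 0, 1, ..., p+1. The function `psLagrange` is a hypothetical name. This version rebuilds the samples and inverse factorials on every call, so it costs O(p log MOD). Precomputing them once per p, as the editorial suggests, leaves O(p) work per evaluation.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// b^e mod MOD by binary exponentiation.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// ps(n, p) = sum_{x=1}^{n} x^p mod MOD via Lagrange interpolation.
// ps(., p) has degree d = p + 1 in n, so the d + 1 = p + 2 samples
// y[i] = ps(i, p), i = 0..d, determine it.
long long psLagrange(long long n, int p) {
    const int d = p + 1;
    std::vector<long long> y(d + 1, 0);
    for (int i = 1; i <= d; ++i) y[i] = (y[i - 1] + power(i, p)) % MOD;
    if (n <= d) return y[n];  // n is itself a sample point

    const long long x = n % MOD;
    // pre[i] = prod_{j < i} (x - j),  suf[i] = prod_{j >= i} (x - j)
    std::vector<long long> pre(d + 2, 1), suf(d + 2, 1);
    for (int i = 0; i <= d; ++i) pre[i + 1] = pre[i] * ((x - i + MOD) % MOD) % MOD;
    for (int i = d; i >= 0; --i) suf[i] = suf[i + 1] * ((x - i + MOD) % MOD) % MOD;

    // Inverse factorials for the denominators i! (d - i)!.
    std::vector<long long> invFact(d + 1, 1);
    long long fact = 1;
    for (int i = 1; i <= d; ++i) fact = fact * i % MOD;
    invFact[d] = power(fact, MOD - 2);
    for (int i = d; i >= 1; --i) invFact[i - 1] = invFact[i] * i % MOD;

    long long result = 0;
    for (int i = 0; i <= d; ++i) {
        // y_i * prod_{j != i} (x - j) / (i! (d - i)!), with sign (-1)^{d - i}
        long long term = y[i] * pre[i] % MOD * suf[i + 1] % MOD
                         * invFact[i] % MOD * invFact[d - i] % MOD;
        result = ((d - i) & 1) ? (result - term + MOD) % MOD : (result + term) % MOD;
    }
    return result;
}
```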

This works for small p, but not for large p, which can be as large as 2^{29}. Here the fact that p \cdot n \leq 10^9 comes into play.

For a fixed p, we only need power sums up to n = 10^9/p. For example, with p = 256 we need power sums up to about 4 \cdot 10^6, and the ranges for larger powers of 2 are smaller still, so all of these tables can be computed beforehand.
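The precomputation for large p can be sketched as follows. Like the setter's and tester's code, it gets x^{2^{k+1}} by squaring x^{2^k} instead of running a separate fast exponentiation for each exponent. The function `buildLargeTables` and its parameters `maxN`, `lo`, `hi` are hypothetical, and the sketch assumes every query satisfies n \cdot 2^k \leq maxN.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

const long long MOD = 998244353;

// Prefix-sum tables for exponents p = 2^k with lo <= k <= hi, assuming every
// query satisfies n * p <= maxN, i.e. n <= maxN >> k.
// table[k - lo][i] = sum_{x=1}^{i} x^{2^k} mod MOD.
std::vector<std::vector<long long>> buildLargeTables(long long maxN, int lo, int hi) {
    std::vector<std::vector<long long>> table(hi - lo + 1);
    for (int k = lo; k <= hi; ++k) table[k - lo].assign((maxN >> k) + 1, 0);

    for (long long x = 1; x <= (maxN >> lo); ++x) {
        long long v = x % MOD;
        for (int s = 0; s < lo; ++s) v = v * v % MOD;  // v = x^{2^lo}
        for (int k = lo; k <= hi && x <= (maxN >> k); ++k) {
            table[k - lo][x] = v;  // store the single term for now
            v = v * v % MOD;       // x^{2^{k+1}} = (x^{2^k})^2
        }
    }
    // Turn the stored terms into prefix sums.
    for (auto& row : table)
        for (std::size_t i = 1; i < row.size(); ++i) row[i] = (row[i] + row[i - 1]) % MOD;
    return table;
}
```

With maxN = 10^9 and lo = 8 (p = 256) the tables hold roughly 7.8 million values in total, because each successive row is half as long as the previous one.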

So, for p < LIM we use Lagrange interpolation (precomputing the first p+2 points for each such p), and for p \geq LIM we use the precomputed power sums.

Hence, we need to precompute power sums for every power of 2 p \geq LIM, and p+2 interpolation points for every p < LIM. It is convenient to use the unit-spaced points 0, 1, \ldots, p+1, since the denominators in the Lagrange formula then reduce to products of factorials, which can be handled with precomputed inverse factorials.
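For reference, with samples y_i = f(i) at the unit-spaced points i = 0, 1, \ldots, d (here d = p + 1), the denominator of each Lagrange basis polynomial becomes a signed product of two factorials:

```latex
f(x) = \sum_{i=0}^{d} y_i \prod_{j \ne i} \frac{x - j}{i - j}
     = \sum_{i=0}^{d} y_i \,
       \frac{\prod_{j<i} (x - j) \cdot \prod_{j>i} (x - j)}{(-1)^{d-i}\, i!\,(d-i)!}
```

The two numerator products are prefix and suffix products over j, so with inverse factorials precomputed each evaluation takes O(d) multiplications. This is the scheme the setter's `Lagrange` function implements.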

Alternative to Lagrange interpolation
If you prefer not to use Lagrange interpolation (for whatever reason), you can instead apply Faulhaber’s formula directly: precompute the Bernoulli numbers and derive the Faulhaber polynomials from them. This gives the same asymptotic complexity, and more than a mild headache if you are not sure what you are doing. Refer to my solution below for this.
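A compact sketch of that route follows. It uses the B_1 = +1/2 convention, which is also what the editorialist's Java code uses (Ber[1] = 1/2), so that \sum_{x=1}^n x^p = \frac{1}{p+1}\sum_{j=0}^{p} \binom{p+1}{j} B_j n^{p+1-j}. The names `bernoulli` and `psFaulhaber` are hypothetical. This version recomputes modular inverses inline for brevity, so it is slower than a fully precomputed table.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// b^e mod MOD by binary exponentiation.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Bernoulli numbers B_0..B_m mod MOD, convention B_1 = +1/2, from the
// recurrence sum_{j=0}^{i} C(i+1, j) B_j = i + 1.
std::vector<long long> bernoulli(int m) {
    // Pascal's triangle up to row m + 1.
    std::vector<std::vector<long long>> C(m + 2, std::vector<long long>(m + 2, 0));
    for (int i = 0; i <= m + 1; ++i) {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % MOD;
    }
    std::vector<long long> B(m + 1, 0);
    for (int i = 0; i <= m; ++i) {
        long long s = 0;
        for (int j = 0; j < i; ++j) s = (s + C[i + 1][j] * B[j]) % MOD;
        // C(i+1, i) = i + 1, hence B_i = 1 - s / (i + 1).
        B[i] = (1 - s * power(i + 1, MOD - 2) % MOD + MOD) % MOD;
    }
    return B;
}

// Faulhaber: sum_{x=1}^{n} x^p = 1/(p+1) * sum_{j=0}^{p} C(p+1, j) B_j n^{p+1-j}.
// Requires B.size() > p.
long long psFaulhaber(long long n, int p, const std::vector<long long>& B) {
    const long long x = n % MOD;
    long long result = 0, binom = 1;  // binom = C(p+1, j)
    for (int j = 0; j <= p; ++j) {
        result = (result + binom * B[j] % MOD * power(x, p + 1 - j)) % MOD;
        binom = binom * (p + 1 - j) % MOD * power(j + 1, MOD - 2) % MOD;
    }
    return result * power(p + 1, MOD - 2) % MOD;
}
```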

Bonus (Not sure)
Instead of separating the odd terms, can we somehow adapt Lagrange interpolation to handle the sum of powers of the first x odd integers? Share your thoughts in the comments.
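One possible direction, which the setter's solution below already takes: (2i+1)^p is a polynomial of degree p in i, so q(m) = \sum_{i=0}^{m} (2i+1)^p is a polynomial of degree p+1 in m and can be interpolated directly from p+2 samples. For a given N, the odd multiples of 2^k up to N correspond to m = (N - 2^k) >> (k+1). The sketch below illustrates this under those assumptions. The names `interpolate` and `oddSum` are hypothetical.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// b^e mod MOD by binary exponentiation.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// Evaluates at x >= 0 the unique polynomial of degree <= d through the points
// (i, y[i]), i = 0..d, where d = y.size() - 1 (prefix/suffix Lagrange form).
long long interpolate(const std::vector<long long>& y, long long x) {
    const int d = (int)y.size() - 1;
    if (x <= d) return y[x];
    x %= MOD;
    std::vector<long long> pre(d + 2, 1), suf(d + 2, 1), invFact(d + 1, 1);
    for (int i = 0; i <= d; ++i) pre[i + 1] = pre[i] * ((x - i + MOD) % MOD) % MOD;
    for (int i = d; i >= 0; --i) suf[i] = suf[i + 1] * ((x - i + MOD) % MOD) % MOD;
    long long fact = 1;
    for (int i = 1; i <= d; ++i) fact = fact * i % MOD;
    invFact[d] = power(fact, MOD - 2);
    for (int i = d; i >= 1; --i) invFact[i - 1] = invFact[i] * i % MOD;
    long long res = 0;
    for (int i = 0; i <= d; ++i) {
        long long t = y[i] * pre[i] % MOD * suf[i + 1] % MOD
                      * invFact[i] % MOD * invFact[d - i] % MOD;
        res = ((d - i) & 1) ? (res - t + MOD) % MOD : (res + t) % MOD;
    }
    return res;
}

// oddSum(m, p) = sum_{i=0}^{m} (2i+1)^p, the p-th powers of the first m+1 odd
// numbers: a degree-(p+1) polynomial in m, so p + 2 samples suffice.
long long oddSum(long long m, int p) {
    std::vector<long long> y(p + 2);
    long long acc = 0;
    for (int i = 0; i <= p + 1; ++i) {
        acc = (acc + power(2 * i + 1, p)) % MOD;
        y[i] = acc;
    }
    return interpolate(y, m);
}
```

With this, the answer is \sum_k (2^k)^{2^k} \cdot oddSum((N - 2^k) >> (k+1), 2^k), with no subtraction of even terms.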

Also, this problem uses Lagrange interpolation to compute power sums and has a nice explanation. Do try it.

Optimizations

  • Precompute all terms except ps(n, p) in the final summation, so that each query runs in O(log(MAXN) \cdot LIM) time.
  • Precompute the interpolation points, but not the Lagrange polynomial itself, as that does not save any time during queries.
  • Be careful with the modulus.

TIME COMPLEXITY

The time complexity is O(MAXN/P \cdot \log(MAXN) + T \cdot P \cdot \log(MAXN)), where T denotes the number of test cases and P is the chosen threshold (LIM above).
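Both terms come from geometric series. Writing the threshold as P = 2^m, the first sum below counts the precomputed table entries for all p \geq P, and the second counts the interpolation points used per query for p < P:

```latex
\sum_{k=m}^{\lfloor \log_2 MAXN \rfloor} \left\lfloor \frac{MAXN}{2^k} \right\rfloor
  \le \frac{MAXN}{2^m} \sum_{j \ge 0} 2^{-j} = \frac{2 \cdot MAXN}{P},
\qquad
\sum_{k=0}^{m-1} \left(2^k + 2\right) = P - 1 + 2m = O(P)
```

For example, with P = 256 and MAXN = 10^9 the tables hold at most 7,812,500 entries. The bound above is looser because it also absorbs a logarithmic factor, for example for computing the powers.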

SOLUTIONS

Setter's Solution
#include<bits/stdc++.h>
using namespace std;

const int N = 3e5 + 9, mod = 998244353;

template <const int32_t MOD>
struct modint {
	int32_t value;
	modint() = default;
	modint(int32_t value_) : value(value_) {}
	inline modint<MOD> operator + (modint<MOD> other) const { int32_t c = this->value + other.value; return modint<MOD>(c >= MOD ? c - MOD : c); }
	inline modint<MOD> operator - (modint<MOD> other) const { int32_t c = this->value - other.value; return modint<MOD>(c <    0 ? c + MOD : c); }
	inline modint<MOD> operator * (modint<MOD> other) const { int32_t c = (int64_t)this->value * other.value % MOD; return modint<MOD>(c < 0 ? c + MOD : c); }
	inline modint<MOD> & operator += (modint<MOD> other) { this->value += other.value; if (this->value >= MOD) this->value -= MOD; return *this; }
	inline modint<MOD> & operator -= (modint<MOD> other) { this->value -= other.value; if (this->value <    0) this->value += MOD; return *this; }
	inline modint<MOD> & operator *= (modint<MOD> other) { this->value = (int64_t)this->value * other.value % MOD; if (this->value < 0) this->value += MOD; return *this; }
	inline modint<MOD> operator - () const { return modint<MOD>(this->value ? MOD - this->value : 0); }
	modint<MOD> pow(uint64_t k) const {
	    modint<MOD> x = *this, y = 1;
	    for (; k; k >>= 1) {
	        if (k & 1) y *= x;
	        x *= x;
	    }
	    return y;
	}
	modint<MOD> inv() const { return pow(MOD - 2); }  // MOD must be a prime
	inline modint<MOD> operator /  (modint<MOD> other) const { return *this *  other.inv(); }
	inline modint<MOD> operator /= (modint<MOD> other)       { return *this *= other.inv(); }
	inline bool operator == (modint<MOD> other) const { return value == other.value; }
	inline bool operator != (modint<MOD> other) const { return value != other.value; }
	inline bool operator < (modint<MOD> other) const { return value < other.value; }
	inline bool operator > (modint<MOD> other) const { return value > other.value; }
};
template <int32_t MOD> modint<MOD> operator * (int64_t value, modint<MOD> n) { return modint<MOD>(value) * n; }
template <int32_t MOD> modint<MOD> operator * (int32_t value, modint<MOD> n) { return modint<MOD>(value % MOD) * n; }
template <int32_t MOD> istream & operator >> (istream & in, modint<MOD> &n) { return in >> n.value; }
template <int32_t MOD> ostream & operator << (ostream & out, modint<MOD> n) { return out << n.value; }

using mint = modint<mod>;

vector<mint> fact, finv;
// p = first at least n + 1 points of the n degree polynomial, returns f(x)
// O(n)
mint Lagrange(const vector<mint> &p, mint x) {
	int n = p.size() - 1;
	if (x.value <= n) return p[x.value];

	vector<mint> pref(n + 1, 1), suf(n + 1, 1);
	for (int i = 0; i < n; i++) pref[i + 1] = pref[i] * (x - i);
	for (int i = n; i > 0; i--) suf[i - 1] = suf[i] * (x - i);

	mint ans = 0;
	for (int i = 0; i <= n; i++) {
	    mint tmp = p[i] * pref[i] * suf[i] * finv[i] * finv[n-i];
	    if ((n - i) & 1) ans -= tmp;
	    else ans += tmp;
	}

	return ans;
}

const int M = 6;
vector<mint> P[30];
mint ppw2[30];
int pw2[30];

int32_t main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);

	// precomputing factorials
	fact = finv = vector<mint> (1 << M, 1);
	for (int i = 1; i < (1 << M); i++) fact[i] = fact[i - 1] * i;
	finv[(1 << M) - 1] /= fact[(1 << M) - 1];
	for (int i = (1 << M) - 1; i >= 1; i--) finv[i - 1] = finv[i] * i;

	pw2[0] = 1;
	for (int i = 1; i < 30; i++) pw2[i] = pw2[i - 1] * 2; // 2^i
	for (int i = 0; i < 30; i++) ppw2[i] = mint(pw2[i] % mod).pow(pw2[i]);// (2^i)^(2^i)

	// precomputing first (2^k+1) terms for Lagrange
	for (int k = 0; k < M; k++) {
		mint ans = 0;
		for (int i = 0; i <= (1 << k) + 1; i++) {
			ans += mint(1 + 2 * i).pow(pw2[k]);
			P[k].push_back(ans);
		}
	}

	// precomputing first 2^30/(2^(k+1)) terms
	for (int i = 0; i <= (1 << (30 - M - 1)); i++) {
		mint nw = mint(1 + 2 * i).pow(pw2[M]); // (1 + 2 * i) ^ (2^k)
		P[M].push_back(nw);
	}
	for (int k = M + 1; k < 30; k++) {
		for (int i = 0; i <= (1 << (30 - k - 1)); i++) {
			auto nw = P[k - 1][i] * P[k - 1][i]; // saving extra log (mod) factor
			P[k].push_back(nw);
		}
	}
	for (int k = M; k < 30; k++) {
		for (int i = 1; i <= (1 << (30 - k - 1)); i++) {
			P[k][i] += P[k][i - 1]; // summing up all (1 + 2 * i) ^ (2^k)
		}
	}

	int q; cin >> q;
	while (q--) {
		int n; cin >> n;
		mint ans = 0;
		for (int k = 0; k < 30 && n >= pw2[k]; k++) {
			if (k < M) {
				ans += ppw2[k] * Lagrange(P[k], ((n - pw2[k]) >> (k + 1)) % mod); // use Lagrange for smaller degrees
			}
			else {
				ans += ppw2[k] * P[k][(n - pw2[k]) >> (k + 1)]; // use precomputed values
			}
		}
		cout << ans << '\n';
	}
	return 0;
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'

#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back

using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
const int64_t mod = 998244353;

struct polynomial {
	vector<int> coef;
	polynomial() { coef = {}; }
	polynomial(int sz, int v) {  
		coef.assign(sz, v);
	}

	int eval(int x) {
		int ret = 0, pw = 1;
		for(int i = 0; i < SZ(coef); i++) {
			ret = (ret + pw * 1ll * coef[i]) % mod;
			pw = pw * 1ll * x % mod;
		}

		return ret;
	}

	void trim() {
		while(!coef.empty() && coef.back() == 0) {
			coef.pop_back();
		}
	}

	polynomial operator*(const polynomial &that) {
		polynomial ret(coef.size() + that.coef.size(), 0);
		for(int i = 0; i < SZ(coef); i++) {
			for(int j = 0; j < SZ(that.coef); j++) {
				ret.coef[i + j] = (ret.coef[i + j] + coef[i] * 1ll * that.coef[j]) % mod;
			}
		}

		ret.trim();
		return ret;
	}

	polynomial operator*(const int &x) {
		polynomial ret(coef.size(), 0);
		for(int i = 0; i < SZ(coef); i++) {
			ret.coef[i] = coef[i] * 1ll * x % mod;
		}
	
		ret.trim();
		return ret;
	}

	polynomial operator+(const polynomial &that) {
		polynomial ret(max(coef.size(), that.coef.size()), 0);
		for(int i = 0; i < SZ(ret.coef); i++) {
			if(i < SZ(coef)) ret.coef[i] = (ret.coef[i] + coef[i]) % mod;
			if(i < SZ(that.coef)) ret.coef[i] = (ret.coef[i] + that.coef[i]) % mod;
		}

		ret.trim();
		return ret;
	}

	void flip() {
		for(int i = 0; i < SZ(coef); i++) {
			coef[i] = (mod - coef[i]) % mod; 	
		}

		trim();
	}
};

// ------------
// NUMBER THEORY
// ------------
int pw(int x, int p) {
	int r = 1;
	while(p) {
		if(p & 1) r = r * 1ll * x % mod;
		x = x * 1ll * x % mod;
		p >>= 1;
	}

	return r;
}

int inv(int x) { return pw(x, mod - 2); } 
int fix(int x) { return x < 0 ? (x + mod) : x; }
// ------------

polynomial lagrange_interpolation(int k) {
	static int y[MAXN];
	static int fact[MAXN];
	static polynomial pref[MAXN];
	static polynomial suff[MAXN];

	// Put the polynomial we want to evaluate in y[:].
	y[0] = 0;
	for(int i = 1; i <= k + 1; i++) 
		y[i] = (y[i - 1] + pw(i, k)) % mod;

	pref[0] = polynomial(2, 0);
	pref[0].coef[0] = 0;
	pref[0].coef[1] = 1;

	suff[k + 1] = polynomial(2, 0);
	suff[k + 1].coef[0] = (mod - (k + 1)) % mod;
	suff[k + 1].coef[1] = 1;

	for(int i = 1; i <= k + 1; i++) {
		polynomial curr = polynomial(2, 0);
		curr.coef[0] = (mod - i) % mod;
		curr.coef[1] = 1;
		pref[i] = pref[i - 1] * curr;
	}

	for(int i = k; i >= 0; i--) {
		polynomial curr = polynomial(2, 0);
		curr.coef[0] = (mod - i) % mod;
		curr.coef[1] = 1;
		suff[i] = suff[i + 1] * curr;
	}

	fact[0] = 1;
	for(int i = 1; i <= k + 1; i++)
		fact[i] = fact[i - 1] * 1ll * i % mod;

	polynomial ans;
	for(int i = 0; i <= k + 1; i++) {
		polynomial v(1, 1);
		if(i > 0) v = v * pref[i - 1];
		if(i < k + 1) v = v * suff[i + 1];
	
		v = v * inv(fact[i]);
		v = v * inv(fact[(k + 1) - i]);

		if(((k + 1) - i) & 1) 
			v.flip();

		v = v * y[i];
		ans = ans + v;
	}

	return ans;
}

int N;

void read() {
	// We notice that the G(X) = 2^lowest_bit_pos(X) = X & -X
	// This can be proven by induction with base case X being odd.
	cin >> N;
}

// Used for precompute:
const int B = 8;
vector<int> prec[31 - B];
polynomial lagrange_prec[B];


void solve() {
	int answer = 0;	

	for(int k = 0; (1 << k) <= N; k++) {
		int cnt_with_zero_suff = N / (1 << k);

		int _pw = pw(1 << k, 1 << k);
		int _pw2 = pw(2 << k, 1 << k);

		if(k < B) {
			answer = (answer + _pw * 1ll * lagrange_prec[k].eval(cnt_with_zero_suff)) % mod;
			answer = (answer - _pw2 * 1ll * lagrange_prec[k].eval(cnt_with_zero_suff / 2)) % mod;
			answer = fix(answer);
		} else {
			answer = (answer + _pw * 1ll * prec[k - B][cnt_with_zero_suff]) % mod;
			answer = (answer - _pw2 * 1ll * prec[k - B][cnt_with_zero_suff / 2]) % mod;
			answer = fix(answer);
		}
	}

	cout << answer << endl;
}

void precompute() {
	// Full partial sums precompute for power >= 2^B:

	for(int i = B; i <= 30; i++) {
		prec[i - B].assign((1 << (31 - i)) + 1, 0);
	}

	for(int i = 1; i <= (1 << (31 - B)); i++) {
		int x = i;
		for(int v = 0; v < B; v++) {
			x = x * 1ll * x % mod;
		}

		// j <= 30 keeps j - B inside prec[0 .. 30 - B]; without it, small i reads prec[31 - B]
		for(int j = B; j <= 30 && i < (int)prec[j - B].size(); j++) {
			// x = i^(2^j)
			prec[j - B][i] = (prec[j - B][i - 1] + x) % mod;		
			x = x * 1ll * x % mod;	
		}
	}

	for(int i = 0; i < B; i++) {
		lagrange_prec[i] = lagrange_interpolation(1 << i);
	}
}

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);

	precompute();

	int T;
	cin >> T;
	while(T--) {
		read();
		solve();
	}

	return 0;
}
Editorialist's Solution (using Faulhaber's formula)
import java.util.*;
import java.io.*;
import java.text.*;
class GCDMASK{
	//SOLUTION BEGIN
	int MOD = 998244353;
	int LIM = 7, B = 30, MAXN = (int)1e9;
	int[][] PS = new int[B][];
	int BMAX = 1<<(LIM-1);
	int[] Ber = new int[1+BMAX];//Bernoulli Numbers
	int[][] C = new int[2+BMAX][2+BMAX];
	int[][] Faulhaber = new int[LIM][];
	int[] F = new int[1+BMAX], F2 = new int[1+BMAX];
	void pre() throws Exception{
	    long ct = System.currentTimeMillis();
	    for(int b = LIM; b < B; b++)PS[b] = new int[1+MAXN/(1<<b)];
	    for(int i = 1; i<= MAXN/(1<<LIM); i++){
	        long p = i;
	        for(int b = 0; b< LIM; b++)p = (p*p)%MOD;
	        for(int b = LIM; b < B; b++){
	            if(i >= PS[b].length)break;
	            PS[b][i] = (PS[b][i-1]+(int)p);
	            if(PS[b][i] >= MOD)PS[b][i] -= MOD;
	            p = (p*p)%MOD;
	        }
	    }
	    C[0][0] = 1;
	    for(int i = 1; i<= 1+BMAX; i++)
	        for(int j = 0; j<= 1+BMAX; j++){
	            C[i][j] = C[i-1][j]+(j>0?C[i-1][j-1]:0);
	            if(C[i][j] >= MOD)C[i][j] -= MOD;
	        }
	    Ber[0] = 1;
	    Ber[1] = (int)pow(2, MOD-2);
	    for(int i = 2; i<= BMAX; i++){
	        long s = 0;
	        for(int j = 0; j< i; j++){
	            s += (C[i+1][j]*(long)Ber[j])%MOD;
	            if(s >= MOD)s -= MOD;
	        }
	        s *= pow(i+1, MOD-2);
	        
	        Ber[i] = (int)((1+MOD-s%MOD)%MOD);
	    }
	    for(int p = 1, b = 0; b < LIM; b++, p<<=1){
	        Faulhaber[b] = new int[2+p];
	        long ip = pow(p+1, MOD-2);
	        for(int j = 0; j<= p; j++)
	            Faulhaber[b][p+1-j] = (int)(((Ber[j]*(long)C[p+1][j]%MOD)*ip)%MOD);
	    }
	    int p2 = 1;
	    for(int i = 0; i<= BMAX; i++){
	        F[i] = pow(p2, p2);
	        F2[i] = pow(2*p2, p2);
	        p2 = (p2*2)%MOD;
	    }
	    System.err.println(System.currentTimeMillis()-ct);
	}
	//Returns \sum_{x = 1}^N  x^(2^p)
	int faulhaber_compute(int N, int p) throws Exception{//N, 2^p;
	    int ans = 0;
	    long pow = 1;
	    for(int i = 0; i< Faulhaber[p].length; i++){
	        ans += (int)((pow*Faulhaber[p][i])%MOD);
	        if(ans >= MOD)ans -= MOD;
	        pow = (pow*N)%MOD;
	    }
	    return ans;
	    
	}
	void solve(int TC) throws Exception{
	    int N = ni();
	    long ans = 0;
	    for(int pow2 = 1, p = 0; p< BMAX && N/pow2 > 0; pow2 *= 2, p++){
	        hold(pow2 == 1<<p);
	        int t1 = (int)((F[p]*1L*ps(N/pow2, p))%MOD);
	        int t2 = (int)((F2[p]*1L*ps(N/(2*pow2), p))%MOD);
	        if(t1-t2 < MOD)ans += t1+MOD-t2;
	        else ans += t1-t2;
	        while(ans >= MOD)ans -= MOD;
	    }
	    pn(ans);
	}
	int ps(int n, int p) throws Exception{
	    if(p >= LIM){
	        hold(n < PS[p].length);
	        return PS[p][(int)n];
	    }else
	        return faulhaber_compute(n, p);
	}
	int pow(int a, int p){
	    int ans = 1;
	    for(; p> 0; p>>=1){
	        if((p&1)==1)ans = (int)((ans*1L*a)%MOD);
	        a = (int)((a*1L*a)%MOD);
	    }
	    return ans;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    long ct = System.currentTimeMillis();
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    System.err.println(System.currentTimeMillis()-ct);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new GCDMASK().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome as always. :slight_smile:


Can somebody please point out why I'm getting TLE (even for subtask 1)? https://www.codechef.com/viewsolution/33446143

Many have also done the same (for subtask 1)…

T \leq 2 \cdot 10^5, N \leq 10^6


@physics0523 @akee could you explain your code? Thanks in advance. @akee, how is Gauss used here? Also, @physics0523, is your solution randomised or based on some discrete logic? I am very curious.

Got it :confused:

My solution has a very simple improvement.
Let’s set p = 32 as explained in the editorial, and take cumulative sums.
Now, we want to calculate \sum_{i=1}^{N} (2i-1)^{2^k} for 0 \le k \le 4. Then, look it up in WolframAlpha like this.

my solution


Hello, can you please explain your solution? It seems interesting!