BANQUNT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Jubayer Nirjhor
Tester: Raja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Non-basic maths

PROBLEM:

Given two integers N and M, find the size of the largest subset of integers in the range [1, N] such that no two elements x and y of the subset satisfy \displaystyle\frac{x}{y} = M.

Also, find the number of ways to select such subset mod 998244353.

QUICK EXPLANATION

  • Let’s write every number in the range [1, N] in the form x*M^p with the largest possible p, and group the numbers that share the same x. Each group can be handled independently of the others.
  • Each group is just x*M^p where 0 \leq p \leq P = \lfloor\log_M{\frac{N}{x}}\rfloor. Let C = P+1 denote the number of elements in the group.
  • If C is odd, we select x, x*M^2, x*M^4, and so on up to x*M^P. This gives (C+1)/2 elements, and there is only one way to choose that many.
  • If C is even, we can select at most C/2 elements, and the number of ways to select them is C/2+1.
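
The two cases above can be sketched as a single helper. This is only an illustration: the name `groupContribution` is hypothetical and does not appear in any of the solutions below.

```cpp
#include <cassert>
#include <utility>

// Hypothetical helper: for a group of c elements, returns
// {maximum number of elements we can pick, number of ways to pick that many}.
std::pair<long long, long long> groupContribution(long long c) {
    assert(c >= 1);
    long long picks = (c + 1) / 2;                  // (C+1)/2 for odd C, C/2 for even C
    long long ways = (c % 2 == 1) ? 1 : c / 2 + 1;  // one way for odd C, C/2+1 for even C
    return {picks, ways};
}
```

For example, a group of size 4 should give 2 picks and 3 ways, and a group of size 5 should give 3 picks and 1 way.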

EXPLANATION

Let’s work through an example with N = 10 and M = 2. Write down all pairs of numbers that conflict with each other, i.e. all pairs (y, x) such that x = y*M.

We have (1, 2), (2, 4), (3, 6), (4, 8), (5, 10). The numbers that appear in no pair, 7 and 9, do not affect any other number and can always be included.

We can notice that choosing 3 or 6 doesn’t affect the choice between 5 and 10, but the choice between 4 and 8 does affect the choice between 2 and 4, since both pairs share the element 4.

With this in mind, let’s write every integer in the form x*M^p such that p is the maximum possible. The numbers from 1 to 10 are written as
1*2^0, 1*2^1, 3*2^0, 1*2^2, 5*2^0, 3*2^1, 7*2^0, 1*2^3, 9*2^0, 5*2^1
Grouping above by value of x, we have
x = 1: 1*2^0,1*2^1,1*2^2,1*2^3 = 1,2,4,8
x = 3: 3*2^0,3*2^1 = 3,6
x = 5: 5*2^0, 5*2^1 = 5,10
x = 7: 7*2^0 = 7
x = 9: 9*2^0 = 9
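
The grouping above can be reproduced with a short sketch. The function name `groupByBase` is hypothetical, and the loop is linear in N, so it is only meant for small illustrative inputs.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical helper: groups the numbers 1..n by the value x that remains
// after dividing out every factor of m (so each number is x * m^p, p maximal).
std::map<long long, std::vector<long long>> groupByBase(long long n, long long m) {
    assert(n >= 1 && m >= 2);
    std::map<long long, std::vector<long long>> groups;
    for (long long v = 1; v <= n; ++v) {
        long long x = v;
        while (x % m == 0) x /= m;  // strip factors of m
        groups[x].push_back(v);     // members arrive in increasing order of p
    }
    return groups;
}
```

With n = 10 and m = 2 this should yield the five groups listed above, keyed by x = 1, 3, 5, 7, 9.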

The critical observation is that, written in this form, all conflicting pairs are adjacent within a group. Also, numbers with different x do not affect each other.

Hence, the problem reduces to choosing, within each group, the largest number of elements such that no two chosen elements are adjacent.

For 1,2,4,8 we can choose \{1, 4\}, \{1, 8\} and \{2,8\}. These choices do not affect our selection in other groups. So we choose 2 elements and have 3 ways of doing so. For groups with x = 3 and x = 5, we can choose at most 1 element, and have exactly 2 ways of doing so. For x = 7 and x = 9, we can choose only one element, and only 1 way to do so.

This gives us 2+1+1+1+1 = 6 elements and 3*2*2*1*1 = 12 ways to select elements.
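
As a sanity check on these numbers, one could enumerate every subset directly for tiny N. The sketch below is an exhaustive checker written for this purpose only; the name `bruteSubsets` and the limit of 20 on n are assumptions, not part of the original solutions.

```cpp
#include <bitset>
#include <cassert>
#include <utility>

// Hypothetical checker: tries every subset of {1..n} and returns
// {largest valid subset size, number of valid subsets of that size}.
// A subset is valid if it never contains both y and y*m.
std::pair<int, long long> bruteSubsets(int n, int m) {
    assert(1 <= n && n <= 20 && m >= 2);
    int best = 0;
    long long count = 0;
    for (unsigned mask = 0; mask < (1u << n); ++mask) {
        bool ok = true;
        for (int y = 1; ok && static_cast<long long>(y) * m <= n; ++y) {
            // bit (v - 1) of mask tells whether v is in the subset
            bool hasY = (mask >> (y - 1)) & 1u;
            bool hasX = (mask >> (y * m - 1)) & 1u;
            if (hasY && hasX) ok = false;
        }
        if (!ok) continue;
        int size = static_cast<int>(std::bitset<32>(mask).count());
        if (size > best) { best = size; count = 1; }
        else if (size == best) ++count;
    }
    return {best, count};
}
```

Calling `bruteSubsets(10, 2)` should agree with the 6 elements and 12 ways derived above.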

Let’s consider group 3,6,12,24,48. It is obvious that we can select only 3 elements maximum and we have only one way to select 3 elements.

But what if the group is 3,6,12,24,48,96? We can still select at most 3 elements, but now there are multiple ways of doing so.

A way to visualize selection

Let’s pair (3,6) and (12, 24) and (48, 96), and initially, we have selected the second element of each pair. i.e. we have selected \{6,24,96\}. Now, let’s consider pairs from left to right, and add the first element of pair and remove the second element of this pair. We get \{3,24,96\}, and then \{3,12,96\} and \{3,12,48\}.

This gives us 4 ways to choose subset of size 3 such that no two adjacent elements are chosen.

For a general group of even size S, the same argument gives S/2+1 ways to select exactly S/2 elements.

But what happens for odd S? In this visualization, the last element is left unpaired at the end. To reach (S+1)/2 elements we must choose it, which forces us to choose the first element of every pair, leaving only one way to choose (S+1)/2 elements.
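
The counts for even and odd S can be checked by exhaustive enumeration on a single chain. This is a verification sketch only, with a hypothetical name `bruteChain` and an assumed cap of 20 on the chain length.

```cpp
#include <bitset>
#include <cassert>
#include <utility>

// Hypothetical checker: for a chain of s elements, enumerates every subset
// that contains no two neighbouring positions and returns
// {largest such subset size, number of subsets of that size}.
std::pair<int, long long> bruteChain(int s) {
    assert(1 <= s && s <= 20);
    int best = 0;
    long long count = 0;
    for (unsigned mask = 0; mask < (1u << s); ++mask) {
        if (mask & (mask >> 1)) continue;  // two neighbouring positions picked
        int size = static_cast<int>(std::bitset<32>(mask).count());
        if (size > best) { best = size; count = 1; }
        else if (size == best) ++count;
    }
    return {best, count};
}
```

For S = 6 this should report 3 elements in 4 ways, and for S = 5 it should report 3 elements in 1 way, matching the two groups discussed above.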

Hence, we found a way to solve this in O(N). Make all groups, and for each group, if group size S is odd, we get (S+1)/2 elements and one way to choose, otherwise we get S/2 elements and S/2+1 ways to choose. We take the sum of the elements for each group and the product of number of ways to select elements to get the required answer.

Optimizing to fit Time Limit
Since N can be as large as 10^{18}, it’d be a miracle if the above solution got accepted within our lifetimes.

What we can notice is that while computing the number of elements and the number of ways to choose those in a group, we only used the group size and not the actual elements. So, if we could find out the number of groups with same size for each size, we can solve the problem in time proportional to the number of sizes of groups.

If the group size is S, then the last element must be x*M^{S-1} where x*M^{S-1} \leq N. Taking x = 1, we get S-1 \leq \log_M N. For N = 10^{18} and M = 2, we get S \leq 60.
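
A small sketch of this bound, assuming 64-bit inputs. The name `maxGroupSize` is hypothetical; the comparison against n / m (instead of multiplying first) is one way to keep the running power from overflowing.

```cpp
#include <cassert>

// Hypothetical helper: size of the largest group, i.e. the group starting at x = 1.
// It counts the powers m^0, m^1, ... that do not exceed n.
long long maxGroupSize(long long n, long long m) {
    assert(n >= 1 && m >= 2);
    long long size = 1, cur = 1;
    while (cur <= n / m) {  // equivalent to cur * m <= n, but cannot overflow
        cur *= m;
        ++size;
    }
    return size;
}
```

For n = 10^18 and m = 2 this should return 60, in line with the bound above.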

So, the maximum group size is at most 60. Now, we need the number of groups of each size. Also, each group is uniquely defined by x.

Another observation is that each group starts with x*M^0 = x, where x is not a multiple of M. This means that for an interval [L, R], the number of groups starting in it is the number of non-multiples of M in the interval. This is given by R-(L-1)-(\lfloor R/M \rfloor-\lfloor (L-1)/M \rfloor).
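
The counting formula translates directly into code. The helper name `countNonMultiples` is hypothetical; integer division in C++ performs the floor for non-negative operands.

```cpp
#include <cassert>

// Hypothetical helper: number of integers in [lo, hi] that are not multiples of m.
long long countNonMultiples(long long lo, long long hi, long long m) {
    assert(1 <= lo && lo <= hi && m >= 2);
    long long total = hi - (lo - 1);              // how many integers are in [lo, hi]
    long long multiples = hi / m - (lo - 1) / m;  // how many of them are divisible by m
    return total - multiples;
}
```

For instance, `countNonMultiples(6, 10, 2)` should give 2, corresponding to 7 and 9.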

Hence, we need to decompose interval [1, N] into the set of intervals [L_p, R_p] such that group size for all groups starting in this interval is p.

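We can prove that L_p = \displaystyle \bigg\lfloor\frac{N}{M^p}\bigg\rfloor+1 and R_p = \displaystyle \bigg\lfloor\frac{N}{M^{p-1}}\bigg\rfloor, since \displaystyle \bigg\lfloor\frac{N}{M^{p}}\bigg\rfloor is the largest value of x such that x*M^p \leq N. Any x above it cannot reach group size p+1, and any x up to R_p reaches group size at least p.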
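
One possible way to turn these intervals into per-size group counts is sketched below. The name `groupCountsBySize` is hypothetical. Note the swapped-in technique: instead of computing M^p explicitly, the sketch divides the previous bound by M repeatedly, relying on the identity floor(floor(N/a)/M) = floor(N/(a*M)). This is an assumption of the sketch, not how the solutions below are written.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: counts[p] = number of groups of size exactly p
// (index 0 is unused and left at 0).
std::vector<long long> groupCountsBySize(long long n, long long m) {
    assert(n >= 1 && m >= 2);
    std::vector<long long> counts(1, 0);
    long long hi = n;  // R_p = floor(n / m^(p-1)), starting with p = 1
    while (hi >= 1) {
        long long lo = hi / m;  // floor(n / m^p) = L_p - 1
        // non-multiples of m in [lo + 1, hi]
        counts.push_back((hi - lo) - (hi / m - lo / m));
        hi = lo;                // next size p + 1 ends where this interval began
    }
    return counts;
}
```

For n = 10 and m = 2 the result should be {0, 2, 2, 0, 1}: two groups of size 1, two of size 2, none of size 3, and one of size 4.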

Suppose C_p is the number of groups with group size p. Then the largest subset has size \sum_p C_p*\lfloor (p+1)/2 \rfloor, and the number of ways is \prod (p/2+1)^{C_p} taken over all even p.
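
As a worked check of these formulas on the running example N = 10, M = 2, the group counts are C_1 = 2 (groups 7 and 9), C_2 = 2 (groups 3,6 and 5,10), C_3 = 0 and C_4 = 1 (group 1,2,4,8). Plugging them in:

```latex
\text{size} = 2\left\lfloor\tfrac{2}{2}\right\rfloor + 2\left\lfloor\tfrac{3}{2}\right\rfloor
            + 0\left\lfloor\tfrac{4}{2}\right\rfloor + 1\left\lfloor\tfrac{5}{2}\right\rfloor
            = 2 + 2 + 0 + 2 = 6,
\qquad
\text{ways} = \left(\tfrac{2}{2}+1\right)^{2}\left(\tfrac{4}{2}+1\right)^{1} = 4 \cdot 3 = 12.
```

Both values match the 6 elements and 12 ways found by hand earlier.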

TIME COMPLEXITY

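The loop runs over O(\log_M N) group sizes. Each even size needs one modular exponentiation with an exponent of up to N, which costs O(\log N). So the total is O(\log_M N \cdot \log N) per test case.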

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MOD = 998244353;

ll bigMod (ll a, ll e) {
  ll ret = 1;
  while (e) {
	if (e & 1) ret = ret * a % MOD;
	a = a * a % MOD, e >>= 1;
  }
  return ret;
}

pair <ll, ll> brute (ll n, ll m) {
  assert(n <= 100000);
  bitset <100069> vis;
  ll size = 0, tot = 1;
  for (int i = 1; i <= n; ++i) {
	if (vis[i]) continue;
	ll cur = i, len = 0;
	while (cur <= n) vis[cur] = 1, ++len, cur *= m;
	size += 1 + len >> 1;
	if (~len & 1) tot *= 2 + len >> 1, tot %= MOD;
  }
  return make_pair(size, tot);
}

ll t, n, m;

int main() {
  cin >> t;
  while (t--) {
	cin >> n >> m;
	assert(2 <= n and n <= 1000000000000000000LL);
	assert(2 <= m and m <= 1000000000000000000LL);
	ll size = 0, tot = 1;
	__int128 one = 1, two = m;
	for (ll k = 1; k <= 69 and one <= n; ++k, one *= m, two *= m) {
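	  // Groups of size k start at a with a*m^(k-1) <= n and a*m^k > n, i.e. a in (l, r].
	  ll l = n / two, r = n / one, cnt = (r - l) - (r / m - l / m);  // non-multiples of m in (l, r]
	  size += cnt * ((1 + k) >> 1);                                  // ceil(k / 2) picks per group
	  if (~k & 1) tot *= bigMod((2 + k) >> 1, cnt), tot %= MOD;      // even k: k/2 + 1 ways per group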
	}
	cout << size << " " << tot << '\n';
  }
  return 0;
}

Tester's Solution
//raja1999

//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")

#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string> 
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip> 
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define int ll

typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;


//std::ios::sync_with_stdio(false);
int power(int a,int b){
	int res=1;
	while(b>0){
		if(b%2){
			res*=a;
			res%=mod;
		}
		b/=2;
		a*=a;
		a%=mod;
	}
	return res;
}
int n,m;
pii solve(int a){
	if(a>n){
		return {0,1};
	}
	int val=a,c=1,p=1,size=0,ways=0,u,rem;
	while(1){
		if(val>(n/m)){
			break;
		}
		val*=m;
		p*=m;
		c++;
	}
	u=n/p;
	pii ans=solve(u+1);
	rem=u/m;
	rem-=(a-1)/m;
	u-=rem;
	size=((c+1)/2)*(u-a+1);
	if(c%2)
		ways=1;
	else
		ways=power(c/2+1,u-a+1);
	size+=ans.ff;
	ways*=ans.ss;
	ways%=mod;
	//cout<<a<<" "<<ways<<" "<<size<<endl;
	return {size,ways};
}
signed main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	while(t--){
		pii ans;
		cin>>n>>m;
		ans=solve(1);
		cout<<ans.ff<<" "<<ans.ss<<endl;
	}
	return 0;
} 

Editorialist's Solution
import java.util.*;
import java.io.*;
class BANQUNT{
	//SOLUTION BEGIN
	long MOD = 998244353;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    long[] ans = solve(nl(), nl());
	    pn(ans[0]+" "+ans[1]);
	}
	//O(log_M(N)) solution
	long[] solve(long N, long M){
	    long pw = 1;
	    long last = N;
	    long ways = 1, max = 0;
	    for(int size = 1; pw <= N; size++){
	        long nxt = pw;
	        if(pw <= N/M)nxt *= M;
	        else nxt = N+1;
	        long hi = N/nxt;
	        //L-1 is hi, R is last
	        long cnt = (last-hi-(last/M-hi/M));//Number of groups starting in range [hi+1, last]
	        max += cnt*((size+1)/2);
	        if(size%2 == 0)
	            ways = (ways*pow(size/2+1, cnt, MOD))%MOD;
	        last = hi;
	        pw = nxt;
	    }
	    return new long[]{max, ways};
	}
	//O(N) solution
	long[] brute(int N, int M){
	    long ways = 1;int sz = 0;
	    for(int i = 1; i<= N; i++){
	        if(i%M == 0)continue;
	        int cur = i, cnt = 0;
	        while(cur <= N){
	            cnt++;
	            cur *= M;
	        }
	        sz += (cnt+1)/2;
	        if(cnt%2 == 0)ways = (ways*(cnt/2+1))%MOD;
	    }
	    return new long[]{sz, ways};
	}
	long pow(long a, long p, long MOD){
	    long o = 1;a%=MOD;
	    for(;p>0;p>>=1){
	        if((p&1)==1)o = (o*a)%MOD;
	        a = (a*a)%MOD;
	    }
	    return o;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new BANQUNT().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:

Can anyone help me debug my solution - https://www.codechef.com/viewsolution/34620870

I think my code runs in the above complexity. Can someone help with TLE? Link

here’s my solution : https://www.codechef.com/viewsolution/34620405

I did exactly the same, but in PyPy 3, and I am getting TLE. All my submissions got TLE. Can anyone suggest a way to avoid TLE in Python?

Pretty sure you have integer overflows in the line A *= m;.
Maybe that’s the reason.

You have an infinite loop. Just try N = 6 and M = 3. Doesn’t matter how many times you divide N by M, you will never reach 1.

Thanks! That was the reason.

Thanks a lot. It got accepted.

very nice editorial
please correct me if im wrong

  • if I want a group of size 1: x will lie between ⌊N / M⌋ + 1 and N

    group of size k: x will lie between ⌊N / M^k⌋ + 1 and ⌊N / M^(k - 1)⌋

  • count of number of groups for any range will be number of x’s

  • all x’s are basically numbers from L to R which arent a multiple of M

    number of x’s are (R - L + 1) - ((R / M) - ((L - 1) / M))

@afaxnraner can someone please help me with TLE?

    #include<bits/stdc++.h>
    using namespace std;
    #define MOD 998244353
    #define ll long long
    inline ll mul(ll a, ll b)
    {
    	a = ((a % MOD) + MOD) % MOD;
    	b = ((b % MOD) + MOD) % MOD;
    	return (a * b) % MOD;
    }
    ll power(ll a, ll b)
    {
    	a = a % MOD;
    	ll res = 1;
    	while (b)
    	{
    		if (b & 1)
    			res = (res * a) % MOD;
    		b >>= 1;
    		a = (a * a) % MOD;
    	}
    	return res;
    }
    int main()
    {
    	ios::sync_with_stdio(0);
    	cin.tie(0); cout.tie(0);

    	int t; cin >> t;
    	ll m, n, L, R, flag, x, sz, poss;
    	while (t--)
    	{
    		cin >> n >> m;
    		flag = 1;
    		sz = 0;
    		poss = 1;
    		for (ll i = 1; (n / flag); i++)
    		{
    			L = (n / (flag * m)) + 1;
    			R = n / flag;
    			x = (R - L + 1) - ((R / m) - ((L - 1) / m));
    			sz = sz + (x * ((i + 1) / 2));
    			if (i % 2 == 0)
    				poss = mul(poss, power(1 + (i / 2), x));
    			flag *= m;
    		}
    		cout << sz << " " << poss << "\n";
    	}
    	return 0;
    }

Probably an integer overflow at flag *= m;

x/y!=m . Is it only for x%y==0 or floored division also?

I didn't understand it, but it sounded nice. :sweat_smile:

@afaxnraner thanks!!

only for x%y==0

Any idea why my code gives WA?
https://www.codechef.com/viewsolution/34636748

Great Question !!!

can anyone help me to find out why I am getting runtime Error?
Solution is : https://www.codechef.com/viewsolution/34638659

Can anyone help me to find out why my code is giving WA?
https://www.codechef.com/viewsolution/34641540

Decrement b in power function when b is odd