GUESS_ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: satyam_343
Editorialist: iceknight1093

DIFFICULTY:

3273

PREREQUISITES:

Math

PROBLEM:

Alice and Bob play a game.
Carol gives Alice an integer A and Bob an integer B, such that 1 \leq A, B \leq N and \max(A, B) is a strictly larger multiple of \min(A, B).

Alice and Bob take turns playing, Alice goes first.
On their turn, a player does the following:

  • If they know the value of \min(A, B) for certain, they say it and the game ends.
  • Otherwise, they say “I don’t know” and the turn passes to the other player.

If every valid (A, B) pair is equally likely to be chosen, find the expected number of turns the game takes.
Note that Alice and Bob only know their own numbers (and, subsequently, whatever the other person says); in particular, neither player knows N.

EXPLANATION:

First, let’s analyze who wins a single instance of the game and how many turns it takes.

  • If A = 1, Alice knows that \min(A, B) = 1 for sure and wins.
    This took 1 turn.
  • Otherwise, A \gt 1, but Alice has no further information: she cannot rule out B = 1.
    So, the turn passes to Bob.
  • On Bob’s turn, he knows that A \neq 1 for sure.
    If B = 1, he knows \min(A, B) = 1 and the game ends in 2 turns.
    Further, if B is prime, Bob still knows that \min(A, B) = B. This is because Alice doesn’t have 1, and the only divisors of the prime B are 1 and B itself, so A must be a multiple of B.
  • If B is neither 1 nor a prime, Bob cannot answer definitively: he has no way of knowing whether Alice has a prime, for example. So, the turn passes to Alice.
  • Now, Alice knows that Bob has neither 1 nor a prime.
    So, if she has a prime herself, \min(A, B) is known.
    Further, if A has two prime factors (i.e., A = pq or A = p^2 for primes p, q), Alice knows that \min(A, B) = A: Bob’s number has at least two prime factors, while every proper divisor of A has at most one, so B must be a multiple of A.
  • If not, the turn passes to Bob, and so on.

It can be observed that the number of turns depends on the number of prime factors of \min(A, B), counted with multiplicity.
In particular, once “I don’t know” has been said x times in total, the player to move knows that the other player’s number has at least x prime factors; so if their own number has at most x prime factors, it must be the minimum.
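A minimal sketch that replays the game turn by turn under exactly this rule can help sanity-check the case analysis on small pairs. The helper names `omega` and `turns` are hypothetical; the sketch assumes the pair is valid and counts prime factors with multiplicity.

```cpp
#include <cassert>

// Number of prime factors of x counted with multiplicity (omega(1) = 0,
// omega(12) = 3). Trial division is enough for a small sanity check.
int omega(long long x) {
    int cnt = 0;
    for (long long p = 2; p * p <= x; ++p)
        while (x % p == 0) { x /= p; ++cnt; }
    if (x > 1) ++cnt;
    return cnt;
}

// Replays the game for Alice holding a and Bob holding b, assuming the rule
// above: on turn t (after t-1 "I don't know"s), the player to move names the
// minimum exactly when their own number has at most t-1 prime factors.
// Alice moves on odd turns, Bob on even turns.
int turns(long long a, long long b) {
    assert(a != b && (a % b == 0 || b % a == 0));  // must be a valid pair
    for (int t = 1;; ++t) {
        long long own = (t % 2 == 1) ? a : b;
        if (omega(own) <= t - 1) return t;
    }
}
```

For example, turns(2, 4) returns 3 (Alice names 2 on her second move), while turns(8, 4) returns 4.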

So, if \min(A, B) has k prime factors,

  • Suppose A = \min(A, B). Then,
    • If k is even, Alice wins on the (k+1)-th move (for example, A = 4 and B = 8).
    • Otherwise, Alice wins on the (k+2)-th move (for example, A = 2 and B = 4).
  • Suppose B = \min(A, B). Then,
    • If k is odd, Bob wins on the (k+1)-th move.
    • Otherwise, Bob wins on the (k+2)-th move.

Notice that for any valid pair (x, y), the orderings (x, y) and (y, x) together take 2k+3 moves, where k is the number of prime factors of \min(x, y): whatever the parity of k, one ordering ends on move k+1 and the other on move k+2.
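To double-check the table and the 2k+3 claim, the sketch below compares the closed form against a turn-by-turn replay of the knowledge rule for every valid pair up to a small bound. The names `omega`, `simulate`, `closedForm` and `checkUpTo` are hypothetical; this is a brute-force check, not part of the intended solution.

```cpp
#include <cassert>

// Prime factors with multiplicity, by trial division.
int omega(long long x) {
    int cnt = 0;
    for (long long p = 2; p * p <= x; ++p)
        while (x % p == 0) { x /= p; ++cnt; }
    if (x > 1) ++cnt;
    return cnt;
}

// Turn-by-turn replay: the player to move on turn t answers iff their own
// number has at most t-1 prime factors (Alice on odd turns, Bob on even).
int simulate(long long a, long long b) {
    for (int t = 1;; ++t)
        if (omega(t % 2 == 1 ? a : b) <= t - 1) return t;
}

// Closed form from the case analysis, with k = omega(min(a, b)).
int closedForm(long long a, long long b) {
    assert(a != b);
    bool aliceHasMin = a < b;
    int k = omega(aliceHasMin ? a : b);
    bool kEven = (k % 2 == 0);
    if (aliceHasMin) return kEven ? k + 1 : k + 2;
    return kEven ? k + 2 : k + 1;
}

// For every valid pair with max(x, y) <= limit, checks that the closed form
// matches the replay and that both orderings together take 2k+3 turns.
bool checkUpTo(long long limit) {
    for (long long x = 1; x <= limit; ++x)
        for (long long y = 2 * x; y <= limit; y += x) {
            if (closedForm(x, y) != simulate(x, y)) return false;
            if (closedForm(y, x) != simulate(y, x)) return false;
            if (closedForm(x, y) + closedForm(y, x) != 2 * omega(x) + 3) return false;
        }
    return true;
}
```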

This already gives us a solution in \mathcal{O}(N\log N).

  • Iterate over all pairs (a, d\cdot a) where d \geq 2 and d\cdot a \leq N.
    There are \mathcal{O}(N\log N) such pairs, since a has \lfloor N/a \rfloor - 1 valid multiples and \sum_a N/a = \mathcal{O}(N\log N).
  • For each such pair, let k be the number of prime factors of a.
    k can be computed for all integers \leq N with a sieve.
    Then, the pairs (a, d\cdot a) and (d\cdot a, a) contribute 2k+3 moves in total.
  • This gives both the total number of moves and the number of valid ordered pairs (twice the number of pairs (a, d\cdot a)), so their ratio is the answer.
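A sketch of this \mathcal{O}(N\log N) approach, assuming the answer is kept as an exact fraction (the solutions below report it modulo 10^9+7 instead). `bruteForce` is a hypothetical name; it returns the total number of moves and the number of valid ordered pairs.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Returns {total moves, valid ordered pairs}; expected turns = first / second.
std::pair<long long, long long> bruteForce(int N) {
    assert(N >= 1);
    // pf[x] = some prime factor of x (later primes overwrite earlier ones),
    // which is all the recurrence pct[i] = pct[i / pf[i]] + 1 needs.
    std::vector<int> pf(N + 1, 0), pct(N + 1, 0);
    for (int i = 2; i <= N; ++i)
        if (pf[i] == 0)
            for (int j = i; j <= N; j += i) pf[j] = i;
    for (int i = 2; i <= N; ++i) pct[i] = pct[i / pf[i]] + 1;

    long long moves = 0, pairs = 0;
    for (int a = 1; a <= N; ++a)
        for (long long m = 2LL * a; m <= N; m += a) {  // m = d * a with d >= 2
            moves += 2 * pct[a] + 3;  // (a, m) and (m, a) together
            pairs += 2;
        }
    return {moves, pairs};
}
```

For N = 4 it returns {14, 8}, an expected 14/8 = 1.75 turns: the pairs are (1, 2), (1, 3), (1, 4), (2, 4) and their reverses.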

This will pass subtask 1.


Subtask 2 follows the same general idea as subtask 1, but some optimizations are needed.

First, the sum of N across all test cases is no longer bounded.
So, we’ll precompute the answers for all 1 \leq N \leq 10^7 and just look up the answer for each query.

However, the precomputation itself can’t take \mathcal{O}(N\log N) time: with N = 10^7, that is likely too slow.

We need to find two things:

  • First, the contribution of each pair (a, d\cdot a), to form the numerator.
  • Second, the number of proper divisors of each number, to form the denominator.
    A proper divisor of a number is a factor that’s strictly less than it.

The total number of divisors can be found faster than \mathcal{O}(N\log N) by computing it from the prime factorization of each number.
That is, if n = p_1^{a_1} p_2^{a_2} \ldots p_m^{a_m}, it has (a_1+1)\cdot (a_2+1) \cdot \ldots \cdot (a_m+1) divisors in total; subtracting 1 gives the number of proper divisors.
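For a single n, the formula can be illustrated with trial-division factorization. This is only a sketch for checking small values (the helper names are hypothetical); running it for every n \leq 10^7 would be far too slow, which is why a sieve is used instead.

```cpp
#include <cassert>

// Number of divisors of n from its prime factorization:
// n = p1^a1 * ... * pm^am  =>  divisors = (a1 + 1) * ... * (am + 1).
long long countDivisors(long long n) {
    assert(n >= 1);
    long long total = 1;
    for (long long p = 2; p * p <= n; ++p) {
        int e = 0;
        while (n % p == 0) { n /= p; ++e; }
        total *= (e + 1);
    }
    if (n > 1) total *= 2;  // one leftover prime with exponent 1
    return total;
}

// Proper divisors exclude n itself.
long long countProperDivisors(long long n) { return countDivisors(n) - 1; }
```

For example, 12 = 2^2 \cdot 3 gives (2+1)(1+1) = 6 divisors, hence 5 proper divisors.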

The prime factorization can be found in \mathcal{O}(N\log \log N) time by sieving only across primes and storing a prime divisor of each number.
This allows us to find divisor counts in \mathcal{O}(N\log \log N) time too.
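A minimal sketch of such a sieve, assuming we store the smallest prime factor of every number together with its exponent (the editorialist’s code below ends up storing the largest prime factor instead, which works just as well). `buildTables`, `spf` and `expo` are illustrative names. Each n is handled in \mathcal{O}(1) after the sieve, so the total stays \mathcal{O}(N\log\log N).

```cpp
#include <cassert>
#include <vector>

// pct[n]: prime factors of n with multiplicity; divs[n]: number of divisors.
struct DivisorTables {
    std::vector<int> pct;
    std::vector<long long> divs;
};

DivisorTables buildTables(int limit) {
    assert(limit >= 1);
    std::vector<int> spf(limit + 1, 0);   // smallest prime factor
    std::vector<int> expo(limit + 1, 0);  // exponent of spf[n] in n
    DivisorTables t{std::vector<int>(limit + 1, 0),
                    std::vector<long long>(limit + 1, 1)};
    // Sieve over primes only: each prime walks over its multiples once.
    for (int i = 2; i <= limit; ++i)
        if (spf[i] == 0)
            for (int j = i; j <= limit; j += i)
                if (spf[j] == 0) spf[j] = i;
    // Write n = p * x with p = spf[n] and update from x in O(1).
    for (int n = 2; n <= limit; ++n) {
        int p = spf[n], x = n / p;
        t.pct[n] = t.pct[x] + 1;
        if (x > 1 && spf[x] == p) {
            // p^e exactly divides x (e = expo[x]), so p^(e+1) exactly divides n.
            expo[n] = expo[x] + 1;
            t.divs[n] = t.divs[x] / (expo[x] + 1) * (expo[n] + 1);
        } else {
            // p does not divide x: one new prime with exponent 1.
            expo[n] = 1;
            t.divs[n] = t.divs[x] * 2;
        }
    }
    return t;
}
```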

Now, let’s look at the numerator.
Let \text{pct}[a] denote the number of prime factors of a (which we already computed, above).
Instead of fixing (a, d\cdot a) and adding 2\text{pct}[a]+3 to the total, let’s fix the larger number and count the contribution of all its factors to it.

Consider the integer n.
For each proper divisor d of n, we add 2\text{pct}[d] + 3 to the numerator of every bound N \geq n.

Let’s take the 3 out: if n has k proper divisors, it’ll contribute a total of 3k.
So, we only really need to know the sum of \text{pct}[d] across all proper factors of n.

To compute it, count, for each prime in the factorization of n, how many times it appears across the divisors of n.
Summed over all divisors d of n, this gives exactly \frac{1}{2} \cdot\text{pct}[n] \cdot \text{divs}[n].
This sum includes d = n itself, so subtract \text{pct}[n] to restrict it to proper divisors, leaving \frac{1}{2} \cdot \text{pct}[n] \cdot (\text{divs}[n] - 2).
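The counting step behind that identity: a divisor picks an exponent 0 \leq e_i \leq a_i for each prime independently, so each value of e_i occurs in \prod_{j \neq i}(a_j+1) divisors. Written out:

```latex
n = \prod_{i=1}^{m} p_i^{a_i}, \qquad
\text{pct}[n] = \sum_{i=1}^{m} a_i, \qquad
\text{divs}[n] = \prod_{i=1}^{m} (a_i + 1)

\sum_{d \mid n} \text{pct}[d]
  = \sum_{i=1}^{m} \Big( \sum_{e=0}^{a_i} e \Big) \prod_{j \neq i} (a_j + 1)
  = \sum_{i=1}^{m} \frac{a_i (a_i + 1)}{2} \cdot \frac{\text{divs}[n]}{a_i + 1}
  = \frac{1}{2}\,\text{pct}[n]\,\text{divs}[n]

\sum_{d \mid n,\ d < n} \text{pct}[d]
  = \frac{1}{2}\,\text{pct}[n]\,\text{divs}[n] - \text{pct}[n]
  = \frac{1}{2}\,\text{pct}[n]\,(\text{divs}[n] - 2)
```

For n = 12 = 2^2 \cdot 3: \text{pct}[12] = 3 and \text{divs}[12] = 6, and the divisors 1, 2, 3, 4, 6, 12 have 0 + 1 + 1 + 2 + 2 + 3 = 9 = \frac{1}{2} \cdot 3 \cdot 6 prime factors in total; dropping 12 leaves 6 = \frac{1}{2} \cdot 3 \cdot (6 - 2).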

Notice that we’ve already computed \text{pct} and \text{divs}, so this formula takes \mathcal{O}(1) time for a fixed n.
Let \text{val}[n] denote the required value for a fixed n, including the 3\cdot (\text{divs}[n]-1) we took out earlier; that is, \text{val}[n] = \text{pct}[n]\cdot(\text{divs}[n]-2) + 3\cdot(\text{divs}[n]-1).
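A quick brute-force check of this closed form on small n (the same expression the editorialist’s code computes as `ans[i]`). This is a sketch with hypothetical helper names, using trial division rather than the sieve.

```cpp
#include <cassert>

// Prime factors with multiplicity (trial division; small n only).
long long pctOf(long long x) {
    long long cnt = 0;
    for (long long p = 2; p * p <= x; ++p)
        while (x % p == 0) { x /= p; ++cnt; }
    return cnt + (x > 1 ? 1 : 0);
}

// Number of divisors (direct count; small n only).
long long divsOf(long long n) {
    long long cnt = 0;
    for (long long d = 1; d <= n; ++d)
        if (n % d == 0) ++cnt;
    return cnt;
}

// Closed form: val[n] = pct[n] * (divs[n] - 2) + 3 * (divs[n] - 1).
long long valFormula(long long n) {
    assert(n >= 1);
    long long p = pctOf(n), t = divsOf(n);
    return p * (t - 2) + 3 * (t - 1);
}

// Direct definition: sum of 2 * pct[d] + 3 over proper divisors d of n.
long long valBrute(long long n) {
    long long s = 0;
    for (long long d = 1; d < n; ++d)
        if (n % d == 0) s += 2 * pctOf(d) + 3;
    return s;
}
```

For n = 12 both return 27: the proper divisors 1, 2, 3, 4, 6 contribute 3 + 5 + 5 + 7 + 7.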

Then, for a given N,

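  • The numerator is \text{val}[1] + \text{val}[2] + \ldots + \text{val}[N].
  • The denominator is 2\cdot\big((\text{divs}[1]-1) + (\text{divs}[2]-1) + \ldots + (\text{divs}[N]-1)\big), since each n \leq N contributes one unordered pair per proper divisor and every unordered pair appears in both orders.

Both of these can be precomputed using prefix sums, and their ratio gives the answer (printed modulo 10^9+7 using a modular inverse, as in all three codes below).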
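A compact sketch of the prefix-sum step and the final modular answer, assuming the answer is reported as P \cdot Q^{-1} modulo 10^9+7. For brevity it fills divisor counts with a simple \mathcal{O}(N\log N) harmonic loop rather than the \mathcal{O}(N\log\log N) sieve, so it illustrates the formulas, not the intended running time. `Solver` and `power` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fast modular exponentiation; power(q, MOD - 2) is q^{-1} since MOD is prime.
long long power(long long b, long long e) {
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

struct Solver {
    std::vector<long long> num, den;  // prefix sums of val[n] and of (divs[n] - 1)

    explicit Solver(int limit) : num(limit + 1, 0), den(limit + 1, 0) {
        std::vector<int> pf(limit + 1, 0), pct(limit + 1, 0);
        std::vector<long long> divs(limit + 1, 0);
        // Sieve storing one prime factor per number, then pct[n] = pct[n / pf[n]] + 1.
        for (int i = 2; i <= limit; ++i)
            if (pf[i] == 0)
                for (int j = i; j <= limit; j += i) pf[j] = i;
        for (int i = 2; i <= limit; ++i) pct[i] = pct[i / pf[i]] + 1;
        // Divisor counts via a harmonic loop: O(N log N), used here only for brevity.
        for (int d = 1; d <= limit; ++d)
            for (int m = d; m <= limit; m += d) ++divs[m];
        for (int n = 1; n <= limit; ++n) {
            long long val = pct[n] * (divs[n] - 2) + 3 * (divs[n] - 1);
            num[n] = num[n - 1] + val;
            den[n] = den[n - 1] + (divs[n] - 1);  // proper divisors of n
        }
    }

    // Expected number of turns for bound N, as P * Q^{-1} mod 1e9+7.
    // N = 1 has no valid pairs (denominator 0), so N >= 2 is required.
    long long answer(int N) const {
        assert(N >= 2 && N < static_cast<int>(num.size()));
        long long P = num[N] % MOD;
        long long Q = 2 * den[N] % MOD;  // every unordered pair counted in both orders
        return P * power(Q, MOD - 2) % MOD;
    }
};
```

For N = 4 the numerator is 14 and the denominator is 8, so `answer(4)` is 7 \cdot 4^{-1} \bmod (10^9+7) = 750000007.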

TIME COMPLEXITY:

\mathcal{O}(N\log\log N + T\log{MOD}), where N = 10^7.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
const int N=10000010;
const int LOGN=28;
const ll  TMD=1000000007;
const ll  INF=2147483647;
int T,n;
int pfac[N],f[N],fa[N],fb[N],Sa[N],Sb[N];

void init()
{
	for(int i=2;i<N;i++)
	{
		if(pfac[i]) continue;
		for(int j=2;j*i<N;j++) pfac[j*i]=i;
	}
	for(int i=2;i<N;i++)
	{
		if(pfac[i]) f[i]=f[i/pfac[i]]+1;
		else f[i]=1;
	}
	for(int i=1;i<N;i++)
	{
		fa[i]=2*((f[i]+1)/2)+1;
		fb[i]=2*(f[i]/2+1);
		Sa[i]=Sa[i-1]+fa[i];
		Sb[i]=Sb[i-1]+fb[i];
	}
}

ll pw(ll x,ll p)
{
	if(!p) return 1;
	ll y=pw(x,p>>1);
	y=y*y%TMD;
	if(p&1) y=y*(x%TMD)%TMD;
	return y;
}

ll inv(ll x)
{
	return pw(x,TMD-2);
}

int main()
{
	init();
	scanf("%d",&T);
	while(T--)
	{
    	scanf("%d",&n);
    	int sqn=(int)sqrt(n),cur;
    	ll  P=(-Sa[n]-Sb[n]+TMD*2)%TMD,Q=TMD-n;
    	for(int i=1;i<=sqn;i++)
    	{
	    	P=(P+fa[i]*(n/i))%TMD;
	    	P=(P+fb[i]*(n/i))%TMD;
	    	Q=(Q+n/i)%TMD;
	    }
	    cur=sqn+1;
	    while(cur<=n)
	    {
    		int L=cur,R=n+1,M;
    		while(L+1!=R)
    		{
		    	M=(L+R)>>1;
		    	if(n/M==n/cur) L=M;
				else R=M; 
		    }
		    P=(P+(Sa[L]-Sa[cur-1])*(n/cur))%TMD;
		    P=(P+(Sb[L]-Sb[cur-1])*(n/cur))%TMD;
		    Q=(Q+(L-cur+1)*(n/cur))%TMD;
		    cur=L+1;
    	}
    	Q=Q*2%TMD;
    	printf("%lld\n",P*inv(Q)%TMD);
	}
	
	return 0;
}

Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
 
 
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=2e18; 
#define pb push_back                   
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>             
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);    
#endif       
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}     
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;     
const ll MAX=10001000;
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;  
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
ll freq[MAX],spf[MAX];
void solve(){     
    ll n; cin>>n;
    ll till=min(n,1010ll),num=0,den=0;    
    for(ll i=1;i<=till;i++){
        num+=(freq[i]-freq[i-1])*(n/i-1);
        den+=2*(n/i-1);  
    }  
    for(ll i=2;i<=n/till;i++){
        ll l=n/(i+1)+1,r=n/i; 
        l=max(l,till+1);
        if(l>r){
            continue;   
        }
        num+=(freq[r]-freq[l-1])*(i-1); 
        den+=2ll*(r-l+1)*(i-1); 
    }
    num%=MOD;
    den%=MOD;
    ll ans=(num*inverse(den,MOD))%MOD;  
    cout<<ans<<nline;   
    return;                                
}                                                  
int main()                                                                                                 
{            
    ios_base::sync_with_stdio(false);                             
    cin.tie(NULL);        
    #ifndef ONLINE_JUDGE                     
    freopen("input.txt", "r", stdin);                                                    
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif                          
    ll test_cases=1;               
    cin>>test_cases;
    for(ll i=1;i<MAX;i++){
        spf[i]=i;
        freq[i]=0;
    }
    for(ll i=2;i<MAX;i++){
        if(spf[i]==i){
            for(ll j=i;j<MAX;j+=i){
                spf[j]=min(spf[j],i);
            }
        }
        freq[i]=freq[i/spf[i]]+1; 
    }
    for(ll i=1;i<MAX;i++){
        freq[i]=2ll*freq[i]+3+freq[i-1];
    }
    while(test_cases--){
        solve();
    } 
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}   
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const int maxn = 1e7 + 5;
int spf[maxn], pct[maxn], mul[maxn];
ll divs[maxn], ans[maxn];

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    
    divs[1] = 1;
    for (int i = 2; i < maxn; ++i) {
        if (spf[i] == 0) {
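            for (int j = i; j < maxn; j += i) spf[j] = i; // larger primes overwrite, so spf[j] ends up as the largest prime factor of j; the recurrences below work with this consistent choice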
        }

        int x = i / spf[i];
        pct[i] = 1 + pct[x];
        if (spf[i] == spf[x]) {
            divs[i] = (divs[x] / mul[x]) * (mul[x] + 1);
            mul[i] = mul[x] + 1;
        }
        else {
            divs[i] = divs[x] * 2;
            mul[i] = 2;
        }
        
        ans[i] = 3*divs[i] - 3;
        ans[i] += 1LL * pct[i] * (divs[i] - 2);
    }

    divs[1] = 0;
    for (int i = 2; i < maxn; ++i) {
        divs[i] += divs[i-1] - 1;
        ans[i] += ans[i-1];
    }

    const int mod = 1e9 + 7;
    auto inv = [&] (ll a) {
        a %= mod;
        int pw = mod-2;
        ll r = 1;
        while (pw) {
            if (pw & 1) r = (a * r) % mod;
            a = (a * a) % mod;
            pw /= 2;
        }
        return r;
    };
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        ll out = (ans[n] * inv(2*divs[n]) % mod);
        cout << out << '\n';
    }
}

We can also use the fact that there are at most 2\sqrt{n} distinct values of \lfloor n/i \rfloor, which gives a much simpler solution for subtask 2 running in \mathcal{O}(T\sqrt{n}).

https://www.codechef.com/viewsolution/99025607
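A sketch of that idea (our own illustration, not a copy of the linked submission): precompute prefix sums W of the weight 2\cdot\text{pct}[a] + 3 once, then for each query walk over the blocks of a that share the same \lfloor n/a \rfloor, so each query takes \mathcal{O}(\sqrt{n}). `buildWeightPrefix` and `queryBlocks` are hypothetical names, and the final modular division is omitted.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// W[x] = sum over a <= x of (2 * pct[a] + 3); built once up to limit.
std::vector<long long> buildWeightPrefix(int limit) {
    std::vector<int> pf(limit + 1, 0), pct(limit + 1, 0);
    std::vector<long long> W(limit + 1, 0);
    for (int i = 2; i <= limit; ++i)
        if (pf[i] == 0)
            for (int j = i; j <= limit; j += i) pf[j] = i;
    for (int i = 2; i <= limit; ++i) pct[i] = pct[i / pf[i]] + 1;
    for (int a = 1; a <= limit; ++a) W[a] = W[a - 1] + 2 * pct[a] + 3;
    return W;
}

// Returns {total moves, valid ordered pairs} for bound n, i.e. the sums of
// (2 * pct[a] + 3) * (n / a - 1) and 2 * (n / a - 1) over a = 1..n.
// All a in [l, r] share q = n / a, and there are O(sqrt(n)) such blocks.
std::pair<long long, long long> queryBlocks(const std::vector<long long>& W, long long n) {
    assert(n >= 1 && n < static_cast<long long>(W.size()));
    long long moves = 0, pairs = 0;
    for (long long l = 1; l <= n;) {
        long long q = n / l;
        long long r = n / q;  // largest a with n / a == q
        moves += (W[r] - W[l - 1]) * (q - 1);
        pairs += 2 * (r - l + 1) * (q - 1);
        l = r + 1;
    }
    return {moves, pairs};
}
```

For n = 4 it visits the blocks [1, 1], [2, 2], [3, 4] and returns {14, 8}, matching the direct \mathcal{O}(N\log N) count.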


Correct; in fact, the author and tester both had this solution, which is why the limits were left low enough for it to pass.

The approach I outlined in the editorial works for higher constraints, though: the same N, but T up to 10^5 or so.