PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author:
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Euclidean algorithm, prime factorization, the inclusion-exclusion principle
PROBLEM:
Given N and M, consider the set of all binary strings with exactly M distinct non-empty subsequences.
Count the number of inversions in the N-th lexicographically largest among them, or output -1 if fewer than N such strings exist.
EXPLANATION:
First, we must analyze when exactly a string can have M distinct non-empty subsequences.
Let’s make things a bit easier for ourselves and allow the empty subsequence too (so we’re looking for M+1 subsequences instead).
Suppose we have a binary string S. Let’s try to count the number of distinct subsequences it has.
One way to do this is to iterate through the string and compute the number of new subsequences ending at each index, that is, subsequences that do not already occur in the shorter prefix.
Say we’re considering index i and, without loss of generality, S_i = 1.
Then, it turns out that the number of new subsequences ending at index i is exactly the number of distinct subsequences ending with a 0 so far.
Proof
Consider some new subsequence ending at index i.
Let’s look at the part of this subsequence up to and including its last 0, and let j \lt i be the earliest index at which that part can end.
Note that the subsequence must then include every 1 present at indices j+1, \ldots, i: if it did not, then the subsequence would also appear ending at an earlier index, and wouldn’t be new here.
Now, we observe that any subsequence ending with a 0 can be uniquely extended to end at index i, by appending every 1 that appears after the earliest index at which it can end.
This gives us a bijection between the new subsequences ending at index i and subsequences so far ending with 0, as claimed.
(Note that the subsequence with only ones is also taken care of here, since it can be thought of as taking the empty subsequence and then appending all ones after it.)
Similarly, if S_i = 0 instead, the number of new subsequences would just be the number of distinct subsequences ending with a 1 so far.
With the above knowledge, let’s define pairs (x_i, y_i) for 0 \leq i \leq |S|, where x_i is the number of distinct subsequences ending with a 0 in the prefix of length i of S, with the empty subsequence also counted.
y_i is defined similarly, but for subsequences ending with a 1.
With these definitions, the number of distinct subsequences of the length-i prefix of S is simply x_i + y_i - 1 (1 is subtracted because the empty subsequence is counted twice).
Our previous observation gives us a relation between these pairs: depending on the value of S_i, we’ll have one of:
- if S_i = 1: (x_i, y_i) = (x_{i-1}, x_{i-1} + y_{i-1}), or
- if S_i = 0: (x_i, y_i) = (x_{i-1} + y_{i-1}, y_{i-1})
The base case is, of course, (x_0, y_0) = (1, 1).
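To make the recurrence concrete, here is a minimal Python sketch (the helper names `subsequence_pair` and `count_distinct_brute` are hypothetical, chosen only for illustration). It runs the recurrence from the base case and checks x_{|S|} + y_{|S|} - 1 against a brute-force count of distinct subsequences, empty one included, for every binary string of length at most 8.

```python
from itertools import combinations, product


def subsequence_pair(s):
    """Final (x, y) for s: x (resp. y) counts distinct subsequences ending in '0'
    (resp. '1'), each count also including the empty subsequence."""
    x, y = 1, 1  # base case (x_0, y_0) = (1, 1)
    for c in s:
        if c == '1':
            x, y = x, x + y  # one new subsequence ending here per one ending in 0
        else:
            x, y = x + y, y
    return x, y


def count_distinct_brute(s):
    """Distinct subsequences of s, empty one included, by listing all index subsets."""
    seen = set()
    for r in range(len(s) + 1):
        for idx in combinations(range(len(s)), r):
            seen.add(''.join(s[i] for i in idx))
    return len(seen)


# Check x + y - 1 against brute force for every binary string of length <= 8.
for length in range(9):
    for t in product('01', repeat=length):
        s = ''.join(t)
        x, y = subsequence_pair(s)
        assert x + y - 1 == count_distinct_brute(s), s

print(subsequence_pair("0110"))  # (7, 5): 11 subsequences including the empty one
```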
Here we make the most important observation necessary to solve this task: this process is really just what happens when you run the Euclidean algorithm!
In particular, notice that if you reverse the process, the pairs change as either
(x, y) \to (x, y-x) or (x, y) \to (x-y, y) depending on which of x or y is greater.
Observe that this choice is unique: unless (x, y) = (1, 1), one of x and y is strictly greater than the other, since the last character of the string always has strictly more subsequences ending with it. So, given just the final pair, the string can be reconstructed uniquely.
Now, since we’re really just running the (subtractive) Euclidean algorithm starting with the pair (x_{|S|}, y_{|S|}), and each step preserves the gcd, reaching the final state (1, 1) is equivalent to \gcd(x_{|S|}, y_{|S|}) = 1.
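The reverse process can be sketched the same way. The Python helper below (`string_from_pair` is a hypothetical name) undoes one step at a time, recovering characters from the end of the string, and returns `None` when the gcd condition fails. It performs one subtraction per character, so it is only meant for small pairs; a division-based version avoids that cost.

```python
from math import gcd


def string_from_pair(x, y):
    """Undo the recurrence one step at a time (subtractive Euclidean algorithm).

    Returns the unique binary string whose final pair is (x, y), or None if
    gcd(x, y) != 1, in which case (1, 1) can never be reached.
    """
    if x < 1 or y < 1 or gcd(x, y) != 1:
        return None
    chars = []
    while (x, y) != (1, 1):
        if y > x:            # last character was '1': (x, y) came from (x, y - x)
            chars.append('1')
            y -= x
        else:                # x > y, last character was '0': it came from (x - y, y)
            chars.append('0')
            x -= y
    return ''.join(reversed(chars))  # characters were recovered back to front


print(string_from_pair(3, 4))  # 001
print(string_from_pair(4, 6))  # None
```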
In particular, note that we now have a way to classify all binary strings with K distinct subsequences (counting the empty one).
There exists a bijection between binary strings with K distinct subsequences, and ordered pairs of positive integers (x, y) such that \gcd(x, y) = 1 and x+y-1 = K.
This is a powerful criterion: since \gcd(x, y) = \gcd(x, x+y) = \gcd(x, K+1), the conditions \gcd(x, y) = 1 and x+y = K+1 are equivalent to 1 \leq x \leq K+1 and \gcd(x, K+1) = 1 (with y = K+1-x).
So, each binary string with K distinct subsequences corresponds uniquely to an integer that’s not larger than K+1, and is coprime to it.
There are, by definition, \varphi(K+1) such integers (\varphi denotes Euler’s totient function).
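A quick brute-force sanity check of this count, as a Python sketch that only covers small K: it enumerates all binary strings of length at most K-1 (a string of length L has at least L+1 distinct subsequences, namely its prefixes), counts their subsequences with the (x, y) recurrence, and compares how many have exactly K of them against a naive \varphi(K+1). All function names are hypothetical.

```python
from itertools import product
from math import gcd


def count_distinct(s):
    """Distinct subsequences of s, empty one included, via the (x, y) recurrence."""
    x, y = 1, 1
    for c in s:
        if c == '1':
            y += x
        else:
            x += y
    return x + y - 1


def strings_with(k):
    """All binary strings with exactly k distinct subsequences (empty one included).
    Lengths up to k - 1 suffice, since a length-L string has L + 1 distinct prefixes."""
    result = []
    for length in range(1, k):
        for t in product('01', repeat=length):
            s = ''.join(t)
            if count_distinct(s) == k:
                result.append(s)
    return result


def phi(n):
    """Euler's totient by direct gcd checks (fine for tiny n)."""
    return sum(1 for v in range(1, n + 1) if gcd(v, n) == 1)


for k in range(2, 12):
    assert len(strings_with(k)) == phi(k + 1), k
print(strings_with(4))  # ['01', '10', '000', '111']
```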
Let’s go back to the original problem.
We’re interested in strings with M+1 distinct subsequences, counting the empty one.
As noted above, there are exactly \varphi(M+2) such strings.
So, if N \gt \varphi(M+2), no valid string exists and the answer is -1.
When N \leq \varphi(M+2) we need to find the appropriate string and compute its inversion count.
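For testing on small inputs, a brute-force reference can help. The Python sketch below uses hypothetical names and is exponential in M, so it is only a testing aid, not part of the intended solution. It enumerates all strings of length at most M, keeps those with exactly M distinct non-empty subsequences (via the standard last-occurrence DP), sorts them in decreasing lexicographic order, and counts inversions (pairs i \lt j with S_i = 1 and S_j = 0) in the N-th one.

```python
from itertools import product


def count_nonempty(s):
    """Distinct non-empty subsequences of s (standard last-occurrence DP)."""
    total, last = 1, {}  # total counts the empty subsequence as well
    for c in s:
        prev = total
        total = 2 * total - last.get(c, 0)  # drop duplicates created at the previous c
        last[c] = prev
    return total - 1


def inversions(s):
    """Number of pairs i < j with s[i] == '1' and s[j] == '0'."""
    ones = inv = 0
    for c in s:
        if c == '1':
            ones += 1
        else:
            inv += ones
    return inv


def brute_answer(n, m):
    """Inversions of the n-th lexicographically largest binary string with exactly
    m distinct non-empty subsequences, or -1 if fewer than n such strings exist."""
    valid = []
    # A string of length L has at least L distinct non-empty subsequences, so L <= m.
    for length in range(1, m + 1):
        for t in product('01', repeat=length):
            s = ''.join(t)
            if count_nonempty(s) == m:
                valid.append(s)
    valid.sort(reverse=True)
    return inversions(valid[n - 1]) if n <= len(valid) else -1


print([brute_answer(n, 3) for n in range(1, 6)])  # [0, 1, 0, 0, -1]
```

For M = 3 the valid strings in decreasing order are 111, 10, 01 and 000, which gives the printed values.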
As it turns out, this is not too hard.
Looking at the bijection between integers coprime to M+2 and the strings their Euclidean algorithm constructs, it can be seen that the smaller the integer, the lexicographically smaller the string it constructs.
Since we want the N-th lexicographically largest string, our first order of business is to find the N-th largest integer coprime to M+2, which is the (\varphi(M+2) - N + 1)-th smallest one.
First, we need \varphi(M+2) itself, both to compare against N and to convert N into this rank. It can be computed in \mathcal{O}(\sqrt M) time by prime-factorizing M+2 with trial division and using \varphi(n) = n\prod_{p \mid n}\left(1 - \frac{1}{p}\right).
Once this is done, the k-th smallest integer coprime to M+2 (for any k \leq \varphi(M+2)) can be found using binary search and inclusion-exclusion on the prime divisors of M+2.
Details
Let the distinct prime factors of M+2 be p_1, p_2, \ldots, p_k.
Note that an integer is not coprime to M+2 if and only if it has some p_i as a prime factor.
So, to count the number of coprimes to M+2 that are \leq x for some fixed x:
- Start with all the elements in [1, x].
- Subtract out all multiples of each p_i, since these are not coprime to M+2.
- Add back in all multiples of p_i p_j for each i \lt j.
- Subtract out all multiples of products of three primes.
- And so on.
Each such count takes \mathcal{O}(2^k k) time by iterating through all subsets (bitmasks) of the primes, and for our bounds k is no more than 11 (the product of the first 12 primes already exceeds 10^{12}+2), so this is fast even inside a binary search.
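The counting and the binary search can be sketched in Python as follows. This mirrors the bitmask loop used in the codes below, but the function names `distinct_primes`, `coprimes_upto` and `kth_coprime` are hypothetical. Note that `coprimes_upto(n, distinct_primes(n))` is exactly \varphi(n), so the same routine also provides the value compared against N.

```python
def distinct_primes(n):
    """Distinct prime factors of n, by trial division up to sqrt(n)."""
    primes = []
    p = 2
    while p * p <= n:
        if n % p == 0:
            primes.append(p)
            while n % p == 0:
                n //= p
        p += 1
    if n > 1:
        primes.append(n)  # whatever remains after trial division is a prime
    return primes


def coprimes_upto(x, primes):
    """How many v in [1, x] are divisible by none of `primes` (inclusion-exclusion)."""
    total = 0
    for mask in range(1 << len(primes)):
        prod, sign = 1, 1
        for i, p in enumerate(primes):
            if (mask >> i) & 1:
                prod *= p
                sign = -sign
        total += sign * (x // prod)  # x // prod multiples of prod lie in [1, x]
    return total


def kth_coprime(k, n):
    """k-th smallest integer coprime to n, assuming n >= 2 and 1 <= k <= phi(n)."""
    primes = distinct_primes(n)
    lo, hi = 1, n - 1  # n - 1 is always coprime to n, so the answer is at most n - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if coprimes_upto(mid, primes) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


n = 12
print(coprimes_upto(n, distinct_primes(n)))      # 4 == phi(12)
print([kth_coprime(k, n) for k in range(1, 5)])  # [1, 5, 7, 11]
```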
Once the appropriate integer x is known, running the Euclidean algorithm on (x, M+2-x) will allow us to reconstruct the string, and hence compute its inversion count.
Note that since M \leq 10^{12}, the constructed string might be too long to build explicitly (a string of M equal characters is one of the candidates); however, since it’s constructed by the Euclidean algorithm, its run-length encoding can be found in \mathcal{O}(\log M) time (and length) by replacing each block of repeated subtractions with a single modulo operation.
The inversion count of a binary string in RLE form is also easy to compute in linear time, which finishes the task.
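A Python sketch of these last two steps (hypothetical names `runs_from_pair` and `inversions_from_runs`). The text leaves implicit which end of the string the Euclidean algorithm produces and which character each step emits; this sketch assumes the orientation used by the Editorialist's `gen` below, where a step with a larger first coordinate emits a run of 1s and runs are read front to back. Reversing or complementing a string does not change its number of distinct subsequences, so this is still a valid bijection; that the pair (x, M+2-x), with x the (\varphi(M+2) - N + 1)-th smallest coprime, then yields the N-th largest string is an assumption of the sketch, checked only against brute force for small M.

```python
from math import gcd


def runs_from_pair(a, b):
    """Run-length encoding, front to back, of the string decoded from a coprime pair.

    Assumed orientation (as in the Editorialist's gen): while a > b the step emits
    '1's, otherwise '0's. Each block of repeated subtractions becomes one division,
    so only O(log(a + b)) runs are produced.
    """
    if a < 1 or b < 1 or gcd(a, b) != 1:
        raise ValueError("expected a coprime pair of positive integers")
    runs = []
    while (a, b) != (1, 1):
        if a > b:
            k = (a - 1) // b  # subtractions of b until a drops to at most b
            runs.append(('1', k))
            a -= k * b
        else:
            k = (b - 1) // a
            runs.append(('0', k))
            b -= k * a
    return runs


def inversions_from_runs(runs):
    """Pairs i < j with a '1' at i and a '0' at j, in O(number of runs)."""
    ones = inv = 0
    for ch, length in runs:
        if ch == '1':
            ones += length
        else:
            inv += ones * length  # every zero here follows all ones seen so far
    return inv


# M = 9: M + 2 = 11, phi(11) = 10, so N = 4 uses the (10 - 4 + 1) = 7th smallest
# coprime, x = 7, and the pair (7, 11 - 7).
runs = runs_from_pair(7, 4)
print(runs)                        # [('1', 1), ('0', 1), ('1', 2)], i.e. 1011
print(inversions_from_runs(runs))  # 1
```

Because the loop works on run lengths rather than characters, a pair such as (10^{12}+1, 1) decodes to a single run of 10^{12} ones without ever building the string.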
TIME COMPLEXITY:
\mathcal{O}(\sqrt M + 2^{\omega(M+2)}\cdot\omega(M+2)\cdot\log M) per testcase, where \omega(n) denotes the number of distinct prime factors of n.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define fast \
ios_base::sync_with_stdio(0); \
cin.tie(0); \
cout.tie(0);
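// p: distinct prime factors of M+2
// d: squarefree divisors of M+2, stored as +x or -x by the parity of their prime count
vector<ll> p,d;
// If x is squarefree, push it into d with the sign of its Mobius value.
void sol(ll x){
    ll y=x,var=0;
    bool ok=true;
    for(auto i:p){
        ll cnt=0;
        while(y%i==0){
            cnt++;
            y/=i;
        }
        if(cnt>=2){
            ok=false;
            break;
        }
        else if(cnt==1) var++;
    }
    if(ok){
        if(var%2) d.push_back(-x);
        else d.push_back(x);
    }
}
// Number of integers in [1, x] coprime to M+2 (inclusion-exclusion over d).
ll sol1(ll x){
    ll ans=0;
    for(auto i:d){
        if(i>0) ans+=(x/i);
        else ans-=(x/(-i));
    }
    return ans;
}
int main(){
    fast;
    ll t;
    cin>>t;
    while(t--){
        ll m,n;
        cin>>m>>n;   // m = N, n = M
        n++;
        ll xx=n+1;   // xx = M+2
        p.clear();
        d.clear();
        for(ll i=2;i*i<=xx;i++){
            if(xx%i==0){
                p.push_back(i);
                while(xx%i==0) xx/=i;
            }
        }
        if(xx>1) p.push_back(xx);
        xx=n+1;
        for(ll i=1;i*i<=xx;i++){
            if(xx%i==0){
                sol(i);
                ll j=xx/i;
                if(j!=i) sol(j);
            }
        }
        ll tot=sol1(n);   // phi(M+2)
        m=tot-m+1;        // rank of the required coprime, counted from the smallest
        if(m<=0) cout<<"-1\n";
        else{
            ll l=1,r=n,res=n;
            while(l<=r){
                ll mid=(l+r)/2;
                ll val=sol1(mid);
                if(val>=m){
                    res=mid;
                    r=mid-1;
                }
                else l=mid+1;
            }
            l=res;
            r=n+1-res;
            // Euclidean algorithm on (l, r): runs while l > r accumulate in sum,
            // and each run while r > l adds ex*sum to the answer.
            ll ans=0,sum=0;
            while(l!=1 || r!=1){
                if(l>r){
                    sum+=(l-1)/r;
                    l-=(l-1)/r*r;
                }
                else{
                    ll ex=(r-1)/l;
                    ans+=ex*sum;
                    r-=(r-1)/l*l;
                }
            }
            cout<<ans<<"\n";
        }
    }
}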
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
ll go(ll x, ll y, ll z){
    if(x == 1){
        return (y-1)*z;
    }
    if(y == 1){
        return 0;
    }
    if(x == 1 and y == 1) return 0;
    assert(x != y);
    if(x > y){
        // 0
        return go(x%y,y,z+x/y);
    }
    else{
        // 1
        return (y/x)*z+go(x,y%x,z);
    }
}
void solve(int test_case){
    ll n,m; cin >> n >> m;
    m += 2;
    // find the nth smallest x s.t gcd(x,m) = 1
    vector<ll> primes;
    {
        ll x = m;
        for(ll i = 2; i*i <= x; ++i){
            bool ok = false;
            while(x%i == 0){
                x /= i;
                ok = true;
            }
            if(ok) primes.pb(i);
        }
        if(x > 1) primes.pb(x);
    }
    ll lo = 1, hi = m;
    ll fx = 0;
    while(lo <= hi){
        ll mid = (lo+hi)>>1;
        ll cnt = 0;
        rep(mask,1<<sz(primes)){
            ll coeff = 1, prod = 1;
            rep(i,sz(primes)){
                if(mask&(1<<i)){
                    coeff = -coeff;
                    prod *= primes[i];
                }
            }
            cnt += (mid/prod)*coeff;
        }
        if(cnt >= n){
            fx = mid;
            hi = mid-1;
        }
        else{
            lo = mid+1;
        }
    }
    /*
    ll c = 0;
    ll fx = 0;
    rep1(x,m){
        if(gcd(x,m) != 1) conts;
        c++;
        if(c == n){
            fx = x;
            break;
        }
    }
    */
    if(!fx){
        cout << -1 << endl;
        return;
    }
    ll fy = m-fx;
    swap(fx,fy);
    ll ans = go(fx,fy,0);
    cout << ans << endl;
}
int main()
{
    fastio;
    int t = 1;
    cin >> t;
    rep1(i, t) {
        solve(i);
    }
    return 0;
}
Editorialist's code (PyPy3)
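def gen(a, b, ones):
    # Euclidean algorithm on (a, b): while a > b we emit a run of ones,
    # otherwise a run of zeros, which forms an inversion with every one seen so far.
    if a == 1 and b == 1: return 0
    if a > b:
        ones += a//b
        a %= b
        if a == 0:
            a = b
            ones -= 1
        return gen(a, b, ones)
    zeros = b//a
    b %= a
    if b == 0:
        zeros -= 1
        b = a
    return zeros*ones + gen(a, b, ones)
def calc(m):
    # distinct prime factors of m and Euler's totient phi(m), by trial division
    res = []
    phi = m
    p = 2
    while p*p <= m:
        if m%p == 0:
            phi = phi*(p-1)//p
            while m%p == 0: m //= p
            res.append(p)
        p += 1
    if m > 1:
        phi = phi*(m-1)//m
        res.append(m)
    return res, phi
for _ in range(int(input())):
    n, m = map(int, input().split())
    m += 2
    pfac, phi = calc(m)
    if n > phi:
        print(-1)
        continue
    # the n-th largest string corresponds to the (phi - n + 1)-th smallest coprime
    n = phi - n + 1
    pct = len(pfac)
    lo, hi = 1, m-1
    while lo < hi:
        mid = (lo + hi)//2
        ct = 0
        # inclusion-exclusion: count integers in [1, mid] coprime to m
        for mask in range(2**pct):
            val, sgn = 1, 1
            for i in range(pct):
                if mask & (1 << i):
                    val *= pfac[i]
                    sgn *= -1
            ct += sgn * (mid // val)
        if ct >= n: hi = mid
        else: lo = mid + 1
    # lo is now the n-th smallest integer coprime to m
    print(gen(lo, m-lo, 0))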