BIGTREE : Editorial

Link to the Contest:
https://www.codechef.com/CSPK2020

BIGTREE : Editorial
Author: CodeChef Public Repository
Tester: easy3279
Editorialist : chaitanya_44
Difficulty : Hard
Problem:
There are all kinds of trees in Chef’s garden. There are black sapote trees. There are big santol trees. And of course, there are binary search trees.

Chef’s binary search trees are grown in a very specific way. Binary search trees are grown by inserting nodes into them. The first inserted node is called the root. Unlike most of Chef’s other trees, binary search trees are grown downwards, and the root is at the top of the tree.

Each node is connected by edges to at most two other nodes below it, called its left child and right child. These can be “null”, meaning there are no such children. Finally, each node also has a number on it, called its label.

To grow a binary search tree, you insert a new node to it. The exact rules for inserting a new node are specified by the following pseudocode:

def insert(x, y)
    if x == null
        return y

    if y.label < x.label
        x.left  = insert(x.left,  y)
    else
        x.right = insert(x.right, y)

    return x

To insert a new node y to the tree with root x, simply call insert(x, y). The value returned is the root of the new tree.

The following is an example of a binary search tree:

  3
 / \
1   5
     \
      7
The root of this tree has label 3.

After inserting a new node with a label of 3, the binary search tree becomes:

  3
 / \
1   5
   / \
  3   7
The depth of a node is the number of edges in the path from that node to the root. For example, the node with label 7 above has a depth of 2, while the node with label 1 has a depth of 1.

Chef has a total of T binary search trees in his garden. Each tree has four numbers associated with it: a, b, m and N. This means that the tree has a total of N nodes, and the k-th inserted node (k = 1, 2, …, N) has label (a + b·k) mod m. So the root, being the first inserted node, has label (a + b) mod m.

For each of Chef’s binary search trees, can you determine the depth of the node that was inserted last?

Input
The first line of the input contains an integer T denoting the number of trees. The description of T trees follows.

Each test case consists of a single line containing four space separated integers a, b, m and N.

Output
For each test case, output a single line containing a single integer, denoting the depth of the node that was inserted last in the binary search tree.

Constraints
$1 \le T \le 5 \times 10^4$
$1 \le N \le 10^{16}$
$0 \le a, b < m \le 10^8$

Solution:

CPP14

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <vector>
#include <map>

// Inclusive loop helpers: fo counts a..b upward, fd counts a..b downward.
// (Neither is used by the code in this solution.)
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)

using namespace std;

typedef long long LL;  // labels, indices and N (up to 1e16) need 64 bits
typedef double db;

// get(): fast reader for one (possibly negative) int from stdin.
// Skips every character that is neither a digit nor '-', then accumulates
// the digits that follow. Returns 0 if end of input is reached first
// (ch is an int so EOF is distinguishable from a real character).
int get(){
    int ch;
    while((ch=getchar())!=EOF&&(ch<'0'||ch>'9')&&ch!='-');
    if(ch==EOF)return 0;
    bool neg=false;
    if(ch=='-'){
        neg=true;
        ch=getchar();
    }
    int s=0;
    while(ch>='0'&&ch<='9'){
        s=s*10+(ch-'0');
        ch=getchar();
    }
    return neg?-s:s;
}
// gcd(x, y): greatest common divisor by Euclid's algorithm (iterative form).
LL gcd(LL x,LL y){
    while(y){
        LL r=x%y;
        x=y;
        y=r;
    }
    return x;
}

// exgcd(a, b, x, y): extended Euclid. On return a*x + b*y == gcd(a, b).
void exgcd(LL a,LL b,LL &x,LL &y){
    if(b==0){
        x=1;
        y=0;
        return;
    }
    LL x1,y1;
    exgcd(b,a%b,x1,y1);
    // Back-substitute: b*x1 + (a%b)*y1 == g and a%b == a - (a/b)*b.
    x=y1;
    y=x1-(a/b)*y1;
}

// Per-test-case state: assigned in solve(), read by the helper functions.
LL a,b,n,m,v;  // input a, b, m, N (stored as n); v = label of the N-th node
LL inv,tmp;    // exgcd(b, m) coefficients: b*inv + m*tmp == g
LL g,cir;      // g = gcd(m, b); cir = m / g, the period of the label sequence

// getw0(a, b, m, v)  --  condition: (a + b*x) mod m <= v
// Scans the terms (a + b*x) mod m for x = 1, 2, ... and returns the first
// term VALUE (not its index x) that is <= v, or -1 if it finds none.
// solve() maps values back to indices with gettim().
// Instead of stepping one term at a time it shrinks the step sizes (p, q)
// in a Euclid-like fashion.
// Assumes b is not 0 mod m (otherwise p == 0 and now%p divides by zero) and
// that the global g == gcd(m, b): every term is congruent to a mod g, so no
// value below a % g is reachable.
LL getw0(LL a,LL b,LL m,LL v){
    a=(a%m+m)%m,b=(b%m+m)%m;      // callers pass negative a and b
    if(v<a%g)return -1;
    LL now=(a+b)%m;                // term for x = 1
    if(now>v){
        // While now >= p, adding b wraps past m, so the next term is now - p.
        LL p=(m-b)%m,q=m;
        if(p<=now){
            if(now%p>v)now%=p;
            else{
                // Largest value <= v in now's residue class mod p.
                now=now%p+v/p*p;
                while(now>v)now-=p;
                while(now+p<=v)now+=p;
            }
        }
        for(;now>v;){
            LL q_=q%p;
            if(!q_)return -1;
            LL p_=p%q_;
            while(p_<=now&&now>v){
                if(p<=now){
                    if(now%p>v)now%=p;
                    else{
                        now=now%p+v/p*p;
                        while(now>v)now-=p;
                        while(now+p<=v)now+=p;
                    }
                }
                else{
                    LL key=p-(p-now)/q_*q_;
                    while(key>now)key-=q_;
                    while(key+q_<=now)key+=q_;
                    if(!key)key+=q_;
                    if(now%key>v)now%=key;
                    else{
                        now=now%key+v/key*key;
                        while(now>v)now-=key;
                        while(now+key<=v)now+=key;
                    }
                }
            }
            p=p_,q=q_;
        }
    }
    return now;
}

// getw1(a, b, m, v)  --  condition: (a + b*x) mod m > v
// Mirror of getw0: with w = m-1 - label, the sequence (-a-1 - b*x) mod m
// produces exactly w, and  label > v  <=>  w <= m-v-2.  So getw0 runs on the
// mirrored sequence and its result is mapped back to a label.
LL getw1(LL a,LL b,LL m,LL v){
    if(v==m-1)return -1;               // no residue exceeds m-1
    LL w=getw0(-a-1,-b,m,m-v-2);
    if(w==-1)return -1;
    return ((m-w-1)%m+m)%m;            // undo the mirroring
}

// gettim(v): index x in [1, cir] with a + b*x == v (mod m), using the globals
// a, g, inv, cir. Since b*inv == g (mod m), x = (v-a)/g * inv works when g
// divides v-a (assumed by callers). A residue of 0 is reported as cir.
LL gettim(LL v){
    LL x=(v-a)/g*inv%cir;
    x=(x+cir)%cir;
    return x==0?cir:x;
}

// solve(): handle one test case. Reads a, b, m, N (N into the global n) and
// prints the depth of the N-th inserted node, whose label is
// v = (a + b*N) mod m.
//
// NOTE(review): outline inferred from the code, confirm against the intended
// editorial argument.
//  * "down": walks labels <= v, using getw0 to jump straight to the next
//    qualifying label. It then adds the earlier copies of v: an equal label
//    always goes right, so every earlier node labelled v is an ancestor.
//  * "up": walks labels > v via getw1 and getw0 on the mirrored sequence,
//    using gettim() to stop once the insertion index passes N.
void solve(){
    cin>>a>>b>>m>>n;
    a%=m,b%=m;
    if(b==0){
        // Every label equals a; each new node goes right of all previous ones.
        cout<<n-1<<'\n';
        return;
    }
    v=((n%m)*b+a)%m;              // label of the last inserted node
    g=gcd(m,b);
    cir=m/g;                      // labels repeat every cir insertions
    exgcd(b,m,inv,tmp);           // b*inv == g (mod m); used by gettim()

    // down
    LL st=getw0(a,b,m,v);
    LL ans=1;
    for(;st<v;){
        LL u=getw0(0,b,m,v-st);
        if(u==-1)break;
        LL ad=(v-st)/u;
        st+=ad*u;
        ans+=ad;
    }
    LL delt=((v-a)/g*inv%cir+cir)%cir;   // first index whose label is v
    if(delt==0)delt=cir;
    ans=ans+(n-delt)/cir-1;              // (n-delt)/cir earlier copies of v

    // up
    st=getw1(a,b,m,v);
    if(st!=-1&&gettim(st)<n){
        ans++;
        LL now=gettim(st);
        for(;st>v+1;){
            LL u=getw0(-1,m-b,m,st-v-2);
            if(u==-1)break;
            u++;
            LL t=gettim((m*2-u+a)%m);
            LL ad=(st-v-1)/u;
            if(now+t*ad>n){
                ad=(n-now)/t;
                ans+=ad;
                st-=u*ad;
                now+=t*ad;
                break;
            }
            else{
                ans+=ad;
                st-=u*ad;
                now+=t*ad;
            }
        }
    }
    // '\n' instead of endl: no per-test flush with up to 5e4 test cases.
    cout<<ans<<'\n';
}

// Entry point: read the number of test cases T, then solve each one.
// T > 0 (not just T != 0) so a negative count cannot loop forever.
int main(){
    for(int T=get();T>0;T--)solve();
    return 0;
}

PYTHON3.6:

# cook your dish here


class Node:
    """One binary-search-tree node: a label plus links to its two children."""

    def __init__(self, val):
        self.l = None    # left child: receives labels strictly smaller than v
        self.r = None    # right child: receives labels >= v
        self.v = val     # this node's label
        self.depth = 0   # edges to the root; Tree._add sets it on insertion


class Tree:
    """Unbalanced binary search tree grown with the problem's insert rule.

    A label smaller than a node's label goes to its left subtree; an equal or
    larger label goes right. Every walk is iterative, so chain-shaped trees
    (e.g. labels inserted in increasing order) deeper than Python's recursion
    limit are handled.
    """

    def __init__(self):
        self.root = None

    def getRoot(self):
        """Return the root node, or None if the tree is empty."""
        return self.root

    def add(self, val):
        """Insert a new node labelled *val* and return that node."""
        if self.root is None:
            self.root = Node(val)  # the root keeps depth 0
            return self.root
        return self._add(val, self.root)

    def _add(self, val, node):
        """Insert *val* below *node* and return the new node.

        Walks down from *node* until the side chosen by the comparison is
        empty, then attaches the new node there, one level deeper than its
        parent.
        """
        while True:
            go_left = val < node.v
            child = node.l if go_left else node.r
            if child is None:
                child = Node(val)
                child.depth = node.depth + 1
                if go_left:
                    node.l = child
                else:
                    node.r = child
                return child
            node = child

    def find(self, val):
        """Return the node labelled *val* nearest the root, or None."""
        if self.root is None:
            return None
        return self._find(val, self.root)

    def _find(self, val, node):
        """Search the subtree rooted at *node*; return the match or None."""
        while node is not None:
            if val == node.v:
                return node
            node = node.l if val < node.v else node.r
        return None

    def deleteTree(self):
        """Drop the whole tree; the garbage collector frees the nodes."""
        self.root = None

    def printTree(self):
        """Print every label in sorted (in-order) order, one per line."""
        if self.root is not None:
            self._printTree(self.root)

    def _printTree(self, node):
        """In-order print of the subtree at *node*, using an explicit stack."""
        stack = []
        while stack or node is not None:
            if node is not None:
                stack.append(node)
                node = node.l
            else:
                node = stack.pop()
                print(str(node.v) + ' ')
                node = node.r

def main():
    """Read T test cases from stdin and print one answer per line.

    Each case is a line "a b m N". The k-th inserted label (k = 1..N) is
    (a + b*k) % m, and the answer is the depth of the last inserted node.

    NOTE: this is a brute-force reference. It builds every tree node by node,
    so it is only usable for small N, nowhere near the stated N <= 10**16.
    """
    t = int(input())
    for _ in range(t):
        # split() rather than split(" "): tolerates repeated/trailing spaces.
        a, b, m, N = map(int, input().split())
        tree = Tree()
        last = None
        for k in range(1, N + 1):
            last = tree.add((a + b * k) % m)
        print(last.depth)  # N >= 1 by the constraints, so last is set


if __name__ == "__main__":
    main()