CHFNSWAP - Editorial

PROBLEM LINK:

Division 1
Division 2
Video Editorial

Author: Ritesh Gupta
Tester: Suchan Park
Editorialist: Ritesh Gupta

DIFFICULTY:

EASY

PREREQUISITES:

Number Theory, Math

PROBLEM:

You are given the numbers from 1 to N in ascending order. You perform exactly one swap operation, in which you select two distinct numbers from 1 to N and swap them. The swap is called good if, after it, the sequence can be divided into two non-empty contiguous parts with the same sum. Find the number of good swaps.
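As a worked check of the definition (computed by hand from the statement, not taken from the official samples), take N = 3. The total is 6, so each part must sum to 3, and there are three possible swaps:

```latex
\begin{aligned}
(1,2):&\quad [2,1,3]:\; 2+1 = 3 && \text{good}\\
(1,3):&\quad [3,2,1]:\; 3 = 2+1 && \text{good}\\
(2,3):&\quad [1,3,2]:\; \text{prefix sums } 1,\,4 \text{ never equal } 3 && \text{not good}
\end{aligned}
```

So the answer for N = 3 is 2. Note that the swap (1,3) mixes the two groups of the original split 1+2 | 3.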

QUICK EXPLANATION:

  • For the first subtask, a brute force works: try every possible pair of numbers to swap and check whether the resulting sequence satisfies the condition.
  • Suppose it is initially possible to divide the numbers into two contiguous parts with equal sums, i.e. there is an M such that the sum from 1 to M equals the sum from M+1 to N. Then every swap of two numbers from the same group (both in 1..M or both in M+1..N) is good. In addition, swapping u with v = u + M (for 1 ≤ u ≤ N - M) is good, because it moves the valid split one position to the left; this gives N - M more good swaps.
  • Otherwise, if no such M exists, there is an X such that the sum from 1 to X is less than the sum from X+1 to N, but the sum from 1 to X+1 is greater than the sum from X+2 to N. In this case, the only good swaps are pairs in which one number belongs to the first group (1..X) and the other to the second group (X+1..N).
  • For the second subtask, first check whether a valid M exists. If it does, the answer follows from a simple formula; if not, find X and, for the numbers of one group, use binary search to check whether a matching partner exists in the other group.
  • Since the numbers are exactly 1 to N, if there is no valid M one can show that 2X ≥ N, and that for every Z with X < Z ≤ N there is exactly one Y ≤ X such that swapping Y and Z is good. So for the third subtask the answer is simply N - X. When a valid M exists, the same holds with M in place of X, on top of the same-group swaps.

EXPLANATION:

Terms:

  • N - the numbers 1 to N are given in ascending order.
  • M - if it exists, 1 < M < N and the sum of the first M numbers equals the sum of the numbers from M+1 to N.
  • X - if it exists, 1 < X < N, the sum from 1 to X is less than the sum from X+1 to N, but the sum from 1 to X+1 is greater than the sum from X+2 to N.

Lemma:

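Throughout, let T = N(N+1)/4 be half of the total sum (when N(N+1)/2 is odd, no split is possible and the answer is 0). When X exists, let d = T - X(X+1)/2; since X(X+1)/2 < T < (X+1)(X+2)/2, we have 1 ≤ d ≤ X.

  • M and X can never both exist for the same N.
    • M divides the numbers into two groups with equal sums, while X divides them into two groups with unequal sums and exists only when no prefix sum is exactly T.
  • If a valid M exists, every swap inside one group is good, i.e. any pair (u, v) with 1 ≤ u < v ≤ M or M+1 ≤ u < v ≤ N. This gives M(M-1)/2 + (N-M)(N-M-1)/2 swaps.
    • Such a swap does not change the sum of the first M numbers, so the split after M stays valid.
  • If a valid M exists, a swap of u ≤ M with v > M is good exactly when v - u = M, which gives N - M more good swaps.
    • Swapping u < v adds v - u to the prefix sums of lengths u to v-1 and leaves the others unchanged. Affected prefixes of length at least M now exceed T, so the new split must be after some m ≤ M-1 with v - u = T - m(m+1)/2. For m = M-1 this means v - u = M, so every v from M+1 to N has the partner u = v - M, and u ≤ M-1 because 2M > N (if 2M ≤ N, then M(M+1) ≤ N(N+2)/4 < N(N+1)/2). For m ≤ M-2 we would need v - u ≥ 2M-1 ≥ N, which is impossible.
  • If a valid X exists, swapping two numbers from the same group (both in 1..X or both in X+1..N) is never a good swap.
    • Unaffected prefix sums never equal T, because no M exists. In the first group, an affected prefix of length m < v ≤ X becomes at most (v-1)v/2 + v - 1 = v(v+1)/2 - 1 ≤ X(X+1)/2 - 1 < T. In the second group, every affected prefix has length at least X+1, so its sum already exceeds T.
  • If a valid X exists, then for every v from X+1 to N there is exactly one u ≤ X such that swapping u and v is good, so there are exactly N - X good swaps.
    • First, 2X ≥ N: if 2X < N, then (X+1)(X+2) ≤ (N+1)(N+3)/4 ≤ N(N+1)/2 for N ≥ 3, contradicting (X+1)(X+2)/2 > T.
    • Keeping the split after X, swapping u ≤ X with v > X adds v - u to the first part, so we need v - u = d = (sum_X' - sum_X)/2, where sum_X is the sum of the first X numbers and sum_X' the sum of the rest. This works for X+1 ≤ v ≤ X+d, with u = v - d.
    • Moving the split to after X-1, we need v - u = T - (X-1)X/2 = d + X. This works for X+d+1 ≤ v ≤ N, with u = v - d - X ≤ X-1 because v ≤ N ≤ 2X and d ≥ 1.
    • Splits after X-2 or earlier would need v - u ≥ d + 2X - 1 ≥ N, which is impossible, and splits after X+1 or later already have a sum above T.
    • For a valid M, the same computation with X = M and d = 0 gives the N - M cross swaps described above. This is why the setter's solution prints m-n (that is, N - X) in both cases and adds the same-group count only when the split is exact.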

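Putting it together: if N(N+1)/2 is odd, the answer is 0. Otherwise, let X be the largest index with X(X+1)/2 ≤ N(N+1)/4. The answer is N - X, plus X(X-1)/2 + (N-X)(N-X-1)/2 if X(X+1)/2 = N(N+1)/4 (that is, if X is a valid M).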

TIME COMPLEXITY:

TIME: O(T * log(N))
SPACE: O(1)

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
 
#define int long long
#define endl "\n"
#define mod 1000000007
 
using namespace std;
 
int32_t main()
{
	ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    
	int t;
	cin >> t;
 
	while(t--)
    {
        int m;
        cin >> m;
 
        if(m%4 == 1 || m%4 == 2)
            cout << 0 << endl;
        else
        {
            int cnt = m*(m+1);
            cnt /= 4;
 
            int n = sqrt(2*cnt) + 1;
            
            while((n*(n+1))/2 > cnt)
                n--;
 
            cnt -= (n*(n+1))/2;
 
            int ans = m-n;
 
            if(cnt == 0)
                ans += n*(n-1)/2 + (m-n)*(m-n-1)/2;
 
            cout << ans << endl;
        }
    }
} 
Tester's Solution
// subtask 3
class CHFNSWAP_SUBTASK3 {
    fun triangular (n: Long) = n*(n+1)/2
    fun choose2 (n: Long) = triangular(n-1)

    fun lbTriangular(target: Long): Long {
        var low = 0L
        var high = 1000000000L
        var ret = -1L

        while (low <= high) {
            val mid = (low + high) / 2
            val diff = triangular(mid) - target
            if (diff <= 0L) {
                ret = mid
                low = mid+1
            } else {
                high = mid-1
            }
        }

        require(triangular(ret) <= target)
        require(triangular(ret+1) > target)

        return ret
    }

    /*
    93924
27510
     */

    fun get(N: Long): Long {
        if (triangular(N) % 2 == 1L) return 0L

        var ans = 0L

        val target = triangular(N) / 2
        lbTriangular(target).let {
            if (triangular(it) == target) {
                ans += choose2(it) + choose2(N-it)
            }
        }

        // maxOf(1, m+1-diff) <= minOf(m, N-diff) ?
        // diff >= 1

        // The four conditions below must all hold at the same time
        // 1 <= m
        // 1 <= N-(target - triangular(m))   iff triangular(m) >= target-(N-1)
        // m+1-(target - triangular(m)) <= m iff triangular(m) <= target-1
        // m+1-diff <= N-diff                iff m <= N-1
        // 1 <= target - triangular(m)       iff triangular(m) <= target-1

        val smallestM = maxOf(1, lbTriangular(target - (N-1) - 1) + 1)
        val largestM = minOf(N-1, lbTriangular(target))
        for (m in smallestM..largestM) {
            val diff = target - m * (m+1)/2
            ans += maxOf(minOf(m, N-diff) - maxOf(1, m+1-diff) + 1, 0L)
        }
        return ans
    }
}

fun main(args: Array<String>) {
    val br = java.io.BufferedReader(java.io.InputStreamReader(System.`in`), 32768 * 10)
    val bw = java.io.BufferedWriter(java.io.OutputStreamWriter(System.`out`), 32768 * 10)

    val solver = CHFNSWAP_SUBTASK3()

    val T = br.readLine()!!.toInt()
    require(T in 1..1000000)

    for (_t in 0 until T) {
        val N = br.readLine()!!.toLong()
        require(N in 1..1000000000)
        bw.write("${solver.get(N)}\n")
    }
    bw.flush()
}

Video Editorial


Check out Screencast Tutorial for this problem: https://www.youtube.com/watch?v=xlMriamMbPo&list=PLz-fHCc6WaNJa2QJq7qULBBV0YOJunq75&index=4

Very much of an observation based problem. Thanks for this!


Hey, please help me out with my code. I tried to find an alternative to reduce the running time, but my code still exceeds the time limit.

import java.util.Scanner;
class Codechef {
    public static void main (String[] args) throws java.lang.Exception
    {
        Scanner scan=new Scanner(System.in);
        int t=scan.nextInt();
        for(int i=0;i<t;i++) {
            int n=scan.nextInt();
            count(n);
        }
        scan.close();
    }
    public static void count(int n) {
        int c=0;
        int sum=n*(n+1)/2;
        if(sum%2==0) {
            int[] a=new int[n];
            for(int p=0;p<n;p++) {
                a[p]=p+1;
            }
            for(int j=0;j<n;j++) {
                for(int k=j+1;k<n;k++){
                    int[] b=a.clone();
                    int temp=b[j];
                    b[j]=b[k];
                    b[k]=temp;
                    int u=0;
                    for(int w:b) {
                        u+=w;
                        if(u==sum/2) {
                            c+=1;
                            break ;
                        }else if(u>sum/2) {
                            break;
                        }
                    }
                }
            }
        }
        System.out.println(c);
    }
}
Hoping someone will surely help.

It is a simple binary search problem: you can do it in O(log n). There is an O(1) approach too.


Can you please share the code in Java?

No, I can only share it in C++.
Here is my C++ code:

#include <bits/stdc++.h>
#define whatis(x) cout << #x << " is " << x << endl;
#define whatis2(x, y) cout << #x << " is " << x << " and " << #y << " is " << y << endl;
#define whatis3(x, y, z) cout << #x << " is " << x << " and " << #y << " is " << y << " and " << #z << " is " << z << endl;

#define MOD 1000000007
#define IOS                           \
    ios_base::sync_with_stdio(false); \
    cin.tie(NULL);
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

//function to return sum of n elements
unsigned long long sum(ull n) {
    ull a = n;
    ull b = n+1;
    if(a%2==0){
        a = a/2;
    }
    else{
        b = b/2;
    }
    ull res = a*b;
    return res;
}

//Takes O(log n) time to find the largest number whose prefix sum is smaller than or equal to total sum / 2
ull findindex(ull n) {
    ull target = (ull)sum(n) / 2;
    ull l = 1;
    ull r = n;
    while (l < r) {
        if ((l+1)== r) {
            break;
        }
        ull m = (l + r) / 2;
        ull curr = (ull)sum(m);

        //if the sum of m which is the middle number is equal to the sum of the target return m
        if (curr == target) {
            return m;
        }
        if (curr < target) {
            l = m;
        } else {
            r = m;
        }
    }
    if(sum(r)<target) return r;
    else return l;
}

ull findstart(ull target, ull end, ull n) {
    ull l = 1;
    ull r = end;

    while (l < r) {

        if ((l + 1) == r) {
            break;
        }
        ull m = (l + r) / 2;
        if (target - (ull)sum(m) > n) {
            l = m;
        }else if (target - (ull)sum(m) <= n) {
            r = m;
        }

    }
    if((ull)sum(r) < target) return r;
    else return l;

    return l;
}
int main() {
    IOS;
    long long t;
    cin >> t;
    while (t--) {
        ull n;
        cin >> n;
        ull s = sum(n);
        if(s%2!=0){
            cout<<0<<"\n";
            continue;
        }
        ull target = sum(n) / 2;
        ull rend = findindex(n);
        ull rbegin = findstart(target, rend, n);
        ull ans = 0;
        ull f = rend - rbegin;
        for (ull i = rbegin; i <= rend; i++) {
            ull diff = target - sum(i);
            ull lower = 1 + diff;
            if(lower<=i){
                lower = i+1;
            }
            ull upper = i + diff;
            if(upper > n){
                upper = n;
            }
            ans+=(upper-lower) + 1;
            if(diff==0){
                ans+=sum(i-1);
                ans+=sum(n-i-1);
            }
        }
        cout<<ans<<"\n";
        }
    return 0;
}

Felt more like a September math challenge this time xD, nice problem set btw.


I mostly code in Java, but thanks for the help anyway. It will help me develop the logic. Thanks a lot!

Thanks @sk_kh5037coder for the solution. It will be of great help to me. 🙂

My solution:

#include<iostream>
#include<math.h>
//Done by vaibhavgadag
#define fastio ios_base::sync_with_stdio(0);cin.tie(NULL);cout.tie(NULL);
typedef long long int ll;
using namespace std;
int main()
{
    fastio;
    ll t;
    cin>>t;
    while(t--)
    {
        ll n;
        cin>>n;
        ll sum=((n)*(n+1))/2;
        if(sum%2==0)
        {
            ll x=(sqrt(1+4*sum)-1)/2;
            ll y=n-x;
            if(sum/2==(y*(2*n-y+1))/2)
            {
                cout<<ll((y*(y-1))/2+((n-y)*(n-y-1))/2+y)<<endl;
            }
            else
            {
                cout<<y<<endl;
            }
        }
        else
        {
            cout<<0<<endl;
        }
    }
    return 0;
}

Can anyone help me shorten the code or reduce its running time?
Note that my solution was accepted, but it took 1.92 sec. 😑

Here is my observation-based code:

 //Author:- Rahul Arya

#include<bits/stdc++.h>
using namespace std;
#define fio ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define w(x) ll x; cin>>x; while(x--)
#define ll long long int

int main(){
    fio
    ll n,k,s,c,e; double r;
    w(t)
    {
        cin>>n; c=0; r=0;
        k=n*(n+1)/2;
        if(k%2!=0)
            cout<<"0\n";
        else
        {
            r=sqrt(1+2*(n*n+n)); e=sqrt(1+2*(n*n+n));
            c=n-(e-1)/2; s=r;
            if(r-s==0)
            {
                c=c*(c-1)/2+(n-c-1)*(n-c)/2+c;
            }
            cout<<c<<"\n";
        }
    }
}

time complexity of your code is O(1)(you cannot achieve less than that) as for the time is concerned it may be because of heavy math operations like sqrt, division , multiplication on large dataset

Got TLE
https://www.codechef.com/viewsolution/37915951
Any help will be appreciated.

Here's my submission in Java (CodeChef link). I would recommend using BufferedOutputStream; check the Stack Overflow question “What's the fastest way to output a string to system out?”.

What is the problem with my solution? It passes only test case 1.
I have done it in O(1) (probably) using the roots of a quadratic equation. To deal with precision errors in sqrt, I have also used sqrtl (with 1.0L*). Kindly help.

https://www.codechef.com/viewsolution/37905032

You have to use ll instead of int for n.

Thanks.

Why is there only a single value of X for which nice swaps are possible?
Please provide some intuition or proof.
