PYTHAGORAS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Aryan
Preparer: Utkarsh Gupta
Testers: Nishank Suresh, Takuki Kurokawa
Editorialist: Nishank Suresh

DIFFICULTY:

1862

PREREQUISITES:

Algebraic manipulation

PROBLEM:

Given an integer N whose largest odd factor is at most 10^5, find two integers A and B such that A^2 + B^2 = N or claim that none exist.

EXPLANATION:

There are several different constructions that can work in this task, so if you have an interesting one feel free to share it in the comments below.

The constraint on the largest odd factor is a bit unusual, so let’s try to exploit it. It immediately implies that any N > 10^5 must be even, since an odd N is its own largest odd factor.

So, if we could obtain a solution for N from a solution for N/2, we could repeatedly halve N until it is small, solve that small case, and then build the answer back up.

It turns out that the relationship is a bit stronger: when N is even, there exists an integer pair (A_1, B_1) such that A_1^2 + B_1^2 = N if and only if there exists an integer pair (A_2, B_2) such that A_2^2 + B_2^2 = N/2.

Proof

Suppose A^2 + B^2 = N/2. Then, (A+B)^2 + (A-B)^2 = 2A^2 + 2B^2 = N gives us a solution for N.

Conversely, if A^2 + B^2 = N, turning the above construction around gives us

\left ( \frac{A+B}{2} \right )^2 + \left ( \frac{A-B}{2} \right )^2 = \frac{A^2}{2} + \frac{B^2}{2} = \frac{N}{2}

The interesting thing here is that (A+B)/2 and (A-B)/2 are both integers: if N is even, then A^2 + B^2 is even, and since a square has the same parity as its root, A and B must have the same parity. Hence A+B and A-B are both even.
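The two directions of this argument can be sketched as a pair of small C++ helpers. The names `doubleUp` and `halveDown` are hypothetical, chosen only for illustration; the preparer’s code below applies the doubling step inline rather than through a function.

```cpp
#include <cassert>
#include <cstdlib>
#include <utility>

using ll = long long;

// Doubling step: if a^2 + b^2 = M, then (a+b)^2 + (a-b)^2 = 2M.
// |a - b| only keeps the output non-negative; squaring ignores the sign anyway.
std::pair<ll, ll> doubleUp(ll a, ll b) {
    return {a + b, std::llabs(a - b)};
}

// Halving step: if a^2 + b^2 = M with M even, a and b share a parity,
// so (a+b)/2 and (a-b)/2 are integers whose squares sum to M/2.
std::pair<ll, ll> halveDown(ll a, ll b) {
    assert((a + b) % 2 == 0);  // same parity, guaranteed when a^2 + b^2 is even
    return {(a + b) / 2, std::llabs(a - b) / 2};
}
```

For example, doubleUp(3, 4) gives (7, 1), since 7^2 + 1^2 = 50 = 2 \cdot 25, and halveDown(7, 1) recovers (4, 3).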

This tells us that it is enough to solve the problem for small N: we can reduce N to its largest odd factor, solve for this factor, then reconstruct the result for the original N using the method above.

Solving for N \leq 10^5 can be done by brute force: fix a value of A, then check whether N - A^2 is a perfect square.
We only need to check those A with A^2 \leq N, giving an \mathcal{O}(\sqrt{10^5}) solution per test case, which is fast enough.
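A minimal sketch of this brute force, assuming the value fits in a `long long`. The helper names `isqrtExact` and `bruteForce` are hypothetical; the integer correction after `sqrt` guards against floating-point rounding, in the same spirit as the editorialist’s code.

```cpp
#include <cassert>
#include <cmath>
#include <utility>

using ll = long long;

// Largest s with s*s <= x: a floating-point guess, corrected with integer arithmetic.
ll isqrtExact(ll x) {
    assert(x >= 0);
    ll s = (ll)std::sqrt((long double)x);
    while (s * s > x) --s;
    while ((s + 1) * (s + 1) <= x) ++s;
    return s;
}

// Try every a with a*a <= n and test whether n - a*a is a perfect square.
// Returns {-1, -1} when no pair exists.
std::pair<ll, ll> bruteForce(ll n) {
    for (ll a = 0; a * a <= n; ++a) {
        ll rem = n - a * a;
        ll b = isqrtExact(rem);
        if (b * b == rem) return {a, b};
    }
    return {-1, -1};
}
```

For instance, bruteForce(50) finds (1, 7), while bruteForce(3) reports that no pair exists.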

There are also other constructions, though most use the same idea. One simple one is as follows:
Suppose we have a solution to A^2 + B^2 = N. Then, simply multiplying A and B by 2 gives us a solution to 4N.
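So we can divide N by 4 as long as N > 2\cdot 10^5. This always succeeds: the odd part of such an N is at most 10^5, so the power of two in its factorization is at least 4. The process ends with N \leq 2\cdot 10^5. Use the brute force to solve this smaller instance, then reconstruct the answer for the original N by multiplying A and B by 2 once for each division by 4.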
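Putting the pieces together, a sketch of the divide-by-4 construction for a single query might look as follows. The function name `solveDivideByFour` and the returned pair `{-1, -1}` for "no answer" are assumptions for illustration; the `assert` encodes the claim that any valid N > 2\cdot 10^5 is divisible by 4.

```cpp
#include <cassert>
#include <cmath>
#include <utility>

using ll = long long;

// Largest s with s*s <= x, corrected after the floating-point estimate.
ll isqrtExact(ll x) {
    ll s = (ll)std::sqrt((long double)x);
    while (s * s > x) --s;
    while ((s + 1) * (s + 1) <= x) ++s;
    return s;
}

std::pair<ll, ll> solveDivideByFour(ll n) {
    const ll LIMIT = 200000;
    ll mul = 1;  // product of the factors of 2 removed so far
    while (n > LIMIT) {
        assert(n % 4 == 0);  // holds for valid inputs (largest odd factor <= 1e5)
        n /= 4;
        mul *= 2;
    }
    // Brute force on the small remainder, then scale both coordinates back up.
    for (ll a = 0; a * a <= n; ++a) {
        ll rem = n - a * a;
        ll b = isqrtExact(rem);
        if (b * b == rem) return {a * mul, b * mul};
    }
    return {-1, -1};
}
```

For N = 2^{50}, the loop divides by 4 seventeen times down to 65536 = 0^2 + 256^2, and the answer is (0, 2^{25}).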

TIME COMPLEXITY

\mathcal{O}(\sqrt{10^5} + \log N) per test case.
Spending 10^5 \cdot \sqrt{10^5} time on precomputation can bring this down to \mathcal{O}(\log N) per test case, but is unnecessary.

CODE:

Preparer's code (C++)
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int avail[N];
pair<ll,ll> good[N];
int issq[N];
int sqrtval[N];
void solve()
{
    ll n=readInt(1,(ll)100000000*10000000,'\n');
    ll tmp=n;
    while(tmp%2==0)
        tmp/=2;
    assert(tmp<=100000);
    if(avail[tmp]==0)
        cout<<-1<<'\n';
    else
    {
        ll a=good[tmp].first;
        ll b=good[tmp].second;
        while(tmp!=n)
        {
            ll c=a+b;
            ll d=abs(a-b);
            a=c;
            b=d;
            tmp*=2;
        }
        cout<<a<<' '<<b<<'\n';
    }
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,100000,'\n');
    for(int i=0;i<400;i++)
    {
        issq[i*i]=1;
        sqrtval[i*i]=i;
    }
    for(int n=1;n<=100000;n++)
    {
        for(int a=0;a<=1000;a++)
        {
            if(a*a>n)
                break;
            if(issq[n-a*a]==1)
            {
                ll b=sqrtval[n-a*a];
                good[n]=mp(a,b);
                avail[n]=1;
                break;
            }
        }
    }
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    while (tt--) {
        long long n = in.readLong(1, 1e15);
        in.readEoln();
        long long m = 1;
        while (n % 2 == 0) {
            n /= 2;
            m *= 2;
        }
        assert(n <= 1e5);
        if (__builtin_ctzll(m) % 2 == 1) {
            m /= 2;
            n *= 2;
        }
        m = llround(sqrtl(m));
        int a = -1;
        for (int i = 0; i * i <= n; i++) {
            int j = (int) llround(sqrtl(n - i * i));
            if (i * i + j * j == n) {
                a = i;
                break;
            }
        }
        if (a == -1) {
            cout << -1 << '\n';
        } else {
            cout << a * m << " " << llround(sqrtl(n - a * a)) * m << '\n';
        }
    }
    in.readEof();
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	const int mx = 2e5 + 10;

	int t; cin >> t;
	while (t--) {
		ll n; cin >> n;
		ll mul = 1;
		while (n > mx) {
			n /= 4;
			mul *= 2;
		}
		bool done = false;
		for (int i = 0; i*i <= n; ++i) {
			int rem = n - i*i;
			int s = sqrtl(rem);
			while (s*s > rem) --s;
			while ((s+1)*(s+1) <= rem) ++s;
			if (s*s == rem) {
				cout << mul*i << ' ' << mul*s << '\n';
				done = true;
				break;
			}
		}
		if (!done) cout << -1 << '\n';
	}
}

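I solved it using the idea from:

and the Cornacchia algorithm:

Basically, for every prime p in the factorization of N that appears with an odd exponent, solve x^2 + y^2 = p; then multiply those solutions together in complex form (as Gaussian integers) and scale the result by the remaining square factor.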

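        // f holds the prime factorization of N as (p, e) pairs; multiply, cornacchia
        // and fexp are helpers that are not shown in this excerpt.
        ll aux = 1;                 // product of p^(floor(e/2)) over all prime powers p^e
        pair<ll, ll> ans = {1, 0};  // running Gaussian-integer product, starts at 1 + 0i
        for (int i = 0; i < f.size(); i++) {
            if (f[i].second % 2 == 1) {
                // odd exponent: one copy of p comes from a solution of x^2 + y^2 = p
                ans = multiply(ans, cornacchia(f[i].first));
                int rest = f[i].second - 1;
                aux = aux * fexp(f[i].first, rest / 2);
            } else {
                // even exponent: p^e = (p^(e/2))^2 just scales both coordinates
                aux = aux * fexp(f[i].first, f[i].second / 2);
            }
        }
        ans.first *= aux;
        ans.second *= aux;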
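The "multiply in complex form" step rests on the identity (a^2 + b^2)(c^2 + d^2) = (ac - bd)^2 + (ad + bc)^2, which is the norm of the product (a + bi)(c + di). The commenter’s `multiply` helper is not shown, so the following is only a hypothetical sketch of what such a helper could compute:

```cpp
#include <cassert>
#include <utility>

using ll = long long;

// (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
// If a^2 + b^2 = P and c^2 + d^2 = Q, the result's squares sum to P * Q
// (Brahmagupta-Fibonacci identity). Coordinates may come out negative;
// that is harmless because only their squares matter.
std::pair<ll, ll> multiplyGaussian(std::pair<ll, ll> x, std::pair<ll, ll> y) {
    ll a = x.first, b = x.second, c = y.first, d = y.second;
    return {a * c - b * d, a * d + b * c};
}
```

For example, 5 = 1^2 + 2^2 and 13 = 2^2 + 3^2 combine to (-4, 7), and (-4)^2 + 7^2 = 65 = 5 \cdot 13.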

I found the provided solution very difficult to understand. Could someone hire a person who can explain this in Hindi, or maybe in a bit clearer English?
I don’t mean to sound rude, and I am not at all commenting on the coding skill of the solution provider, but since I have paid for the subscription, it should be a bit clearer.


So I came up with this, but there is one problem I couldn’t figure out: the case of even n. Can anyone explain?

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// bool isPerfectSq(long double n){
//     if(n>=0){
//         ll sr = sqrt(n);
//         return (sr*sr == n);
//     }
//     return false;
// }

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);  // fast I/O

    ll t; cin>>t;
    while(t--){
        ll n; cin>>n;
        ll x = sqrt(n);

        if(x*x == n){
            // n is a perfect square: 0^2 + x^2 = n
            cout<<0<<" "<<x<<endl;
        }
        else if(n%2==1){
            // only tries the two consecutive values just below sqrt(n)
            ll first = x;
            ll second = first-1;
            if((first*first)+(second*second)==n)
                cout<<second<<" "<<first<<endl;
            else cout<<-1<<endl;
        }
        else if(n%2==0){
            // only tries values adjacent to sqrt(n)
            ll first = x;
            ll second = x-1;
            ll third = x+1;
            if((first*first)+(first*first)==n) cout<<first<<" "<<first<<endl;
            else if((first*first)+(second*second)==n) cout<<first<<" "<<second<<endl;
            else if((second*second)+(second*second)==n) cout<<second<<" "<<second<<endl;
            else if((second*second)+(third*third)==n) cout<<second<<" "<<third<<endl;
            else if((third*third)+(third*third)==n) cout<<third<<" "<<third<<endl;
            else if((first*first)+(third*third)==n) cout<<first<<" "<<third<<endl;
            else cout<<-1<<endl;
        }
    }

    // int n;cin>>n;
    // cout<<(sqrt(n));

    return 0;
}

Can anyone explain this one? The editorial is …

From the problem statement, N’s largest odd factor is at most 10^5, so we can write N = n * 2^k, where n <= 10^5 is odd and k >= 0.

  1. When k % 2 == 0, if we can find some (a, b) such that a^2 + b^2 = n, we can construct an answer:
    A = a * 2^{k/2}
    B = b * 2^{k/2}

  2. When k % 2 == 1, if we can find some (a, b) such that a^2 + b^2 = 2n, we can also construct an answer:
    A = a * 2^{ \lfloor k/2 \rfloor}
    B = b * 2^{ \lfloor k/2 \rfloor }

So the problem reduces to finding (a, b) with a^2 + b^2 = x, where x <= 2 * 10^5.
We can brute-force all a, b <= sqrt(2e5) and store x -> (a, b) in a map.
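A sketch of this approach, assuming a `std::map` from x to one pair (a, b) that is built once before reading the queries. `buildTable` and `solveWithTable` are hypothetical names, and `{-1, -1}` stands for "no answer".

```cpp
#include <cassert>
#include <map>
#include <utility>

using ll = long long;

// One representation x -> (a, b) for every x = a^2 + b^2 <= limit.
std::map<ll, std::pair<ll, ll>> buildTable(ll limit) {
    std::map<ll, std::pair<ll, ll>> rep;
    for (ll a = 0; a * a <= limit; ++a)
        for (ll b = a; a * a + b * b <= limit; ++b)
            rep.emplace(a * a + b * b, std::make_pair(a, b));  // keeps the first pair found
    return rep;
}

// Write N = n * 2^k with n odd. Even k: look up n; odd k: look up 2n.
// In both cases, scale by 2^(k/2) with integer division.
std::pair<ll, ll> solveWithTable(ll N, const std::map<ll, std::pair<ll, ll>>& rep) {
    assert(N > 0);
    ll n = N;
    int k = 0;
    while (n % 2 == 0) { n /= 2; ++k; }
    ll x = (k % 2 == 0) ? n : 2 * n;
    auto it = rep.find(x);
    if (it == rep.end()) return {-1, -1};
    ll scale = 1LL << (k / 2);
    return {it->second.first * scale, it->second.second * scale};
}
```

Building the table once and then answering each query with one lookup is what makes the per-test cost \mathcal{O}(\log N).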


#include <bits/stdc++.h>
using namespace std;
#define ll long long int

vector<pair<ll, ll>> square;   // (i*i, i) for every square i*i <= 1e5
unordered_map<ll, ll> s;       // maps i*i -> i

void store() {
    for (ll i = 0; i * i <= 1e5; i++) {
        square.push_back(make_pair(i * i, i));
        s[i * i] = i;
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    ll t;
    cin >> t;
    store();
    while (t--) {
        ll n;
        cin >> n;
        ll odddivisor = 0;
        if (n % 2) odddivisor = n;
        else {
            ll temp = n;
            while (temp) {
                temp /= 4;
                if (temp % 2) {
                    odddivisor = temp;
                    break;
                }
            }
        }
        // NOTE: when n is an odd power of two (2, 8, 32, ...), temp reaches 0
        // without ever being odd, so odddivisor stays 0 and n / odddivisor
        // below is an integer division by zero.
        bool flag = 0;
        ll multiplesquare = n / odddivisor;
        if (s.find(multiplesquare) == s.end()) cout << -1 << '\n';
        else {
            for (size_t i = 0; i < square.size(); i++) {
                if (square[i].first > odddivisor) break;
                if (s.find(odddivisor - square[i].first) != s.end() && square[i].first != odddivisor - square[i].first) {
                    cout << square[i].second * s[multiplesquare] << " " << s[odddivisor - square[i].first] * s[multiplesquare] << '\n';
                    flag = 1;
                    break;
                }
            }
            if (!flag) cout << -1 << '\n';
        }
    }
    return 0;
}
Can anyone help me figure out why it is giving a runtime error?

It is giving Time Limit Exceeded. Can you please show me where I am going wrong in this submission?
https://www.codechef.com/viewsolution/78413809

You will need to prepare the map first, then answer the T queries.
https://www.codechef.com/viewsolution/78325119

hey @iceknight1093

So, knowing X^2 + Y^2 = N,
we can find X and Y for 4N, 16N, and so on. How do we find X and Y for 2N, 6N, …?
Your solution seems to only find answers for 4^k N.
For example,
knowing 3^2 + 4^2 = 25,
how do we find X and Y with X^2 + Y^2 = 50?


Now the solutions require money. Isn’t that cheap? We do the contest and then can’t upsolve without paying. The solutions are really helpful.

You’re correct that my code finds a solution for 4N knowing a solution for N. For this problem, that is enough because if N \gt 2 \times 10^5, then N will definitely be divisible by 4 because of the constraint on the largest odd divisor.
So while N \gt 2 \times 10^5 you can keep dividing by 4, and once you reach something \leq 2\times 10^5 you just bruteforce.

There is also a way to get a solution for 2N from a solution for N. The method of doing that is detailed in the editorial: see the “Proof” section.
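As a worked check, applying the Proof’s map (A, B) \to (A+B, A-B) to the example from the question:

```latex
3^2 + 4^2 = 25
\quad\Longrightarrow\quad
(3+4)^2 + (3-4)^2 = 7^2 + (-1)^2 = 49 + 1 = 50 = 2 \cdot 25
```

So (7, 1) is a valid answer for N = 50.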

Written editorials have never been paid and never will be; I sure hope you didn’t pay anything to view this page.

Which part do you find unclear?


@vedant_k07

I think for N which is not of the form 4^k N we would have to use brute force,
and the worst-case time complexity would be 10^15 \cdot 10^5.

But that case never comes up, because the largest odd factor of N is at most 10^5.
For example, 10000000019 is not a valid test case, because its largest odd factor is 10000000019 itself (which is greater than 10^5).

Also 2 \cdot 10000000019, 6 \cdot 10000000019 and 8 \cdot 10000000019 are all invalid test cases.

Please correct me if I am wrong.

How?
Wouldn’t it be (A^2+B^2)/2 = N/2?

Oh, I see your confusion, I used the same variable names.

What I meant was the following:
When N is even, there exists a pair (A_1, B_1) such that A_1^2 + B_1^2 = N if and only if there exists a pair (A_2, B_2) such that A_2^2 + B_2^2 = N/2.

Essentially, you can find an answer for N if and only if you can find an answer for N/2. This is the main reason why the given solution works at all.


@iceknight1093 thanks for the great explanation.
Most of the explanation revolves around 2n but in the actual code we are checking for 4n; I get it’s the same concept, but I’m not able to understand how 4n will always exist and why not 8n or 16n.

Thanks!

My code (and lots of other people’s) uses 4N simply because it’s easy to implement: you just multiply both A and B by 2.

It’s totally fine to use the 2N construction by doing (A, B) \to (A+B, A-B), you can see that the preparer’s code does exactly this. In fact, if you notice, the construction for 4N just comes from applying the construction for 2N twice, since (A+B)+(A-B) = 2A and (A+B)-(A-B) = 2B.


What’s wrong in this code?
It cannot pass the last two test cases.

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define tc int t;cin>>t;while(t--)

void solve()
{
  ll n,count = 0,flag = 0;
  cin>>n;

  while(n % 4 == 0)
  {
    n /= 4; 
    count++; 
  }

  ll i , j ;

  for(i = 0 ; i*i <= n ; i++)
  {
     j = n - i * i ;  
     ll root = sqrt(j);

     if(root * root == j)
     {
        flag = 1 ;
        j = root ;
        break;
     } 
  }

  if(!flag)
  cout << " -1 "<< endl; 
  else
  cout << pow(2,count) * i << " " << pow(2,count) * j << endl ;   
}

int main() 
{
  #ifndef ONLINE_JUDGE
  freopen("error.txt","w",stderr);
  #endif
  tc solve();
  return 0;
}

@iceknight1093 what is the use of these two lines?
The code gets AC even without them.

I had to use a type conversion:
cout << (ll)powl(2,count) * i << " " << (ll)powl(2,count) * j << endl ;

Now it works.
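The underlying issue is that `pow` (and `powl`) return floating-point values, so `pow(2,count) * i` is a `double`. By default, `cout` prints a large double with six significant digits in scientific notation (for example, 3145728.0 prints as `3.14573e+06`), which is likely what broke the last tests. An integer-only alternative is a left shift, sketched here with a hypothetical helper name:

```cpp
#include <cassert>

using ll = long long;

// value * 2^count using exact 64-bit integer arithmetic (valid for 0 <= count <= 62).
// No floating-point value is involved, so the result prints as an ordinary integer.
ll scaleByPowerOfTwo(ll value, int count) {
    assert(0 <= count && count <= 62);
    return value * (1LL << count);
}
```

In the code above, this would replace `pow(2,count) * i` with `scaleByPowerOfTwo(i, count)`, and likewise for `j`.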