EXPPERM - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Lavish Gupta
Tester: Abhinav Sharma
Editorialist: Taranpreet Singh

DIFFICULTY

Easy

PREREQUISITES

Contribution Trick, Modular arithmetic

PROBLEM

Let A denote a permutation of the numbers from 1 to N, and let A_i denote the i^{th} number of the permutation.

We define a function over the permutation as follows:
F(A) = (A_1 * A_2) + (A_2 * A_3) + \cdots + (A_{N-2} * A_{N-1}) + (A_{N-1}*A_N)

What is the expected value of the function over all the possible permutations A of numbers from 1 to N?

The expected value of the function can be represented as a fraction of the form \frac{P}{Q}. You are required to print P \cdot Q^{-1} \pmod{1 \, 000 \, 000 \, 007}.
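To make the definition concrete, here is a small brute-force sketch (C++17; the name bruteExpected is our own, not from the setter or tester). It averages F over all N! permutations and returns the exact reduced fraction. It is only practical for tiny N, but it gives reference values to check the formula derived below against.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Brute force for small n: average of F(A) over all n! permutations of 1..n,
// returned as a reduced fraction {numerator, denominator}.
std::pair<long long, long long> bruteExpected(int n) {
    std::vector<int> a(n);
    std::iota(a.begin(), a.end(), 1);          // start from the sorted permutation 1..n
    long long total = 0, count = 0;
    do {
        for (int i = 0; i + 1 < n; ++i)        // F(A) = sum of adjacent products
            total += 1LL * a[i] * a[i + 1];
        ++count;
    } while (std::next_permutation(a.begin(), a.end()));
    long long g = std::gcd(total, count);      // reduce the fraction total / n!
    return {total / g, count / g};
}
```

Usage: assert(bruteExpected(3) == std::make_pair(22LL, 3LL)); the six permutations of 1..3 have F values 8, 9, 5, 9, 5, 8, which sum to 44, and 44/6 = 22/3.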

QUICK EXPLANATION

  • For each ordered pair (x, y) with x \neq y, the number of permutations in which y appears immediately after x is (N-1)!. Hence, summed over all such pairs, the total of F over all permutations is \displaystyle (N-1)!*\sum_{x = 1}^N \sum_{y = 1, y \neq x}^{N} x*y
  • The above expression simplifies to (N-1)! * [(N*(N+1)/2)^2 - N*(N+1)*(2*N+1)/6]. Since we want the expected value over all N! permutations, we divide it by N!.

EXPLANATION

Math section

In this problem, we focus on an individual ordered pair (x, y) and compute its contribution by counting the permutations in which y appears immediately after x. Note that the pair (x, y) is different from the pair (y, x).

Since y must appear right after x, we can consider (x, y) to be a single element. The number of elements becomes N-1, so the number of permutations where y appears right after x is (N-1)!.
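The block argument can be sanity-checked by brute force for small N. The sketch below (C++11 or later; countAdjacent is a hypothetical helper, not part of any official solution) counts the permutations of 1..n in which y sits immediately after x; for any x \neq y the count should equal (n-1)!.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Brute force for small n: how many permutations of 1..n place y
// immediately after x? The "glue (x, y) into one block" argument predicts (n-1)!.
long long countAdjacent(int n, int x, int y) {
    std::vector<int> a(n);
    std::iota(a.begin(), a.end(), 1);
    long long hits = 0;
    do {
        for (int i = 0; i + 1 < n; ++i)
            if (a[i] == x && a[i + 1] == y) { ++hits; break; }  // x occurs once, so at most one hit
    } while (std::next_permutation(a.begin(), a.end()));
    return hits;
}
```

Usage: assert(countAdjacent(4, 1, 2) == 6); gluing "1 2" into one block leaves 3 items, and 3! = 6.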

So the pair (x, y) contributes x*y*(N-1)! to the answer. Pairs (x, x) are excluded, since x cannot appear twice.

Hence, the total contribution of all valid pairs is
S = \displaystyle (N-1)! * \sum_{x = 1}^N \sum_{y = 1, y \neq x}^N x*y

The x \neq y condition is awkward to handle directly, so let's sum over all pairs and subtract the (x, x) cases.

S = \displaystyle (N-1)! * \left [ \sum_{x = 1}^N \sum_{y = 1}^N x*y - \sum_{x = 1}^N x^2 \right ]= (N-1)! * \left [\left(\sum_{x = 1}^N x \right) * \left(\sum_{y = 1}^N y\right) - \left( \sum_{x = 1}^N x^2 \right) \right ].

The sum of the first N integers is \displaystyle\frac{N*(N+1)}{2}, and the sum of their squares is \displaystyle\frac{N*(N+1)*(2*N+1)}{6}.
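A quick check of the subtraction step and the two summation formulas, as a sketch with hypothetical names: pairSumDirect sums x*y over ordered pairs with x \neq y by a double loop, while pairSumIdentity evaluates (\sum x)^2 - \sum x^2 with the closed forms above. Plain 64-bit arithmetic is used, so this is meant only for small N.

```cpp
#include <cassert>

// Left side: sum of x*y over ordered pairs (x, y) with x != y, by direct double loop.
long long pairSumDirect(long long n) {
    long long total = 0;
    for (long long x = 1; x <= n; ++x)
        for (long long y = 1; y <= n; ++y)
            if (x != y) total += x * y;
    return total;
}

// Right side: (sum of x)^2 - (sum of x^2), using the standard closed forms.
long long pairSumIdentity(long long n) {
    long long s1 = n * (n + 1) / 2;                  // 1 + 2 + ... + n
    long long s2 = n * (n + 1) * (2 * n + 1) / 6;    // 1^2 + 2^2 + ... + n^2
    return s1 * s1 - s2;
}
```

Usage: for N = 3 both sides equal 2 + 3 + 2 + 6 + 3 + 6 = 22, that is, 36 - 14.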

Hence, S = \displaystyle (N-1)! * \left[ \frac{N*(N+1)}{2} * \frac{N*(N+1)}{2} - \frac{N*(N+1)*(2*N+1)}{6}\right ].

S is the sum of F over all permutations, but we need the expected value over all N! permutations, so we divide S by N!. Since (N-1)!/N! = 1/N, only a factor of 1/N remains.

Hence, the required answer is \displaystyle \frac{1}{N} * \left[ \frac{N*(N+1)}{2} * \frac{N*(N+1)}{2} - \frac{N*(N+1)*(2*N+1)}{6}\right ]
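Cancelling the factor N by hand gives a more compact closed form. For N = 3 it gives (2 \cdot 4 \cdot 11)/12 = 22/3, the same value as the bracketed expression above.

```latex
\frac{1}{N}\left[\frac{N^2(N+1)^2}{4} - \frac{N(N+1)(2N+1)}{6}\right]
  = (N+1)\left[\frac{N(N+1)}{4} - \frac{2N+1}{6}\right]
  = \frac{(N+1)\left(3N^2 + 3N - 4N - 2\right)}{12}
  = \frac{(N+1)(3N^2 - N - 2)}{12}
  = \frac{(N-1)(N+1)(3N+2)}{12}
```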

Implementation

We need modular arithmetic to compute this modulo 10^9+7. This blog explains the idea, and this shares some code.

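You also need the modular multiplicative inverse of N to divide the expression by N. Since 10^9+7 is prime and N < 10^9+7, Fermat's little theorem gives N^{-1} \equiv N^{10^9+5} \pmod{10^9+7}, which binary exponentiation computes quickly. The divisions by 2 and 6 inside the sums can be handled the same way.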
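Putting the pieces together, here is a minimal sketch of the modular computation (C++11 or later; MOD, powMod, invMod and expectedMod are illustrative names). It follows the formula above, obtains the inverses of 2, 6 and N from Fermat's little theorem, and assumes 1 \le N < 10^9+7 so that N is invertible. The tester's and editorialist's solutions below do the same computation with slightly different bookkeeping.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;  // prime modulus from the statement

// Binary exponentiation: a^p mod MOD using O(log p) multiplications.
long long powMod(long long a, long long p) {
    long long result = 1;
    a %= MOD;
    while (p > 0) {
        if (p & 1) result = result * a % MOD;
        a = a * a % MOD;
        p >>= 1;
    }
    return result;
}

// Fermat's little theorem: a^(MOD-2) is the inverse of a when MOD does not divide a.
long long invMod(long long a) { return powMod(a, MOD - 2); }

// Expected value of F as P * Q^{-1} mod MOD, assuming 1 <= n < MOD.
long long expectedMod(long long n) {
    long long s1 = n % MOD * ((n + 1) % MOD) % MOD * invMod(2) % MOD;   // 1 + ... + n
    long long s2 = n % MOD * ((n + 1) % MOD) % MOD * ((2 * n + 1) % MOD) % MOD
                   * invMod(6) % MOD;                                    // 1^2 + ... + n^2
    long long bracket = (s1 * s1 % MOD - s2 + MOD) % MOD;                // keep it non-negative
    return bracket * invMod(n) % MOD;                                    // divide by n
}
```

Usage: assert(expectedMod(3) == 333333343); this is 22/3 modulo 10^9+7, since 3 * 333333343 = 10^9+7 + 22.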

TIME COMPLEXITY

The time complexity is O(\log(10^9+7)) per test case, since each modular inverse is a binary exponentiation with exponent 10^9+5.

SOLUTIONS

Setter's Solution
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        assert('a'<=g and g<='z');
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1e5;
const int MAX_LEN = 1e5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
 
int sum_len = 0;
int max_n = 0;
int max_k = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
 
long long n;
string s;

long long mod = 1e9+7;

long long binary_expo(long long x, long long y){
    long long ret = 1;
    x%=mod;

    while(y){
        if(y&1) ret = (ret*x)%mod;
        x = (x*x)%mod;
        y>>=1;
    }

    return ret;
}

 
void solve()
{

    n = readIntLn(2, 1e9);

    long long sum1 = ((n*n+n)/2)%mod;
    sum1 = (sum1*sum1)%mod;

    long long sum2 = (((n*n+n)%mod)*(2*n+1))%mod;
    sum2 = (sum2*binary_expo(6, mod-2))%mod;

    long long sum = (sum1-sum2)%mod;

    sum = (sum*binary_expo(n, mod-2))%mod;
    if(sum<0) sum+=mod;

    cout<<sum<<'\n';

}
 
signed main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;

    int t = 1;
    
    t = readIntLn(1,MAX_T);
    
    for(int i=1;i<=t;i++)
    {     
       solve();
    }
    
    assert(getchar() == -1);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';

}
Editorialist's Solution
import java.util.*;
import java.io.*;
class EXPPERM{
    //SOLUTION BEGIN
    long MOD = (long)1e9+7;
    long inv2 = inv(2), inv6 = inv(6);
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        long N = nl();
        long ans = (sumN(N)*sumN(N)%MOD + MOD - sumN2(N))%MOD * inv(N)%MOD;
        pn(ans);
    }
    long sumN(long N){
        return (N*(N+1))%MOD*inv2%MOD;
    }
    long sumN2(long N){
        return (N*(N+1))%MOD*(2*N+1)%MOD*inv6%MOD;
    }
    long inv(long a){return pow(a, MOD-2);}
    long pow(long a, long p){
        long o = 1;
        for(; p>0; p>>=1){
            if((p&1)==1)o = (o*a)%MOD;
            a = (a*a)%MOD;
        }
        return o;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new EXPPERM().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.


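N can be cancelled, so the expression simplifies to (3N+2)(N-1)(N+1) \cdot {P+1 \over 12}, where P = 10^9+7 is the modulus. Note that P+1 is divisible by 3 and 4, and thus by 12, so {P+1 \over 12} is an integer and equals 12^{-1} \bmod P. Also, after reducing (N-1)(N+1) modulo P, multiplying the result by (3N+2) stays below (3N+2) \cdot P < 2^{63}, so it does not overflow a signed 64-bit integer.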
Solution.
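A sketch of the closed-form approach described in the comment above (C++11 or later; the function name is hypothetical, not the commenter's code). Here P = 10^9+7, so (P+1)/12 = 83333334 acts as 12^{-1}, and the comments track the overflow bounds under the assumption N \le 10^9.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;  // the modulus P

// Closed form E = (N-1)(N+1)(3N+2)/12 evaluated mod P, assuming 2 <= n <= 1e9.
long long expectedClosedForm(long long n) {
    const long long inv12 = (MOD + 1) / 12;  // 12 * 83333334 = P + 1, which is 1 mod P
    long long t = (n - 1) * (n + 1) % MOD;   // (n-1)(n+1) < 1e18 fits in 64 bits
    t = t * (3 * n + 2) % MOD;               // t < P and 3n+2 <= 3e9+2, so the product < 2^63
    return t * inv12 % MOD;
}
```

Usage: assert(expectedClosedForm(3) == 333333343); this agrees with the editorial's formula, since 2 * 4 * 11 / 12 = 22/3.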


I don't know why my solution is not being accepted.
Need help:
https://www.codechef.com/viewsolution/54231941

Your code fails on this test case:

1
100000000

Expected answer: 142250001

Change int n; to long long int n; on line 18 of your program.

Thanks a lot, man!