PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Lavish Gupta
Tester: Abhinav Sharma
Editorialist: Taranpreet Singh
DIFFICULTY
Easy
PREREQUISITES
Contribution Trick, Modular arithmetic
PROBLEM
Let A denote a permutation of the numbers from 1 to N, and let A_i denote the i^{th} number of the permutation.
We define a function over the permutation as follows:
F(A) = (A_1 * A_2) + (A_2 * A_3) + \cdots + (A_{N-2} * A_{N-1}) + (A_{N-1}*A_N)
What is the expected value of the function over all the possible permutations A of numbers from 1 to N?
The expected value of the function can be represented as a fraction of the form \frac{P}{Q}. You are required to print P \cdot Q^{-1} \pmod{1 \, 000 \, 000 \, 007}.
QUICK EXPLANATION
- For each ordered pair (x, y) with x \neq y, the number of permutations in which y appears immediately after x is (N-1)!. Hence, the sum of F over all permutations is \displaystyle (N-1)!*\sum_{x = 1}^N \sum_{y = 1, y \neq x}^{N} x*y
- The above expression simplifies to (N-1)! * [(N*(N+1)/2)^2 - N*(N+1)*(2*N+1)/6]. Since we need the expected value over all N! permutations, we divide it by N!, which leaves \frac{1}{N} * [(N*(N+1)/2)^2 - N*(N+1)*(2*N+1)/6].
EXPLANATION
Math section
We focus on an individual ordered pair (x, y) appearing consecutively. Its contribution is found by counting the permutations in which y appears right after x. Note that the pair (x, y) is different from the pair (y, x).
Since y must appear right after x, we can consider (x, y) to be a single element. The number of elements becomes N-1, so the number of permutations where y appears right after x is (N-1)!.
So the pair (x, y) contributes x*y*(N-1)! to the answer. The pair (x, x) is not considered, since x cannot appear twice.
Hence, the sum of the contributions of all valid pairs can be written as
S = \displaystyle (N-1)! * \sum_{x = 1}^N \sum_{y = 1, y \neq x}^N x*y
The x \neq y condition is awkward to handle directly, so we sum over all pairs and subtract the cases (x, x).
S = \displaystyle (N-1)! * \left [ \sum_{x = 1}^N \sum_{y = 1}^N x*y - \sum_{x = 1}^N x^2 \right ]= (N-1)! * \left [\left(\sum_{x = 1}^N x \right) * \left(\sum_{y = 1}^N y\right) - \left( \sum_{x = 1}^N x^2 \right) \right ].
The sum of the first N positive integers is \displaystyle\frac{N*(N+1)}{2}, and the sum of their squares is \displaystyle\frac{N*(N+1)*(2*N+1)}{6}.
Hence, S = \displaystyle (N-1)! * \left[ \frac{N*(N+1)}{2} * \frac{N*(N+1)}{2} - \frac{N*(N+1)*(2*N+1)}{6}\right ].
S is the sum of F(A) over all permutations A. Since all N! permutations are equally likely, the expected value is S / N!, and (N-1)!/N! = 1/N.
Hence, the required answer is \displaystyle \frac{1}{N} * \left[ \frac{N*(N+1)}{2} * \frac{N*(N+1)}{2} - \frac{N*(N+1)*(2*N+1)}{6}\right ]
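As a hand check, here is a worked example for N = 3, chosen only for illustration. The six permutations are enumerated directly and compared with the closed form.

```latex
\begin{aligned}
F(1,2,3) &= 2 + 6 = 8, & F(1,3,2) &= 3 + 6 = 9, & F(2,1,3) &= 2 + 3 = 5,\\
F(2,3,1) &= 6 + 3 = 9, & F(3,1,2) &= 3 + 2 = 5, & F(3,2,1) &= 6 + 2 = 8,\\
\frac{S}{3!} &= \frac{8 + 9 + 5 + 9 + 5 + 8}{6} = \frac{44}{6} = \frac{22}{3},\\
\frac{1}{3} * \left[ \frac{3*4}{2} * \frac{3*4}{2} - \frac{3*4*7}{6} \right] &= \frac{36 - 14}{3} = \frac{22}{3}.
\end{aligned}
```

Modulo 10^9+7, this gives 22 * 3^{-1} \equiv 333333343.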
Implementation
Computing this modulo 10^9+7 requires modular arithmetic. This blog explains the idea, and this one shares some code.
Dividing by N (and by 6 in the sum-of-squares formula) requires the modular multiplicative inverse. It can be computed with Fermat's little theorem: since 10^9+7 is prime and N \le 10^9 < 10^9+7, we have N^{-1} \equiv N^{10^9+5} \pmod{10^9+7}, which is evaluated with binary exponentiation.
TIME COMPLEXITY
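The time complexity is O(\log(10^9+7)) per test case, dominated by the binary exponentiation used to compute the modular inverses (the exponent is 10^9+5, independent of N).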
SOLUTIONS
Setter's Solution
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        assert('a'<=g and g<='z');
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
/*
------------------------Main code starts here----------------------------------
*/
const int MAX_T = 1e5;
const int MAX_LEN = 1e5;
const int MAX_SUM_LEN = 1e5;
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
int sum_len = 0;
int max_n = 0;
int max_k = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
long long n;
string s;
long long mod = 1e9+7;
// x^y mod `mod` by binary exponentiation
long long binary_expo(long long x, long long y){
    long long ret = 1;
    x%=mod;
    while(y){
        if(y&1) ret = (ret*x)%mod;
        x = (x*x)%mod;
        y>>=1;
    }
    return ret;
}
void solve()
{
    n = readIntLn(2, 1e9);
    // sum1 = (N(N+1)/2)^2 mod p; N(N+1) is even and fits in 64 bits for N <= 1e9
    long long sum1 = ((n*n+n)/2)%mod;
    sum1 = (sum1*sum1)%mod;
    // sum2 = N(N+1)(2N+1)/6 mod p, dividing by 6 through its modular inverse
    long long sum2 = (((n*n+n)%mod)*(2*n+1))%mod;
    sum2 = (sum2*binary_expo(6, mod-2))%mod;
    // multiply the bracket by N^{-1}; C++ % may return a negative value, fixed below
    long long sum = (sum1-sum2)%mod;
    sum = (sum*binary_expo(n, mod-2))%mod;
    if(sum<0) sum+=mod;
    cout<<sum<<'\n';
}
signed main()
{
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
#endif
    fast;
    int t = 1;
    t = readIntLn(1,MAX_T);
    for(int i=1;i<=t;i++)
    {
        solve();
    }
    assert(getchar() == -1);
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class EXPPERM{
    //SOLUTION BEGIN
    long MOD = (long)1e9+7;
    long inv2 = inv(2), inv6 = inv(6);
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        long N = nl();
        // [sumN(N)^2 - sumN2(N)] * N^{-1} mod MOD; adding MOD keeps the difference non-negative
        long ans = (sumN(N)*sumN(N)%MOD + MOD - sumN2(N))%MOD * inv(N)%MOD;
        pn(ans);
    }
    // N(N+1)/2 mod MOD
    long sumN(long N){
        return (N*(N+1))%MOD*inv2%MOD;
    }
    // N(N+1)(2N+1)/6 mod MOD
    long sumN2(long N){
        return (N*(N+1))%MOD*(2*N+1)%MOD*inv6%MOD;
    }
    // modular inverse via Fermat's little theorem (MOD is prime)
    long inv(long a){return pow(a, MOD-2);}
    // a^p mod MOD by binary exponentiation
    long pow(long a, long p){
        long o = 1;
        for(; p>0; p>>=1){
            if((p&1)==1)o = (o*a)%MOD;
            a = (a*a)%MOD;
        }
        return o;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new EXPPERM().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Feel free to share your approach. Suggestions are welcome as always.