PATHWAY - Editorial

Problem Link:

Practice

Contest

Author: Niket Agarwal

Tester: Pritish Priyatosh Nayak

Editorialist: Niket Agarwal

Difficulty:

Easy

Prerequisites:

Combinatorics, Modular Inverse

Problem:

Given a set of horizontal and vertical moves of various lengths, find the number of ways to travel from one lattice point to another using all of the moves.

Explanation:

This problem was an extension of the normal lattice path counting problem.
In the original problem, we had to find the number of ways to reach $(x, y)$ from $(0, 0)$ with unit steps. Think of it as arranging $x$ horizontal arrows and $y$ vertical arrows in a row: each distinct arrangement is a different path, so the answer is $\frac{(x+y)!}{x!\,y!}$.

In other words, we simply count the permutations of a multiset (permutations with repeated elements).


In this problem the arrows no longer all have the same length. Moves of the same length along the same axis are indistinguishable, so we count the frequency of each distinct move length. Keep the horizontal and vertical counts separate: a horizontal move and a vertical move of the same length are different arrows.

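So the answer is $\frac{(N+M)!}{\prod_i f_i!}$.

- $N$ and $M$ are the numbers of horizontal and vertical moves.
- $f_i$ ranges over the frequency of every distinct horizontal length and every distinct vertical length.

The answer is taken modulo the prime $p = 10^9+7$. Each division is therefore done by multiplying by a modular inverse, computed with Fermat's little theorem as $a^{-1} \equiv a^{p-2} \pmod{p}$.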


Use a map (for example, a hash map) to count the frequencies. Move lengths can be as large as $10^9$, which is too large to use them directly as array indices.

Solutions:

Setter's/Editorialist's Solution
import java.util.*;import java.io.*;import java.math.*;
public class Main
{
    /** fac[i] = i! mod {@link #mod} for 0 <= i <= n+m; rebuilt by every call to process(). */
    static long[]fac;

    /**
     * Solves one test case.
     * Input: the target coordinates x y, the counts n m, then n horizontal move
     * lengths followed by m vertical move lengths.
     * Prints (n+m)! divided by the factorial of the multiplicity of every distinct
     * horizontal length and of every distinct vertical length, modulo {@link #mod}.
     */
    public static void process()throws IOException
    {
        long x=nl(); // target point: consumed from the input, not needed by the formula
        long y=nl();
        int n=ni();
        int m=ni();
        fac=new long[n+m+1];
        fac[0]=1L;
        for(int i=1;i<=n+m;i++)
            fac[i]=(i*fac[i-1])%mod;
        // Counted separately: equal lengths on different axes are different arrows.
        Map<Integer,Integer>mx=countFrequencies(n);
        Map<Integer,Integer>my=countFrequencies(m);
        long ans=fac[n+m];
        for(int cnt:mx.values())
            ans=(ans*modInv(fac[cnt]))%mod;
        for(int cnt:my.values())
            ans=(ans*modInv(fac[cnt]))%mod;
        pn(ans);
    }

    /** Reads k integers from the input and returns how often each distinct value occurs. */
    private static Map<Integer,Integer> countFrequencies(int k)throws IOException
    {
        Map<Integer,Integer>freq=new HashMap<>();
        for(int i=0;i<k;i++)
            freq.merge(ni(),1,Integer::sum);
        return freq;
    }

    static AnotherReader sc;
    static PrintWriter out;

    /**
     * Entry point: with oj == true reads stdin and writes stdout, otherwise uses
     * input.txt / output.txt. Runs a single test case.
     */
    public static void main(String[]args)throws IOException
    {
        boolean oj = true;
        if(oj){sc=new AnotherReader();out=new PrintWriter(System.out);}
        else{sc=new AnotherReader(100);out=new PrintWriter("output.txt");}
        int t=1;
        // t=ni();
        while(t-->0) {process();}
        out.flush();out.close();
    }

    /** Returns x^y mod {@link #mod} by binary exponentiation; x is reduced first, and y <= 0 yields 1. */
    static long power(long x,long y)
    {
        long res=1L;
        x%=mod;
        while(y>0)
        {
            if(y%2==1)
                res=(res*x)%mod;
            y/=2;
            x=(x*x)%mod;
        }
        return res%mod;
    }

    /** Inverse of n modulo the prime {@link #mod} via Fermat's little theorem (n must not be a multiple of mod). */
    static long modInv(long n)
    {return power(n,mod-2);}

    static long mod=(long)1e9+7L;
    static void pn(Object o){out.println(o);}
    static void p(Object o){out.print(o);}
    static void pni(Object o){out.println(o);out.flush();}
    static int ni()throws IOException{return sc.nextInt();}
    static long nl()throws IOException{return sc.nextLong();}
    static double nd()throws IOException{return sc.nextDouble();}
    static String nln()throws IOException{return sc.nextLine();}
    static int[] nai(int N)throws IOException{int[]A=new int[N];for(int i=0;i!=N;i++){A[i]=ni();}return A;}
    static long[] nal(int N)throws IOException{long[]A=new long[N];for(int i=0;i!=N;i++){A[i]=nl();}return A;}
    static long gcd(long a, long b)throws IOException{return (b==0)?a:gcd(b,a%b);}
    static int gcd(int a, int b)throws IOException{return (b==0)?a:gcd(b,a%b);}
    /** Number of set bits in the 64-bit two's-complement representation of n. */
    static int bit(long n)throws IOException{return Long.bitCount(n);}

/////////////////////////////////////////////////////////////////////////////////////////////////////////

    /** Whitespace-token reader over stdin (no-arg constructor) or input.txt (int constructor). */
    static class AnotherReader{
        BufferedReader br; StringTokenizer st;
        AnotherReader()throws FileNotFoundException{
            br=new BufferedReader(new InputStreamReader(System.in));}
        AnotherReader(int a)throws FileNotFoundException{
            br = new BufferedReader(new FileReader("input.txt"));}

        /**
         * Returns the next whitespace-separated token, reading further lines as needed.
         * @throws EOFException if the input ends before another token is available
         * @throws IOException if reading the underlying stream fails (propagated, not retried)
         */
        String next()throws IOException{
            while (st == null || !st.hasMoreElements()) {
                String line = br.readLine();
                if (line == null)
                    throw new EOFException("no more tokens in input");
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }
        int nextInt() throws IOException{ return Integer.parseInt(next()); }
        long nextLong() throws IOException{ return Long.parseLong(next()); }
        double nextDouble()throws IOException { return Double.parseDouble(next()); }

        /**
         * Reads the next raw line from the underlying reader (tokens already buffered
         * by next() are not included); returns null at end of input.
         */
        String nextLine() throws IOException{ return br.readLine(); }
    }

/////////////////////////////////////////////////////////////////////////////////////////////////////////////
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int mod=1e9+7;
 
// Modular-arithmetic helpers. Under the file's `#define int long long`, every
// `int` here is a 64-bit long long; the modulus defaults to `mod`.

// (a + b) mod p; assumes a and b already lie in [0, p).
int add(int a, int b, int p = mod){ int c = a + b; if(c >= p) c -= p; return c; }

// (a - b) mod p; assumes a and b already lie in [0, p).
int sub(int a, int b, int p = mod){ int c = a - b; if(c < 0) c += p; return c; }

// (a * b) mod p. Operands are reduced first so the 64-bit product cannot
// overflow even when a or b is >= p (requires p < ~3.03e9).
int mul(int a, int b, int p = mod){ return (long long)(a % p) * (b % p) % p; }

// b^e mod m by binary exponentiation; e <= 0 yields 1 % m.
// The base is normalized into [0, m) first: without this, b*b overflows
// (undefined behaviour) for |b| >= ~3.04e9, and a negative b leaks a negative result.
int powm(int b, int e,int m=mod){
   b %= m;
   if (b < 0) b += m;
   int r = 1 % m;
   while (e > 0) {
      if (e & 1) r = (long long)r * b % m;
      b = (long long)b * b % m;
      e >>= 1;
   }
   return r;
}

// a / b mod p using Fermat's little theorem (b^(p-2) is b's inverse).
// Requires p prime and b not divisible by p. The modulus is forwarded to mul;
// the original code multiplied modulo `mod` even when a different p was given.
int divi(int a, int b, int p = mod){ return mul(a, powm(b, p - 2, p), p); }
 
 
signed main()
{
   ios_base::sync_with_stdio(false);cin.tie(NULL);
   #ifdef Zoro
   freopen("/home/pritish/Competitive/in", "r", stdin);
   freopen("/home/pritish/Competitive/out", "w", stdout);
   #endif
 
   int x,y;
   cin>>x>>y;
   // assert(x>=1&&y>=1&&x<=1e10&&y<=1e10);
 
   int n,m;
   cin>>n>>m;
   // assert(n>=1&&m>=1&&n<=1e5&&m<=1e5);
 
   int N[n],M[m];
   int sumN=0,sumM=0;
 
   map<int,int> mpp1,mpp2;
   for (int i = 0; i < n; ++i)
   {
      cin>>N[i];
      sumN+=N[i];
      mpp1[N[i]]++;
      // assert(N[i]>=1&&N[i]<=1e9);
   }
   // assert(sumN==x);
   for (int i = 0; i < m; ++i)
   {
      cin>>M[i];
      sumM+=M[i];
      mpp2[M[i]]++;
      // assert(M[i]>=1&&M[i]<=1e9);
   }
   // assert(sumM==y);
 
   int ans=1;
   int fact[n+m+5]={0};
   fact[0]=1;
   for (int i = 1; i <= n+m; ++i)
   {
      fact[i] = mul(fact[i-1],i);
   }
   ans=fact[n+m];
 
   for(auto &[x,y]:mpp1)
   {
      ans=divi(ans,fact[y]);
   }
   for(auto &[x,y]:mpp2)
   {
      ans=divi(ans,fact[y]);
   }
 
   cout<<ans;
 
   cerr<<"\n"<<(float)clock()/CLOCKS_PER_SEC*1000<<" ms"<<endl;
   return 0;
}