YACP - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Aryan Choudhary
Tester: Utkarsh Gupta
Editorialist: Taranpreet Singh

DIFFICULTY

Easy-Medium

PREREQUISITES

Dynamic Programming, Bitmasking

PROBLEM

You are given an array A = [A_1, A_2, \dots, A_N] containing N distinct integers. Count the number of ways to form (unordered) sets of disjoint increasing subsequences of A.

Formally, count the number of sets S = \{S_1, S_2, \dots, S_k\} such that:

  • Each S_i is an increasing subsequence of A.
  • If i \neq j, S_i and S_j are disjoint, i.e., i\neq j \implies S_i \cap S_j = \emptyset

Note that the sequences S_1, S_2, \dots, S_k need not form a partition of A; in other words, some elements of A may not be in any chosen subsequence.

Two sets are considered equal if they contain the same subsequences. For example, the sets \{[1, 2], [3]\} and \{[3], [1, 2]\} are considered to be the same and should only be counted once.

Note that the final answer can be rather large, so compute its remainder after dividing it by 10^9 + 7.
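
To make the definition concrete, here is a brute-force counter that could be used to check small cases. It is a sketch, not one of the official solutions, and the names bruteRec and bruteCount are hypothetical. It enumerates every set explicitly: each element is left out, appended to a subsequence built so far (only if that keeps it increasing), or starts a new subsequence. New subsequences are numbered in order of their first elements, so \{[1, 2], [3]\} and \{[3], [1, 2]\} are produced only once. It is exponential and does not reduce modulo 10^9 + 7, so it is only meant for tiny N.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// tail[s] holds the last value of the s-th subsequence built so far.
// Subsequences are numbered in order of creation (i.e. by first element),
// so every unordered set of subsequences is generated exactly once.
std::uint64_t bruteRec(const std::vector<int>& a, std::size_t pos,
                       std::vector<int>& tail) {
    if (pos == a.size()) return 1;                   // one complete set
    std::uint64_t total = bruteRec(a, pos + 1, tail); // a[pos] is not used
    for (std::size_t s = 0; s < tail.size(); ++s) {
        if (tail[s] < a[pos]) {                      // keeps subsequence s increasing
            int saved = tail[s];
            tail[s] = a[pos];
            total += bruteRec(a, pos + 1, tail);
            tail[s] = saved;                         // undo before the next choice
        }
    }
    tail.push_back(a[pos]);                          // a[pos] starts a new subsequence
    total += bruteRec(a, pos + 1, tail);
    tail.pop_back();
    return total;
}

// Counts all sets of disjoint increasing subsequences (including the empty set).
std::uint64_t bruteCount(const std::vector<int>& a) {
    std::vector<int> tail;
    return bruteRec(a, 0, tail);
}
```

For instance, [2, 1, 3] gives 12, and a strictly increasing array of length 4 gives 52 (every subset can then be split arbitrarily, so the count is a Bell number).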

QUICK EXPLANATION

  • Process the elements from left to right, one at a time, while maintaining the active sequences.
  • Use dynamic programming over bitmasks of active sequences: the current element can be left out of every sequence, start a new sequence, or be appended to the end of any active sequence whose last element is smaller than it.
  • A set bit in the mask represents an active sequence ending at that position.

EXPLANATION

Since the order of sequences does not matter, we process the elements from left to right, and each element is either appended to the end of one of the active sequences, used to start a new sequence, or not used at all.

Consider the example A = [1,2,3,4], and suppose we have processed the first three elements. We may have \{[1,2],[3]\}, or \{[2,3]\}, and so on. What information do we need to continue?

We only need the last element of each active sequence. For \{[1,2],[3]\}, we can add 4 at the end of [1,2], add it at the end of [3], start a new sequence [4], or not use it at all.

If, instead of \{[1,2],[3]\}, we had \{[2],[3]\}, the options would have been exactly the same. The elements of a sequence other than its last one do not matter, because whether we can append to a sequence depends solely on its last element.

Hence, we can represent \{[1,2],[3]\} by the set \{2, 3\} of endpoints of the active sequences. The corresponding set for \{[2,3]\} is \{3\}.

Programmatically, this set can be represented by an N-bit mask in which bit i-1 (counting bits from 0) is set if and only if some active sequence ends at A_i.
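
A minimal sketch of this encoding, assuming each subsequence is given as a list of 0-based positions in increasing order (so A_i of the text sits at position i - 1). The helper name endpointMask is hypothetical. Only the last position of each subsequence contributes a bit, which is why \{[1,2],[3]\} and \{[2],[3]\} end up with the same mask.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the endpoint mask of a set of subsequences: bit p is set exactly
// when some subsequence ends at 0-based position p. Positions must be < 32.
std::uint32_t endpointMask(const std::vector<std::vector<int>>& seqs) {
    std::uint32_t mask = 0;
    for (const auto& s : seqs)
        if (!s.empty()) mask |= 1u << s.back();  // only the last element matters
    return mask;
}
```

With A = [1,2,3,4], \{[1,2],[3]\} becomes positions \{\{0,1\},\{2\}\} and mask 6 (binary 110), as does \{[2],[3]\}, while \{[2,3]\} becomes mask 4 (binary 100).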

We consider the elements from left to right and update these masks as we go.
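
One update step can be sketched as a function that lists every mask reachable from a given mask when the element at 0-based position x is processed. The name successors is hypothetical, and the mask is assumed to use only bits 0..x-1. Each choice (skip, new sequence, or append after a smaller endpoint y) yields a different mask.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Masks reachable from `mask` after deciding what to do with a[x].
std::vector<std::uint32_t> successors(const std::vector<int>& a, int x,
                                      std::uint32_t mask) {
    std::vector<std::uint32_t> next;
    next.push_back(mask);                          // a[x] is not used
    next.push_back(mask | (1u << x));              // a[x] starts a new sequence
    for (int y = 0; y < x; ++y)                    // a[x] extends the sequence ending at y
        if (((mask >> y) & 1u) && a[y] < a[x])
            next.push_back((mask ^ (1u << y)) | (1u << x));
    return next;
}
```

For the state \{[1,2],[3]\} of A = [1,2,3,4] (mask 6) and x = 3, it returns the four masks 6, 14, 12 and 10, matching the four options listed earlier for the element 4.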

Let f_x(mask) denote the number of unordered sets of increasing subsequences of the first x elements of A whose set of endpoints is exactly mask. We aim to compute f_x(mask) from f_{x-1} for every mask.

Initially, f_0(0) = 1 (the empty set of sequences) and f_0(mask) = 0 for every other mask.

The transitions are as follows:

  • If the x-th element is not included at all, f_{x-1}(mask) is added to f_x(mask).
  • If the x-th element starts a new sequence, f_{x-1}(mask) is added to f_x(mask + 2^{x-1}).
  • For every active sequence ending at some element y \lt x (that is, bit y-1 is set in mask) with A_y \lt A_x, A_x can be appended to the end of that sequence. This adds f_{x-1}(mask) to f_x(mask - 2^{y-1} + 2^{x-1}).
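
The transitions above push each f_{x-1}(mask) forward. Read the other way round ("pull" form), they give an expression for each f_x(m), where m ranges over masks using only the bits of the first x elements. Here e_i = 2^{i-1} denotes the bit of A_i and \& is bitwise AND; this notation is introduced only for this block. This appears to be the form the tester's solution evaluates: its dp[mask][i] corresponds to f_{i+1}(mask), since it numbers positions from 0.

```latex
f_x(m) =
\begin{cases}
f_{x-1}(m), & m \mathbin{\&} e_x = 0,\\[6pt]
f_{x-1}(m - e_x) \;+\; \displaystyle\sum_{\substack{y < x,\; A_y < A_x \\ m \,\&\, e_y = 0}} f_{x-1}(m - e_x + e_y), & m \mathbin{\&} e_x \neq 0,
\end{cases}
\qquad e_i = 2^{\,i-1}.
```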

f_x(mask) can be stored in a 2D array, and the DP table can be built either iteratively or recursively. Since f_x depends only on f_{x-1}, keeping two arrays of size 2^N is enough.

Finally, the answer is the sum of f_N(mask) over all masks in [0, 2^N-1].
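
Putting the pieces together, a compact C++ sketch of the iterative DP could look as follows. It is a separate illustration, not the Editorialist's code; the function name countSets is hypothetical, and positions are 0-based, so the element A_{x+1} of the text sits at position x and owns bit x. It keeps only two layers of 2^N counters, so it is intended for N up to about 20.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const std::uint64_t MOD = 1000000007ULL;

// cur[mask] = number of sets of increasing subsequences of the processed
// prefix whose endpoints are exactly `mask` (f_x in the text).
std::uint64_t countSets(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    std::vector<std::uint64_t> cur(std::size_t(1) << n, 0), nxt;
    cur[0] = 1;                                            // f_0(0) = 1
    for (int x = 0; x < n; ++x) {
        nxt.assign(cur.size(), 0);
        for (std::uint32_t mask = 0; mask < (1u << x); ++mask) {
            const std::uint64_t ways = cur[mask];
            if (ways == 0) continue;                       // optional shortcut
            nxt[mask] = (nxt[mask] + ways) % MOD;          // a[x] not used
            const std::uint32_t withX = mask | (1u << x);
            nxt[withX] = (nxt[withX] + ways) % MOD;        // new sequence [a[x]]
            for (int y = 0; y < x; ++y)                    // append a[x] after a[y]
                if (((mask >> y) & 1u) && a[y] < a[x]) {
                    const std::uint32_t m = withX ^ (1u << y);
                    nxt[m] = (nxt[m] + ways) % MOD;
                }
        }
        cur.swap(nxt);
    }
    std::uint64_t total = 0;                               // sum of f_N(mask)
    for (std::uint64_t v : cur) total = (total + v) % MOD;
    return total;
}
```

On small inputs it agrees with direct enumeration: [2, 1, 3] gives 12, a decreasing array of length 3 gives 8 (only singletons are possible), and strictly increasing arrays of lengths 4 and 5 give the Bell numbers 52 and 203.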

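For a recursive (memoized) implementation, see the setter's solution, which builds the sequences one at a time instead of tracking endpoint masks. For an iterative implementation of the DP above, refer to the Editorialist's solution.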

TIME COMPLEXITY

The time complexity would be O(N * 2^N) per test case.

The actual number of operations can be estimated as \displaystyle\sum_{i = 1}^N i*2^i.
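
This sum has a standard closed form, which makes the estimate concrete (a worked check added here, not part of the original analysis). For N = 20 it evaluates to 19 \cdot 2^{21} + 2 = 39\,845\,890, roughly 4 \cdot 10^7 basic steps.

```latex
\begin{aligned}
S_N &= \sum_{i=1}^{N} i \cdot 2^{i}, \qquad 2S_N = \sum_{i=2}^{N+1} (i-1)\, 2^{i},\\
S_N = 2S_N - S_N &= N \cdot 2^{N+1} - \sum_{i=2}^{N} 2^{i} - 2
  = N \cdot 2^{N+1} - \left(2^{N+1} - 4\right) - 2\\
&= (N-1)\cdot 2^{N+1} + 2.
\end{aligned}
```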

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
  
#define int long long 
#define pb push_back
#define S second
#define F first
#define f(i,n) for(int i=0;i<n;i++)
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define vi vector<int>
#define pii pair<int,int>
#define all(x) x.begin(),x.end()
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update> 
#define precise(x) fixed << setprecision(x) 
 
const int MOD = 1e9+7;
const int N = 20;
int n;
int a[N];
int dp[1<<N][N];
 
int recur(int mask,int last)
{
    if(mask ==  0) return 1;
    
    int & res = dp[mask][last];
    
    if(res == -1) 
    {
        res = 0;
        
        //continue the last sequence
        for(int j=last+1;j<n;j++)
            if((mask>>j) & 1) 
              if(a[j] > a[last]) res += recur(mask^(1<<j),j);
            
        //start a new sequence
        for(int j=0;j<n;j++)
            if((mask >> j) & 1) 
        {
            res += recur(mask^(1<<j),j);
            break;
        }
        
        res %= MOD;
    }
    
    return res;
}
 
void solve()
{
   cin >> n;
   f(i,n) cin >> a[i];
    
   f(i,1<<n) f(j,n) dp[i][j] = -1;
    
   int res = 0;
    
   for(int i=0;i<n;i++)
       for(int j=0;j<(1<<(n - i - 1));j++)
   {
           res += recur((j<<(i+1)),i);
   }
    
   res++;
       
   res %= MOD;
    
   cout << res << '\n';
}
 
signed main()
{
    fast;
    
    int t = 1;
    
    cin >> t;
    
    while(t--)
        
    solve();
}
Tester's Solution
//Utkarsh.25dec
#include <bits/stdc++.h>
#include <chrono>
#include <random>
#define ll long long int
#define ull unsigned long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define rep(i,n) for(ll i=0;i<n;i++)
#define loop(i,a,b) for(ll i=a;i<=b;i++)
#define vi vector <int>
#define vs vector <string>
#define vc vector <char>
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
#define max3(a,b,c) max(max(a,b),c)
#define min3(a,b,c) min(min(a,b),c)
#define deb(x) cerr<<#x<<' '<<'='<<' '<<x<<'\n'
using namespace std;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  itertor to ith element (0 indexed)
typedef vector<vector<ll>> matrix;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
void solve()
{
    ll n;
    cin>>n;
    vl v;
    for(int i=0;i<n;i++)
    {
        ll c;
        cin>>c;
        v.pb(c);
    }
    ll dp[(1<<n)+10][n+10];
    memset(dp,0,sizeof(dp));
    dp[0][0]=1;
    dp[1][0]=1;
    for(int i=1;i<n;i++)
    {
        for(int mask=0;mask<(1<<(i+1));mask++)
        {
            if((mask&(1<<i))==0)
                dp[mask][i]=dp[mask][i-1];
            else
            {
                dp[mask][i]=dp[mask^(1<<i)][i-1];
                for(int j=0;j<i;j++)
                {
                    if(v[j]<v[i] && (mask&(1<<j))==0)
                    {
                        dp[mask][i]+=dp[mask^(1<<j)^(1<<i)][i-1];
                        dp[mask][i]%=mod;
                    }
                }
            }
        }
    }
    ll ans=0;
    for(int mask=0;mask<(1<<n);mask++)
    {
        ans+=dp[mask][n-1];
        ans%=mod;
    }
    cout<<ans<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int T=1;
    cin>>T;
    int t=0;
    while(t++<T)
    {
        //cout<<"Case #"<<t<<":"<<' ';
        solve();
        //cout<<'\n';
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class YACP{
    //SOLUTION BEGIN
    int MOD = (int)1e9+7;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni();
        int[] A = new int[N];
        for(int i = 0; i< N; i++)A[i] = ni();
        
        long[] ans = new long[1<<N];
        // ans[mask] denotes the number of sets of sequences, such that their end set is represented by mask
        ans[0] = 1;
        for(int i = 0; i< N; i++){
            
            long[] nxt = new long[1<<N];
            for(int mask = 0; mask < 1<<i; mask++){
                nxt[mask] += ans[mask];//If A[i] is not included in any sequence
                if(nxt[mask] >= MOD)nxt[mask] -= MOD;
                nxt[mask|(1<<i)] += ans[mask];//If A[i] is the first element of a new sequence
                if(nxt[mask|(1<<i)] >= MOD)nxt[mask|(1<<i)] -= MOD;
                
                for(int x = 0; x< i; x++){
                    if(((mask>>x)&1)==1 && A[x] < A[i]){
                        //If A[i] is added at the end of sequence ending with A[x]
                        nxt[mask^(1<<x)^(1<<i)] += ans[mask];
                        if(nxt[mask^(1<<x)^(1<<i)] >= MOD)nxt[mask^(1<<x)^(1<<i)] -= MOD;
                    }
                }
            }
            ans = nxt;
        }

        long total = 0;
        for(long x:ans)total = (total+x)%MOD;
        pn(total);
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new YACP().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome, as always.


Time complexity is O(N*2^N)


True. @taran_1407 was about to update it, but it got published before he could. He would later add a section on time complexity.

Roughly, these were the intended complexities for the different subtasks:
15 pts - O(N*3^N)
30 pts - O(3^N)
70 pts - O(N^2*2^N)
100 pts - O(N*2^N)

There also exists an O(N^3*2^N) solution, but we didn’t have an explicit subtask for it because it is very difficult to separate O(3^N) from O(N^2*2^N). For N=18, the difference between O(3^N) and O(N^2*2^N) isn’t much; in the end, most O(3^N) solutions in the contest did get 70 pts. For N=19, the difference between them isn’t much either, and for N=20 both of them would be too slow.