SORTSET - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: notsoloud
Tester: yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

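Given a multiset A, count the number of its non-empty subsets (sub-multisets, in which copies of the same value are indistinguishable) whose mode is unique, i.e. exactly one value occurs the maximum number of times. The answer is computed modulo 10^9+7.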
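For reference, here is a minimal brute-force sketch in Python. It assumes that subsets are compared only by how many copies of each value they contain, which matches the counting below where a value y can occur 0, 1, \ldots, \text{freq}_y times. It also assumes the empty subset, which has no mode, is not counted. It enumerates every choice of counts, so it is exponential and only useful for checking small cases. `brute_unique_mode` is a hypothetical name.

```python
from collections import Counter
from itertools import product

def brute_unique_mode(a):
    """Count non-empty sub-multisets of a whose mode is unique (exponential; small inputs only).

    Assumption: copies of the same value are indistinguishable, so a subset is
    identified by how many copies c_v (0 <= c_v <= freq_v) of each value v it takes.
    """
    freq = Counter(a)
    values = list(freq)
    total = 0
    for counts in product(*(range(freq[v] + 1) for v in values)):
        best = max(counts, default=0)
        if best == 0:
            continue                    # empty subset: no mode
        if counts.count(best) == 1:     # exactly one value attains the maximum
            total += 1
    return total

print(brute_unique_mode([1, 1, 2]))        # 4
print(brute_unique_mode([1, 1, 2, 2, 3]))  # 11
```

For A = \{1, 1, 2\}, the four counted subsets are \{1\}, \{2\}, \{1, 1\} and \{1, 1, 2\}. The subset \{1, 2\} is excluded because both values tie for the maximum.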

EXPLANATION:

For each x, let \text{freq}_x denote the frequency of x in the multiset.

Most counting problems require us to fix some property of the object we’re counting, and this one is no different.
Since we’d like to count the number of subsets with unique mode, the defining factors of a subset are:

  • Which element is the mode, and
  • The frequency of the mode in this subset

So, let’s fix both of these: suppose x is the unique mode of the subset and occurs exactly k times in it, and let’s count the valid subsets.

Note that, for any other value y present in A:

  • If \text{freq}_y \lt k, then there are (\text{freq}_y+1) choices for y: it can occur 0, 1, 2, \ldots, \text{freq}_y times. All of these choices are valid, since y can never reach k occurrences and therefore never ties with or overtakes x.
  • If \text{freq}_y \geq k, then y has k choices: it can appear anywhere between 0 and k-1 times, because appearing k or more times would tie with or overtake x.
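For a single fixed pair (x, k), the per-value choices above can be multiplied directly, since they are independent. The following Python sketch does exactly that in \mathcal{O}(\text{number of distinct values}) per pair. `count_for_mode` is a hypothetical helper, and the multiset \{1, 1, 2, 2, 3\} is only an illustrative input.

```python
from collections import Counter

def count_for_mode(freq, x, k):
    """Number of valid subsets in which x is the unique mode with exactly k copies.

    freq maps each value to its frequency in A; this is a direct product over
    the per-value choices, without any precomputation.
    """
    if not 1 <= k <= freq[x]:
        raise ValueError("need 1 <= k <= freq[x]")
    ways = 1
    for y, f in freq.items():
        if y == x:
            continue          # x's count is fixed to k
        if f < k:
            ways *= f + 1     # y can take 0..f copies and never reaches k
        else:
            ways *= k         # y must take 0..k-1 copies to stay strictly below x
    return ways

freq = Counter([1, 1, 2, 2, 3])
print(count_for_mode(freq, 1, 2))  # 4
print(sum(count_for_mode(freq, x, k) for x in freq for k in range(1, freq[x] + 1)))  # 11
```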

The choices for different values of y are independent, so the number of valid subsets for the pair (x, k) is the product of:

  • \prod(\text{freq}_y + 1) over all y with \text{freq}_y \lt k (this never includes x, since \text{freq}_x \geq k); and
  • k^{m-1}, where m is the number of distinct values whose frequency in A is \geq k. We subtract 1 because x also has frequency \geq k, but its count in the subset is already fixed to k.
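The same count can be written as a single formula, with m_k denoting the value of m for a given k and the answer summed over all pairs. Below it is a hand-worked check on the illustrative multiset A = \{1, 1, 2, 2, 3\}:

```latex
\begin{aligned}
\text{ways}(x, k) &= \Bigl(\prod_{y \,:\, \text{freq}_y < k} (\text{freq}_y + 1)\Bigr) \cdot k^{\,m_k - 1},
  \qquad m_k = \bigl|\{\, y : \text{freq}_y \geq k \,\}\bigr|, \\
\text{answer} &= \sum_{x} \sum_{k=1}^{\text{freq}_x} \text{ways}(x, k) \pmod{10^9 + 7}.
\end{aligned}

% Worked check on A = {1, 1, 2, 2, 3}: freq_1 = freq_2 = 2, freq_3 = 1, so m_1 = 3 and m_2 = 2.
\begin{aligned}
\text{ways}(x, 1) &= 1 \cdot 1^{3-1} = 1 \quad \text{for each } x \in \{1, 2, 3\}, \\
\text{ways}(1, 2) = \text{ways}(2, 2) &= (1 + 1) \cdot 2^{2-1} = 4, \\
\text{answer} &= 3 \cdot 1 + 2 \cdot 4 = 11.
\end{aligned}
```

The total of 11 agrees with a direct enumeration of the sub-multisets of this example.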

Now, notice that:

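  • \prod(\text{freq}_y + 1) is a prefix product: if the frequencies are sorted in ascending order, it is the product of (\text{freq}_y + 1) over the prefix of frequencies smaller than k. Equivalently, if c_f is the number of distinct values with frequency exactly f, it equals \prod_{f=1}^{k-1} (f+1)^{c_f}.
  • m, the number of frequencies that are \geq k, can also be read off the sorted list of frequencies. This can be done either directly via binary search or in \mathcal{O}(1) by precomputing suffix sums of c_f.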

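In particular, after this precomputation, the number of subsets for a pair (x, k) can be computed in \mathcal{O}(\log N). The prefix product and m are \mathcal{O}(1) lookups (or \mathcal{O}(\log N) with binary search), and k^{m-1} takes \mathcal{O}(\log N) with fast exponentiation.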

This is enough to solve the problem: iterate over all pairs (x, k) with 1 \leq k \leq \text{freq}_x and compute this value for each. The answer is the sum of all these values, modulo 10^9+7.
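A minimal Python sketch of the precomputation and the per-pair summation follows. Instead of explicitly sorting frequencies, it stores the number of distinct values with each frequency in `freqct`, as the editorialist's code does. Since the count for (x, k) depends only on k, `pair_count` takes just k. The names `build_tables`, `pair_count` and `count_unique_mode` are hypothetical.

```python
from collections import Counter

MOD = 10**9 + 7

def build_tables(freq, n):
    """pref[j] = product of (f + 1) over distinct values with frequency f <= j (mod MOD);
    suf[k]  = number of distinct values with frequency >= k."""
    freqct = [0] * (n + 2)              # freqct[f] = number of distinct values with frequency exactly f
    for f in freq.values():
        freqct[f] += 1
    suf = [0] * (n + 2)
    for k in range(n, 0, -1):
        suf[k] = suf[k + 1] + freqct[k]
    pref = [1] * (n + 1)
    for j in range(1, n + 1):
        # each of the freqct[j] values with frequency j contributes a factor (j + 1)
        pref[j] = pref[j - 1] * pow(j + 1, freqct[j], MOD) % MOD
    return pref, suf

def pair_count(pref, suf, k):
    """Count for any x with freq_x >= k: two table lookups plus one modular power."""
    return pref[k - 1] * pow(k, suf[k] - 1, MOD) % MOD

def count_unique_mode(a):
    freq = Counter(a)
    pref, suf = build_tables(freq, len(a))
    # one term per pair (x, k) with 1 <= k <= freq_x, i.e. N terms in total
    return sum(pair_count(pref, suf, k)
               for f in freq.values() for k in range(1, f + 1)) % MOD

print(count_unique_mode([1, 1, 2, 2, 3]))  # 11
```

For the input [1, 1, 2, 2, 3], the tables are `pref = [1, 2, 18, 18, 18, 18]` and `suf = [0, 3, 2, 0, 0, 0, 0]`.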

The number of pairs considered is \text{freq}_1 + \text{freq}_2 + \text{freq}_3 + \ldots + \text{freq}_{10^9} = N, since every value in A is at most 10^9. Each pair takes \mathcal{O}(\log N) time, so the overall time complexity is \mathcal{O}(N\log N).
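The setter's and tester's codes appear to group the same sum by k instead. Because the count for (x, k) depends only on k, all m values with frequency \geq k contribute the same amount. The loop can therefore run over k = 1, 2, \ldots, \max \text{freq} alone, maintaining the prefix product and m incrementally. A Python sketch of that grouping follows; the function name is hypothetical.

```python
from collections import Counter

MOD = 10**9 + 7

def count_unique_mode_by_k(a):
    """Same answer as summing over pairs (x, k), but grouped by k."""
    n = len(a)
    freq = Counter(a)
    freqct = [0] * (n + 1)      # freqct[f] = number of distinct values with frequency exactly f
    for f in freq.values():
        freqct[f] += 1
    ans = 0
    small = 1                   # product of (f + 1) over values with frequency < k
    m = len(freq)               # number of distinct values with frequency >= k
    for k in range(1, max(freq.values()) + 1):
        # m choices for the mode x, each contributing small * k^(m-1)
        ans = (ans + small * m % MOD * pow(k, m - 1, MOD)) % MOD
        # values with frequency exactly k drop out of the ">= k" group for larger k
        m -= freqct[k]
        small = small * pow(k + 1, freqct[k], MOD) % MOD
    return ans

print(count_unique_mode_by_k([1, 1, 2, 2, 3]))  # 11
```

This version performs a constant number of modular exponentiations per value of k rather than one per pair (x, k). The \mathcal{O}(N\log N) bound is unchanged.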

TIME COMPLEXITY:

\mathcal{O}(N\log{N}) per testcase.

CODE:

Setter's code (C++)
#include <iostream> 
#include <string> 
#include <set> 
#include <map> 
#include <stack> 
#include <queue> 
#include <vector> 
#include <utility> 
#include <iomanip> 
#include <sstream> 
#include <bitset> 
#include <cstdlib> 
#include <iterator> 
#include <algorithm> 
#include <cstdio> 
#include <cctype> 
#include <cmath> 
#include <math.h> 
#include <ctime> 
#include <cstring> 
#include <unordered_set> 
#include <unordered_map> 
#include <cassert>
#define int long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;

const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}

int power(int a, int b, int m){
    if(a == 0)
        return 0;
    if(b == 0)
        return 1;
    int res = 1;
    while(b){
        if(b&1){
            res = (res*a)%m;
        }
        a = (a*a)%m;
        b >>= 1;
    }
    return res;
}

int sumN = 0;
void solve()
{
    int n = readInt(1,100000,'\n');
    sumN += n;
    int a[n];
    for(int i=0; i<n-1; i++){
        a[i] = readInt(1,1000000000,' ');
    }
    a[n-1] = readInt(1,1000000000,'\n');

    unordered_map<int, int> freq;
    int maxFreq = 0;
    for(int i=0; i<n; i++){
        freq[a[i]]++;
        maxFreq = max(maxFreq, freq[a[i]]);
    }
    vector<int> count(n+1, 0);   // count[f] = number of distinct values with frequency f
    for(auto i: freq){
        count[i.second]++;
    }
    
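    int dist = freq.size();   // number of distinct values
    int total = 0;            // distinct values with frequency < i
    int ans = 0;
    int maxReached = 1;       // product of (f+1) over values with frequency f < i
    for(int i = 1; i<=maxFreq; i++){
        // each of the (dist-total) values with frequency >= i can be the mode with i copies;
        // every other such value has i choices (0..i-1 copies)
        ans = (ans + (((maxReached*(dist-total))%mod)*power(i, dist-total-1, mod)))%mod;
        total += count[i];
        maxReached = (maxReached * power(i+1, count[i], mod))%mod;
    }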

    cout << ans << '\n';
}

int32_t main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,20000,'\n');
    while(T--){
        solve();
        // cout<<'\n';
    }
    cerr << sumN << '\n';
    assert(sumN <= 200000);
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's code (C++)
#include <bits/stdc++.h>                   
#define int long long     
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);
#define mod 1000000007ll //998244353ll
#define mii map<int, int> 
using namespace std;

int power(int a, int b, int p) {
    if(a==0)
    return 0;
    int res=1;
    a%=p;
    while(b>0)
    {
        if(b&1)
        res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}

int32_t main()
{
    IOS;
    int t;
    cin>>t;
    while(t--)
    {
        int n;
        cin>>n;
        int co[n+1];
        memset(co, 0, sizeof(co));
        mii mp;
        for(int i=0;i<n;i++)
        {
            int a;
            cin>>a;
            mp[a]++;
        }
        for(auto it:mp)
            co[it.second]++;
        int small_choices=1, ans=0, dist=mp.size(), small_count=0;
        for(int i=1;i<=n;i++)
        {
            ans += (small_choices*(dist-small_count)%mod*power(i, dist-small_count-1, mod))%mod;
            ans %= mod;
            small_count += co[i];
            small_choices = (small_choices*power(i+1, co[i], mod))%mod;
        }
        cout<<ans<<"\n";
    }
}
Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    
    freq = {}
    for x in a:
        if x not in freq: freq[x] = 0
        freq[x] += 1
    
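    # freqct[f] = number of distinct values occurring exactly f times
    freqct = [0]*(n + 1)
    for f in freq.values(): freqct[f] += 1
    
    # suf[i] = number of distinct values occurring at least i times
    # pref[i] = product of (f+1) over distinct values occurring f <= i times
    suf, pref = [0]*(n+1), [0]*(n+1)
    
    suf[n] = freqct[n]
    for i in reversed(range(1, n)):
        suf[i] = suf[i+1] + freqct[i]
    
    pref[0] = 1
    for i in range(1, n+1):
        pref[i] = pow(i+1, freqct[i], mod)
        pref[i] *= pref[i-1]
        pref[i] %= mod
    
    # x occurs k = i times: values with frequency < i contribute pref[i-1],
    # and the other suf[i] - 1 values with frequency >= i have i choices each
    ans = 0
    for f in freq.values():
        for i in range(1, f+1):
            more = suf[i] - 1
            ans += pref[i-1] * pow(i, more, mod)
            ans %= mod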
    print(ans)

Hello @iceknight1093,

If freq[x] >= k, then y has k choices, not k+1: anywhere from 0 to k-1.

Small typo; please correct it if you'd like.

Very nice editorial. It was a good contest with a very nice problem set.

Fixed, thanks for noticing.

Can anyone please help me debug my code? It gives a wrong answer only on the last 2 test cases.
It prints a negative number. I have declared all variables as long, but I can't figure out where it overflows.
Thanks in advance.

Solution: 92248513 | CodeChef

I was able to figure out the problem.
I had declared a function named “pow”, which is also the name of a built-in function, and that caused the unexpected results.
After renaming it to “poww”, it gives the correct answer on all test cases.
Here is the solution:
Solution: 92249784 | CodeChef