SPBALL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Abhinav Sharma
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

1435

PREREQUISITES:

Precomputing factorials

PROBLEM:

Every second, each ball with number k > 1 splits into k balls with number k-1; a ball with number 1 stays as it is.

You have N balls, the i-th with number A_i. After 10^{100} seconds, how many balls will you have? Output the answer modulo 10^9 + 7.

EXPLANATION:

A ball that starts with number k eventually turns into some number of balls with number 1. This takes exactly k-1 seconds, since the number on every ball drops by 1 each second until it reaches 1.
Since 10^{100} seconds is far more than k-1 for any valid k, at the end we simply have a bunch of balls with number 1, and our task is to count them.

Consider a ball with number k. The following happens:

  • After one second, it splits into k balls with number k-1. We now have k balls.
  • After another second, each of these splits into k-1 balls with number k-2. We now have k\cdot (k-1) balls.
  • After another second, each of these splits into k-2 balls with number k-3. We now have k\cdot (k-1)\cdot(k-2) balls.
  • And so on, until every ball has number 1.

After k-1 seconds, we therefore end up with k \cdot (k-1) \cdot (k-2) \cdots 2 = k! balls, all with number 1. This also holds for k = 1, since 1! = 1.
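As a quick sanity check of the bullet list above, here is a minimal Python sketch (the function name `ball_counts` is only for illustration) that records the total number of balls after each second, starting from a single ball numbered k. It relies on the observation that, when we start from one ball, all balls carry the same number at any moment, so one multiplication per second is enough.

```python
def ball_counts(k):
    """Total number of balls after each second, starting from one ball numbered k.

    totals[0] is the initial single ball; totals[t] is the total after t seconds.
    """
    totals = [1]
    number = k  # starting from one ball, every ball currently carries this number
    while number > 1:
        # each ball numbered `number` splits into `number` balls numbered number-1
        totals.append(totals[-1] * number)
        number -= 1
    return totals


print(ball_counts(4))  # [1, 4, 12, 24]
```

The last entry is k! and the list has k entries, matching the k-1 seconds needed to reach number 1.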

In particular, the i-th ball turns into A_i! balls in the end, and different balls evolve independently of each other, so the answer is simply A_1! + A_2! + \ldots + A_N!. All that remains is to compute this quickly.
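The independence argument can be checked with a naive simulation on tiny inputs. The sketch below is illustrative only (`simulate` and `formula` are hypothetical helper names): it tracks how many balls carry each number, runs the process until every ball has number 1, and compares the final count with the sum of factorials. The count grows like k!, so this is only meant for small A_i.

```python
from collections import Counter
from math import factorial


def simulate(a):
    """Naively run the process on balls numbered a[0], a[1], ...; return the final count."""
    counts = Counter(a)  # counts[j] = number of balls currently numbered j
    while any(j > 1 for j in counts):
        nxt = Counter()
        for j, c in counts.items():
            if j == 1:
                nxt[1] += c          # balls numbered 1 never change
            else:
                nxt[j - 1] += c * j  # each of the c balls splits into j balls numbered j-1
        counts = nxt
    return counts[1]


def formula(a):
    """Closed form from the editorial: A_1! + A_2! + ... + A_N! (no modulus)."""
    return sum(factorial(x) for x in a)


a = [1, 2, 3, 4]
print(simulate(a), formula(a))  # 33 33
```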

To do this, note that A_i \leq 10^6, so we can precompute k! for every k \leq 10^6 once, using k! = k \cdot (k-1)!. After this, each test case is solved in \mathcal{O}(N) by summing the precomputed values.
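A minimal Python sketch of the precompute-then-sum approach, assuming the modulus 10^9 + 7 used by the reference solutions below. The names `MOD`, `precompute_factorials` and `solve` are illustrative, and input parsing is left out.

```python
MOD = 10**9 + 7


def precompute_factorials(limit, mod=MOD):
    """fac[i] = i! mod `mod` for 0 <= i <= limit, built with fac[i] = i * fac[i-1]."""
    fac = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fac[i] = i * fac[i - 1] % mod
    return fac


def solve(a, fac, mod=MOD):
    """Answer for one test case: (A_1! + ... + A_N!) mod `mod`, in O(N)."""
    total = 0
    for x in a:
        total = (total + fac[x]) % mod
    return total


fac = precompute_factorials(10**6)  # done once, shared by all test cases
print(solve([1, 2, 3], fac))        # 9
```

Reducing after every multiplication and addition keeps all stored values below the modulus, which is what the "Remember to perform all operations modulo" note asks for.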

Remember to perform all operations modulo 10^9 + 7.

TIME COMPLEXITY:

\mathcal{O}(N) per test case, after a one-time \mathcal{O}(10^6) precomputation of factorials.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;

#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 

#define ll long long
#define db double
#define el "\n"
#define ld long double
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define all(ds) ds.begin(), ds.end()
#define ff first
#define ss second
#define pb push_back
#define mp make_pair
typedef vector< long long > vi;
typedef pair<long long, long long> ii;
typedef priority_queue <ll> pq;
#define o_set tree<ll, null_type,less<ll>, rb_tree_tag,tree_order_statistics_node_update> 

const ll mod = 1000000007;
const ll INF = (ll)1e18;
const ll MAXN = 1000006;

ll po(ll x, ll n){ 
    ll ans=1;
    while(n>0){ if(n&1) ans=(ans*x)%mod; x=(x*x)%mod; n/=2;}
    return ans;
}


const ll MX=1000005;
ll fac[MX];

// fac[i] = i! mod 1e9+7, for every 0 <= i < MX
void pre(){
 fac[0]=1;
 rep_a(i,1,MX) fac[i]= (i*fac[i-1])%mod;
}



int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
   
    int T=1;
    cin >> T;
    pre();
    while(T--){
        int n;
        cin>>n;

        ll ans = 0;
        ll x;
        rep(i,n){
            cin>>x;
            ans += fac[x];
            ans%=mod;
        }

        cout<<ans<<el;

    
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
fac = [1, 1]
for x in range(2, 10**6 + 10):
    fac.append(x * fac[x-1])
    fac[x] %= mod
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    ans = 0
    for x in a:
        ans += fac[x]
    print(ans % mod)

Can you please tell me what's wrong with my code? It's giving me WA.
Solution: 79942463 | CodeChef

fac[i] = i * fac[i - 1];

This line overflows; use long long for the fac array.
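To see why the int version fails, assume fac is a 32-bit int array and A_i can be as large as 10^6. Even if fac[i-1] has already been reduced modulo 10^9 + 7, the product i * fac[i-1] can be close to 10^6 \cdot 10^9, which does not fit in 32 bits but fits comfortably in a 64-bit long long. The Python snippet below only compares magnitudes (Python integers never overflow), and also finds the first factorial that leaves the 32-bit range when no modulus is applied at all.

```python
from math import factorial

INT32_MAX = 2**31 - 1  # largest value of a signed 32-bit int
INT64_MAX = 2**63 - 1  # largest value of a signed 64-bit long long
MOD = 10**9 + 7
MAX_A = 10**6          # upper bound on A_i from the problem

# Largest possible i * fac[i-1] when fac[i-1] is already reduced modulo MOD
worst_product = MAX_A * (MOD - 1)
print(worst_product > INT32_MAX)   # True: overflows a 32-bit int
print(worst_product <= INT64_MAX)  # True: fits in a 64-bit long long

# Without any modulus, the factorials themselves leave the 32-bit range at 13!
first_overflow = next(k for k in range(1, 30) if factorial(k) > INT32_MAX)
print(first_overflow)              # 13
```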

Thank you so much!

What's wrong with this code?
#include <iostream>
#include <vector>
#include <cstdio>

using namespace std;

#define ll long long int

const unsigned int mod = 1e9+7;

ll result[1000000] = {0};

ll fact_dp(ll n){
    if(n>=0){
        result[0] = 1;
        for(ll i=1; i<=n; i++){
            result[i] = (i*result[i-1])%mod;
            result[i]%=mod;
        }
        return result[n];
    }
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    //this is fast I/O (input output) use header file <cstdio>
    int t;cin>>t;
    while(t--){
        ll n;cin>>n;
        vector<ll>v(n);
        for(int i=0; i<n; i++) cin>>v[i];
        ll sum = 0;
        for(ll i=0; i<n; i++){
            sum += fact_dp(v[i]);
            sum%=mod;
            //else if(v[i]==1) sum = (sum+1)%mod;
        }
        cout<<sum<<endl;
    }
    return 0;
}