CNTTRIANGLE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Combinatorics

PROBLEM:

You are given integers N and M.
Count the number of sets of size N, containing integers between 1 and M, such that any three distinct elements in the set can form the sides of a non-degenerate triangle.

EXPLANATION:

Recall that for three positive numbers x, y, z to form the sides of a non-degenerate triangle, the three inequalities

x + y \gt z \\ x + z \gt y \\ y + z \gt x

must hold.
In particular, the largest one of the three should be strictly smaller than the sum of the other two.

Now, suppose you have a set of size N, say \{x_1, x_2, \ldots, x_N\} where x_i \lt x_{i+1}.
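We’d like any triplet among them to form a triangle: it’s not hard to see that this will happen if and only if (x_1, x_2, x_N) can form a triangle - that is, x_1 + x_2 \gt x_N should hold. Indeed, for any triplet x_i \lt x_j \lt x_k we have x_i + x_j \geq x_1 + x_2 and x_k \leq x_N, so x_1 + x_2 \gt x_N implies x_i + x_j \gt x_k.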

Let’s fix the values of x_1 and x_2.
Then, as long as x_N \lt x_1 + x_2, the elements \{x_3, x_4, \ldots, x_N\} can be chosen freely.
That is, among the values \gt x_2 and \lt \min(M+1, x_1 + x_2), we can freely choose N-2 of them. (The \min(M+1, \ldots) is needed since we aren’t allowed to choose values \gt M no matter what.)

The number of elements we have is thus \min(M+1, x_1 + x_2) - x_2 - 1, from which we want to choose N-2, which obviously can be done in

\binom{\min(M+1, x_1+x_2)-x_2-1}{N-2}

ways.

Note that the term \min(M+1, x_1+x_2)-x_2-1 is:

  • Just x_1-1, when x_1+x_2 \leq M+1.
    • Rewrite this condition to x_2 \leq M+1-x_1.
  • M - x_2, otherwise.

So, if we fix only x_1 and look at what’s being added across all x_2, we see that it’s

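  • \displaystyle\binom{x_1-1}{N-2} (which is a constant, given that x_1 is fixed) for a certain range of x_2 (specifically from x_1 + 1 to M+1-x_1). This range contains \max(0, M+1-2x_1) values of x_2 (it is empty when x_1 + 1 \gt M+1-x_1), so its total contribution is \max(0, M+1-2x_1) \cdot \binom{x_1-1}{N-2}.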

  • \displaystyle\binom{0}{N-2} + \binom{1}{N-2} + \ldots + \binom{k}{N-2} summed up across the rest.
    Here k = \min(x_1-2, M-x_1-1), since this part comes from choosing x_2 = M, M-1, M-2, \ldots till we either reach x_2 = x_1 + 1, or x_2 = M+2-x_1 (at which point it moves to the first case).

Note that the first term is a binomial coefficient multiplied by a constant, while the second is a prefix sum. (In fact, the summation can be reduced to a single binomial coefficient using the hockey-stick identity, though seeing it as a prefix sum suffices to solve this problem.)

So, for a fixed x_1, we’re able to compute the answer across all x_2 in constant time.
Sum this up across all x_1 to obtain a solution in \mathcal{O}(M).

TIME COMPLEXITY:

\mathcal{O}(M) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

// 64-bit random generator seeded from the steady clock (template boilerplate;
// not referenced by the solution code shown below).
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

// Length of the factorial tables; every index used with ff/iff must stay below it.
const int facN = 1e6 + 5;
// Modulus under which all answers are reported.
const int mod = 1e9 + 7; // 998244353
// ff[i] = i! mod `mod`, iff[i] = modular inverse of ff[i]; both are filled by factorialinit().
int ff[facN], iff[facN];
// True once factorialinit() has run, so C()/P() build the tables lazily on first use.
bool facinit = false;

// Returns x^y modulo `mod` by binary exponentiation.
// Intended for y >= 0 and 0 <= x < mod (the only caller passes y = mod - 2).
int power(int x, int y){
    // Record the parity of y, y/2, y/4, ... until the exponent reaches zero.
    bool odd[64];
    int steps = 0;
    for (; y != 0; y /= 2){
        odd[steps++] = (y & 1);
    }

    // Replay from the most significant step downwards:
    // square the accumulator, then multiply by x when that step was odd.
    int result = 1;
    while (steps > 0){
        steps--;
        result = 1LL * result * result % mod;
        if (odd[steps]) result = 1LL * result * x % mod;
    }
    return result;
}

// Fills ff[] with factorials and iff[] with inverse factorials modulo `mod`,
// and flags the tables as ready.
void factorialinit(){
    facinit = true;

    // Forward pass: ff[i] = i! mod `mod`.
    ff[0] = 1;
    int idx = 1;
    while (idx < facN){
        ff[idx] = 1LL * ff[idx - 1] * idx % mod;
        idx++;
    }

    // Invert only the largest factorial (via power(., mod - 2)), then walk
    // downwards using 1/(i-1)! = i * 1/i!.
    const int top = facN - 1;
    iff[top] = power(ff[top], mod - 2);
    for (idx = top; idx >= 2; idx--){
        iff[idx - 1] = 1LL * iff[idx] * idx % mod;
    }
    iff[0] = 1;
}

// Binomial coefficient "n choose r" modulo `mod`, using the lazily built tables.
// n == r returns 1 before any range check (so C(-1, -1) is 1, which
// Solutions(0, 0) relies on); any other r outside [0, n] returns 0.
int C(int n, int r){
    if (!facinit){
        factorialinit();
    }

    if (n == r){
        return 1;
    }

    const bool in_range = (0 <= r) && (r <= n);
    if (!in_range){
        return 0;
    }

    int res = 1LL * ff[n] * iff[r] % mod;
    res = 1LL * res * iff[n - r] % mod;
    return res;
}

// Number of ordered selections n! / (n - r)! modulo `mod`.
// Asserts that 0 <= r <= n.
int P(int n, int r){
    if (!facinit){
        factorialinit();
    }

    assert(0 <= r && r <= n);

    const int remaining = n - r;
    return 1LL * ff[n] * iff[remaining] % mod;
}

// Number of solutions of x1 + ... + xn = r with every xi >= 0 (stars and bars),
// i.e. C(n + r - 1, n - 1).
int Solutions(int n, int r){
    const int slots = n + r - 1;   // r stars plus n - 1 bars
    const int bars = n - 1;
    return C(slots, bars);
}

void Solve() 
{
    int n, m; cin >> n >> m;
    
    vector <int> a(m + 1, 0);
    vector <int> d(m + 1, 0);
    for (int y = 4; y <= m; y++){
        int part = (y + 1) / 2;
        d[y - part - 1] += 1;
        part--;
        d[part - 1] += 1;
    }
    
    a[m] = d[m];
    
    for (int i = m - 1; i >= 0; i--){
        a[i] = a[i + 1] + d[i];
        a[i] %= mod;
    }
    
    d = a;
    a[m] = d[m];
    
    for (int i = m - 1; i >= 0; i--){
        a[i] = a[i + 1] + d[i];
        a[i] %= mod;
    }
    
    int ans = 0;
    for (int i = 1; i <= m; i++){
        ans += a[i] * C(i - 1, n - 3) % mod;
    }
    ans %= mod;
    cout << ans << "\n";
}

// Entry point of the author's solution: reads the number of test cases, runs
// Solve() for each, and reports the elapsed wall-clock time on stderr.
int32_t main() 
{
    const auto started = std::chrono::high_resolution_clock::now();

    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int t = 1;
    cin >> t;
    for (int done = 0; done < t; done++){
        Solve();
    }

    const auto finished = std::chrono::high_resolution_clock::now();
    const auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(finished - started).count();
    cerr << "Time measured: " << nanos * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

namespace mint_ns {
/// Integer modulo the compile-time constant P, kept normalised in [0, P).
///
/// P is expected to be a positive value of a signed integer type: operator-=
/// relies on a negative intermediate and operator+= on 2 * P fitting in that type.
/// All standard-library names are qualified, so the namespace does not depend on a
/// `using namespace std;` in the including file.
template<auto P>
struct Modular {
    using value_type = decltype(P);
    value_type value;

    /// Implicit on purpose, so plain integers can be mixed with Modular in expressions.
    Modular(long long k = 0) : value(norm(k)) {}

    friend Modular<P>& operator += (Modular<P>& n, const Modular<P>& m) {
        n.value += m.value;
        if (n.value >= P) n.value -= P;
        return n;
    }
    friend Modular<P> operator + (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r += m; }

    friend Modular<P>& operator -= (Modular<P>& n, const Modular<P>& m) {
        n.value -= m.value;
        if (n.value < 0) n.value += P;
        return n;
    }
    friend Modular<P> operator - (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r -= m; }
    friend Modular<P> operator - (const Modular<P>& n) { return Modular<P>(-n.value); }

    friend Modular<P>& operator *= (Modular<P>& n, const Modular<P>& m) {
        // Widen to long long before multiplying so the product cannot overflow value_type.
        n.value = static_cast<value_type>(n.value * 1ll * m.value % P);
        return n;
    }
    friend Modular<P> operator * (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r *= m; }

    /// Division multiplies by m.inv(); see inv() for when that is a true inverse.
    friend Modular<P>& operator /= (Modular<P>& n, const Modular<P>& m) { return n *= m.inv(); }
    friend Modular<P> operator / (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r /= m; }

    Modular<P>& operator ++ () { return *this += 1; }
    Modular<P>& operator -- () { return *this -= 1; }
    Modular<P> operator ++ (int) { Modular<P> r = *this; *this += 1; return r; }
    Modular<P> operator -- (int) { Modular<P> r = *this; *this -= 1; return r; }

    friend bool operator == (const Modular<P>& n, const Modular<P>& m) { return n.value == m.value; }
    friend bool operator != (const Modular<P>& n, const Modular<P>& m) { return n.value != m.value; }

    explicit operator int() const { return value; }
    explicit operator bool() const { return value; }
    explicit operator long long() const { return value; }

    constexpr static value_type mod() { return P; }

    /// Reduces k into [0, P). Static (it reads no object state), so it is also
    /// usable on const objects; existing `m.norm(k)` call sites keep working.
    static value_type norm(long long k) {
        if (!(-P <= k && k < P)) k %= P;
        if (k < 0) k += P;
        return static_cast<value_type>(k);
    }

    /// Multiplicative inverse via the extended Euclidean algorithm.
    /// The result is a true inverse only when gcd(value, P) == 1; for value == 0 it returns 0.
    Modular<P> inv() const {
        // Invariant: a == y * value and b == x * value (mod P); the loop ends with b == gcd.
        value_type a = value, b = P, x = 0, y = 1;
        while (a != 0) {
            value_type k = b / a;
            b -= k * a;
            x -= k * y;
            std::swap(a, b);
            std::swap(x, y);
        }
        return Modular<P>(x);
    }

    /// Debug-printing hook (name kept as-is for the local debug header).
    friend void __print(Modular<P> v) {
        std::cerr << v.value;
    }
};

/// m raised to the power p by binary exponentiation.
/// A negative p is handled as (m^-1)^|p| instead of looping forever.
template<auto P> Modular<P> pow(Modular<P> m, long long p) {
    unsigned long long e = static_cast<unsigned long long>(p);
    if (p < 0) {
        m = m.inv();
        e = 0ULL - e;  // |p| without signed overflow, even for the minimum long long
    }
    Modular<P> r(1);
    while (e) {
        if (e & 1) r *= m;
        m *= m;
        e >>= 1;
    }
    return r;
}

template<auto P> std::ostream& operator << (std::ostream& o, const Modular<P>& m) { return o << m.value; }
template<auto P> std::istream& operator >> (std::istream& i, Modular<P>& m) {
    long long k = 0;
    i >> k;
    m.value = Modular<P>::norm(k);
    return i;
}
// Must call std::to_string explicitly: an unqualified call here would find only this
// template (it hides std::to_string inside mint_ns) and fail to compile when instantiated.
template<auto P> std::string to_string(const Modular<P>& m) { return std::to_string(m.value); }

}
// Modulus for all answers in the tester's solution; mod_int is the modular
// integer type bound to it.
constexpr int mod = (int)1e9 + 7;
using mod_int = mint_ns::Modular<mod>;

// Lazily growing tables of factorials, inverse factorials and inverses modulo `mod`.
// Indices 0..n are valid; a query beyond n extends the tables to twice the requested index.
struct Comb {
    int n;
    vector<mod_int> _fac, _invfac, _inv;

    Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
    Comb(int limit) : Comb() { init(limit); }

    // Extends every table up to index min(m, mod - 1); does nothing if that is already covered.
    void init(int m) {
        m = min(m, mod - 1);
        if (m <= n) return;

        _fac.resize(m + 1);
        _invfac.resize(m + 1);
        _inv.resize(m + 1);

        for (int i = n + 1; i <= m; i++) {
            _fac[i] = _fac[i - 1] * i;
        }

        // One modular inversion at the top, then peel one factor off per step going down.
        _invfac[m] = _fac[m].inv();
        for (int i = m; i > n; i--) {
            _invfac[i - 1] = _invfac[i] * i;
            _inv[i] = _invfac[i] * _fac[i - 1];
        }

        n = m;
    }

    // m! modulo `mod`.
    mod_int fac(int m) {
        if (n < m) init(2 * m);
        return _fac[m];
    }

    // Inverse of m! modulo `mod`.
    mod_int invfac(int m) {
        if (n < m) init(2 * m);
        return _invfac[m];
    }

    // Inverse of m modulo `mod`.
    mod_int inv(int m) {
        if (n < m) init(2 * m);
        return _inv[m];
    }

    // Binomial coefficient; 0 when r is outside [0, n].
    mod_int ncr(int n, int r) {
        if (r < 0 || n < r) return 0;
        return fac(n) * invfac(r) * invfac(n - r);
    }

    // Stars and bars: x1 + x2 + ... + xr = n.
    mod_int place(int n, int r) {
        return ncr(n + r - 1, r - 1);
    }
} comb;
 
int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  auto __solve_testcase = [&](int test) {
    int N, M; cin >> N >> M;

    auto get_inc = [&](int a, int b) -> mod_int {
      if(b < a) return 0;
      return comb.ncr(b, a);
    };

    mod_int res = 0;
    for(int df = N - 1 ; df <= M ; ++df) {
      res += get_inc(N - 1, min(df, M - df));
      res += max(M - 2 * df, 0) * comb.ncr(df - 1, N - 2);
    }
    cout << res << '\n';
  };
  
  int NumTest = 1;
  cin >> NumTest;
  for(int testno = 1; testno <= NumTest ; ++testno) {
    __solve_testcase(testno);
  }
  
  return 0;
}

Editorialist's code (Python)
# Factorials and inverse factorials modulo `mod`, for indices 0 .. maxn-1.
mod = 10**9 + 7
maxn = 10**6 + 5

fac = [1] * maxn
for i in range(1, maxn):
    fac[i] = fac[i-1] * i % mod

# Invert only the last factorial with pow(., mod-2, mod), then fill downwards
# using 1/i! = (i+1) * 1/(i+1)!.
ifac = [1] * maxn
ifac[maxn-1] = pow(fac[maxn-1], mod-2, mod)
for i in range(maxn-2, -1, -1):
    ifac[i] = ifac[i+1] * (i+1) % mod

def C(n, r):
    """Return n choose r modulo `mod`, or 0 when r lies outside [0, n]."""
    if not (0 <= r <= n):
        return 0
    return fac[n] * ifac[r] * ifac[n-r] % mod

for _ in range(int(input())):
    n, m = map(int, input().split())

    # Sum the contribution of every candidate smallest element x.
    total = 0
    for x in range(1, m-1):
        # Constant binomial term, counted max(0, m+1-2x) times.
        repeats = max(0, m + 1 - 2*x)
        flat_part = repeats * C(x-1, n-2)
        # Single binomial standing in for the prefix sum over the remaining second elements.
        tail_part = C(min(x, m+1-x) - 1, n-1)
        total += flat_part + tail_part
    print(total % mod)

With some algebra and simplification, this problem can be solved in O(log(MOD)) per testcase (given precomputed factorials), provided one knows how to evaluate the sum \sum\limits_{r=0}^{m} \binom{r}{n}. To calculate this sum, one just needs to find the coefficient of x^n in \sum\limits_{r=0}^{m} (1+x)^r, which is equal to \binom{m+1}{n+1} (this is the hockey-stick identity). The final answer then comes out to be:
4 \cdot \binom{1 + \left\lfloor \frac{M}{2} \right\rfloor}{N} - k \cdot \binom{\left\lfloor \frac{M}{2} \right\rfloor}{N - 1} , where k is equal to 1 if M is odd and is equal to 3 if M is even.
So time complexity is just O(log(MOD)) per testcase.

1 Like