ARLR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics, matrix exponentiation

PROBLEM:

For an array A of length N, consider the following process:

  • At time T, all indices i such that A_i = T will explode (if they haven’t exploded before).
  • When index i explodes, if A_{i-1} = A_i + 1, index i-1 explodes too (in the same second).
    The same applies to index i+1, and such explosions can cascade.

At time T, you can save a single index i such that A_i = T and index i hasn’t exploded before.
The safety of A is the maximum number of indices that can possibly be saved.

Given N and M, find the sum of the safety of A across all arrays A of length N with elements from 1 to M, modulo 10^9 + 7.
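To make the process concrete, here is a brute-force sketch (our own, with the hypothetical name simulate_safety) that tries every possible sequence of save choices. It assumes, as the explanation below does, that the save at time T happens before that second's explosions and that a saved index never explodes afterwards. It is exponential, so it is only meant for tiny arrays.

```python
from functools import lru_cache


def simulate_safety(A):
    """Exhaustive search over all save choices (assumes saved indices never explode)."""
    n = len(A)
    last_time = max(A)

    @lru_cache(maxsize=None)
    def best_from(t, exploded, saved):
        if t > last_time:
            return len(saved)
        # Indices that explode at time t unless one of them is saved right now.
        due = [i for i in range(n) if A[i] == t and i not in exploded and i not in saved]
        best = 0
        for choice in [None] + due:  # also try saving nobody (never better, but harmless)
            now_saved = saved | {choice} if choice is not None else saved
            now_exploded = set(exploded)
            stack = [i for i in due if i != choice]
            while stack:
                i = stack.pop()
                if i in now_exploded or i in now_saved:
                    continue
                now_exploded.add(i)
                # Chain reaction: a neighbor holding A[i] + 1 explodes in the same second.
                for j in (i - 1, i + 1):
                    if 0 <= j < n and A[j] == A[i] + 1:
                        stack.append(j)
            best = max(best, best_from(t + 1, frozenset(now_exploded), now_saved))
        return best

    return best_from(1, frozenset(), frozenset())
```

For instance, on A = [1, 2, 1] it finds that only one index can be saved.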

EXPLANATION:

First, as always, we need to figure out what the safety of a fixed array A is.

Observe that at time T we can save at most one index, and that index must hold the value T.
So at most one occurrence of each value can be saved, which means the answer is bounded above by the number of distinct elements in A.

However, it’s not always possible to save one copy of each element - the simplest example of this is A = [1, 2, 1], where it’s impossible to save the 2 (no matter which 1 is saved, the other one will explode and take the 2 with it).
Extending the idea behind this example, it’s easy to see that if both neighbors of some value x equal x-1, it’s impossible to save it.

Let’s say A_i is surrounded if A_{i-1} = A_i - 1 and A_{i+1} = A_i - 1 (meaning it’s impossible to save A_i). An element at either end of A is never surrounded.
For a fixed value x, if every occurrence of x in A is surrounded, clearly it’s impossible to save any of them.
On the other hand, if there exists even one occurrence that’s not surrounded, it’s always possible to save this occurrence!

Proof

We’ll prove this with induction on the value being saved.
The claim is “if there exists a non-surrounded occurrence of x, it can be saved at time x”.

For x = 1 this is trivially true: any occurrence of 1 can be saved in the first second.

Suppose the claim is true till x-1, and we’ll look at what happens with x.
Consider some non-surrounded occurrence of x, with neighbors y and z (if x is on the border of A, y or z might not exist; in such a case pretend they’re some large number, say 10^{100}).

If neither y nor z equals x-1, it’s obviously possible to save x: an element with value x can only explode before time x if a neighbor with value x-1 explodes, so whether its neighbors explode or not doesn’t affect it at all. It’ll remain intact at time x (and hence can be saved).

The remaining case is when exactly one neighbor equals x-1; without loss of generality, let this be y.
Note that z \neq x-1, so whether z explodes or not won’t affect x, meaning we can ignore it entirely.

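Now, y has x as a neighbor, and x \neq y-1, so y isn’t surrounded either.
Since y = x-1, the inductive hypothesis says y can be saved at time x-1. If y never explodes (and z is irrelevant), then x can’t explode before time x either.
So, this occurrence of x can be kept unexploded at time x, at which point it can be saved.
Moreover, unrolling the induction shows that saving this x only requires saving a run of adjacent elements with values x, x-1, \ldots, k. Choosing these runs greedily, each time for the largest saveable value not already covered, gives runs over disjoint ranges of values, so all of them can be saved at the same time.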
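The inductive step can be unrolled into an explicit plan. Starting from a non-surrounded index i, walk toward its (at most one) neighbor holding A_i - 1 for as long as the values keep dropping by one. Every index on this run can only be set off by a neighbor holding one less, and that neighbor is either the next index on the run or doesn't exist, so saving each returned index at the time equal to its value keeps A_i intact. A minimal sketch with a hypothetical helper save_chain, using 0-indexed arrays:

```python
def save_chain(A, i):
    """Indices to save, ordered from A[i] downwards, so A[i] survives until time A[i].

    Assumes 0-indexed A; raises ValueError if index i is surrounded.
    """
    n = len(A)
    left = i > 0 and A[i - 1] == A[i] - 1
    right = i + 1 < n and A[i + 1] == A[i] - 1
    if left and right:
        raise ValueError("index i is surrounded, so it cannot be saved")
    chain = [i]
    if not (left or right):
        return chain  # no neighbor can set A[i] off
    step = -1 if left else 1
    j = i
    # Follow the values x-1, x-2, ... in the same direction while they keep dropping by one.
    while 0 <= j + step < n and A[j + step] == A[j] - 1:
        j += step
        chain.append(j)
    return chain
```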


So, the safety of a fixed array is the number of distinct values x such that at least one occurrence of x is not surrounded.
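As a quick sanity check of this characterization, a small sketch (the name safety is our own) that computes it directly. Indices are 0-based, so the border elements A[0] and A[n-1] are never surrounded:

```python
def safety(A):
    """Safety of a fixed array via the 'non-surrounded occurrence' rule (0-indexed)."""
    n = len(A)

    def surrounded(i):
        # Border indices are never surrounded; inner ones need both neighbors equal to A[i] - 1.
        return 0 < i < n - 1 and A[i - 1] == A[i] - 1 and A[i + 1] == A[i] - 1

    # Count each value once if at least one of its occurrences is not surrounded.
    return len({A[i] for i in range(n) if not surrounded(i)})
```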

To find the sum of safety across all arrays, we use the trick of counting contributions.
That is, instead of fixing an array A and computing its answer, we fix a value x and count the arrays whose safety x contributes 1 to; summing this count over all values of x gives the answer.
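For small N and M, the contribution trick can be checked by brute force: summing the safety of every array should match summing, over each value x, the number of arrays in which x has a non-surrounded occurrence. A sketch with our own helper names that enumerates all M^N arrays:

```python
from itertools import product


def surrounded(A, i):
    # 0-indexed; both neighbors must exist and equal A[i] - 1.
    return 0 < i < len(A) - 1 and A[i - 1] == A[i] - 1 and A[i + 1] == A[i] - 1


def total_direct(N, M):
    """Sum of safety over all M**N arrays, one array at a time."""
    return sum(
        len({A[i] for i in range(N) if not surrounded(A, i)})
        for A in product(range(1, M + 1), repeat=N)
    )


def total_by_contribution(N, M):
    """Same sum, counted value by value: x adds 1 for each array where it is saveable."""
    total = 0
    for x in range(1, M + 1):
        for A in product(range(1, M + 1), repeat=N):
            if any(A[i] == x and not surrounded(A, i) for i in range(N)):
                total += 1
    return total
```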

That leaves us with the following problem: given x, how many arrays of length N with elements from 1 to M have at least one non-surrounded occurrence of x?

If x = 1, this is simply the number of arrays that contain a 1 at all.
That’s easy to compute: it’s M^N - (M-1)^N (the total number of arrays, minus the number of arrays with elements from 2 to M only).
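A tiny comparison of this formula against explicit enumeration (hypothetical helper names; the brute force is only feasible for small N and M):

```python
from itertools import product


def arrays_with_a_one(N, M):
    """Brute-force count of length-N arrays over 1..M that contain at least one 1."""
    return sum(1 for A in product(range(1, M + 1), repeat=N) if 1 in A)


def arrays_with_a_one_formula(N, M):
    # All arrays minus those that only use the values 2..M.
    return M**N - (M - 1) ** N
```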

If x \gt 1 however, the situation isn’t quite so simple.
Let’s try to count the opposite: how many arrays have all their occurrences of x surrounded?

One (slow) way to compute this is to use DP.
Let f_x(i, j) denote the number of arrays of length i such that:

  • Every occurrence of x so far is surrounded; and
  • j represents the state of the last element. Specifically,
    • If j = 0, the last element is x.
    • If j = 1, the last element is x-1.
    • If j = 2, the last element is anything other than these two values.

Note that in state f_x(i, 0), the occurrence of x at the end of the array technically isn’t surrounded.
We’ll allow this, as long as we ensure that the next placed element is x-1.
In particular, this means the value we’re looking for is f_x(N, 1) + f_x(N, 2) (since the N-th element can’t be x, because we aren’t placing anything after it).

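It’s not too hard to come up with a recurrence for this definition. The base case is f_x(0, 0) = f_x(0, 1) = 0 and f_x(0, 2) = 1: just like after a value other than x and x-1, the first element can be anything except x (an x at the left border can never be surrounded). For i \geq 1: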

f_x(i, 0) = f_x(i-1, 1)
f_x(i, 1) = f_x(i-1, 0) + f_x(i-1, 1) + f_x(i-1, 2)
f_x(i, 2) = (M-2)\cdot (f_x(i-1, 1) + f_x(i-1, 2))

because:

  • Before an x, there must be an x-1.
  • Before an x-1, there can be anything.
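  • Before one of the other M-2 values, there can’t be an x (that x would be followed by something other than x-1, so it wouldn’t be surrounded), but anything else is fine.
    Here, we can also choose which of the M-2 values we’d like to place, which gives the factor M-2.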
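The recurrence translates directly into the slow \mathcal{O}(N) DP. A sketch in Python (count_all_surrounded is our own name), assuming x \geq 2, so M \geq 2, and the base case f_x(0, 2) = 1 that the author's and editorialist's codes also start from. It uses exact integers rather than arithmetic modulo 10^9 + 7, which keeps it easy to compare against brute force:

```python
def count_all_surrounded(N, M):
    """Length-N arrays over 1..M in which every occurrence of x is surrounded (x >= 2).

    Plain O(N) DP with exact integers (no modulus).
    """
    # Empty prefix: behaves like "last element is neither x nor x-1" (j = 2).
    f0, f1, f2 = 0, 0, 1
    for _ in range(N):
        f0, f1, f2 = f1, f0 + f1 + f2, (M - 2) * (f1 + f2)
    # The last element may not be x, since nothing follows it.
    return f1 + f2
```

For example, with N = 3 and M = 2 (x = 2) the only such arrays are [1, 1, 1] and [1, 2, 1], and the DP returns 2.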

Computing this with a direct DP takes \mathcal{O}(N) time, which is too slow since N can be as large as 10^9.
However, notice that f_x(i, \cdot) depends only on f_x(i-1, \cdot) in a linear manner.
This is exactly the type of recurrence that can be optimized with the help of matrix exponentiation!

The exact details of building an appropriate matrix are left as an exercise to the reader, though the code below contains one implementation.
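For readers who want to check their work, one matrix consistent with the recurrence is the one in the editorialist's code. Writing the states as a column vector, one step multiplies by a 3×3 matrix T, and the initial vector (0, 0, 1) encodes the base case. The author's code uses the transposed, row-vector form instead.

```latex
T = \begin{pmatrix} 0 & 1 & 0 \\ 1 & 1 & 1 \\ 0 & M-2 & M-2 \end{pmatrix},
\qquad
\begin{pmatrix} f_x(i, 0) \\ f_x(i, 1) \\ f_x(i, 2) \end{pmatrix}
= T \begin{pmatrix} f_x(i-1, 0) \\ f_x(i-1, 1) \\ f_x(i-1, 2) \end{pmatrix},
\qquad
\begin{pmatrix} f_x(N, 0) \\ f_x(N, 1) \\ f_x(N, 2) \end{pmatrix}
= T^N \begin{pmatrix} 0 \\ 0 \\ 1 \end{pmatrix}.
```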

At any rate, matrix exponentiation allows us to get the required value in \mathcal{O}(\log N) time.
This solves the problem for a single x.
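A minimal Python sketch of this step, written to be self-contained with our own function names. It raises the column-vector transition matrix from the editorialist's code to the N-th power by binary exponentiation modulo 10^9 + 7, then reads f_x(N, 1) + f_x(N, 2) off column 2 of the result, since the starting vector is (0, 0, 1):

```python
MOD = 10**9 + 7


def mat_mul(A, B):
    """Product of two square matrices modulo MOD."""
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) % MOD for j in range(n)] for i in range(n)]


def mat_pow(T, e):
    """T**e modulo MOD using O(log e) matrix multiplications."""
    n = len(T)
    R = [[int(i == j) for j in range(n)] for i in range(n)]  # identity
    while e:
        if e & 1:
            R = mat_mul(R, T)
        T = mat_mul(T, T)
        e >>= 1
    return R


def count_all_surrounded_mod(N, M):
    """f_x(N, 1) + f_x(N, 2) modulo MOD, for any x >= 2."""
    T = [[0, 1, 0],
         [1, 1, 1],
         [0, (M - 2) % MOD, (M - 2) % MOD]]
    P = mat_pow(T, N)
    # Starting vector is (0, 0, 1), so f_x(N, .) is column 2 of T^N.
    return (P[1][2] + P[2][2]) % MOD
```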
How do we handle every x from 2 to M, when M can be as large as 10^9?

As it turns out, we don’t have to!
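Observe that the recurrence we defined doesn’t depend on the value of x at all (as long as x \geq 2)!
This means the count is exactly the same for every x from 2 to M, so we compute it once and multiply it by M-1.
Putting everything together, the final answer is (M^N - (M-1)^N) + (M-1)\cdot (M^N - f_x(N, 1) - f_x(N, 2)).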
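Putting the pieces together gives an end-to-end sketch of the whole computation. It follows the same structure as the editorialist's code below but is our own self-contained version (total_safety is a hypothetical name). M = 1 needs no special case here: the factor M-1 vanishes and M^N - (M-1)^N = 1.

```python
MOD = 10**9 + 7


def mat_mul(A, B):
    """Product of two square matrices modulo MOD."""
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) % MOD for j in range(n)] for i in range(n)]


def total_safety(N, M):
    """Sum of safety over all length-N arrays with values in 1..M, modulo MOD."""
    # x = 1: every array that contains a 1.
    ones = (pow(M, N, MOD) - pow(M - 1, N, MOD)) % MOD
    # x >= 2: f_x(N, 1) + f_x(N, 2), obtained from T^N applied to (0, 0, 1).
    T = [[0, 1, 0], [1, 1, 1], [0, (M - 2) % MOD, (M - 2) % MOD]]
    P = [[int(i == j) for j in range(3)] for i in range(3)]
    e = N
    while e:
        if e & 1:
            P = mat_mul(P, T)
        T = mat_mul(T, T)
        e >>= 1
    bad = (P[1][2] + P[2][2]) % MOD
    # The same count applies to each of the M - 1 values x = 2..M.
    return (ones + (M - 1) * (pow(M, N, MOD) - bad)) % MOD
```

For example, total_safety(3, 2) gives 13, which matches adding up the safeties of the eight arrays of length 3 over {1, 2} by hand.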

TIME COMPLEXITY:

\mathcal{O}(\log N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define di(a) int a;cin>>a;
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define sis string s;
#define sin string s;cin>>s;
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;


const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
long long pow(long long a, long long b, long long m = M) {
    a %= m;
    long long res = 1;
    while (b > 0) {
        if (b & 1)
            res = res * a % m;
        a = a * a % m;
        b /= 2;
    }
    return res;
}

const int MOD = M;
 
struct Matrix
{
    vector< vector<int> > mat;
    int n_rows, n_cols;
 
    Matrix() {}
 
    Matrix(vector< vector<int> > values): mat(values), n_rows(values.size()),
        n_cols(values[0].size()) {}
 
    static Matrix identity_matrix(int n)
    {
        vector< vector<int> > values(n, vector<int>(n, 0));
        for(int i = 0; i < n; i++)
            values[i][i] = 1;
        return values;
    }
 
    Matrix operator*(const Matrix &other) const 
    {
        int n = n_rows, m = other.n_cols;
        vector< vector<int> > result(n_rows, vector<int>(m, 0));
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++) {
                long long tmp = 0;
                for(int k = 0; k < n_cols; k++) {
                    tmp += ((mat[i][k]%MOD) * 1ll * (other.mat[k][j]%MOD))%MOD;
                    tmp %= MOD;
                }
                result[i][j] = tmp % MOD;
            }
 
        return move(Matrix(move(result)));
    }
 
    inline bool is_square() const
    {
        return n_rows == n_cols;
    }
};
Matrix pw(Matrix a,int p){
    Matrix result = Matrix::identity_matrix(a.n_cols);
    while (p > 0) {
        if (p & 1)
            result = a * result;
        a = a * a;
        p /= 2;
    }
    return result;
}
 
void solve()
{
    int n,m;
    cin>>n>>m;
    assert(1<=n && n<=1e9);
    assert(1<=m && m<=1e9);
    if(m==1){
        cout<<"1\n";
        return;
    }
    vector<vector<int>> v = {{0,0,1}};
    vector<vector<int>> v1 = {
                                {0,1,0},
                                {1,1,(m-2)%M},
                                {0,1,(m-2)%M},
                             };
    Matrix A(v1);
    Matrix B(v);
    Matrix ans = B*pw(A,n);
    ll final=(pow(m,n)-pow(m-1,n))%MOD;
    final += MOD;final %= MOD;
    ll fff=(m-1)%MOD;
    ll cnt=(((pow(m,n)-ans.mat[0][1])%MOD)-ans.mat[0][2])%MOD;
    cnt*=fff;cnt%=MOD;cnt+=MOD;cnt%=MOD;
    final+=cnt;final%=MOD;final += MOD;final %= MOD;
    cout<<final<<"\n";

}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    //freopen("input.txt","r",stdin);freopen("output.txt","w",stdout);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    ll tt;
    cin>>tt;
    assert(1<=tt && tt<=100);
    while(tt--){
        solve();
    }
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = (int) 1e9 + 7;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

template <typename T>
vector<vector<T>> operator*(const vector<vector<T>>& a, const vector<vector<T>>& b) {
    vector<vector<T>> c(a.size(), vector<T>(b[0].size()));
    for (int i = 0; i < (int) c.size(); i++) {
        for (int k = 0; k < (int) b.size(); k++) {
            for (int j = 0; j < (int) c[0].size(); j++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return c;
}

template <typename T>
vector<vector<T>>& operator*=(vector<vector<T>>& a, const vector<vector<T>>& b) {
    return a = a * b;
}

template <typename T>
vector<vector<T>> power(vector<vector<T>> a, long long n) {
    vector<vector<T>> res(a.size(), vector<T>(a.size()));
    for (int i = 0; i < (int) a.size(); i++) {
        res[i][i] = 1;
    }
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e2);
    in.readEoln();
    while (tt--) {
        int n = in.readInt(1, 1e9);
        in.readSpace();
        int m = in.readInt(1, 1e9);
        in.readEoln();
        vector<vector<mint>> a = {
            {m - 2, m - 2, 0, 0},
            {1, 1, 1, 0},
            {0, 1, 0, 0},
            {1, 0, m - 1, m}};
        a = power(a, n);
        mint t = a[2][0] + a[3][0];
        t *= m - 1;
        t += power(m, n) - power(m - 1, n);
        cout << t << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7

mat_mul = lambda A, B: [[sum((i % mod) * (j % mod) for i, j in zip(row, col)) % mod for col in zip(*B)] for row in A]
def eye(m):
    """returns an identity matrix of order m"""
    identity = [[0] * m for _ in range(m)]
    for i, row in enumerate(identity):
        row[i] = 1
    return identity
def mat_pow(mat, power):
    """returns mat**power"""

    result = eye(len(mat))
    if power == 0:
        return result

    while power > 1:
        if power & 1 == 1:
            result = mat_mul(result, mat)
        mat = mat_mul(mat, mat)
        power >>= 1
    return mat_mul(result, mat)

for _ in range(int(input())):
    n, m = map(int, input().split())
    mat = [
        [0, 1, 0],
        [1, 1, 1],
        [0, m-2, m-2]
    ]
    pw = mat_pow(mat, n)
    res = mat_mul(pw, [[0], [0], [1]])
    ans = pow(m, n, mod) - res[1][0] - res[2][0]
    ans *= m-1
    ans += pow(m, n, mod) - pow(m-1, n, mod)
    print(ans % mod)