CHARVER - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Srikkanth R
Tester: Utkarsh Gupta
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Randomization, Number theory

PROBLEM

Given an array C of M integers and a square matrix A (with integer entries) of dimension N \times N, verify whether the following equation is true,

C_0 I_{N} + C_1 A + C_2 A^2 + \dots + C_{M-1} A^{M-1} \equiv 0_N \pmod{998244353}

where 0_N is the square null matrix (matrix in which all elements are zero) and I_N is the identity matrix, both having dimensions N \times N.

QUICK EXPLANATION

  • Choose a column vector B of N random values in the range [0, 998244352] and multiply every term of the left-hand side by B, obtaining B, A*B, A^2*B, ... one after another. If the resulting sum vector contains at least one non-zero entry, the given equation does not hold.
  • The method is probabilistic: it may give false positives (answering yes when the equation does not hold) but never false negatives. Repeating it with several independent random vectors makes a false positive extremely unlikely.

EXPLANATION

We could compute the final matrix on the left side of the expression directly and check whether it is 0_N, but that requires computing powers of A through N \times N matrix multiplications, each taking O(N^3) time. Since N can be as large as 1000, O(N^3) is too slow.

The core of this problem is to find a faster way of checking the above condition while avoiding multiplying two N \times N matrices.

Let’s consider the first example, where
A = \left [ \begin{matrix} 1 & 2 \\ 3 & 4\end{matrix} \right ] and we have -2*I_N -5*A + A^2, which gives \left [\begin{matrix} -2 -5 + 7 & 0 -10 + 10 \\ 0-15+15 & -2-20+22 \end{matrix} \right ] = 0_N
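The entries 7, 10, 15 and 22 used above come from squaring A:

```latex
A^2 = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}
    = \begin{bmatrix} 1\cdot 1 + 2\cdot 3 & 1\cdot 2 + 2\cdot 4 \\ 3\cdot 1 + 4\cdot 3 & 3\cdot 2 + 4\cdot 4 \end{bmatrix}
    = \begin{bmatrix} 7 & 10 \\ 15 & 22 \end{bmatrix}
```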

Let’s consider a column vector B = \left [ \begin{matrix} x \\ y \end{matrix} \right ] and multiply each term of the expression by it, getting -2*I_N*B -5*A*B + A^2 * B = 0_N*B = \left [ \begin{matrix} 0 \\ 0 \end{matrix} \right ], the zero column vector.

The expression becomes

\displaystyle \left [ \begin{matrix} -2*x \\ -2*y \end{matrix} \right ] + \left [ \begin{matrix} -5*x-10*y \\ -15*x - 20*y \end{matrix} \right ] + \left [ \begin{matrix} 7*x+10*y \\ 15*x+22*y\end{matrix} \right ] = \left [ \begin{matrix} 0 \\ 0 \end{matrix} \right ]

Hence, irrespective of the values of x and y, the result is a column vector of N zeros whenever the equation holds.

So if the equation holds, trying B filled with random values always gives the zero vector.

Theory
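If we think about it theoretically, multiplying a matrix by a column vector collapses each of its rows into a single number: entry i of the result is a linear combination of the entries of row i, with the entries of the vector as coefficients.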

For example, for the matrix X = \left [ \begin{matrix} a & b \\ c & d \end{matrix} \right ] and the column vector B = \left [ \begin{matrix} x \\ y \end{matrix} \right ], X*B = \left [ \begin{matrix} a*x+b*y \\ c*x+d*y \end{matrix} \right ]. Each row, \left [ \begin{matrix} a & b\end{matrix} \right ] and \left [ \begin{matrix} c & d \end{matrix} \right ], is collapsed into one entry of the resulting vector.

In other words, it is like checking whether a*x + b*y = 0 for random values of x and y in order to decide whether a = 0 and b = 0, and this is done for every row at once.
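The phrase "random values" can be made precise. Assuming x and y are drawn independently and uniformly from [0, 998244352] (as the setter's and editorialist's solutions do), a non-zero row passes this test with probability exactly 1/998244353, by the following standard argument:

```latex
\begin{aligned}
&p = 998244353 \text{ (prime)},\quad (a, b) \not\equiv (0, 0) \pmod p,\ \text{say } a \not\equiv 0 \pmod p. \\
&\text{For each fixed } y:\quad a x + b y \equiv 0 \iff x \equiv -b\,y\,a^{-1} \pmod p, \\
&\text{which is exactly one of the } p \text{ possible values of } x. \\
&\Pr_{x, y}\left[a x + b y \equiv 0 \pmod p\right] = \frac{p}{p^2} = \frac{1}{p} \approx 1.0 \times 10^{-9}.
\end{aligned}
```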

If it doesn’t satisfy the recurrence

But suppose we had instead got the column vector \left [ \begin{matrix} 3*x+5*y \\ 0*x+0*y \end{matrix} \right ]. It is the zero vector whenever 3*x+5*y = 0, which can happen for some choices of x and y even though the first row is non-zero.
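Such an unlucky choice really exists modulo 998244353: for every y there is exactly one x with 3x + 5y ≡ 0, namely x ≡ -5y * 3^{-1}. The sketch below (hypothetical helper names `rowTimesVector`, `modPow` and `badX`; inputs assumed to be in [0, 998244352]) computes that x, using Fermat's little theorem for the modular inverse. For instance, x = 5 and y = 998244350 give 15 + 5*998244350 = 5*998244353 ≡ 0.

```cpp
#include <cassert>
#include <cstdint>

const int64_t MOD = 998244353;

// The row [3 5] combined with B = (x, y): 3x + 5y mod MOD. Assumes 0 <= x, y < MOD.
int64_t rowTimesVector(int64_t x, int64_t y) {
    return (3 * x + 5 * y) % MOD;
}

// Binary exponentiation b^e mod MOD.
int64_t modPow(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// The unique x for which 3x + 5y == 0 (mod MOD), i.e. x = -5y * 3^{-1}.
int64_t badX(int64_t y) {
    int64_t minusFiveY = (MOD - (5 * y) % MOD) % MOD;   // -5y mod MOD
    int64_t inverseOfThree = modPow(3, MOD - 2);        // Fermat: 3^{p-2} = 3^{-1} mod p
    return minusFiveY * inverseOfThree % MOD;
}
```

Since only one x out of 998244353 is bad for each y, a uniformly random pair hits this case with probability 1/998244353.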

So we run the process several times with independent random vectors. If even one run produces a non-zero entry, the equation definitely does not hold; if every run produces the zero vector, we answer that it holds.
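The error probability of this rule can be bounded. Let S be the matrix on the left-hand side. If S is non-zero it has a non-zero row s_k; fixing all coordinates of B except one where s_k is non-zero, exactly one value of that coordinate makes s_k · B vanish, so a single run is fooled with probability at most 1/p. Independent runs multiply:

```latex
\begin{aligned}
&S = C_0 I_N + C_1 A + \dots + C_{M-1} A^{M-1} \bmod p, \qquad p = 998244353. \\
&S \ne 0_N \implies \exists k:\ s_k \ne 0 \implies \Pr_B[SB = 0] \le \Pr_B[s_k \cdot B \equiv 0] = \tfrac{1}{p}. \\
&R \text{ independent vectors:}\quad \Pr[\text{all } R \text{ runs give } 0] \le p^{-R}, \qquad R = 5:\ p^{-5} < 10^{-44}.
\end{aligned}
```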

Implementation

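We pick a random vector B for each of R repetitions. For the first sample, the expression can also be written as -2*I_N + A*(-5*I_N + A), but the simplest approach is to compute I_N*B, A*B and A^2*B in this order, obtaining each vector from the previous one with a single matrix-vector product, and add C_i times the i-th vector to a sum vector. If the equation holds, the sum vector is filled with zeros every time.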
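Putting the pieces together, one possible C++ sketch of the randomized check follows. It uses the same idea as the setter's solution but is not his code; `randomizedCheck`, `matVec`, `Vec` and `Matrix` are hypothetical names, and C and A are assumed to be reduced into [0, 998244352]. Random entries are drawn uniformly with `std::uniform_int_distribution`, which the probability bound relies on.

```cpp
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

const int64_t MOD = 998244353;
using Vec = std::vector<int64_t>;
using Matrix = std::vector<Vec>;

// One O(N^2) matrix-vector product modulo MOD.
Vec matVec(const Matrix& A, const Vec& v) {
    Vec out(A.size(), 0);
    for (size_t i = 0; i < A.size(); ++i) {
        int64_t acc = 0;
        for (size_t j = 0; j < v.size(); ++j) acc = (acc + A[i][j] * v[j]) % MOD;
        out[i] = acc;
    }
    return out;
}

// Randomized test of C_0 I + C_1 A + ... + C_{M-1} A^{M-1} == 0_N (mod MOD).
// Never rejects a true equation; accepts a false one with probability <= MOD^{-R}.
bool randomizedCheck(const Vec& C, const Matrix& A, int R, std::mt19937_64& rng) {
    const size_t n = A.size();
    std::uniform_int_distribution<int64_t> dist(0, MOD - 1);
    for (int rep = 0; rep < R; ++rep) {
        Vec term(n);                              // A^0 B = B, filled at random
        for (auto& x : term) x = dist(rng);
        Vec sum(n, 0);                            // accumulates C_i A^i B
        for (size_t i = 0; i < C.size(); ++i) {
            for (size_t r = 0; r < n; ++r) sum[r] = (sum[r] + C[i] * term[r]) % MOD;
            if (i + 1 < C.size()) term = matVec(A, term);   // A^{i+1} B
        }
        for (int64_t x : sum)
            if (x != 0) return false;             // a non-zero entry proves the equation false
    }
    return true;
}
```

Each repetition performs M - 1 matrix-vector products, so R repetitions cost O(R*M*N^2); R = 5 matches the setter's choice.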

We only ever multiply an N \times N matrix by an N \times 1 vector, which takes O(N^2) time, so the overall time complexity is O(R*M*N^2).

TIME COMPLEXITY

The time complexity is O(R*M*N^2) per test case, where R is the number of repetitions.

SOLUTIONS

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
typedef vector<int> vint;
typedef vector<vector<int>> mat;
#define LL long long
LL seed = chrono::steady_clock::now().time_since_epoch().count();
mt19937_64 rng(seed);
#define rand(l, r) uniform_int_distribution<LL>(l, r)(rng)
clock_t start = clock();
mat operator+(mat a, mat b) {
    int n = a.size(), m = a[0].size();
    assert(b.size() == n && b[0].size() == m);
    mat ret(n, vint(m, 0));
    for (int i=0;i<n;++i) {
        for (int j=0;j<m;++j) {
            ret[i][j] = a[i][j] + b[i][j];
            if (ret[i][j] >= MOD) ret[i][j] -= MOD;
        }
    }
    return ret;
}   
mat operator*(mat a, mat b) {
    int n = a.size(), m = a[0].size(), r = b[0].size();
    assert(b.size() == m);
    mat ret(n, vint(r, 0));
    for (int i=0;i<n;++i) {
        for (int j=0;j<r;++j) {
            int res = 0;
            for (int k=0;k<m;++k) {
                res += (a[i][k] * 1LL * b[k][j]) % MOD;
                if (res >= MOD) res -= MOD;
            }
            ret[i][j] = res;
        }
    }
    return ret;
}   
mat operator*(mat a, int b) {
    int n = a.size(), m = a[0].size();
    mat ret(n, vint(m, 0));
    for (int i=0;i<n;++i) {
        for (int j=0;j<m;++j) {
            ret[i][j] = (a[i][j] * 1LL * b) % MOD;
        }
    }
    return ret;
}   
mat zero(int n) {
    return mat(n, vint(n, 0));
}
mat eye(int n) {
    mat ret = zero(n);
    for (int i=0;i<n;++i) ret[i][i] = 1;
    return ret;
}
int main() {
    ios_base::sync_with_stdio(false);cin.tie(NULL);
    int T;
    cin >> T;
    while (T--) {
        int m;
        cin >> m;
        vint v(m);
        for (int i=0;i<m;++i) cin >> v[i];
        int n;
        cin >> n;
        mat A(n, vint(n, 0));
        for (int i=0;i<n;++i) for (int j=0;j<n;++j) cin >> A[i][j];
        bool ans = true;
        int iter = 5;
        while (iter--) {
            mat X(n, vint(1, 0));
            for (int i=0;i<n;++i) X[i][0] = (int)rand(0, MOD-1);
            mat mul = X, ret(n, vint(1, 0));
            for (int i=0;i<m;++i) {
                ret = ret + (mul * v[i]);
                mul = (A * mul);
            }
            for (int i=0;i<n;++i) {
                if (ret[i][0] != 0) ans = false;
            }
        }
        cout << (ans ? "yes\n" : "no\n");
    }
    cerr << fixed << setprecision(10);
    cerr << (clock() - start) / ((long double)CLOCKS_PER_SEC) << " secs\n";
    return 0;
}
Tester's Solution
//Utkarsh.25dec
#include <bits/stdc++.h>
#include <chrono>
#include <random>
#define ll long long int
#define ull unsigned long long int
#define pb push_back
#define mp make_pair
#define mod 998244353
#define rep(i,n) for(ll i=0;i<n;i++)
#define loop(i,a,b) for(ll i=a;i<=b;i++)
#define vi vector <int>
#define vs vector <string>
#define vc vector <char>
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
#define max3(a,b,c) max(max(a,b),c)
#define min3(a,b,c) min(min(a,b),c)
#define deb(x) cerr<<#x<<' '<<'='<<' '<<x<<'\n'
using namespace std;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  itertor to ith element (0 indexed)
typedef vector<vector<ll>> matrix;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
int K;
// computes A * B
matrix mul(matrix A, matrix B)
{
    matrix C(K+1, vector<ll>(K+1));
    for(int i=1;i<=K;i++) for(int j=1;j<=K;j++) for(int k=1;k<=K;k++)
        C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % mod;
    return C;
}

// computes A ^ p
matrix pow(matrix A, ll p)
{
    if (p == 1)
        return A;
    if (p % 2)
        return mul(A, pow(A, p-1));
    matrix X = pow(A, p/2);
    return mul(X, X);
}
//matrix ans(K+1,vl(K+1));
vl matrixmultvector(matrix A,vl B)
{
    vl temp;
    temp.pb(0);
    for(int i=1;i<=K;i++)
    {
        ll ans=0;
        for(int j=1;j<=K;j++)
        {
            ans+=(A[i][j]*B[j]);
            ans%=mod;
        }
        temp.pb(ans);
    }
    return temp;
}
vl constmultvector(vl A,ll x)
{
    for(int i=0;i<=K;i++)
    {
        A[i]*=x;
        A[i]%=mod;
    }
    return A;
}
vl addvectors(vl A,vl B)
{
    for(int i=0;i<=K;i++)
    {
        A[i]+=B[i];
        A[i]%=mod;
    }
    return A;
}
void solve()
{
    ll m;
    cin>>m;
    vl coeff;
    for(int i=0;i<m;i++)
    {
        ll c;
        cin>>c;
        coeff.pb(c);
    }
    ll n;
    cin>>n;
    matrix A(n+1,vl(n+1));
    K=n;
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
            cin>>A[i][j];
    }
    vl ans(n+1);
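    vl curr(n+1);
    // Note: curr is a fixed all-ones vector (entries 1..n), not a random one,
    // so this check can be fooled (see the discussion below the editorial).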
    for(int i=0;i<=n;i++)
    {
        ans[i]=0;
        curr[i]=1;
    }
    curr[0]=0;
    for(int i=0;i<m;i++)
    {
        vl temp=constmultvector(curr,coeff[i]);
        ans=addvectors(ans,temp);
        curr=matrixmultvector(A,curr);
    }
    int flag=0;
    for(int i=1;i<=K;i++)
    {
        if(ans[i]!=0)
        {
            flag=1;
            break;
        }
    }
    if(flag)
        cout<<"no\n";
    else
        cout<<"yes\n";
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int T=1;
    cin>>T;
    int t=0;
    while(t++<T)
    {
        //cout<<"Case #"<<t<<":"<<' ';
        solve();
        //cout<<'\n';
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class CHARVER{
    //SOLUTION BEGIN
    final long MOD = 998244353;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int M = ni();
        long[] C = new long[M];
        for(int i = 0; i< M; i++)C[i] = nl();
        int N = ni();
        long[][] A = new long[N][N];
        for(int i = 0; i< N; i++)
            for(int j = 0; j< N; j++)
                A[i][j] = nl();
        
        boolean yes = true;
        Random R = new Random();
        for(int t = 0; t < 50; t++){
            long[] delta = new long[N];
            long[] col = new long[N];
            for(int i = 0; i< N; i++)col[i] = R.nextInt((int)MOD);
            for(int r = 0; r< M; r++){
                for(int i = 0; i< N; i++)delta[i] += col[i]*C[r]%MOD;
                col = mul(A, col);
            }
            boolean good = true;
            for(int i = 0; i< N; i++)good &= delta[i]%MOD == 0;
            yes &= good;
        }
        pn(yes?"YES":"NO");
    }
    long[] mul(long[][] A, long[] B){
        int N = B.length;
        long[] C = new long[N];
        for(int i = 0; i< N; i++){
            for(int j = 0; j< N; j++)
                C[i] += A[i][j]*B[j]%MOD;
            C[i] %= MOD;
        }
        return C;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new CHARVER().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:


Can anyone give any tips on how I can solve this kind of hard problem?


Can anyone explain how we will find A^2, A^3, and so on?

You don’t find them. That’s the point of the editorial. Instead, find B, A*B, A*(A*B) and so on, where B is a random column vector of size n. Note that A*B, A^2*B, etc. are all column vectors as well, so each multiplication costs O(n^2) instead of the usual O(n^3).


Such a nice editorial @taran_1407
But I have an issue with the tester’s code @utkarsh_adm

For the case below, the expected output is no and the setter’s code gives the same, but the tester’s code gives yes. So it must have failed.
CASE
T=1
M=2
C=[1,1]
N=2
A=[[499122176 , 499122176],[ 998244352 , 0]]

Expected Output: no
setter’s output : no
Tester’s Output : yes
Correct me if I am wrong.


ohh…Now I understood. Thank you so much.

I knew the logic but got stuck because I wasn’t able to handle matrices larger than 2 \times 2.

I think the tester’s algorithm is computing a vector of size N where each element is the sum of the elements of the corresponding row of the resultant matrix (it multiplies by an all-ones vector). Hence, in your test case, where X = 998244353, matrix A is
A = [[(X-1)/2, (X-1)/2],[X-1, 0]]
and the resultant matrix comes out to be
res = I + A = [[(X+1)/2, (X-1)/2],[X-1, 1]]
where both rows sum to X, which is 0 modulo X.


The setter’s solution is based on randomisation, which has a very low probability of failure because it checks several random column vectors. The tester instead used a fixed column vector with all elements equal to one, which fails on any input where every row of the resultant matrix sums to 0 modulo 998244353 even though the matrix itself is non-zero; this test case is exactly such an input.