PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Md. Mahamudur Rahaman Sajib
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Matrix Exponentiation, Bitwise operations.
PROBLEM:
Given an array A of N elements, N index pairs (l_i, r_i) with 1 \le l_i, r_i \le N, and an integer K, we define the k-th power of array A, denoted A^k, as follows:
- If k = 1, then A^1 = A, the original array.
- Otherwise, A^k is an N-length array whose i-th element is
(A^k)_i = (A^{k-1})_{l_i} \oplus (A^{k-1})_{l_i+1} \oplus \ldots \oplus (A^{k-1})_{r_i-1} \oplus (A^{k-1})_{r_i} \,.
You need to find the array A^K.
QUICK EXPLANATION
- In the simpler problem where XOR is replaced by addition, the recurrence is linear and can be solved with standard matrix exponentiation.
- The only obstacle to applying matrix exponentiation here is carries. Addition carries from lower bits into higher bits, whereas XOR never carries, so the sums would give wrong values for the higher bits.
- To avoid that, we solve for each bit of the answer separately and perform the matrix multiplication modulo 2 (over the field GF(2)). Every carry is a multiple of 2 and therefore vanishes, and addition modulo 2 is exactly XOR.
- Hence, we compute the (K-1)-th power of the recurrence matrix once. Then, for each bit, we take the binary vector marking the positions whose value has the current bit set, and multiply the matrix power by it. The i-th entry of the resulting vector determines whether the i-th element of A^K has the current bit set.
EXPLANATION
Considering a simpler problem where XOR is replaced by addition, we get a linear recurrence relation, which we can solve via matrix exponentiation.
Suppose we have the current vector A^{K-1} and the recurrence matrix M; then A^K = M \cdot A^{K-1}. The matrix M is easy to write down: M_{i,j} = 1 if the j-th term contributes to the i-th element of the next power (that is, l_i \le j \le r_i), and M_{i,j} = 0 otherwise.
But in our original problem, this approach fails. XOR is a bitwise operation: the b-th bits of the values in A^K depend only on the b-th bits of the values in A. Ordinary matrix multiplication, however, carries from lower bits into higher bits, which breaks this independence.
There is still a way to handle this. On single bits, XOR is exactly addition modulo 2, so if all operations are done modulo 2, no carries can arise.
So we solve each bit separately and do all operations modulo 2.
For the b-th bit, we build a binary vector C such that C_i = 1 if A_i has the b-th bit set, and C_i = 0 otherwise.
Our problem now becomes: given a binary vector C of N values and a binary matrix M, find C^K = M^{K-1} \cdot C. Doing the matrix operations modulo 2 gives C^K. If (C^K)_i = 1, then (A^K)_i has the b-th bit set.
Hence, we can solve the problem in O(N^3 \log K + N^2 \log(\max A_i)) time, computing M^{K-1} only once and reusing it for every bit.
Optimization
This complexity is still a bit high. We can reduce it by observing that all vector and matrix multiplications are performed over binary vectors and matrices only. We can therefore store them in bitsets, packing 64 entries into each machine word. Multiplication modulo 2 becomes a bitwise AND, and addition modulo 2 becomes XOR. Each entry of a product is then the parity of the popcount of (row AND column).
TIME COMPLEXITY
The final time complexity is O((N^3 \log K + N^2 \log(\max A_i)) / 64) per test case.
SOLUTIONS:
Setter's Solution
#include<bits/stdc++.h>
//#include <ext/pb_ds/assoc_container.hpp>
//#include <ext/pb_ds/tree_policy.hpp>
#include <cstring>
#include <iostream>
// ---------------------------------------------------------------------------
// Contest template shorthand. The solution below uses only si, sii, sl, pl,
// sp, nl, for0, mem, ll and MOD; the remaining macros belong to the template.
// ---------------------------------------------------------------------------
#define pie acos(-1)
// scanf shorthands: %d = int, %lld = long long, %s = C string, %lf = double.
#define si(a) scanf("%d",&a)
#define sii(a,b) scanf("%d %d",&a,&b)
#define siii(a,b,c) scanf("%d %d %d",&a,&b,&c)
#define sl(a) scanf("%lld",&a)
#define sll(a,b) scanf("%lld %lld",&a,&b)
#define slll(a,b,c) scanf("%lld %lld %lld",&a,&b,&c)
#define ss(st) scanf("%s",st)
// Reads exactly one character ("%ch" would additionally try to match a literal 'h').
#define sch(ch) scanf("%c",&ch)
// printf shorthands mirroring the scanf ones above.
#define ps(a) printf("%s",a)
#define newLine() printf("\n")
#define pi(a) printf("%d",a)
#define pii(a,b) printf("%d %d",a,b)
#define piii(a,b,c) printf("%d %d %d",a,b,c)
#define pl(a) printf("%lld",a)
#define pll(a,b) printf("%lld %lld",a,b)
#define plll(a,b,c) printf("%lld %lld %lld",a,b,c)
#define pd(a) printf("%lf",a)
#define pdd(a,b) printf("%lf %lf",a,b)
#define pddd(a,b,c) printf("%lf %lf %lf",a,b,c)
// Prints exactly one character ("%ch" would print a stray 'h' after it).
#define pch(c) printf("%c",c)
#define debug1(str,a) printf("%s=%d\n",str,a)
#define debug2(str1,str2,a,b) printf("%s=%d %s=%d\n",str1,a,str2,b)
#define debug3(str1,str2,str3,a,b,c) printf("%s=%d %s=%d %s=%d\n",str1,a,str2,b,str3,c)
#define debug4(str1,str2,str3,str4,a,b,c,d) printf("%s=%d %s=%d %s=%d %s=%d\n",str1,a,str2,b,str3,c,str4,d)
// Loop shorthands; the loop variable must already be declared.
#define for0(i,n) for(i=0;i<n;i++)
#define for1(i,n) for(i=1;i<=n;i++)
#define forab(i,a,b) for(i=a;i<=b;i++)
#define forstl(i, s) for (__typeof ((s).end ()) i = (s).begin (); i != (s).end (); ++i)
#define nl puts("")
#define sd(a) scanf("%lf",&a)
#define sdd(a,b) scanf("%lf %lf",&a,&b)
#define sddd(a,b,c) scanf("%lf %lf %lf",&a,&b,&c)
#define sp printf(" ")
#define ll long long int
#define ull unsigned long long int
#define MOD 1000000007
#define mpr make_pair
#define pub(x) push_back(x)
#define pob(x) pop_back(x)
// Only valid for real arrays (sizeof of a pointer would be wrong).
#define mem(ara,value) memset(ara,value,sizeof(ara))
#define INF INT_MAX
#define eps 1e-9
// Bit helpers: fully parenthesised and using a 64-bit shift, so they are safe
// inside larger expressions and for bit positions up to 62 on long long values.
#define checkbit(n, pos) ((n) & (1LL << (pos)))
#define setbit(n, pos) ((n) | (1LL << (pos)))
// Debug printers for arrays, vectors, 2-D arrays and STL containers.
#define para(i,a,b,ara)\
for(i=a;i<=b;i++){\
if(i!=0){printf(" ");}\
cout<<ara[i];\
}\
printf("\n");
#define pvec(i,vec)\
for(i=0;i<vec.size();i++){\
if(i!=0){printf(" ");}\
cout<<vec[i];\
}\
printf("\n");
#define ppara(i,j,n,m,ara)\
for(i=0;i<n;i++){\
for(j=0;j<m;j++){\
if(j!=0){printf(" ");}\
cout<<ara[i][j];\
}\
printf("\n");\
}
#define ppstructara(i,j,n,m,ara)\
for(i=0;i<n;i++){\
printf("%d:\n",i);\
for(j=0;j<m;j++){\
cout<<ara[i][j];printf("\n");\
}\
}
#define ppvec(i,j,n,vec)\
for(i=0;i<n;i++){\
printf("%d:",i);\
for(j=0;j<vec[i].size();j++){\
if(j!=0){printf(" ");}\
cout<<vec[i][j];\
}\
printf("\n");\
}
#define ppstructvec(i,j,n,vec)\
for(i=0;i<n;i++){\
printf("%d:",i);\
for(j=0;j<vec[i].size();j++){\
cout<<vec[i][j];printf("\n");\
}\
}
#define sara(i,a,b,ara)\
for(i=a;i<=b;i++){\
scanf("%d",&ara[i]);\
}
#define pstructara(i,a,b,ara)\
for(i=a;i<=b;i++){\
cout<<ara[i];nl;\
}
#define pstructvec(i,vec)\
for(i=0;i<vec.size();i++){\
cout<<vec[i];nl;\
}
#define pstructstl(stl,x)\
for(__typeof(stl.begin()) it=stl.begin();it!=stl.end();++it){\
x=*it;\
cout<<x;nl;\
}\
nl;
#define pstl(stl)\
for(__typeof(stl.begin()) it=stl.begin();it!=stl.end();++it){\
if(it!=stl.begin()){sp;}\
pi(*it);\
}\
nl;
#define ppairvec(i,vec)\
for(i=0;i<vec.size();i++){\
cout<<vec[i].first;sp;cout<<vec[i].second;printf("\n");\
}
#define ppairara(i,a,b,ara)\
for(i=a;i<=b;i++){\
cout<<ara[i].first;sp;cout<<ara[i].second;printf("\n");\
}
#define pppairvec(i,j,n,vec)\
for(i=0;i<n;i++){\
printf("%d:\n",i);\
for(j=0;j<vec[i].size();j++){\
cout<<vec[i][j].first;sp;cout<<vec[i][j].second;nl;\
}\
}
#define pppairara(i,j,n,m,ara)\
for(i=0;i<n;i++){\
printf("%d:\n",i);\
for(j=0;j<m;j++){\
cout<<ara[i][j].first;printf(" ");cout<<ara[i][j].second;nl;\
}\
}
// Parenthesised so that expressions such as x / SZ or x % SZ use the whole product.
#define SZ (2 * 100010)
#define xx first
#define yy second
using namespace std;
//using namespace __gnu_pbds;
//bool status[100010];
//vector <int> prime;
//void siv(){
// int N = 100005, i, j; prime.clear();
// int sq = sqrt(N);
// for(i = 4; i <= N; i += 2){ status[i] = true; }
// for(i = 3; i <= sq; i+= 2){
// if(status[i] == false){
// for(j = i * i; j <= N; j += i){ status[j] = true; }
// }
// }
// status[1] = true;
// for1(i, N){ if(!status[i]){ prime.pub(i); } }
//}
//mt19937_64 mt(chrono::steady_clock::now().time_since_epoch().count());
//auto seed = chrono::high_resolution_clock::now().time_since_epoch().count();
//std::mt19937 mt(seed);
// Modular addition modulo MOD. Each operand is first lifted from a negative
// value by one MOD (presumably callers pass values in (-MOD, MOD)).
inline int add(int _a, int _b){
    const int lhs = (_a < 0) ? _a + MOD : _a;
    const int rhs = (_b < 0) ? _b + MOD : _b;
    const int total = lhs + rhs;
    return (total >= MOD) ? total - MOD : total;
}
inline int mul(int _a, int _b){
if(_a < 0){ _a += MOD; }
if(_b < 0){ _b += MOD; }
return ((ll)((ll)_a * (ll)_b)) % MOD;
}
// Maximum supported array length (main asserts n <= 500).
const int N = 500;
// n, k: size and requested power of the current test; l[i], r[i]: 0-based inclusive range for position i.
int n, k, l[N + 5], r[N + 5];
// ara: input values; sol: the answer, assembled one bit at a time in solve().
ll ara[N + 5], sol[N + 5];
// Recurrence matrix M kept both row-wise and column-wise (col[j][i] == row[i][j]),
// so a product entry is the popcount parity of (row AND column).
bitset <N> row[N + 5], col[N + 5];
// Accumulated power of M, stored in the same two orientations.
bitset <N> mat_row[N + 5], mat_col[N + 5];
// In-place GF(2) product a := a * b over the top-left n x n entries.
// Both orientations of a are rewritten. Only a's rows and b's columns are read
// (b_row is unused), and the product is built in scratch storage first, so
// calling it with a and b aliased (squaring) is safe.
void mat_mul(bitset <N> *a_row, bitset <N> *a_col, bitset <N> *b_row, bitset <N> *b_col){
    bitset <N> c_row[N + 5], c_col[N + 5];
    for(int r = 0; r < n; ++r){
        for(int q = 0; q < n; ++q){
            // Entry (r, q) = parity of the number of k with a[r][k] & b[k][q].
            const bool parity = ((a_row[r] & b_col[q]).count() & 1) != 0;
            c_col[q][r] = parity;
            c_row[r][q] = parity;
        }
    }
    for(int r = 0; r < n; ++r){
        a_row[r] = c_row[r];
        a_col[r] = c_col[r];
    }
}
// Computes mat := M^p by square-and-multiply, where M is held in row/col.
// Assumes mat_row/mat_col were zeroed beforehand (solve() does this); row/col
// are squared in place, so they no longer hold M afterwards.
void expo(int p){
    for(int d = 0; d < n; ++d){
        mat_row[d][d] = 1;
        mat_col[d][d] = 1;
    }
    for(; p; p >>= 1){
        if(p & 1){
            mat_mul(mat_row, mat_col, row, col);
        }
        mat_mul(row, col, row, col);
    }
}
void solve(){
int i, j;
mem(sol, 0);
for0(i, n) for0(j, n) row[i][j] = col[j][i] = 0, mat_row[i][j] = mat_col[j][i] = 0;
for0(i, n) for(j = l[i]; j <= r[i]; ++j) row[i][j] = col[j][i] = 1;
expo(k - 1);
bitset <N> bit;
for0(i, 60){
for0(j, n){
if(ara[j] & 1ll << i) bit[j] = 1;
else bit[j] = 0;
}
for0(j, n){
int x = (mat_row[j] & bit).count() & 1;
if(x) sol[j] |= 1ll << i;
}
}
for0(i, n){
if(i) sp;
pl(sol[i]);
} nl;
}
int main(){
// freopen("input.txt","r",stdin);
// freopen("0.in", "r", stdin);
// freopen("0.out", "w", stdout);
// freopen("1.in", "r", stdin);
// freopen("1.out", "w", stdout);
// freopen("2.in", "r", stdin);
// freopen("2.out", "w", stdout);
// freopen("output.txt", "w", stdout);
int cs, ts;
si(ts);
for0(cs, ts){
int i, j;
sii(n, k);
assert(n >= 1 && n <= 500 && k >= 1 && k <= 50000000);
for0(i, n){
sl(ara[i]);
assert(ara[i] >= 1ll && ara[i] <= 1000000000000000000ll);
}
for0(i, n){
sii(l[i], r[i]);
assert(l[i] >= 1 && l[i] <= n && r[i] >= 1 && r[i] <= n);
--l[i], --r[i];
}
solve();
}
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
// Loop shorthands: f = [a, b) ascending, rep = [0, n), fd = [b, a] descending.
// The loop variable must already be declared.
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
// Container and pair shorthands.
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout)
#define primeDEN 727999983
// From here on every `int` (including those produced by vi/pii/... on expansion)
// means `long long`; main must therefore be declared as `signed main`.
#define int ll
// Matrix storage (long long entries because of `#define int ll`):
// res = accumulated power built by getpow(), mat = base matrix squared in getpow(),
// c = raw popcount scratch written by mult().
int res[512][512],c[512][512],mat[512][512];
int mult(int n,int a[512][512],int b[512][512]){
int i,j,k;
rep(i,n){
rep(j,n){
c[i][j]=0;
}
}
vector< bitset<512> > vec1(n,0);
vector< bitset<512> > vec2(n,0);
bitset<512> gg;
rep(i,n){
rep(j,n){
vec1[i][j]=a[i][j];
}
//cout<<vec1[i]<<endl;
}
rep(i,n){
rep(j,n){
vec2[i][j]=b[j][i];
}
}
rep(i,n){
rep(j,n){
gg=vec1[i]&vec2[j];
c[i][j]=gg.count();
}
}
rep(i,n){
rep(j,n){
a[i][j]=(c[i][j]&1);
}
}
return 0;
}
// Computes res := mat^k (mod 2) for an n x n matrix by binary exponentiation.
// mat is squared in place and no longer holds the original matrix afterwards.
int getpow(int n,int k){
    int x,y;
    rep(x,n){
        rep(y,n){
            res[x][y]=(x==y);
        }
    }
    for(;k!=0;k/=2){
        if(k%2!=0){
            mult(n,res,mat);
        }
        mult(n,mat,mat);
    }
    return 0;
}
// Per-test arrays (long long via `#define int ll`): l/r = 0-based ranges,
// a = input values, bit = one bit of each a[j], ans = the resulting A^K.
int l[12345],r[12345],a[12345],bit[12345],ans[12345];
// Reads T test cases; for each, builds the recurrence matrix M, computes
// M^(K-1) over GF(2) and prints A^K as space-separated values on one line.
signed main(){
    std::ios::sync_with_stdio(false); cin.tie(NULL);
    int t=0;
    cin>>t;
    while(t--){
        int n=0,k=0;
        // Stop cleanly on truncated input instead of working on garbage sizes.
        if(!(cin>>n>>k)) break;
        int i,j,p;
        rep(i,n){
            cin>>a[i];
            ans[i]=0;
        }
        rep(i,n){
            cin>>l[i]>>r[i];
            l[i]--;
            r[i]--;
        }
        // mat := M, whose row i has ones exactly in columns l[i]..r[i].
        rep(i,n){
            rep(j,n){
                mat[i][j]=0;
            }
            f(j,l[i],r[i]+1){
                mat[i][j]=1;
            }
        }
        // res := M^(K-1), since A^1 is the input itself.
        getpow(n,k-1);
        // Each bit of A^K is res applied to the same bit of A modulo 2, and XOR
        // does that for all bits of a word at once:
        // ans[j] = XOR of a[p] over every p with res[j][p] == 1.
        rep(j,n){
            rep(p,n){
                if(res[j][p]) ans[j]^=a[p];
            }
        }
        rep(i,n){
            cout<<ans[i]<<" ";
        }
        cout<<endl;
    }
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class HXR{
    //SOLUTION BEGIN
    /*
     * Let M be the N x N 0/1 matrix whose row i has ones exactly in columns l_i..r_i (0-based).
     * Each bit of A^K equals M^(K-1) applied, modulo 2, to the same bit of A. Because XOR works
     * on all 64 bits of a long at once, those per-bit products collapse into
     *   (A^K)_i = XOR of a_j over every j with (M^(K-1))_{i,j} = 1,
     * so no explicit loop over bit positions is needed, and every long value is handled.
     */
    void pre() throws Exception{}

    /**
     * Reads one test case and prints A^K. The input is N and K, then N values, then N ranges
     * "l r" (1-based, inclusive). The output is one line with every value followed by a single
     * space.
     */
    void solve(int TC) throws Exception{
        int N = ni();
        long K = nl()-1;  // A^1 = A, so only K-1 applications of M are needed
        long[] a = new long[N];
        for(int i = 0; i< N; i++)a[i] = nl();
        BitSet[] matrix = new BitSet[N];
        for(int i = 0; i< N; i++){
            int l = ni()-1, r = ni()-1;
            matrix[i] = new BitSet(N);
            if(l <= r)matrix[i].set(l, r+1);  // BitSet.set's upper bound is exclusive
        }
        BitSet[] power = pow(N, matrix, K);
        StringBuilder line = new StringBuilder();
        for(int i = 0; i< N; i++){
            long value = 0;
            for(int j = power[i].nextSetBit(0); j >= 0 && j < N; j = power[i].nextSetBit(j+1))value ^= a[j];
            line.append(value).append(' ');
        }
        pn(line);
    }

    /**
     * Returns matrix^K over GF(2) by binary exponentiation; K = 0 yields the identity.
     * The argument array and its rows are not modified.
     */
    BitSet[] pow(int N, BitSet[] matrix, long K){
        BitSet[] ans = new BitSet[N];
        for(int i = 0; i< N; i++){
            ans[i] = new BitSet(N);
            ans[i].set(i);
        }
        for(; K > 0; K>>=1){
            if((K&1)==1)ans = mul(N, ans, matrix);
            if(K > 1)matrix = mul(N, matrix, matrix);  // the final squaring would never be used
        }
        return ans;
    }

    /**
     * Returns matrix * column over GF(2). Bit i of the result is the parity of
     * |row i AND column|, with each row truncated to its first N bits.
     */
    BitSet mul(int N, BitSet[] matrix, BitSet column){
        BitSet res = new BitSet(N);
        for(int i = 0; i< N; i++){
            BitSet row = matrix[i].get(0, N);
            row.and(column);
            if(row.cardinality()%2 == 1)res.set(i);
        }
        return res;
    }

    /** Returns column idx of an N x N row-wise bit matrix as a BitSet of length N. */
    BitSet getColumn(int N, BitSet[] matrix, int idx){
        BitSet col = new BitSet(N);
        for(int i = 0; i< N; i++)col.set(i, matrix[i].get(idx));
        return col;
    }

    /**
     * Returns A * B over GF(2) for N x N matrices stored as one BitSet per row.
     * Row i of the product is the XOR of the rows B[k] for every k < N with A[i][k] set. This
     * needs only word-wide XORs, with no column extraction or per-entry bit writes.
     * Neither argument is modified, so A and B may be the same array.
     */
    BitSet[] mul(int N, BitSet[] A, BitSet[] B){
        BitSet[] C = new BitSet[N];
        for(int i = 0; i< N; i++){
            C[i] = new BitSet(N);
            for(int k = A[i].nextSetBit(0); k >= 0 && k < N; k = A[i].nextSetBit(k+1))C[i].xor(B[k]);
        }
        return C;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    DecimalFormat df = new DecimalFormat("0.00000000000");
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;

    /** Reads the number of test cases (when multipleTC) and solves each, writing to stdout. */
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new HXR().run();
    }
    /** Population count of n, computed by clearing the lowest set bit recursively. */
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    /** Whitespace-token reader over a BufferedReader. */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        /**
         * Returns the next whitespace-separated token.
         * Throws EOFException if the input ends first; I/O errors propagate with their cause.
         */
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                String line = br.readLine();
                if(line == null)throw new EOFException("Unexpected end of input");
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }
        /** Returns the next raw line, or null at end of input. */
        String nextLine() throws Exception{
            return br.readLine();
        }
    }
}
Feel free to share your approach. Suggestions are welcomed as always.