PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Setter: Abhinav Sharma
Tester: Tejas Pandey
Editorialist: Kanhaiya Mohan
DIFFICULTY:
Easy-Medium
PREREQUISITES:
PROBLEM:
You are given an array A of length N such that 1 \leq A_i \leq N, and an integer M. Count the number of arrays B of length N such that, for every 1 \leq i \leq N:
- B_i \neq B_{A_i}
- 1 \leq B_i \leq M
Since the answer can be large, print it modulo 10^9 + 7.
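For testing any of the solutions on tiny inputs, a brute-force counter that enumerates all M^N arrays can be handy. This is a hypothetical helper (bruteForceCount is our name), not part of the original solutions. It applies no modulus, so it is only meant for very small N and M.

```cpp
#include <cassert>
#include <vector>

// Brute-force counter for tiny inputs: tries all M^N arrays B.
// A is 1-indexed as in the statement (1 <= A[i] <= N). No modulus is
// applied, so keep M^N small.
long long bruteForceCount(const std::vector<int>& A, int M) {
    const int n = static_cast<int>(A.size());
    assert(M >= 1);
    for (int x : A) assert(1 <= x && x <= n);
    std::vector<int> B(n, 1);  // current candidate, values in [1, M]
    long long count = 0;
    while (true) {
        bool ok = true;
        for (int i = 0; i < n && ok; ++i)
            if (B[i] == B[A[i] - 1]) ok = false;  // B_i must differ from B_{A_i}
        if (ok) ++count;
        // Advance B like an odometer in base M.
        int pos = 0;
        while (pos < n && B[pos] == M) { B[pos] = 1; ++pos; }
        if (pos == n) break;  // wrapped around: every array was tried
        ++B[pos];
    }
    return count;
}
```

For example, A = [2, 1, 1] with M = 3 gives 12 (3 choices for B_1, then 2 each for B_2 and B_3), and A = [1] gives 0 because B_1 \neq B_1 can never hold.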
QUICK EXPLANATION:
- Build a graph with N edges as (i, A_i). The graph will have some connected components.
- Each connected component has exactly one cycle.
- Count the valid assignments for the cycle. Every node that is not on the cycle can then be filled in (M-1) ways. The answer for a component is the cycle's count multiplied by (M-1) once for every non-cycle node.
- Final answer is the product of the answers of all connected components.
EXPLANATION:
Observation
We are given that B_i \neq B_{A_i}. Each such constraint links position i with position A_i, so we build an undirected graph on the N positions with one edge (i, A_i) for every i, i.e. N edges in total. A valid B is then exactly an assignment of values from [1, M] to the vertices such that the two endpoints of every edge get different values.
Since 1 \leq A_i \leq N, every vertex i contributes exactly one edge, and the graph splits into several connected components. Each of the connected components has exactly one cycle. In other words, each component is a tree with one extra edge that forms a cycle.
Proof
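Any connected component of size K has exactly K edges: each of its K vertices i contributes the single edge (i, A_i), and both endpoints of that edge lie in the same component. A connected graph with K vertices and no cycle is a tree and has exactly K-1 edges, so with K edges there is at least one cycle. For more than one cycle to exist, there have to be at least K+1 edges. Thus, there is exactly one cycle present. Here a self-loop (A_i = i) counts as a cycle of length 1, and A_i = j together with A_j = i gives a cycle of length 2 made of two parallel edges.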
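The observation can be turned into code by walking the pointers i \to A_i. This is one possible sketch (componentsWithCycles is a hypothetical name; indices are 0-based, i.e. a[i] = A_i - 1): it returns, for every component, its number of nodes and the length of its cycle, which is all the later counting needs. It uses the same three-state marking idea as the setter's and editorialist's code.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// For the graph with edges (i, a[i]) (0-indexed), return one
// {component size, cycle length} pair per connected component.
// Every vertex has exactly one outgoing pointer, so following a[] from any
// vertex eventually enters the unique cycle of its component.
std::vector<std::pair<int, int>> componentsWithCycles(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    for (int x : a) assert(0 <= x && x < n);
    std::vector<int> state(n, 0);   // 0 = unvisited, 1 = on current walk, 2 = finished
    std::vector<int> compId(n, -1); // component index of finished vertices
    std::vector<std::pair<int, int>> comps;  // {size, cycle length}
    for (int s = 0; s < n; ++s) {
        if (state[s] != 0) continue;
        std::vector<int> path;
        int v = s;
        while (state[v] == 0) { state[v] = 1; path.push_back(v); v = a[v]; }
        int id;
        if (state[v] == 1) {
            // The walk closed on itself: v lies on a brand-new cycle.
            int len = 1;
            for (int u = a[v]; u != v; u = a[u]) ++len;
            id = static_cast<int>(comps.size());
            comps.push_back({0, len});
        } else {
            // The walk ran into an already finished component.
            id = compId[v];
        }
        for (int u : path) { state[u] = 2; compId[u] = id; }
        comps[id].first += static_cast<int>(path.size());
    }
    return comps;
}
```

For a = {1, 0, 0} (A = [2, 1, 1]) this yields a single component with 3 nodes and a cycle of length 2.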
For each connected component, let us look at this cycle part separately.
Subproblem
Problem: Find the number of arrays B of length N (\geq 1; in this subproblem N stands for the cycle length) such that:
- B_1 \neq B_N
- B_i \neq B_{i+1} (1 \leq i < N)
- 1 \leq B_i \leq M
Solution: We need to find the number of arrays in which no two consecutive elements have the same value (the first and last elements also count as consecutive) and every element lies in the range [1, M]. We can use dynamic programming to calculate this.
Let dp[i][0] denote the number of ways to fill the first i elements such that B_j \neq B_{j+1} (1 \leq j < i) and B_1 \neq B_i. Similarly, let dp[i][1] denote the number of ways to fill the first i elements such that B_j \neq B_{j+1} (1 \leq j < i) and B_1 = B_i.
The answer to our problem is the value dp[N][0].
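- Base Case: dp[1][0] = 0, since there is no way to fill the first element so that it differs from itself. Similarly, dp[1][1] = M, since any of the M values equals itself. In particular, a cycle of length 1 (A_i = i) contributes 0, because B_i \neq B_i can never hold.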
- Calculating dp[i][1]: B_i must equal B_1, so there is exactly one way to fill the i^{th} element, and this is valid only if B_{i-1} \neq B_1. Hence dp[i][1] = dp[i-1][0].
- Calculating dp[i][0]: If the (i-1)^{th} element is equal to the first element, the i^{th} element can be filled in (M-1) ways (all values except B_1). Otherwise, B_{i-1} \neq B_1, and there are (M-2) ways (all values except B_1 and B_{i-1}). Thus, dp[i][0] = (dp[i-1][1] \cdot (M-1) + dp[i-1][0] \cdot (M-2)) \% mod.
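A minimal sketch of this DP, assuming M fits in a long long. Since row i depends only on row i-1, it keeps just two running values instead of the full table. cycleWays is our name for dp[L][0]; it is not taken from the solutions.

```cpp
#include <cassert>
#include <vector>

constexpr long long MOD = 1'000'000'007LL;

// cycleWays(L, M) = dp[L][0]: arrays of length L with values in [1, M]
// where neighbours differ, including positions L and 1.
long long cycleWays(int L, long long M) {
    assert(L >= 1 && M >= 1);
    long long diff = 0;        // dp[i][0]: B_i != B_1  (dp[1][0] = 0)
    long long same = M % MOD;  // dp[i][1]: B_i == B_1  (dp[1][1] = M)
    for (int i = 2; i <= L; ++i) {
        // dp[i][0] = dp[i-1][1]*(M-1) + dp[i-1][0]*(M-2)
        long long newDiff = (same * ((M - 1) % MOD) + diff * ((M - 2 + MOD) % MOD)) % MOD;
        same = diff;  // dp[i][1] = dp[i-1][0]
        diff = newDiff;
    }
    return diff;
}
```

For example, cycleWays(3, 3) is 6 (a triangle with 3 values), cycleWays(4, 3) is 18, and cycleWays(1, M) is 0 for every M.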
Conclusion
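Consider a connected component with X nodes. Let this component contain a cycle of length Y. For the cycle part, the answer is dp[Y][0]. Each of the remaining X-Y nodes lies on a tree hanging off the cycle. If we fill them in order of increasing distance from the cycle, then when node i is filled, its only already-assigned neighbour is A_i, so there are (M-1) ways to fill each one of them. Thus, the answer for this component is (dp[Y][0] \cdot (M-1)^{X-Y}) \% mod.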
The final answer would be the product of the answers of all such components.
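Putting the pieces together, here is a minimal sketch of the per-component product, assuming each component has already been reduced to a pair {X, Y} (total nodes, cycle length), for example by a walk like the ones in the solutions below. The helper names powMod, cycleWays and combine are ours, not from the setter, tester or editorialist code.

```cpp
#include <cassert>
#include <utility>
#include <vector>

constexpr long long MOD = 1'000'000'007LL;

// Fast modular exponentiation: base^exp mod MOD (base >= 0, exp >= 0).
long long powMod(long long base, long long exp) {
    base %= MOD;
    long long result = 1;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// dp[L][0]: valid arrays on a cycle of length L (same DP as above).
long long cycleWays(int L, long long M) {
    long long diff = 0, same = M % MOD;  // dp[1][0], dp[1][1]
    for (int i = 2; i <= L; ++i) {
        long long newDiff = (same * ((M - 1) % MOD) + diff * ((M - 2 + MOD) % MOD)) % MOD;
        same = diff;
        diff = newDiff;
    }
    return diff;
}

// One {X, Y} pair per component: X nodes in total, Y of them on the cycle.
long long combine(const std::vector<std::pair<int, int>>& comps, long long M) {
    long long ans = 1;
    for (const auto& [X, Y] : comps) {
        assert(1 <= Y && Y <= X);
        ans = ans * cycleWays(Y, M) % MOD;       // the cycle part
        ans = ans * powMod(M - 1, X - Y) % MOD;  // (M-1) choices per tree node
    }
    return ans;
}
```

With A = [2, 1, 1] there is one component with X = 3 and Y = 2, and combine({{3, 2}}, 3) gives 6 \cdot 2 = 12, which agrees with enumerating all 27 arrays by hand.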
TIME COMPLEXITY:
The time complexity is O(N) per test case.
SOLUTION:
Setter's Solution
#include<bits/stdc++.h>
using namespace std;
const long long mod = 1000000007;
long long po(long long x, long long n){
long long ans=1;
while(n>0){ if(n&1) ans=(ans*x)%mod; x=(x*x)%mod; n/=2;}
return ans;
}
int main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int T=1;
int sum_len = 0;
cin >> T;
while(T--){
int n,m;
cin>>n>>m;
int a[n];
for(int i=0; i<n; i++){
cin>>a[i];
a[i]--;
}
long long ans = 1;
vector<int> vis(n, 0); // 0 = unvisited, 1 = on current walk, 2 = finished (initialized VLAs are not standard C++)
for(int i=0; i<n; i++){
if(!vis[i]){
int cnt = 0;
int curr = i;
while(!vis[curr]){
cnt++;
vis[curr]=1;
curr = a[curr];
}
if(vis[curr]==2){ // walk hit an already processed component: all cnt new nodes are tree nodes
ans*=po(m-1, cnt);
ans%=mod;
}
else{
int cyc_len = 1;
for(int j=a[curr]; j!=curr; j=a[j]) cyc_len++; // walk once around the cycle
// valid arrays on a cycle of length L: (M-1)^L + (-1)^L * (M-1)
long long tmp = po(m-1, cyc_len)+((cyc_len&1)?-1:1)*(m-1);
tmp%=mod;
ans*=tmp;
ans%=mod;
ans*=po(m-1, cnt-cyc_len);
ans%=mod;
}
curr = i;
while(vis[curr]!=2){
vis[curr]=2;
curr=a[curr];
}
}
}
ans%=mod;
if(ans<0) ans+=mod;
cout<<ans<<'\n';
}
return 0;
}
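The setter computes the cycle count with a closed form instead of the DP. A short derivation (ours, not from the setter) showing that it agrees with dp[L][0], writing c_L = dp[L][0]. It uses dp[L][1] = dp[L-1][0] and the fact that dp[L][0] + dp[L][1] counts arrays with only the path constraints B_j \neq B_{j+1}, of which there are M(M-1)^{L-1}:

```latex
\begin{aligned}
c_L + c_{L-1} &= dp[L][0] + dp[L][1] = M(M-1)^{L-1}, \qquad c_1 = 0 = (M-1)^1 + (-1)^1 (M-1), \\
c_L &= M(M-1)^{L-1} - c_{L-1} \\
    &= M(M-1)^{L-1} - (M-1)^{L-1} - (-1)^{L-1}(M-1) \quad \text{(induction hypothesis)} \\
    &= (M-1)^{L} + (-1)^{L}(M-1).
\end{aligned}
```

For example, M = 3 and L = 4 give 2^4 + 2 = 18. Because the subtraction can make tmp negative in C++, the setter normalizes the result at the end with if(ans<0) ans+=mod.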
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long int
#define mod 1000000007
ll mpow(ll a, ll b) {
ll res = 1;
while(b) {
if(b&1) res *= a, res %= mod;
a *= a;
a %= mod;
b >>= 1;
}
return res;
}
int main() {
int t;
cin >> t;
while(t--) {
int n, m; cin >> n >> m;
int a[n];
for(int i = 0; i < n; i++) cin >> a[i], a[i]--;
int vis[n];
memset(vis, 0, sizeof(vis));
ll ans = 1;
vector<ll> v(max(n, 2) + 1); // at least 3 slots so v[2] is valid even when n == 1
v[1] = 0, v[2] = m - 1;
for(int i = 3; i <= n; i++)
v[i] = ((v[i - 1]*(m - 2))%mod + (v[i - 2]*(m - 1))%mod)%mod;
for(int i = 0; i < n; i++) {
if(vis[i]) continue;
int lst = -1, now = i, sz = 0;
vector<int> pro;
while(!vis[now]) {
vis[now] = 1;
pro.push_back(now);
sz++;
int nxt = a[now];
if(vis[nxt] == -1) {
ans *= mpow(m - 1, sz);
ans %= mod;
break;
}
if(vis[nxt]) {
int cycle = -1;
for(int i = 0; i < sz; i++)
if(nxt == pro[i])
cycle = sz - i;
ans *= (m*v[cycle])%mod;
ans %= mod;
ans *= mpow(m - 1, sz - cycle);
ans %= mod;
break;
}
lst = now;
now = nxt;
}
for(int j = 0; j < sz; j++)
vis[pro[j]] = -1;
}
cout << ans << "\n";
}
return 0;
}
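The tester's array v differs from the editorial DP by a factor of M: v[L] counts valid cycle arrays with B_1 fixed to one value, so by symmetry over the M choices of B_1, dp[L][0] = M \cdot v[L]. Substituting dp[i-1][1] = dp[i-2][0] into the editorial recurrence gives dp[i][0] = (M-2) \cdot dp[i-1][0] + (M-1) \cdot dp[i-2][0], and dividing by M yields exactly the tester's recurrence for v. This sketch builds both tables so the relation can be checked numerically; testerTable and editorialTable are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

constexpr long long MOD = 1'000'000'007LL;

// Tester-style table: v[L] = valid cycle arrays with B_1 fixed.
std::vector<long long> testerTable(int maxLen, long long m) {
    assert(maxLen >= 1 && m >= 1);
    std::vector<long long> v(std::max(maxLen, 2) + 1, 0);
    v[1] = 0;
    v[2] = (m - 1) % MOD;
    for (int i = 3; i <= maxLen; ++i)
        v[i] = (v[i - 1] * ((m - 2 + MOD) % MOD) + v[i - 2] * ((m - 1) % MOD)) % MOD;
    return v;
}

// Editorialist-style table: dp0[L] = dp[L][0].
std::vector<long long> editorialTable(int maxLen, long long m) {
    assert(maxLen >= 1 && m >= 1);
    std::vector<long long> dp0(maxLen + 1, 0), dp1(maxLen + 1, 0);
    dp1[1] = m % MOD;
    for (int i = 2; i <= maxLen; ++i) {
        dp1[i] = dp0[i - 1];
        dp0[i] = (dp1[i - 1] * ((m - 1) % MOD) + dp0[i - 1] * ((m - 2 + MOD) % MOD)) % MOD;
    }
    return dp0;
}
```

With m = 7, for instance, v[4] = 186 and dp[4][0] = 1302 = 7 \cdot 186.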
Editorialist's Solution
#include<bits/stdc++.h>
using namespace std;
const long long mod = 1000000007;
long long dp[100005][2];
void init(int n, int m){
dp[1][0] = 0;
dp[1][1] = m;
for(int i = 2; i<=n; i++){
dp[i][1] = dp[i-1][0];
dp[i][0] = ((dp[i-1][1]*(m-1))%mod + (dp[i-1][0]*(m-2))%mod)%mod;
}
}
long long po(long long x, long long n){
long long ans=1;
while(n>0){ if(n&1) ans=(ans*x)%mod; x=(x*x)%mod; n/=2;}
return ans;
}
int main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int T=1;
int sum_len = 0;
cin >> T;
while(T--){
int n,m;
cin>>n>>m;
int a[n];
for(int i=0; i<n; i++){
cin>>a[i];
a[i]--;
}
long long ans = 1;
vector<int> vis(n, 0); // 0 = unvisited, 1 = on current walk, 2 = finished (initialized VLAs are not standard C++)
init(n, m);
for(int i=0; i<n; i++){
if(!vis[i]){
int cnt = 0;
int curr = i;
while(!vis[curr]){
cnt++;
vis[curr]=1;
curr = a[curr];
}
if(vis[curr]==2){ // walk hit an already processed component: all cnt new nodes are tree nodes
ans*= po(m-1, cnt);
ans%=mod;
}
else{
int cyc_len = 1;
for(int j=a[curr]; j!=curr; j=a[j]) cyc_len++; // walk once around the cycle
long long tmp = dp[cyc_len][0];
tmp%=mod;
ans*=tmp;
ans%=mod;
ans*=po(m-1, cnt-cyc_len);
ans%=mod;
}
curr = i;
while(vis[curr]!=2){
vis[curr]=2;
curr=a[curr];
}
}
}
ans%=mod;
if(ans<0) ans+=mod;
cout<<ans<<'\n';
}
return 0;
}