CNTAR - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3

Setter: Abhinav Sharma
Tester: Tejas Pandey
Editorialist: Kanhaiya Mohan

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Combinatorics, Graph Theory

PROBLEM:

You are given an integer M and an array A of length N with 1 \leq A_i \leq N. Count the number of arrays B of length N such that, for every i,

  • B_i \neq B_{A_i}
  • 1 \leq B_i \leq M

Since the answer can be large, print it modulo 10^9 + 7.

QUICK EXPLANATION:

  • Build an undirected graph on N vertices with an edge (i, A_i) for every i. The graph splits into connected components.
  • Each connected component has exactly one cycle.
  • Count the valid fillings of the cycle. Each node not on the cycle can then be filled in (M-1) ways. The answer for a component is the count for its cycle multiplied by (M-1) raised to the number of non-cycle nodes.
  • Final answer is the product of the answers of all connected components.

EXPLANATION:

Observation

The constraint B_i \neq B_{A_i} ties position i to position A_i. To capture all these constraints, build an undirected graph on N vertices with one edge (i, A_i) for each i, i.e. N edges in total.

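Since 1 \leq A_i \leq N, every edge joins two of the N vertices, and the graph splits into connected components. Each connected component contains exactly one cycle; in other words, each component is a tree with one extra edge that closes a cycle. (A self-loop A_i = i counts as a cycle of length 1, and a pair A_x = y, A_y = x counts as a cycle of length 2.)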

Proof

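Each vertex i contributes exactly one edge (i, A_i), and that edge lies inside the component of i. Hence a connected component with K vertices has exactly K edges (counting self-loops and repeated edges). A connected graph on K vertices contains a spanning tree with K-1 edges, which has no cycle; the one remaining edge closes exactly one cycle. A second cycle would require at least K+1 edges. Thus, exactly one cycle is present.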
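To make the decomposition concrete, here is a minimal C++ sketch (not taken from any of the reference solutions below) that reports, for each connected component, its number of vertices and the length of its cycle. It assumes a 0-indexed array a with 0 \leq a[i] < n, and the function name componentsAndCycles is hypothetical. It groups vertices with a union-find and finds each cycle by walking i, a[i], a[a[i]], ... until the walk revisits one of its own vertices, which is the same walking idea the reference solutions use.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: for a 0-indexed array a (0 <= a[i] < n), return one
// pair (component size, cycle length) per connected component of the
// undirected graph with edges (i, a[i]), ordered by component representative.
std::vector<std::pair<int, int>> componentsAndCycles(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    for (int x : a) assert(0 <= x && x < n);

    // Union-find with path halving groups the vertices into components.
    std::vector<int> parent(n);
    std::iota(parent.begin(), parent.end(), 0);
    auto find = [&](int x) {
        while (parent[x] != x) x = parent[x] = parent[parent[x]];
        return x;
    };
    for (int i = 0; i < n; i++) parent[find(i)] = find(a[i]);

    std::vector<int> compSize(n, 0), cycleLen(n, 0);
    for (int i = 0; i < n; i++) compSize[find(i)]++;

    // state: 0 = unvisited, 1 = on the current walk, 2 = done.
    std::vector<int> state(n, 0);
    for (int i = 0; i < n; i++) {
        int cur = i;
        while (state[cur] == 0) { state[cur] = 1; cur = a[cur]; }
        if (state[cur] == 1) {
            // The walk revisited one of its own vertices, so cur lies on the cycle.
            int len = 1;
            for (int v = a[cur]; v != cur; v = a[v]) len++;
            cycleLen[find(cur)] = len;
        }
        // Mark the whole walk as finished.
        for (int v = i; state[v] == 1; v = a[v]) state[v] = 2;
    }

    std::vector<std::pair<int, int>> result;
    for (int i = 0; i < n; i++)
        if (find(i) == i) result.push_back({compSize[i], cycleLen[i]});
    return result;
}
```

For example, a = {1, 2, 0, 0} (A = [2, 3, 1, 1] in 1-indexed form) is a single component of 4 vertices whose cycle has length 3, while a = {1, 0, 3, 2} gives two components of 2 vertices, each with a cycle of length 2.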

For each connected component, let us look at this cycle part separately.

Subproblem

Problem: Find the number of arrays B of length N (\geq 2) such that:

  • B_1 \neq B_N
  • B_i \neq B_{i+1} (1 \leq i < N)
  • 1 \leq B_i \leq M

Solution: We need to count arrays in which no two adjacent elements are equal, where the first and last elements also count as adjacent, and every element lies in the range [1, M]. We can compute this with dynamic programming.
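A brute-force counter is handy for checking any formula for this subproblem on tiny inputs. The sketch below (the name bruteForceCycle is hypothetical) enumerates all M^N arrays, so it is only usable for very small N and M. It also accepts N = 1, which corresponds to a self-loop A_i = i and always gives 0.

```cpp
#include <cassert>
#include <vector>

// Brute force for the subproblem: count arrays b of length n with values in
// [1, m] such that cyclically adjacent elements differ (b[n-1] vs b[0] too).
// Runs in O(n * m^n), so it is only meant for cross-checking tiny cases.
long long bruteForceCycle(int n, int m) {
    assert(n >= 1 && m >= 1);
    std::vector<int> b(n, 1);
    long long count = 0;
    while (true) {
        bool ok = true;
        for (int i = 0; i < n; i++)
            if (b[i] == b[(i + 1) % n]) { ok = false; break; }
        if (ok) count++;
        // Advance b like an odometer whose digits run from 1 to m.
        int pos = 0;
        while (pos < n && b[pos] == m) b[pos++] = 1;
        if (pos == n) break;  // every array has been visited
        b[pos]++;
    }
    return count;
}
```

For instance, bruteForceCycle(3, 3) returns 6 (the arrangements of three distinct values) and bruteForceCycle(3, 2) returns 0, since a cycle of odd length cannot alternate between two values.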

Let dp[i][0] denote the number of ways to fill the first i elements such that B_j \neq B_{j+1} (1 \leq j < i) and B_1 \neq B_i. Similarly, let dp[i][1] denote the number of ways to fill the first i elements such that B_j \neq B_{j+1} (1 \leq j < i) and B_1 = B_i.
The answer to our problem is the value dp[N][0].

  • Base Case: dp[1][0] = 0, since the first element cannot differ from itself. Similarly, dp[1][1] = M, since B_1 can take any of the M values.
  • Calculating dp[i][1]: B_i is forced to equal B_1, so there is only one choice for it. This is valid exactly when B_{i-1} \neq B_i = B_1, so dp[i][1] = dp[i-1][0].
  • Calculating dp[i][0]: If the (i-1)^{th} element equals the first element, the i^{th} element can take (M-1) values (every value except that of the first element). Otherwise, it can take (M-2) values (every value except those of the first and the (i-1)^{th} elements). Thus, dp[i][0] = (dp[i-1][1] \cdot (M-1) + dp[i-1][0] \cdot (M-2)) \% mod.

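The recurrence translates directly into a loop. This sketch keeps only the previous row of dp in two rolling variables instead of the full table used in the editorialist's code, and assumes M \geq 1; the name cycleWays is hypothetical.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;

// Returns dp[len][0]: the number of arrays of length len with values in
// [1, m] where cyclically adjacent elements differ, modulo MOD.
long long cycleWays(int len, long long m) {
    assert(len >= 1 && m >= 1);
    long long differ = 0;                              // dp[1][0]
    long long same = m % MOD;                          // dp[1][1]
    const long long mMinus1 = (m - 1) % MOD;
    const long long mMinus2 = ((m - 2) % MOD + MOD) % MOD;  // non-negative even if m == 1
    for (int i = 2; i <= len; i++) {
        long long newSame = differ;                                       // dp[i][1] = dp[i-1][0]
        long long newDiffer = (same * mMinus1 + differ * mMinus2) % MOD;  // dp[i][0]
        same = newSame;
        differ = newDiffer;
    }
    return differ;
}
```

With M = 3, cycleWays(3, 3) gives 6 and cycleWays(4, 3) gives 18, the number of ways to fill a cycle of length 3 and of length 4.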
Conclusion

Consider a connected component with X nodes whose cycle has length Y. For the cycle part, the answer is dp[Y][0]. Each of the remaining X-Y nodes can be filled in (M-1) ways: filling outward from the cycle, node i only has to differ from B_{A_i}, which is already fixed. Thus, the answer for this component is (dp[Y][0] \cdot (M-1)^{X-Y}) \% mod.
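A small sketch of the per-component formula, assuming the cycle count dp[Y][0] has already been computed (for example with the recurrence for dp[i][0]). The names powMod and componentWays are hypothetical; powMod is the usual binary exponentiation, like po in the reference solutions.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;

// Binary exponentiation: base^exp mod MOD.
long long powMod(long long base, long long exp) {
    base %= MOD;
    if (base < 0) base += MOD;
    long long result = 1;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Answer for one component with x nodes and a cycle of y nodes, given
// cycleCount = dp[y][0] (assumed already reduced into [0, MOD)).
// Each of the x - y non-cycle nodes contributes a factor of (m - 1).
long long componentWays(long long cycleCount, int x, int y, long long m) {
    assert(1 <= y && y <= x && m >= 1);
    return cycleCount % MOD * powMod(m - 1, x - y) % MOD;
}
```

For A = [2, 3, 1, 1] and M = 3 there is one component with X = 4 and Y = 3; since dp[3][0] = 6, componentWays(6, 4, 3, 3) = 6 \cdot 2 = 12, which is also the final answer because there is only one component.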

The final answer would be the product of the answers of all such components.

TIME COMPLEXITY:

The time complexity is O(N) per test case.

SOLUTION:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;


const long long mod = 1000000007;

long long po(long long x, long long n){ 
    long long ans=1;
    while(n>0){ if(n&1) ans=(ans*x)%mod; x=(x*x)%mod; n/=2;}
    return ans;
}


int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int T=1;

    int sum_len = 0;

    cin >> T;
    while(T--){
        
        int n,m;
        cin>>n>>m;

        int a[n];

        for(int i=0; i<n; i++){
            cin>>a[i];
            a[i]--;
        }

        long long ans = 1;

        int vis[n] = {0};

        for(int i=0; i<n; i++){
            if(!vis[i]){
                int cnt = 0;
                int curr = i;

                while(!vis[curr]){
                    cnt++;
                    vis[curr]=1;
                    curr = a[curr];
                }

                if(vis[curr]==2){
                    ans*=po(m-1, cnt);
                    ans%=mod;
                }
                else{
                    int cyc_len = 1;
                    for(int i=a[curr]; i!=curr; i=a[i]) cyc_len++;

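                    // Number of valid fillings of a cycle of length L: (m-1)^L + (-1)^L * (m-1).
                    // tmp can be negative here; the final `if(ans<0) ans+=mod;` restores a non-negative result.
                    long long tmp = po(m-1, cyc_len)+((cyc_len&1)?-1:1)*(m-1);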
                    tmp%=mod;
                    ans*=tmp;
                    ans%=mod;

                    ans*=po(m-1, cnt-cyc_len);
                    ans%=mod;
                }

                curr = i;
                while(vis[curr]!=2){
                    vis[curr]=2;
                    curr=a[curr];
                }
            }
        }

        ans%=mod;
        if(ans<0) ans+=mod;

        cout<<ans<<'\n';
    
    }

    return 0;
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long int
#define mod 1000000007

ll mpow(ll a, ll b) {
    ll res = 1;
    while(b) {
        if(b&1) res *= a, res %= mod;
        a *= a;
        a %= mod;
        b >>= 1;
    }
    return res;
}

int main() {
	int t;
	cin >> t;
	while(t--) {
		int n, m; cin >> n >> m;
	    int a[n];
	    for(int i = 0; i < n; i++) cin >> a[i], a[i]--;
	    int vis[n];
	    memset(vis, 0, sizeof(vis));
	    ll ans = 1;
	    ll v[n + 2]; // n + 2 so that v[2] is in bounds even when n == 1
	    v[1] = 0, v[2] = m - 1;
	    for(int i = 3; i <= n; i++)
	        v[i] = ((v[i - 1]*(m - 2))%mod + (v[i - 2]*(m - 1))%mod)%mod;
	    for(int i = 0; i < n; i++) {
	        if(vis[i]) continue;
	        int lst = -1, now = i, sz = 0;
	        vector<int> pro;
	        while(!vis[now]) {
	            vis[now] = 1;
	            pro.push_back(now);
	            sz++;
	            int nxt = a[now];
	            if(vis[nxt] == -1) {
	                ans *= mpow(m - 1, sz);
	                ans %= mod;
	                break;
	            }
	            if(vis[nxt]) {
	                int cycle = -1;
	                for(int i = 0; i < sz; i++)
	                    if(nxt == pro[i])
	                        cycle = sz - i;
	                ans *= (m*v[cycle])%mod;
	                ans %= mod;
	                ans *= mpow(m - 1, sz - cycle);
	                ans %= mod;
	                break;
	            }
	            lst = now;
	            now = nxt;
	        }
	        for(int j = 0; j < sz; j++)
	            vis[pro[j]] = -1;
	    }
	    cout << ans << "\n";
	}
	return 0;
}
Editorialist's Solution
#include<bits/stdc++.h>
using namespace std;


const long long mod = 1000000007;
long long dp[100005][2];

void init(int n, int m){
    dp[1][0] = 0;
    dp[1][1] = m;
    for(int i = 2; i<=n; i++){
        dp[i][1] = dp[i-1][0];
        dp[i][0] = ((dp[i-1][1]*(m-1))%mod + (dp[i-1][0]*(m-2))%mod)%mod;
    }
}

long long po(long long x, long long n){ 
    long long ans=1;
    while(n>0){ if(n&1) ans=(ans*x)%mod; x=(x*x)%mod; n/=2;}
    return ans;
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int T=1;

    int sum_len = 0;

    cin >> T;
    while(T--){
        
        int n,m;
        cin>>n>>m;

        int a[n];

        for(int i=0; i<n; i++){
            cin>>a[i];
            a[i]--;
        }

        long long ans = 1;

        int vis[n] = {0};
        init(n, m);
        
        for(int i=0; i<n; i++){
            if(!vis[i]){
                int cnt = 0;
                int curr = i;

                while(!vis[curr]){
                    cnt++;
                    vis[curr]=1;
                    curr = a[curr];
                }

                if(vis[curr]==2){
                    ans*= po(m-1, cnt);
                    ans%=mod;
                }
                else{
                    int cyc_len = 1;
                    for(int i=a[curr]; i!=curr; i=a[i]) cyc_len++;

                    long long tmp = dp[cyc_len][0];
                    tmp%=mod;
                    ans*=tmp;
                    ans%=mod;

                    ans*=po(m-1, cnt-cyc_len);
                    ans%=mod;
                }

                curr = i;
                while(vis[curr]!=2){
                    vis[curr]=2;
                    curr=a[curr];
                }
            }
        }

        ans%=mod;
        if(ans<0) ans+=mod;

        cout<<ans<<'\n';
    
    }

    return 0;
}

Are the edges directed, or undirected?
If undirected, we can have a situation of duplicate edges (A[x] = y, A[y] = x). So, we may not have exactly N edges.

Yes, the edges are undirected; the notion of connected components is defined for undirected graphs.

True, there will be at most N edges. But with fewer edges, the problem is trivialized for a connected component, as the contribution from that connected component would just be M \times (M-1)^{K-1}.


Is there any Inclusion-Exclusion solution available for this problem?

Yes, you can use Inclusion-Exclusion to find the number of possible ways of assigning values to the nodes of a cycle. You can refer to this


How does the setter’s solution work exactly?
It didn’t use any dp…
Also what does this mean-

long long tmp = po(m-1, cyc_len)+((cyc_len&1)?-1:1)*(m-1);

His solution is based on the inclusion-exclusion approach. You can refer to this.
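One way to see where the setter's expression comes from, without inclusion-exclusion, is to solve the editorial's DP recurrence directly. The derivation below is a sketch; c_L = dp[L][0] is notation introduced here. Because dp[i][1] = dp[i-1][0], the recurrence for dp[i][0] only involves c, and its characteristic polynomial factors nicely:

```latex
\begin{aligned}
& c_L = (M-2)\,c_{L-1} + (M-1)\,c_{L-2} \quad (L \ge 3), \qquad c_1 = 0, \quad c_2 = M(M-1), \\
& x^2 - (M-2)\,x - (M-1) = \bigl(x - (M-1)\bigr)(x + 1)
  \;\Rightarrow\; c_L = a\,(M-1)^L + b\,(-1)^L, \\
& a(M-1) - b = 0, \quad a(M-1)^2 + b = M(M-1)
  \;\Rightarrow\; a = 1, \; b = M-1 \quad (M \neq 1), \\
& c_L = (M-1)^L + (-1)^L\,(M-1).
\end{aligned}
```

For M = 1 both sides are 0, so the closed form holds for every M \geq 1, and it matches `po(m-1, cyc_len)+((cyc_len&1)?-1:1)*(m-1)` term by term. The tester's array v follows the same recurrence with v[1] = 0 and v[2] = m-1, which is why the tester multiplies v[cycle] by m.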

oh okay. Thanks

Implemented the same solution as the setter in Python but still getting TLE. Solution Link

pow(x,y,z) will efficiently calculate (x**y)%z (the built-in pow, not math.pow, which only accepts two arguments)

Tried replacing operators with pow function, but still getting the same error

It’s quite similar to this problem but nevertheless a nice problem!


Such a nice problem, felt good after solving it!