ALIKE_THEM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: shubham_grg
Testers: iceknight1093, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Functional graphs

PROBLEM:

You have an array A of N elements, each between 0 and M, and a permutation P of length N.
For each i such that A_i = 0, you must replace A_i with some integer between 1 and M.

An array is beautiful if the following holds:

  • For each i from 1 to N in order, replace A_i with A_{P_i}. This is one operation.
  • If repeating this operation eventually makes all the elements of A equal, A is said to be beautiful.

Count the number of ways of replacing zeros such that the resulting array is beautiful.
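To sanity-check the formula derived in this editorial on tiny inputs, a brute force that follows the statement literally is handy. This is a minimal sketch, not taken from the official solutions: `apply_operation`, `is_beautiful` and `count_brute` are hypothetical helper names, and lists are assumed 1-indexed with an unused index 0. It is exponential in the number of zeros.

```python
from itertools import product


def apply_operation(a, p):
    """One operation: for i = 1..N in order, set A_i = A_{P_i}.
    Lists are 1-indexed; index 0 is an unused placeholder."""
    b = a[:]
    for i in range(1, len(b)):
        b[i] = b[p[i]]  # reads the current value, which may already be overwritten
    return b


def is_beautiful(a, p):
    """Repeat the operation until a state repeats. The state space is finite, so
    every state the process will ever reach has been seen by then."""
    seen = set()
    cur = a[:]
    while tuple(cur) not in seen:
        if len(set(cur[1:])) == 1:
            return True
        seen.add(tuple(cur))
        cur = apply_operation(cur, p)
    return False


def count_brute(n, m, p, a):
    """Try every way to fill the zeros with 1..m (exponential; tiny inputs only)."""
    zeros = [i for i in range(1, n + 1) if a[i] == 0]
    total = 0
    for fill in product(range(1, m + 1), repeat=len(zeros)):
        b = a[:]
        for idx, value in zip(zeros, fill):
            b[idx] = value
        total += is_beautiful(b, p)
    return total
```

For instance, with P = [2, 1] and both entries zero, one operation copies A_2 into both positions, so all M^2 fillings are beautiful.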

EXPLANATION:

Let A be the initial array, and B be the array obtained by performing the given operation on A once.

Then, for each 1 \leq i \leq N, there’s a unique position \text{pos}_i such that B_i = A_{\text{pos}_i}.
That is, the value ending up at position i comes from one specific position, which depends only on P (and not on the values in A).

This is useful information to have, so let’s see how we can compute it.

In particular, we can see that:

  • If P_i \geq i, then \text{pos}_i = P_i, since position i just directly receives the value at position P_i
  • If P_i \lt i, then \text{pos}_i = \text{pos}_{P_i}, because P_i receives its value from somewhere, then i receives this value from position P_i; so their sources are the same.

In this way, \text{pos}_i can be computed for all i in \mathcal{O}(N) total.
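The two rules above translate directly into a single left-to-right loop. A minimal sketch, assuming a 1-indexed list `p` with an unused `p[0]`; the name `compute_pos` is hypothetical:

```python
def compute_pos(p):
    """pos[i] = original index whose value lands at position i after one pass.
    p is 1-indexed with p[0] unused; pos[0] is left as 0."""
    n = len(p) - 1
    pos = [0] * (n + 1)
    for i in range(1, n + 1):
        if p[i] >= i:
            pos[i] = p[i]        # position P_i has not been overwritten yet
        else:
            pos[i] = pos[p[i]]   # P_i was already overwritten: same source as P_i
    return pos
```

Since P_i < i in the second branch, \text{pos}_{P_i} is always computed before it is needed, which is what makes the single pass valid.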

Now, let’s attempt to use this information.
Consider a directed graph on N vertices, containing the N edges i \to \text{pos}_i for each i.

Note that each vertex has exactly one outedge, so this is a functional graph.
In particular, we know what a functional graph looks like: several disjoint cycles, with trees hanging off of some vertices of the cycles (each tree is directed towards its corresponding cycle).
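To see this structure concretely, here is a minimal sketch (not the official approach; the editorialist's code below peels off in-degree-0 vertices instead) that walks from every vertex, detects each cycle the first time a walk runs into itself, and then records for every vertex the cycle vertex its tree hangs off and its distance to it. The name `decompose` is hypothetical, and `f` is assumed 1-indexed with `f[0]` unused; for this problem, `f` would be the \text{pos} array.

```python
def decompose(f):
    """Split the functional graph v -> f[v] (1-indexed, f[0] unused) into its
    cycles plus, for every vertex, the cycle vertex its tree hangs off and its
    distance to that vertex (0 for cycle vertices)."""
    n = len(f) - 1
    state = [0] * (n + 1)  # 0 = unvisited, 1 = on the current walk, 2 = finished
    cycles = []
    for s in range(1, n + 1):
        path, v = [], s
        while state[v] == 0:
            state[v] = 1
            path.append(v)
            v = f[v]
        if state[v] == 1:  # walked back into this walk: the part from v is a new cycle
            cycles.append(path[path.index(v):])
        for u in path:
            state[u] = 2

    attach = {v: (v, 0) for cyc in cycles for v in cyc}
    for s in range(1, n + 1):
        path, v = [], s
        while v not in attach:  # follow edges until reaching a vertex already placed
            path.append(v)
            v = f[v]
        root, depth = attach[v]
        for u in reversed(path):  # assign depths from the cycle outwards
            depth += 1
            attach[u] = (root, depth)
    return cycles, attach
```

For f = [2, 3, 2, 3, 1] (vertices 1 to 5) the only cycle is 2 \to 3 \to 2; vertices 1 and 4 hang directly off 2 and 3, and vertex 5 hangs off 1, at distance 2 from the cycle.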

Analyzing this information, we can see the following:

  • If a vertex is on a cycle, it will remain on that cycle forever. In particular, each step simply shuffles the values on a cycle within the vertices on it.
    • In particular, values on a cycle never disappear. So, if there are at least two distinct non-zero values on cycles, the answer is immediately 0.
    • If there is a non-zero element on a cycle, all zeros on cycles must be set to this value; so they are uniquely determined.
    • If all the elements on cycles are zeros, then there are M choices for what to choose for them.
  • If a vertex is not on a cycle, it will eventually receive a value from a cycle; so its initial value doesn’t matter at all.
    • If there are x indices such that A_i = 0 that are not on a cycle, they thus contribute M^x to the answer.
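Both claims can be checked by composing \text{pos} with itself: since \text{pos} does not depend on the values, after k operations position i holds A_{\text{pos}^k(i)}. A minimal sketch with hypothetical names, assuming 1-indexed lists with an unused index 0:

```python
def sources_after(pos, k):
    """src[i] = original position whose value sits at position i after k passes.
    Since B_i = A_{pos_i}, this is pos applied k times. pos is 1-indexed."""
    src = list(range(len(pos)))  # k = 0: every position holds its own value
    for _ in range(k):
        src = [pos[s] for s in src]
    return src


def values_after(a, pos, k):
    """The array after k passes, read off from the original array a."""
    return [a[s] for s in sources_after(pos, k)]


def surviving_positions(pos):
    """Original positions whose values are still present after N passes.
    Every tree path to a cycle has fewer than N edges, so these are exactly
    the cycle vertices."""
    n = len(pos) - 1
    return set(sources_after(pos, n)[1:])
```

With \text{pos} = [2, 3, 2] (from P = [2, 3, 1]) and A = [5, 7, 7], `values_after([0, 5, 7, 7], [0, 2, 3, 2], 3)` gives `[0, 7, 7, 7]`: the value 5 on the tree vertex 1 is gone, while the cycle values survive.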

Together, this gives us a simple formula for the answer:

  • If there are at least two distinct non-zero elements on cycles, the answer is 0.
  • Otherwise, let there be x indices not lying on cycles such that A_i = 0 for these indices.
  • If any cycle contains a non-zero element, the answer is M^x.
  • Otherwise, the answer is M^{x+1}.
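Putting the pieces together, the formula can be sketched as one function. This is a sketch under assumptions, not one of the official solutions: the name `count_beautiful` is hypothetical, `p` and `a` are 1-indexed with index 0 unused, and cycle vertices are found by repeatedly removing vertices with no incoming edge (the same idea as the editorialist's code).

```python
MOD = 10**9 + 7


def count_beautiful(n, m, p, a):
    """Count beautiful fillings modulo MOD using the formula above.
    Assumes p and a are 1-indexed lists of length n + 1 (index 0 unused)."""
    # pos_i: the source of the value that lands at position i after one pass
    pos = [0] * (n + 1)
    for i in range(1, n + 1):
        pos[i] = p[i] if p[i] >= i else pos[p[i]]

    # Repeatedly remove vertices with no incoming edge j -> pos_j;
    # whatever survives lies on a cycle.
    indeg = [0] * (n + 1)
    for i in range(1, n + 1):
        indeg[pos[i]] += 1
    on_cycle = [True] * (n + 1)
    stack = [v for v in range(1, n + 1) if indeg[v] == 0]
    while stack:
        v = stack.pop()
        on_cycle[v] = False
        indeg[pos[v]] -= 1
        if indeg[pos[v]] == 0:
            stack.append(pos[v])

    # Distinct non-zero values that sit on cycles
    fixed = {a[v] for v in range(1, n + 1) if on_cycle[v] and a[v] != 0}
    if len(fixed) >= 2:
        return 0
    # x = number of zeros off the cycles; each is a free choice among M values
    x = sum(1 for v in range(1, n + 1) if not on_cycle[v] and a[v] == 0)
    return pow(m, x if fixed else x + 1, MOD)
```

For example, with N = 2, M = 2, P = [2, 1] and A = [0, 0], position 2 is the only cycle vertex and position 1 is a zero off the cycle, so the function returns M^{x+1} = 2^2 = 4.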

TIME COMPLEXITY

\mathcal{O}(N) per test case.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;

const int MOD=1e9+7;

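#define ll long long int

int solve(int n, int m, vector<int>p, vector<int>a)
{
	// fetch[j] is true when some i with p[i] >= i reads position j directly (pos_i = p[i]).
	// The setter treats exactly these positions as the cycle vertices; x counts them.
	vector<bool> fetch(n+1, false);
	int curr=0, c=0, x=0;
	for(int i=1; i<=n; i++) if(p[i]>=i) fetch[p[i]]=true, x++;
	
	bool zero=false;
	for(int i=1; i<=n; i++)
	{
	    if(fetch[i] && a[i]==0) zero=true;
		if(fetch[i] && a[i])
		{
			// two distinct non-zero values on cycles: answer is 0
			if(curr && (curr^a[i])) return 0;
			curr=a[i];
		}
		// c counts non-zero values off the cycles (not free choices)
		if(!fetch[i] && a[i]) c++;
	}

	// exponent = (zeros off cycles) + 1, minus 1 if some cycle value is already fixed
	int exp=n-x+1-c;
	if(curr) exp--;
	

	ll ans=1;
	while(exp--) ans=(ans*m)%MOD;
	
	if(curr>m && zero)
	{
        return 0;   
	}
	return ans;
}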

int main()
{

	int t; cin>>t;
	
	assert(t<=1e5);
	int total_n=0;

	while(t--)
	{
		int n, m; cin>>n>>m;
		total_n+=n;
		assert(n>=1 && n<=2e5);
		assert(m>=1 && m<=1e9);
		
		vector<int>p(n+1), a(n+1);

		for(int i=1; i<=n; i++) cin>>p[i];
		for(int i=1; i<=n; i++) cin>>a[i];
		
		bool visi[n+1]{};
		for(int i=1; i<=n; i++)
		{
		    visi[p[i]]=true;
		}
		
		for(int i=1; i<=n; i++)
		{
		    assert(a[i]>=0 && a[i]<=1e9);
		    assert(p[i]);
		}

		cout<<solve(n, m, p, a)<<endl;	
	}
	assert(total_n<=2e5);

	
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct dsu {
    vector<int> p;
    vector<int> sz;
    int n;

    dsu(int _n) : n(_n) {
        p = vector<int>(n);
        iota(p.begin(), p.end(), 0);
        sz = vector<int>(n, 1);
    }

    inline int get(int x) {
        if (p[x] == x) {
            return x;
        } else {
            return p[x] = get(p[x]);
        }
    }

    inline bool unite(int x, int y) {
        x = get(x);
        y = get(y);
        if (x == y) {
            return false;
        }
        p[x] = y;
        sz[y] += sz[x];
        return true;
    }

    inline bool same(int x, int y) {
        return (get(x) == get(y));
    }

    inline int size(int x) {
        return sz[get(x)];
    }

    inline bool root(int x) {
        return (x == get(x));
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = (int) 1e9 + 7;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readSpace();
        int m = in.readInt(1, 1e9);
        in.readEoln();
        sn += n;
        vector<int> p = in.readInts(n, 1, n);
        in.readEoln();
        vector<int> a = in.readInts(n, 0, m);
        in.readEoln();
        for (int i = 0; i < n; i++) {
            p[i]--;
        }
        dsu uf(n);
        for (int i = 0; i < n; i++) {
            uf.unite(i, p[i]);
        }
        int free = (int) (count(a.begin(), a.end(), 0));
        int of = free;
        int z = -1;
        for (int i = 0; i < n; i++) {
            if (!uf.root(i)) {
                continue;
            }
            vector<int> t;
            t.emplace_back(i);
            while (true) {
                int x = p[t.back()];
                t.emplace_back(x);
                if (x == t[0]) {
                    break;
                }
            }
            for (int j = 0; j < (int) t.size() - 1; j++) {
                if (t[j] <= t[j + 1]) {
                    if (a[t[j + 1]] == 0) {
                        free--;
                    } else {
                        if (z == -1) {
                            z = a[t[j + 1]];
                        } else if (z != a[t[j + 1]]) {
                            free = -1;
                            break;
                        }
                    }
                }
            }
            if (free == -1) {
                break;
            }
        }
        if (free == -1) {
            cout << 0 << '\n';
        } else {
            if (z == -1 && free < of) {
                free++;
            }
            cout << power(m, free) << '\n';
        }
    }
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
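mod = 10**9 + 7
for _ in range(int(input())):
    n, m = map(int, input().split())
    p = [0] + list(map(int, input().split()))
    a = [0] + list(map(int, input().split()))

    # outedge[i] = pos_i, the source of the value that lands at i after one pass
    indeg, outedge = [0]*(n+1), [0]*(n+1)
    ans = 1

    for i in range(1, n+1):
        if p[i] >= i: outedge[i] = p[i]
        else: outedge[i] = outedge[p[i]]
        indeg[outedge[i]] += 1

    # Peel off vertices with in-degree 0; these are exactly the non-cycle vertices.
    # Each zero among them is a free choice, contributing a factor of m.
    queue = []
    for i in range(1, n+1):
        if indeg[i] == 0: queue.append(i)

    for u in queue:  # the list grows while we iterate over it, acting as a queue
        if a[u] == 0: ans = (ans * m) % mod
        indeg[outedge[u]] -= 1
        if indeg[outedge[u]] == 0: queue.append(outedge[u])

    # Vertices that still have positive in-degree lie on cycles
    cyclevals = set()
    for u in range(1, n+1):
        if indeg[u] == 0: continue
        cyclevals.add(a[u])

    if len(cyclevals) > 2: ans = 0  # at least two distinct non-zero values
    elif len(cyclevals) == 2:
        if 0 not in cyclevals: ans = 0  # two distinct non-zero values
        # otherwise: one non-zero value, and the cycle zeros are forced to it
    else:
        if 0 in cyclevals: ans = (ans * m) % mod  # all cycle values are 0: m choices
    print(ans)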

Could you explain the “unique position” part a little more? Thank you.

It becomes pretty obvious if you look at the process.

If P_i \geq i, position i will receive the value that was initially at position P_i (since position P_i hasn’t changed yet); this is the unique position for position i.
If P_i \lt i, position i will receive the value that’s currently at position P_i, but that value has already been changed.
Since it has changed, position P_i must have received its value from somewhere, and i will receive the original value of this same source.
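A quick way to watch this happen is to run one pass on labels instead of values, so that each cell records which original position its value came from. A minimal sketch; `trace_sources` is a hypothetical name and `p` is assumed 1-indexed with `p[0]` unused:

```python
def trace_sources(p):
    """Run one pass on labels instead of values: cur[j] records which original
    position the value now at j came from. p is 1-indexed (p[0] unused)."""
    n = len(p) - 1
    cur = list(range(n + 1))
    history = []
    for i in range(1, n + 1):
        cur[i] = cur[p[i]]
        history.append(cur[1:])  # snapshot of positions 1..n after step i
    return history


for step, state in enumerate(trace_sources([0, 2, 3, 1]), start=1):
    print(f"after step {step}: {state}")
# after step 1: [2, 2, 3]
# after step 2: [2, 3, 3]
# after step 3: [2, 3, 2]
```

With P = [2, 3, 1], step 3 has P_3 = 1 \lt 3, so position 3 copies what position 1 already holds, which is the original A_2; hence \text{pos}_3 = \text{pos}_1 = 2.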


Thank you, it’s clear now.

Can anyone give me a counterexample for my code?

C = 10**9+7

def pow(a,b):return a**b%C

for i in range(int(input())):
    n,m = map(int,input().split())
    t = list(map(int,input().split()))
    array = list(range(n))
    for _ in range(n): array[_] = array[t[_]-1]
    same = [*{*array}]
    original = list(map(int,input().split()))
    stuff = [original[_] for _ in same]
    stuff = [_ for _ in stuff if _!=0]
    if stuff != [] and stuff != [stuff[0]] * len(stuff): print(0); continue
    print(pow(m%C,(n-len(same)) if stuff != [] else (n-len(same)+1)))

I can’t seem to find a counterexample. It passes checkpoint 1 and nothing else (WA).

Can anyone please throw some light on functional graphs?
In this sentence: “In particular, we know what a functional graph looks like: several disjoint cycles, with trees hanging off of some vertices of the cycles (each tree is directed towards its corresponding cycle)”, I don’t understand “with trees hanging off of some vertices of the cycles”.

Please explain.

Do you know of the cycle decomposition of permutations? If not, I recommend reading about those first; for example in this blog (it’s a bit long but you only need the first three sections).

If you do know of them (and understand why permutations decompose into disjoint cycles), the same idea applies here.

A ‘functional graph’ is so called because it represents a function from \{1, 2, 3, \ldots, N\} to itself (the edges are literally just i \to f(i)). Permutations are a specific class of such functions, so functional graphs essentially generalize cycle decompositions.

Also, try taking some functions with, say, N = 10 and drawing out what their functional graphs look like. It’ll give you an idea as to their structure.
If you’re unclear about certain definitions or claims, working out small examples always helps!
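One way to see the difference from permutations with small examples: in a functional graph, a vertex that nobody points to can never be on a cycle, so it must be the outermost vertex of some hanging tree. Since there are exactly N edges, having no such vertex forces every in-degree to be exactly 1, which is precisely a permutation. A minimal sketch with hypothetical names `leaves` and `is_permutation`, assuming a 1-indexed list `f` with `f[0]` unused:

```python
import random
from collections import Counter


def leaves(f):
    """Vertices with no incoming edge in the functional graph v -> f[v]
    (1-indexed, f[0] unused). They are the outermost vertices of the trees
    hanging off the cycles; a cycle vertex always has an incoming edge."""
    indeg = Counter(f[1:])
    return [v for v in range(1, len(f)) if indeg[v] == 0]


def is_permutation(f):
    # N edges in total, so "every in-degree >= 1" forces every in-degree == 1:
    # no leaves means no trees, i.e. the graph is just disjoint cycles.
    return not leaves(f)


random.seed(0)
n = 10
f = [0] + [random.randint(1, n) for _ in range(n)]
print("f =", f[1:])
print("leaves:", leaves(f), "permutation:", is_permutation(f))
```

For f = [2, 3, 2, 3, 1], vertices 4 and 5 are leaves: 4 hangs directly off the cycle 2 \to 3 \to 2, while 5 points to 1, which in turn points to 2.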
