CIRCOLOR - Editorial

PROBLEM LINK:

Practice
Div-2 Contest
Div-1 Contest

Author: Ildar Gainullin
Tester: Nikolay Budin
Editorialist: Alei Reyes

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Dynamic Programming, Sqrt-decomposition

PROBLEM:

You are given n sets of integers A_1, A_2, \ldots, A_n. The sum of sizes of all sets is m.

Find the number of ways to choose numbers a_1, a_2, \ldots, a_n, such that a_i \in A_i, a_i \neq a_{i + 1} and a_1 \neq a_n.

QUICK EXPLANATION:

  • The problem for a line can be solved with dynamic programming in O(m).
  • If n \leq \sqrt{m}, then count the number of ways by repeatedly replacing a pair of adjacent sets with their intersection (inclusion–exclusion).
  • If n \gt \sqrt{m}, then there is a pair of adjacent sets with intersection of size at most \sqrt{m}. Break the circle in the position of that pair of sets, and solve it as a line.

EXPLANATION:

First let’s solve the problem when the sets are not in a circle, but in a line (so the first and last sets are not adjacent). Let f_{i,x} be the number of ways of choosing elements from the first i sets, given that we are forced to choose x from the i-th set; in particular f_{1,x}=1 for every x \in A_1. We can calculate f_{i,x} using the values from f_{i-1} as follows:

f_{i,x}=\sum_{y \in A_{i-1}} f_{i-1,y} - f_{i-1,x} \qquad (x \in A_i)

From all the ways of choosing elements from the previous i-1 sets, we remove the ways that choose x from set i-1, because two adjacent elements can’t have the same value (f_{i-1,x} is taken as 0 when x \notin A_{i-1}). The sum over y is computed once per set, so this dynamic programming runs in O(m); it is implemented in the function calc in the tester’s solution.

Let L be the answer when the sets are in a line, i.e. L= \sum_x f_{N,x}. Let’s figure out how to extend it to the case when the sets are in a circle. L is an incorrect answer for a circle because it also counts sequences of integers x_1, x_2, ..., x_n (x_i \in A_i) where x_1=x_n (the first element equal to the last element).

Let I be the intersection of the first and last sets. To fix L we can subtract all sequences of the form y, x_2, x_3,...,x_{n-1},y with y \in I, i.e. sequences with the same first and last element. For a fixed y, we can count them by setting A_1=A_N=\{y\} and running the same line DP.

The problem is that this runs one O(m) DP per element of I, i.e. it takes O(|I| \cdot m) time, and |I| can be very big. So the next natural question is under which conditions it is possible to bound |I|.

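It turns out that if the number of sets n is greater than \sqrt{m}, then some set has at most m/n < \sqrt{m} elements, so its intersection with a neighbouring set also has fewer than \sqrt{m} elements. Rotating the circle so that this pair becomes the last and first sets gives |I| < \sqrt{m}, so this case is solved in O(|I| \cdot m) = O(m \cdot \sqrt{m}).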

The remaining case is when the number of sets is at most \sqrt{m}. Let C(A_1,A_2,...,A_N) be the number of ways of choosing one integer from each set without taking the same element from adjacent sets (A_1 is adjacent to A_N). Similarly, let L(A_1,...,A_N) be the number of ways when the sets are in a line (A_1 is not adjacent to A_N). Note that C stands for circle and L for line. The following recurrence holds:

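C(A_1,...,A_N)=L(A_1,...,A_N) - C(A_1 \cap A_N, A_2,...,A_{N-1})

The subtracted term counts the line arrangements with a_1 = a_N. That common value must come from A_1 \cap A_N and must differ from both a_2 and a_{N-1}, so the merged set closes a circle of N-1 sets. A circle with a single set contributes 0, since its only element would be adjacent to itself.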

Each step of the recurrence merges two adjacent sets into their intersection, forcing the same element to be chosen from both. Every iteration runs the O(m) line DP once and decreases the number of sets by one, so there are at most N iterations. Since N \leq \sqrt{m}, the overall running time is again O(m \cdot \sqrt{m}).

SOLUTIONS:

Tester's Solution
// Modulus applied to every count computed below.
const int MOD = 998'244'353;

// In-place modular addition: a = (a + b) mod MOD.
// Precondition: 0 <= a, b < MOD. Then a + b < 2 * MOD, which still fits in a
// 32-bit int, and a single conditional subtraction normalises the result.
void add(int& a, int b) {
  const int raw = a + b;
  a = raw < MOD ? raw : raw - MOD;
}

// Value-returning counterpart of add(): (a + b) mod MOD for a, b in [0, MOD).
int sum(int a, int b) {
  int result = a;
  add(result, b);
  return result;
}

int calc(vector<vector<int>> const& arr) {
  vector<pii> vars, next;
  for (int num : arr[0]) {
    vars.push_back({num, 1});
  }
  for (int i = 1; i < szof(arr); ++i) {
    int tot = 0;
    for (auto p : vars) {
      add(tot, p.ss);
    }
    next.clear();
    int pos = 0;
    for (int num : arr[i]) {
      while (pos < szof(vars) && vars[pos].ff < num) {
        ++pos;
      }
      int cur = tot;
      if (pos < szof(vars) && vars[pos].ff == num) {
        add(cur, MOD - vars[pos].ss);
      }
      next.push_back({num, cur});
    }
    swap(next, vars);
  }

  int ret = 0;
  for (auto p : vars) {
    add(ret, p.ss);
  }

  return ret;
}

// Returns the sorted intersection of two ascending-sorted vectors.
// The output buffer is pre-sized to |a| (an upper bound on the result size)
// and then trimmed to the number of elements actually written.
vector<int> intersection(vector<int> const& a, vector<int> const& b) {
  vector<int> common(a.size());
  const auto written =
      set_intersection(a.begin(), a.end(), b.begin(), b.end(), common.begin());
  common.resize(static_cast<size_t>(written - common.begin()));
  return common;
}

// Reads one test and prints, modulo MOD, the number of ways to choose a_i
// from A_i (i = 1..n) with a_i != a_{i+1} for every i and a_n != a_1.
//
// Input read here: n, then for every set its size k followed by k values.
// Values are shifted to 0-based and each set is sorted, so that calc() and
// intersection() can merge consecutive sets with two pointers.
//
// Let m be the total number of values and Q = floor(sqrt(m)).
//  * n <= Q: circle(A_1..A_k) = line(A_1..A_k)
//                              - circle(A_1 intersect A_k, A_2..A_{k-1}),
//    unrolled with alternating signs until a single set is left. That is at
//    most n line DPs of O(m) each.
//  * n > Q: some set has at most m/n < sqrt(m) values, so the smallest
//    intersection of two neighbouring sets is small. Rotate so that pair
//    becomes (last, first), count lines, and for every common value y
//    subtract the lines that start and end with y.
void solve() {
  int n;
  cin >> n;
  vector<vector<int>> arr;
  long long m = 0;  // total number of values over all sets
  for (int i = 0; i < n; ++i) {
    int k;
    cin >> k;
    m += k;
    arr.push_back({});
    for (int j = 0; j < k; ++j) {
      int num;
      cin >> num;
      --num;
      arr[i].push_back(num);
    }
    sort(arr[i].begin(), arr[i].end());
  }

  // Q = floor(sqrt(m)), at least 1. Deriving the split from the actual input
  // keeps both branches O(m * sqrt(m)) for any m, instead of relying on a
  // constant tuned for one particular bound on m. Both branches are correct
  // for every n >= 1, so only the running time depends on this choice.
  int Q = 1;
  while (static_cast<long long>(Q + 1) * (Q + 1) <= m) {
    ++Q;
  }

  if (n <= Q) {
    int ans = 0;
    bool sign = false;  // true => this round's line count is subtracted
    while (szof(arr) > 1) {
      int cur = calc(arr);
      if (sign) {
        cur = MOD - cur;
      }
      add(ans, cur);

      // Merge the first and last sets, which are adjacent on the circle.
      arr[0] = intersection(arr[0], arr.back());
      arr.pop_back();

      sign ^= 1;
    }
    // A single remaining set would be adjacent to itself: it contributes 0.

    cout << ans << "\n";
  } else {
    // Find the neighbouring pair (arr[i-1], arr[i]) with the smallest intersection.
    int best = INF;
    int mem = -1;
    for (int i = 0; i < n; ++i) {
      int prev = (i + n - 1) % n;
      auto tmp = intersection(arr[i], arr[prev]);
      if (szof(tmp) < best) {
        best = szof(tmp);
        mem = i;
      }
    }

    // Make that pair the (last, first) sets of the line.
    rotate(arr.begin(), arr.begin() + mem, arr.end());

    int ans = calc(arr);
    auto intr = intersection(arr.front(), arr.back());
    for (int val : intr) {
      // Remove the line arrangements that start and end with val.
      arr.front().clear();
      arr.front().push_back(val);
      arr.back().clear();
      arr.back().push_back(val);
      add(ans, MOD - calc(arr));
    }

    cout << ans << "\n";
  }
}
Setter's Solution
// Setter's solution. Reads n and then each set as its size followed by its
// (1-based) values. Prints, modulo 998244353, the number of ways to pick one
// value per set such that no two cyclically adjacent picks are equal.
int main() {
  // Only C++ streams are used, so decoupling them from C stdio is safe and
  // speeds up reading large inputs.
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  int max_c = 0;  // largest 0-based value read; sizes dp / used below
  int n;
  cin >> n;
  vector <vector <int> > a(n);
  for (int i = 0; i < n; i++) {
    int x;
    cin >> x;
    while (x--) {
      int t;
      cin >> t;
      t--;
      max_c = max(max_c, t);
      a[i].push_back(t);
    }
  }
  // Rotate the circle (this does not change the answer) so that a smallest
  // set becomes a[0].
  int x = 0;
  for (int i = 0; i < n; i++) {
    if (a[i].size() < a[x].size()) x = i;
  }
  rotate(a.begin(), a.begin() + x, a.end());
  // dp[v]: ways to pick from the sets processed so far with the last pick
  // equal to v. Entries of values absent from the last processed set are 0.
  vector <int> dp(max_c + 1);
  // used[v]: index of the last processed set containing v (reset by solve).
  vector <int> used(max_c + 1, -1);
  const int mod = 998244353;
  // Line DP over the sets a[start], a[start + 1], ..., a.back().
  //  * first == -1 (called with start == 0): returns the number of line
  //    arrangements of all sets.
  //  * first != -1 (called with start == 1): the pick from a[0] is fixed to
  //    `first`, and arrangements whose last pick equals `first` are excluded.
  //    This is the number of circular arrangements with a_1 == first.
  // A single set yields 0 (on a circle it would be adjacent to itself).
  // `a` is taken by const reference to avoid deep-copying every set per call.
  auto solve = [&] (const vector <vector <int> >& a, int start, int first) {
    if (a.size() == 1) return 0;
    for (int i = 0; i <= max_c; i++) {
      used[i] = -2, dp[i] = 0;
    }
    if (first != -1)
      dp[first] = 1;
    int tot = 1;  // sum of dp over the previous set (a virtual single way at the start)
    for (int i = start; i < (int) a.size(); i++) {
      int new_tot = 0;
      for (int x : a[i]) {
        // Any previous pick except x itself.
        int val = tot - dp[x];
        if (val < 0) val += mod;
        dp[x] = val;
        new_tot += val;
        if (new_tot >= mod) new_tot -= mod;
        used[x] = i;
      }
      if (i) {
        // Zero out values of the previous set that are not in the current one.
        for (int x : a[i - 1]) {
          if (used[x] != i) {
            dp[x] = 0;
          }
        }
      }
      tot = new_tot;
    }
    int sum = 0;
    for (int i = 0; i <= max_c; i++) {
      if (i != first) {
        sum += dp[i];
        if (sum >= mod) sum -= mod;
      }
    }
    return sum;
  };
  if ((int) a[0].size() <= 1000) {
    // Few choices for a_1: fix it to every value of a[0] and sum the circular
    // counts with that first pick.
    int sum = 0;
    for (int x : a[0]) {
      sum += solve(a, 1, x);
      if (sum >= mod) sum -= mod;
    }
    cout << sum << '\n';
  } else {
    // Every set has more than 1000 values, so there are fewer than m / 1000
    // sets (m = total number of values). Use inclusion-exclusion:
    //   circle(A_1..A_k) = line(A_1..A_k) - circle(A_1 intersect A_k, A_2..A_{k-1}),
    // with alternating signs until one set is left (solve returns 0 for it).
    //
    // Values present in both lists, in b's order; arguments are taken by
    // const reference to avoid copying them.
    auto intersect = [&] (const vector <int>& a, const vector <int>& b) {
      vector <char> used(max_c + 1);
      for (int x : a) used[x] = true;
      vector <int> go;
      for (int x : b) {
        if (used[x]) {
          go.push_back(x);
        }
      }
      return go;
    };
    int sum = 0;
    int its = (int) a.size();
    for (int i = 0; i < its; i++) {
      if (i % 2 == 0) {
        sum += solve(a, 0, -1);
        if (sum >= mod) sum -= mod;
      } else {
        sum -= solve(a, 0, -1);
        if (sum < 0) sum += mod;
      }
      if (i == its - 1) break;
      // Merge the first and last sets, which are adjacent on the circle.
      a[0] = intersect(a[0], a[(int) a.size() - 1]);
      a.pop_back();
    }
    cout << sum << '\n';
  }
}

VIDEO EDITORIAL:

2 Likes