UWCOI20C - Mercury Poisoning - Editorial

C. Mercury Poisoning


Author: Jishnu Roychoudhury (astoria)
Tester: Taranpreet Singh (taran_1407)
Editorialist: Jishnu Roychoudhury (astoria)




Union Find Disjoint Set


There is a grid, with each cell having a height value. There are Q queries, each with a starting cell and power value. Mercury begins at the starting cell and can propagate to any adjacent cell with height strictly less than the power value. For each of the Q queries, output how many cells the mercury eventually reaches.


Process the queries offline in increasing order of power value, using UFDS data structure to count the number of reachable cells from the starting cell at that power value.


Subtask 1

For this subtask, we note that all heights are the same. This means that the mercury will either reach the whole grid or reach no cells at all. We can use an if statement to check whether the query power value is greater than the common height of the cells. If it is, output H*W. Otherwise, output 0.

Subtask 2

In this subtask, the grid is a line and Q \leq 1000. This means that we can solve each query in linear time. For each query, note that a contiguous segment of the array will be flooded. So we can search linearly outward from the starting cell twice: once to the left and once to the right. Once we reach a cell with height greater than or equal to the power value, we know that the mercury cannot go any further and we stop. Then output r-l+1, where r is the furthest index reached on the right and l is the furthest index reached on the left. (If the starting cell itself has height greater than or equal to the power value, nothing is flooded and the answer is 0.)

Subtask 3

In this subtask, the starting cell is the same for all queries. Observe that if we sort the queries in ascending order of power value, all the cells that are reached in one query will be reached in all later queries as well. Therefore, we do not need to revisit cells which we have already visited. We can use a min-priority queue of adjacent cells ordered by height, and keep visiting cells until no adjacent cell has height lower than the current query's power value. Once we have done this, we can answer the query and move on to the next query.

Subtask 4

The full solution is somewhat similar. Sort the queries again in ascending order of power value. Now iterate through the cells of the grid in ascending order of height. When a cell is added, union it in a Union Find Disjoint Set data structure with each of its already-added neighbours. Store the size of each connected component at its root. When there are no more cells with height less than the current query's power value, we answer the current query (the answer is the size of the connected component that contains the starting cell) and move on to the next query.

Note that if the height of the starting cell is greater than or equal to the power queried, you must output 0.


Setter's Solution
#include <bits/stdc++.h>
using namespace std;

// ---- Global state shared by the DSU helpers and sol() ----
// 1005x1005 leaves room for 1-indexed cells plus a one-cell border on every side.
// parent[r][c]: DSU parent of cell (r, c); (-1, -1) marks a root.
std::pair<int,int> parent[1005][1005];
// sz[r][c]: number of cells in the component; only meaningful when (r, c) is a root.
int sz[1005][1005];
int h,w,q;               // grid height, grid width, number of queries (current test case)
int grid[1005][1005];    // grid[r][c]: height of cell (r, c), filled 1-indexed by sol()
bool vis[1005][1005];    // vis[r][c]: cell (r, c) has already been added to the DSU

/// Returns the root of the component containing (u, v).
/// Every node on the walked path is re-pointed directly at the root (path compression).
/// Iterative on purpose: a recursive find can recurse ~h*w levels deep on a long
/// parent chain and overflow the call stack.
std::pair<int,int> root(int u, int v){
    // First pass: locate the root.
    int ru = u, rv = v;
    while (parent[ru][rv].first != -1){
        const std::pair<int,int> up = parent[ru][rv];
        ru = up.first;
        rv = up.second;
    }
    // Second pass: compress the path from (u, v) up to the root.
    while (parent[u][v].first != -1){
        const std::pair<int,int> next = parent[u][v];
        parent[u][v] = std::make_pair(ru, rv);
        u = next.first;
        v = next.second;
    }
    return std::make_pair(ru, rv);
}

/// Merges the components containing (x1, y1) and (x2, y2).
/// The new root's sz[] becomes the sum of both component sizes; no-op if already joined.
void connect(int x1, int y1, int x2, int y2){
    std::pair<int,int> small = root(x1,y1), big = root(x2,y2);
    if (small == big) return;
    // Union by size: hang the smaller tree under the larger to keep trees shallow.
    if (sz[small.first][small.second] > sz[big.first][big.second]) std::swap(small, big);
    parent[small.first][small.second] = big;
    sz[big.first][big.second] += sz[small.first][small.second];
}

void sol(){
    cin >> h >> w >> q;
    for (int i=1; i<=h; i++){
        for (int j=1; j<=w; j++){
            cin >> grid[i][j];
    pair<int,pair<int,int> > srt[h*w];
    for (int i=0; i<h; i++){
        for (int j=0; j<w; j++){
            srt[(i*w)+j] = make_pair(grid[i+1][j+1],make_pair(i+1,j+1));
    pair<pair<int,int>,pair<int,int> > queries[q];
    for (int i=0; i<q; i++){
        cin >> queries[i].second.first >> queries[i].second.second >> queries[i].first.first;
        queries[i].first.second = i;
    for (int i=0; i<1005; i++){
        for (int j=0; j<1005; j++){
            sz[i][j] = 1;
    for (int i=0; i<1005; i++){
        for (int j=0; j<1005; j++){
            parent[i][j] = make_pair(-1,-1);
    int ctr=0;
    int ans[q];
    for (int i=0; i<q; i++){
        while (ctr < (h*w)){
            if (srt[ctr].first < queries[i].first.first){
                int x=srt[ctr].second.first,y=srt[ctr].second.second;
                if (vis[x-1][y]) connect(x-1,y,x,y);
                if (vis[x+1][y]) connect(x+1,y,x,y);
                if (vis[x][y-1]) connect(x,y-1,x,y);
                if (vis[x][y+1]) connect(x,y+1,x,y);
                vis[x][y] = 1;
            else break;
        int u=queries[i].second.first,v=queries[i].second.second;
        if (!vis[u][v]){ ans[queries[i].first.second] = 0; continue;}
        pair<int,int> rt = root(u,v);
        ans[queries[i].first.second] = sz[rt.first][rt.second];
    for (int i=0; i<q; i++) cout << ans[i] << '\n';

/// Reads the number of test cases and solves each one.
int main(){
    // Decouple iostreams from C stdio and untie cin from cout. The 1005x1005 grid
    // arrays imply input on the order of 10^6 numbers, so buffered reads matter.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t;
    std::cin >> t;
    while (t--) sol();
    return 0;
}
Tester's Solution
import java.math.BigInteger;
import java.util.*;
import java.io.*;
import java.text.*;
public class Main{
    //Into the Hardware Mode
    /** Orthogonal neighbour offsets: up, down, left, right. */
    int[][] D = new int[][]{
        {-1, 0},{1, 0}, {0, -1}, {0, 1}
    };
    void pre() throws Exception{}
    /**
     * Solves one test case offline.
     * Cells are unioned into a DSU in increasing height, queries are processed in
     * increasing power, and answers are printed in input order.
     */
    void solve(int TC)throws Exception {
        int H = ni(), W = ni(), Q = ni();
        int[][] P = new int[H][W];
        for(int i = 0; i< H; i++)
            for(int j = 0; j< W; j++)
                P[i][j] = ni();
        // Cell (r, c) is encoded as r*W + c.
        int[] set = new int[H*W], size = new int[H*W];
        Integer[] ord = new Integer[H*W];
        for(int i = 0; i< H*W; i++){
            set[i] = i;
            ord[i] = i;
            size[i] = 1;
        }
        Arrays.sort(ord, (Integer i1, Integer i2) -> Integer.compare(P[i1/W][i1%W], P[i2/W][i2%W]));
        int p = 0; // next cell in ord to add
        // Each query: {original index, row (0-based), col (0-based), power}.
        int[][] qu = new int[Q][];
        for(int i = 0; i< Q; i++)qu[i] = new int[]{i, ni()-1, ni()-1, ni()};
        Arrays.sort(qu, (int[] i1, int[] i2) -> Integer.compare(i1[3], i2[3]));
        int[] ans = new int[Q];
        for(int i = 0; i< Q; i++){
            // Add every cell strictly lower than this query's power.
            while(p< ord.length && P[ord[p]/W][ord[p]%W] < qu[i][3]){
                int r = ord[p]/W, c = ord[p]%W;
                for(int[] d:D){
                    int rr = r+d[0], cc = c+d[1];
                    if(rr < 0 || cc < 0 || rr >= H || cc >= W || P[rr][cc] >= qu[i][3])continue;
                    int a = find(set, r*W+c), b = find(set, rr*W+cc);
                    if(a == b)continue;
                    size[a] += size[b];
                    set[b] = a;
                }
                p++; // without this the loop never terminates
            }
            // A starting cell with height >= power floods nothing: answer stays 0.
            if(P[qu[i][1]][qu[i][2]] < qu[i][3])ans[qu[i][0]] = size[find(set, qu[i][1]*W+qu[i][2])];
        }
        for(int i:ans)pn(i);
    }
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    void exit(boolean b){if(!b)System.exit(0);}
    long IINF = (long)1e15;
    final int INF = (int)1e9+2, MX = (int)2e6+5;
    DecimalFormat df = new DecimalFormat("0.00000000000");
    double PI = 3.141592653589793238462643383279502884197169399, eps = 1e-7;
    static boolean multipleTC = true, memory = true, fileIO = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        if(fileIO){
            in = new FastReader("");
            out = new PrintWriter("");
        }else {
            in = new FastReader();
            out = new PrintWriter(System.out);
        }
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        for(int t = 1; t<= T; t++)solve(t);
        // PrintWriter buffers output; without flushing nothing reaches stdout.
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        // Run on a thread with a large stack because find() is recursive.
        if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
        else new Main().run();
    }
    int find(int[] set, int u){return set[u] = (set[u] == u?u:find(set, set[u]));}
    int digit(long s){int ans = 0;while(s>0){s/=10;ans++;}return ans;}
    long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
    int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    /** Whitespace-tokenizing reader over stdin or a file. */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}

I solved similar problem few days ago. https://tlx.toki.id/contests/troc-10-div-2/problems/E


Why does this TLE for the last case? https://www.codechef.com/viewsolution/29896869

Using map is making the solution slow. Replacing it with a vector works fast enough. Your Solution using vector.

Why can't we solve this problem using BFS?

1 Like

Linear time per query is too slow except for subtask 2.

I have solved similar problem (https://www.codechef.com/problems/SNAKEEAT) yesterday only.
I guessed this can be solved Offline, I couldn’t solve it because of lack of knowledge on finding reachable cells :woozy_face:

Can you elaborate?

That’s an unfortunate coincidence. To be clear, the problem idea was made a while back so it’s not as if we were “inspired” by the TOKI problem. However, it’s possible that this can happen because the problem is (somewhat intentionally) standard.

1 Like

Simple BFS from each starting cell runs in O(HWQ) time (O(HW) per query). With the constraints, this is 4 \cdot 10^{11} operations, which is too slow. Using UFDS reduces the complexity to O(HW \log(HW) + Q \log Q).

Ok but for subtask 2 how can you say that answer is r-l+1?

Our two linear searches simulate the spread of mercury on a line. r is the furthest right cell the mercury reaches, l is the furthest left cell the mercury reaches. The interval [l,r] is covered by mercury. The number of cells in this interval is r-l+1, which is the answer.

BFS/DFS can also solve this subtask, but it is overkill.

Interesting, I knew maps were slow, but surprised that the factor is over 5x.

I am trying to understand the code as well as the editorial, but I don't see how the overall code works.

(I assume knowledge of UFDS here. If you do not understand it, then read the Wikipedia link attached and/or other resources.)

root and connect are UFDS functions. connect literally “connects” two nodes. In terms of a graph, it is the same as adding an edge between two nodes. It is also known as “Union”. root returns the node that serves as the root of the connected component. parent[][] is a UFDS helper array.

srt[] stores the grid cells sorted by height, as described in the editorial. queries[] stores the queries sorted by power value.

I initialise the size of each connected component to 1 (sz[][]). Then I answer some queries. I use a while loop that continues until the height value of the next cell is not lower than the power value of the current query. When I am considering a cell, I add it to the graph. I do this by connecting it to all its neighbours. However, you have to remember that you can only add an edge between two adjacent cells if both have height lower than the current query’s power value. So I check whether I can connect to the adjacent cell with if statements. There are 4 adjacent cells, and I do this for each of them.

After doing all of this, I answer the query by finding the value of sz[][] at the root. In my UFDS, I maintain it so that the value of sz[][] at the root is the number of cells in the connected component. I do this within the function connect. When I connect two cells (and equivalently, connect two connected components) I increase the size of the new root (which is the root of one of the old connected components that is being merged) by the size of the root of the other connected component. Therefore, the value of sz[][] at the new root is equal to the size of the total connected component.

It’s somewhat difficult to explain this from first principle. It is important to practice UFDS problems to get better at them and gain greater understanding!

1 Like

hey thanks for replying, now I understand it

Hey, can anybody help me resolve this segmentation fault? I'm not getting it.
this is my code
#include <bits/stdc++.h>
using namespace std;

// parent[x][y]: DSU parent of cell (x, y), 0-indexed; (-1, -1) marks a root.
std::pair<int, int> parent[1000][1000];
// sz[x][y]: component size; only meaningful when (x, y) is a root.
int sz[1000][1000];

/// Returns the root of (x, y)'s component and compresses the walked path.
/// Iterative: a recursive find can recurse ~h*w levels deep and overflow the
/// stack, which is one possible source of a segmentation fault.
std::pair<int, int> findParent(int x, int y)
{
    int rx = x, ry = y;
    while (parent[rx][ry].first != -1)
    {
        const std::pair<int, int> up = parent[rx][ry];
        rx = up.first;
        ry = up.second;
    }
    while (parent[x][y].first != -1)
    {
        const std::pair<int, int> next = parent[x][y];
        parent[x][y] = std::make_pair(rx, ry);
        x = next.first;
        y = next.second;
    }
    return std::make_pair(rx, ry);
}

/// Merges the components of (x1, y1) and (x2, y2) in an h x w grid (0-indexed).
/// Does nothing if either cell lies outside the grid or both are already joined.
void unionSet(int x1, int y1, int x2, int y2, int h, int w)
{
    // Valid rows are 0..h-1 and columns 0..w-1, so the upper bounds need >=.
    if (x1 < 0 || x1 >= h || y1 < 0 || y1 >= w)
        return;
    if (x2 < 0 || x2 >= h || y2 < 0 || y2 >= w)
        return;
    std::pair<int, int> set1 = findParent(x1, y1), set2 = findParent(x2, y2);
    if (set1 == set2)
        return;
    // Union by size keeps the trees shallow.
    if (sz[set1.first][set1.second] > sz[set2.first][set2.second])
        std::swap(set1, set2);
    parent[set1.first][set1.second] = set2;
    sz[set2.first][set2.second] += sz[set1.first][set1.second];
}

void solve()
int h, w, q;
cin >> h >> w >> q;
int grid[h][w], k = 0;
pair<int, pair<int, int>> height[h * w];
for (int i = 0; i < h; i++)
for (int j = 0; j < w; j++)
cin >> grid[i][j];
height[k++] = make_pair(grid[i][j], make_pair(i, j));
pair<pair<int, int>, pair<int, int>> queries[q];
for (int i = 0; i < q; i++)
int r, c, p;
cin >> r >> c >> p;
r–, c–;
queries[i].first.first = p;
queries[i].first.second = i;
queries[i].second.first = r;
queries[i].second.second = c;
int visited[h][w];
for (int i = 0; i < h; i++)
for (int j = 0; j < w; j++)
parent[i][j] = make_pair(-1, -1);
sz[i][j] = 1;
visited[i][j] = 0;
sort(height, height + h * w);
sort(queries, queries + q);
int ctr = 0, ans[q];
for (int i = 0; i < q; i++)
while (ctr < (h * w))
if (height[ctr].first < queries[i].first.first)
int x = height[ctr].second.first, y = height[ctr].second.second;
if (visited[x + 1][y])
unionSet(x + 1, y, x, y, h, w);
if (visited[x - 1][y])
unionSet(x - 1, y, x, y, h, w);
if (visited[x][y - 1])
unionSet(x, y - 1, x, y, h, w);
if (visited[x][y + 1])
unionSet(x, y + 1, x, y, h, w);
visited[x][y] = 1;
int r = queries[i].second.first, c = queries[i].second.second, index = queries[i].first.second;
if (!visited[r][c])
ans[index] = 0;
pair<int, int> rt = findParent(r, c);
ans[index] = sz[rt.first][rt.second];
for (int i = 0; i < q; i++)
cout << ans[i] << “\n”;

int main()
int t;
cin >> t;
while (t–)
return 0;