PREFIXES - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: beevo
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

3447

PREREQUISITES:

Tries, binary lifting, subtree updates and point queries on a tree.

PROBLEM:

You have N strings, all initially with value 0. Process Q queries/updates:

  • Type 1: given (i, k, x), add x to the value of every string whose length-k prefix equals S_i[1:k].
  • Type 2: given (i, k, T), append a new string S_i[1:k] + T, with value 0, to the list.
  • Type 3: given i, print the current value of the i-th string.
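To pin down the semantics of the three query types, here is a naive reference simulation. It is only a sketch for cross-checking on small inputs and is far too slow for the real limits. The Query record and the bruteForce function are hypothetical names, not part of the problem's input format; i and k are 1-based, as in the statement.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical query record; i and k are 1-based as in the statement.
struct Query {
    int type;       // 1, 2 or 3
    int i;          // index of the string S_i
    int k;          // prefix length (types 1 and 2)
    long long x;    // amount to add (type 1)
    std::string t;  // suffix T to append (type 2)
};

// Naive simulation: every type 1 query scans all strings, so this is only for small tests.
std::vector<long long> bruteForce(std::vector<std::string> s, const std::vector<Query>& qs) {
    std::vector<long long> value(s.size(), 0);
    std::vector<long long> answers;
    for (const Query& q : qs) {
        if (q.type == 1) {
            const std::string prefix = s[q.i - 1].substr(0, q.k);
            for (std::size_t j = 0; j < s.size(); j++)
                if (s[j].size() >= prefix.size() && s[j].compare(0, prefix.size(), prefix) == 0)
                    value[j] += q.x;
        } else if (q.type == 2) {
            s.push_back(s[q.i - 1].substr(0, q.k) + q.t);  // the new string starts with value 0
            value.push_back(0);
        } else {
            answers.push_back(value[q.i - 1]);
        }
    }
    return answers;
}
```

For instance, starting from the strings "ab" and "ac", the queries 1 1 1 5, 2 1 2 c, 1 3 2 7, 3 1, 3 2 and 3 3 produce 12, 5 and 7.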

EXPLANATION:

Let’s tackle a slightly easier version of the problem first, without the second query type, i.e., without adding new strings.

Given that we’re dealing with strings, and in particular prefixes of strings, the obvious data structure that comes to mind is a trie.
So, let’s put all the S_i into a trie.

Now, for the queries:

  • Type 1: given (i, k, x), add x to all strings whose length-k prefix equals S_i[1:k].
    Notice that if we find the vertex u of the trie corresponding to S_i[1:k] (i.e., the position of the k-th character of S_i), the valid strings are exactly those whose endpoints lie in the subtree of u.
    So, we’d like to add x to the values of all vertices in this subtree.
    Finding u quickly is not very hard; you can even save that information when you create the trie and just look it up when needed.
  • Type 3: given i, print the value of the i-th string.
    Since type 1 queries only perform subtree updates, all we need is the current value of the vertex where this string ends in the trie.
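Recording the prefix vertices while building the trie could look like the sketch below. It assumes the strings use only lowercase letters, and the names Trie and insert are illustrative, not taken from the reference codes. insert returns, for every k from 0 to |s|, the vertex of the length-k prefix, so a type 1 query (i, k, x) just looks up entry k for S_i, and a type 3 query reads the last entry.

```cpp
#include <array>
#include <cassert>
#include <string>
#include <vector>

// Array-based trie over 'a'..'z'; vertex 0 is the root (the empty prefix).
struct Trie {
    std::vector<std::array<int, 26>> next;  // next[v][c] == 0 means "no child"

    Trie() : next(1) {}  // value-initialized root: no children

    // Inserts s and returns prefixVertex, where prefixVertex[k] is the vertex of s[0..k).
    std::vector<int> insert(const std::string& s) {
        std::vector<int> prefixVertex{0};
        int cur = 0;
        for (char ch : s) {
            int c = ch - 'a';
            if (next[cur][c] == 0) {
                next[cur][c] = static_cast<int>(next.size());
                next.emplace_back();  // new vertex with no children
            }
            cur = next[cur][c];
            prefixVertex.push_back(cur);
        }
        return prefixVertex;
    }
};
```

The stored vectors take O(S) memory in total, since each string contributes one entry per character plus one for the root. Vertex 0 doubles as the "no child" marker, which is safe because the root is never anyone's child.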

After building the trie, our problem has really turned into a problem on a tree: we’d like to add values to certain subtrees, and get the value of certain vertices.
This is a classical problem.

How do I solve it?

Perform an Euler tour of the tree, which “flattens” subtrees into subarrays.

Now, subtree updates and vertex queries correspond to subarray updates and point queries on an array, which is a well-known problem that can be solved in several ways.
For example, you can use a segment tree with lazy propagation, or simply a normal segment tree/BIT built on the difference array of the values: a subarray addition becomes two point updates, and a point query becomes a prefix sum.
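Here is a minimal sketch of both pieces, assuming the tree is given as child lists rooted at vertex 0 (the names EulerTour and RangeAddPointQuery are made up for this example). The tour uses half-open ranges [tin[u], tout[u]), whereas the author's code stores an inclusive out[u]; the Fenwick tree works on the difference array, as in the tester's code.

```cpp
#include <cassert>
#include <vector>

// Euler tour: the subtree of u occupies exactly the positions [tin[u], tout[u]).
struct EulerTour {
    std::vector<int> tin, tout;
    int timer = 0;

    explicit EulerTour(const std::vector<std::vector<int>>& children)
        : tin(children.size()), tout(children.size()) {
        dfs(children, 0);  // vertex 0 is assumed to be the root
    }

    void dfs(const std::vector<std::vector<int>>& children, int u) {
        tin[u] = timer++;
        for (int v : children[u]) dfs(children, v);
        tout[u] = timer;
    }
};

// Fenwick tree on the difference array d, where a[p] = d[0] + ... + d[p].
struct RangeAddPointQuery {
    std::vector<long long> bit;  // 1-based internally

    explicit RangeAddPointQuery(int n) : bit(n + 1, 0) {}

    void addDiff(int pos, long long v) {  // d[pos] += v; does nothing when pos == n
        for (int i = pos + 1; i < static_cast<int>(bit.size()); i += i & -i) bit[i] += v;
    }

    void rangeAdd(int l, int r, long long v) {  // a[l..r) += v
        addDiff(l, v);
        addDiff(r, -v);
    }

    long long pointQuery(int pos) const {  // returns a[pos]
        long long s = 0;
        for (int i = pos + 1; i > 0; i -= i & -i) s += bit[i];
        return s;
    }
};
```

A subtree addition on u is then rangeAdd(tin[u], tout[u], x), and the value of vertex v is pointQuery(tin[v]). The DFS is recursive for brevity; on a trie whose depth can reach 10^5, an iterative version may be safer against stack overflow.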


Now, let’s try to incorporate the second query type as well.
Once again, let’s construct a trie of all the S_i.
The type 2 query (i, k, T) then corresponds to the following:

  • Find the position of the k-th character of S_i in the trie.
  • Then, insert T into the trie, starting from this position.

The second part is easy once we’ve done the first, so let’s focus on the first instead.
Suppose the length of S_i is L.
Then, the k-th character of S_i from the front is the (L-k+1)-th character of S_i from the back.
Moving backwards in the trie is easy; it corresponds to just moving to the parent.
So, if we know the vertex u which is the endpoint of S_i, we’re really just looking for the (L-k)-th ancestor of u (u itself being its own 0-th ancestor), which can be found quickly using binary lifting.
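A sketch of this lookup, assuming the parent of every trie vertex is known and the root is its own parent (KthAncestor and prefixVertex are illustrative names). The number of steps is L - k, because u is its own 0-th ancestor; this matches kth(leaf[i], sz[i] - k - 1) in the author's code, where k has already been decremented to 0-indexed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Binary lifting: up[j][v] is the 2^j-th ancestor of v; the root is its own parent.
struct KthAncestor {
    static constexpr int LOG = 20;  // enough for any number of steps below 2^20
    std::vector<std::vector<int>> up;

    explicit KthAncestor(const std::vector<int>& parent) : up(LOG, parent) {
        for (int j = 1; j < LOG; j++)
            for (std::size_t v = 0; v < parent.size(); v++)
                up[j][v] = up[j - 1][up[j - 1][v]];
    }

    int query(int v, int steps) const {  // the steps-th ancestor of v
        for (int j = 0; j < LOG; j++)
            if ((steps >> j) & 1) v = up[j][v];
        return v;
    }
};

// Trie vertex of S[1:k], given the endpoint u of S and L = |S| (1 <= k <= L).
int prefixVertex(const KthAncestor& anc, int u, int L, int k) {
    return anc.query(u, L - k);  // the (L-k)-th ancestor; 0 steps when k == L
}
```

With LOG = 20, any step count below 2^20 is supported, which comfortably covers strings of length up to 10^5.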

However, notice that this process cannot really be done online: inserting new vertices changes the structure of the trie, and hence the Euler tour computed from it.

Instead, we can process all insertions offline.

That is,

  • Read all the queries initially, but don’t do anything with them just yet.
  • First, process all type 2 queries only, to build a single large trie.
  • Build the Euler tour of this trie; and a segment tree on the tour.
  • Next, go over the queries once again, but this time:
    • Type 1 queries correspond to subtree additions, as before.
    • Type 3 queries correspond to point queries, as before.
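    • Type 2 queries, however, need a small correction.
      Notice that since we’re using the single megatrie, some values might’ve been added to the new string’s endpoint even though the string technically didn’t exist yet.
      Simply resetting that vertex to 0 is not safe, since another string may end at the same vertex. Instead, when the string is created, record the current value at its endpoint (a point query on the segment tree), and subtract this recorded value whenever a type 3 query asks for that string. This is what both codes below do.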
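Putting the offline steps together, a compact sketch could look like the following (C++17). It assumes lowercase letters, re-declares the hypothetical Query record (1-based i and k), and swaps the author's lazy segment tree for a Fenwick tree on the difference array. Like both reference codes, it records the value at a new string's endpoint when the string is created and subtracts it later, instead of resetting the vertex, since several strings can share one endpoint.

```cpp
#include <array>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical query record; i and k are 1-based as in the statement.
struct Query {
    int type, i, k;
    long long x;    // type 1 only
    std::string t;  // type 2 only
};

// Offline sketch: returns the answers to the type 3 queries, in order.
std::vector<long long> solveOffline(const std::vector<std::string>& init,
                                    const std::vector<Query>& qs) {
    const int LOG = 20;
    std::vector<std::array<int, 26>> next(1);  // trie children (0 = none); vertex 0 is the root
    std::vector<std::vector<int>> up(LOG, std::vector<int>(1, 0));  // up[j][v]: 2^j-th ancestor
    std::vector<int> endV, len;  // endpoint vertex and length of every string
    auto kth = [&](int v, int steps) {
        for (int j = 0; j < LOG; j++)
            if ((steps >> j) & 1) v = up[j][v];
        return v;
    };
    auto insertFrom = [&](int cur, const std::string& s) {  // walks/extends the trie from cur
        for (char ch : s) {
            int c = ch - 'a';
            if (next[cur][c] == 0) {
                int v = static_cast<int>(next.size());
                next[cur][c] = v;
                next.emplace_back();  // new vertex, no children yet
                for (int j = 0; j < LOG; j++)  // lifting entries of v, level by level
                    up[j].push_back(j == 0 ? cur : up[j - 1][up[j - 1][v]]);
            }
            cur = next[cur][c];
        }
        return cur;
    };
    // Pass 1: build the final megatrie, including every string added by a type 2 query.
    for (const std::string& s : init) {
        endV.push_back(insertFrom(0, s));
        len.push_back(static_cast<int>(s.size()));
    }
    for (const Query& q : qs) {
        if (q.type != 2) continue;
        int from = kth(endV[q.i - 1], len[q.i - 1] - q.k);  // vertex of S_i[1:k]
        endV.push_back(insertFrom(from, q.t));
        len.push_back(q.k + static_cast<int>(q.t.size()));
    }
    // Iterative Euler tour: the subtree of v occupies positions [tin[v], tout[v]).
    const int n = static_cast<int>(next.size());
    std::vector<int> tin(n), tout(n);
    int timer = 0;
    std::vector<std::pair<int, int>> st;  // (vertex, next letter to try)
    st.emplace_back(0, 0);
    tin[0] = timer++;
    while (!st.empty()) {
        auto& [v, c] = st.back();
        while (c < 26 && next[v][c] == 0) c++;
        if (c == 26) {
            tout[v] = timer;
            st.pop_back();
        } else {
            int w = next[v][c++];
            tin[w] = timer++;
            st.emplace_back(w, 0);  // may invalidate v and c; they are not used again
        }
    }
    // Pass 2: replay the queries on a Fenwick tree over the difference array.
    std::vector<long long> bit(n + 1, 0);
    auto addDiff = [&](int pos, long long val) {  // no-op when pos == n
        for (int i = pos + 1; i <= n; i += i & -i) bit[i] += val;
    };
    auto pointValue = [&](int pos) {
        long long s = 0;
        for (int i = pos + 1; i > 0; i -= i & -i) s += bit[i];
        return s;
    };
    std::vector<long long> baseline(endV.size(), 0), answers;
    int created = static_cast<int>(init.size());
    for (const Query& q : qs) {
        if (q.type == 1) {
            int u = kth(endV[q.i - 1], len[q.i - 1] - q.k);
            addDiff(tin[u], q.x);
            addDiff(tout[u], -q.x);
        } else if (q.type == 2) {
            // The endpoint may already hold earlier updates or be shared with another string,
            // so remember its current value instead of resetting the vertex.
            baseline[created] = pointValue(tin[endV[created]]);
            created++;
        } else {
            answers.push_back(pointValue(tin[endV[q.i - 1]]) - baseline[q.i - 1]);
        }
    }
    return answers;
}
```

For example, with initial strings "abc" and "ab" and the queries 1 2 2 4, 2 1 1 b, 1 1 1 1, 3 2 and 3 3, the new third string is "ab" and ends at the same vertex as string 2. The sketch returns 5 and 1, while resetting that vertex to 0 at creation would wrongly report 1 for string 2.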

TIME COMPLEXITY:

\mathcal{O}(S\cdot (26 + \log S) + Q\log S) per testcase, where S is the sum of lengths of all input strings (including the strings T given in type 2 queries).

CODE:

Author's code (C++)
#include <bits/stdc++.h>

#define el '\n'

typedef long long ll;
typedef long double ld;

#define Beevo ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);

using namespace std;

const int N = 2e5 + 5, ALPHA = 26, LOG = 20;

ll oldVal[N];
int id, timer, sz[N], trie[N][ALPHA], leaf[N], up[N][LOG], in[N], out[N];

struct Query {
    int t, i, k, x;
    string s;
};

struct Node {
   ll sum = 0;
};

struct SegTree {
   ll lazy[N * 4];
   Node tree[N * 4];
   Node neutral = Node();

   Node merge(Node u, Node v) {
       return {u.sum + v.sum};
   }

   void propagate(int x, int lX, int rX) {
       tree[x].sum += lazy[x] * (rX - lX + 1);

       if (lX != rX) {
           lazy[x * 2] += lazy[x];
           lazy[x * 2 + 1] += lazy[x];
       }

       lazy[x] = 0;
   }

   void update(int x, int lX, int rX, int l, int r, int val) {
       propagate(x, lX, rX);

       if (lX > r || rX < l)
           return;

       if (lX >= l && rX <= r) {
           tree[x].sum += 1LL * (rX - lX + 1) * val;

           if (lX != rX) {
               lazy[x * 2] += val;
               lazy[x * 2 + 1] += val;
           }

           return;
       }

       int m = (lX + rX) >> 1;

       update(x * 2, lX, m, l, r, val);
       update(x * 2 + 1, m + 1, rX, l, r, val);

       tree[x] = merge(tree[x * 2], tree[x * 2 + 1]);
   }

   Node query(int x, int lX, int rX, int l, int r) {
       if (lX > r || rX < l)
           return neutral;

       propagate(x, lX, rX);

       if (lX >= l && rX <= r)
           return tree[x];

       int m = (lX + rX) >> 1;

       Node u = query(x * 2, lX, m, l, r);
       Node v = query(x * 2 + 1, m + 1, rX, l, r);

       return merge(u, v);
   }
} st;

int insert(int cur, string &s) {
    int ch;

    for (auto &i: s) {
        ch = i - 'a';

        if (!trie[cur][ch])
            trie[cur][ch] = ++id;

        up[trie[cur][ch]][0] = cur, cur = trie[cur][ch];

        for (int k = 1; k < LOG; k++)
            up[cur][k] = up[up[cur][k - 1]][k - 1];
    }

    return cur;
}

int kth(int cur, int k) {
    for (int i = LOG - 1; i >= 0; i--) {
        if (k & (1 << i))
            cur = up[cur][i];
    }

    return cur;
}

void dfs(int u) {
    in[u] = timer++;

    for (int i = 0; i < ALPHA; i++) {
        if (trie[u][i])
            dfs(trie[u][i]);
    }

    out[u] = timer - 1;
}

void testCase() {
    int n;
    cin >> n;

    string s;
    for (int i = 0; i < n; i++) {
        cin >> s;

        sz[i] = s.size();
        leaf[i] = insert(0, s);
    }

    int q;
    cin >> q;

    vector<Query> v;
    int t, i, k = 0, x = 0, u, cnt = 0;  // k and x stay defined for queries that don't read them
    for (int j = 0; j < q; j++) {
        s.clear();

        cin >> t >> i;

        i--;

        if (t == 1) {
            cin >> k >> x;

            k--;
        }
        else if (t == 2) {
            cin >> k >> s;

            k--;

            sz[n + cnt] = k + 1 + s.size();
            leaf[n + cnt] = insert(kth(leaf[i], sz[i] - k - 1), s);

            cnt++;
        }

        v.push_back({t, i, k, x, s});
    }

    dfs(0);

    cnt = 0;
    for (auto &j: v) {
        t = j.t, i = j.i, k = j.k, x = j.x, s = j.s;

        if (t == 1) {
            // k is 0-indexed here: the prefix vertex is the (sz[i] - k - 1)-th ancestor of the endpoint.
            u = kth(leaf[i], sz[i] - k - 1);

            st.update(1, 0, N - 1, in[u], out[u], x);
        }
        else if (t == 2) {
            // Value already accumulated at the new string's endpoint; type 3 queries subtract it.
            oldVal[n + cnt] = st.query(1, 0, N - 1, in[leaf[n + cnt]], in[leaf[n + cnt]]).sum;

            cnt++;
        }
        else
            cout << st.query(1, 0, N - 1, in[leaf[i]], in[leaf[i]]).sum - oldVal[i] << el;
    }
}

signed main() {
    Beevo

    int t = 1;
//    cin >> t;

    while (t--)
        testCase();
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <typename T>
struct fenwick {
    int n;
    vector<T> node;

    fenwick(int _n) : n(_n) {
        node.resize(n);
    }

    void add(int x, T v) {
        while (x < n) {
            node[x] += v;
            x |= (x + 1);
        }
    }

    T get(int x) {  // [0, x]
        T v = 0;
        while (x >= 0) {
            v += node[x];
            x = (x & (x + 1)) - 1;
        }
        return v;
    }

    T get(int x, int y) {  // [x, y]
        return (get(y) - (x ? get(x - 1) : 0));
    }

    int lower_bound(T v) {
        int x = 0;
        int h = 1;
        while (n >= (h << 1)) {
            h <<= 1;
        }
        for (int k = h; k > 0; k >>= 1) {
            if (x + k <= n && node[x + k - 1] < v) {
                v -= node[x + k - 1];
                x += k;
            }
        }
        return x;
    }
};

int main() {
    input_checker in;
    const int N = 1e5;
    int n = in.readInt(1, 1e5);
    in.readEoln();
    vector<string> s(n);
    for (int i = 0; i < n; i++) {
        s[i] = in.readString(0, N, in.lower);
        in.readEoln();
    }
    vector<vector<int>> trie(1, vector<int>(26, -1));
    vector<vector<int>> pv(1, vector<int>(20, -1));
    vector<int> tail(n), len(n);
    for (int i = 0; i < n; i++) {
        int pos = 0;
        for (char c : s[i]) {
            if (trie[pos][c - 'a'] == -1) {
                trie[pos][c - 'a'] = (int) trie.size();
                trie.emplace_back(vector<int>(26, -1));
                pv.emplace_back(vector<int>(20, -1));
                int p = pos;
                for (int j = 0; j < 20; j++) {
                    pv.back()[j] = p;
                    if (p != -1) {
                        p = pv[p][j];
                    }
                }
            }
            pos = trie[pos][c - 'a'];
        }
        tail[i] = pos;
        len[i] = (int) s[i].size();
    }
    int q = in.readInt(1, 1e5);
    in.readEoln();
    vector<vector<int>> que(q);
    for (int i = 0; i < q; i++) {
        int op = in.readInt(1, 3);
        in.readSpace();
        int x = in.readInt(1, n);
        x--;
        if (op == 1) {
            in.readSpace();
            int y = in.readInt(1, len[x]);
            in.readSpace();
            int z = in.readInt(1, 1e5);
            que[i] = {1, x, y, z};
        } else if (op == 2) {
            in.readSpace();
            int y = in.readInt(1, len[x]);
            in.readSpace();
            string z = in.readString(1, N, in.lower);
            s.emplace_back(z);
            int pos = tail[x];
            int goup = len[x] - y;
            for (int j = 0; j < 20; j++) {
                if (goup & (1 << j)) {
                    pos = pv[pos][j];
                }
            }
            for (char c : z) {
                if (trie[pos][c - 'a'] == -1) {
                    trie[pos][c - 'a'] = (int) trie.size();
                    trie.emplace_back(vector<int>(26, -1));
                    pv.emplace_back(vector<int>(20, -1));
                    int p = pos;
                    for (int j = 0; j < 20; j++) {
                        pv.back()[j] = p;
                        if (p != -1) {
                            p = pv[p][j];
                        }
                    }
                }
                pos = trie[pos][c - 'a'];
            }
            tail.emplace_back(pos);
            len.emplace_back((int) z.size() + y);
            que[i] = {2, n};
            n++;
        } else {
            que[i] = {3, x};
        }
        in.readEoln();
    }
    int sn = 0;
    for (int i = 0; i < n; i++) {
        sn += (int) s[i].size();
    }
    assert(sn <= N);
    vector<int> order, beg(trie.size()), end(trie.size());  // the trie can have N + 1 vertices
    function<void(int)> Dfs = [&](int v) {
        beg[v] = (int) order.size();
        order.emplace_back(v);
        for (int i = 0; i < 26; i++) {
            if (trie[v][i] == -1) {
                continue;
            }
            Dfs(trie[v][i]);
        }
        end[v] = (int) order.size();
    };
    Dfs(0);
    fenwick<long long> f(N + 10);
    vector<long long> t(n);  // value at each string's endpoint when it was created (0 for initial strings)
    for (int i = 0; i < q; i++) {
        int x = que[i][1];
        if (que[i][0] == 1) {
            int y = que[i][2];
            int z = que[i][3];
            int goup = len[x] - y;
            int pos = tail[x];
            for (int j = 0; j < 20; j++) {
                if (goup & (1 << j)) {
                    pos = pv[pos][j];
                }
            }
            f.add(beg[pos], z);
            f.add(end[pos], -z);
        } else if (que[i][0] == 2) {
            t[x] = f.get(beg[tail[x]]);
        } else {
            cout << f.get(beg[tail[x]]) - t[x] << '\n';
        }
    }
    in.readEof();
    return 0;
}