UWCOI20G - Optimal Memory Address - Editorial

G. Optimal Memory Address


Author: Jishnu Roychoudhury (astoria)
Testers: Jatin Yadav (jtnydv25), Taranpreet Singh (taran_1407)
Editorialist: Jishnu Roychoudhury (astoria)




Graph Theory


You are given an array of numbers in prime-factorised form. Define the distance between two numbers as dist(X,Y) = size(X) + size(Y) - 2*V, where size(X) is the sum of the exponents in the prime factorisation of X, and V is the length of the longest common prefix of S(X) and S(Y). S(X) is a sequence constructed from the prime factorisation of X (refer to the problem statement).

Find the number C that minimises \sum_{i=1}^{N} dist(C,A[i]), and output that sum.


Consider a tree on the natural numbers in which the distance between two nodes equals the distance between the corresponding numbers. The distance sum for each relevant node can then be derived from its parent's distance sum in O(1), and compressing the tree further gives full score.


There are multiple ways of solving this problem on the implementation level, but all of them use similar ideas. Before we consider subtask 1, let’s make some basic observations about the problem.

Consider an unweighted infinite tree rooted at 1 containing each of the natural numbers as a node. Let the parent of node x be x/f(x), where f(x) is the largest prime factor of x. For example, the parent of 12 is 12/3 = 4. Now, we notice that the distance between two numbers in the problem statement is the same as the distance between the two corresponding nodes in this tree. One way of coming up with the idea of the tree is noticing that the given distance formula has the same shape as the formula for the distance between two nodes u and v in a rooted tree, depth(u) + depth(v) - 2*depth(lca(u,v)).
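To make the tree concrete, here is a minimal sketch that computes the tree distance between two small numbers by repeatedly moving the deeper one to its parent x/f(x). The helper names (largestPrimeFactor, depth, treeDist) are made up for this illustration, and trial division is used only because the example numbers are tiny; the real input is given in factorised form and never needs factoring.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Largest prime factor of x (x >= 2), found by trial division.
std::uint64_t largestPrimeFactor(std::uint64_t x) {
    assert(x >= 2);
    std::uint64_t best = 1;
    for (std::uint64_t p = 2; p * p <= x; ++p) {
        while (x % p == 0) { best = p; x /= p; }
    }
    if (x > 1) best = x;  // what is left is a prime larger than any found so far
    return best;
}

// Depth of x in the tree = number of prime factors counted with multiplicity.
int depth(std::uint64_t x) {
    assert(x >= 1);
    int d = 0;
    while (x > 1) { x /= largestPrimeFactor(x); ++d; }
    return d;
}

// Distance between a and b: keep lifting the deeper node until both meet.
int treeDist(std::uint64_t a, std::uint64_t b) {
    int da = depth(a), db = depth(b), steps = 0;
    while (a != b) {
        if (da < db) { std::swap(a, b); std::swap(da, db); }
        a /= largestPrimeFactor(a);  // a is now the deeper (or equally deep) node
        --da;
        ++steps;
    }
    return steps;
}
```

For 12 and 18 the root paths are 1, 2, 4, 12 and 1, 2, 6, 18, which share only the nodes 1 and 2, so treeDist(12, 18) is 4. Assuming S(X) lists the prime factors in non-decreasing order with multiplicity (the exact definition is in the problem statement), this agrees with 3 + 3 - 2*1.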

For clarity, call nodes which contain one of the input elements “special nodes” and the node which minimises the sum of distances the “best node”.

Subtask 1

Obviously, we can't build an entire infinite tree. Instead, notice that only nodes on the path from the root to one of the input elements can be the best node, so we only need to consider these nodes, which number at most 10^6 when Q=1. One way of constructing the tree is to use a trie.

It remains to quickly calculate the sum of distances from an arbitrary node. There are multiple ways to do this; we describe the way the editorialist did it. First, compute the sum of distances from node 1. This is just the sum of all the Q (exponent) inputs.

Now suppose we know the sum of distances from a parent node X and want the sum from its child Y. Moving from X to Y brings us one step closer to every special node in the subtree of Y and one step further from every other special node, so the sum decreases by the number of nodes we move closer to minus the number of nodes we move further from. The number of nodes we move closer to is the number of special nodes in the subtree of Y (counting duplicates), which can be computed for all nodes with a DFS; the number we move further from is N minus that count. In this way each node's sum follows from its parent's in O(1), so all sums together take time linear in the size of the tree. Then just take the minimum.
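The parent-to-child update can be sketched on a small explicit tree as follows. This is only an illustration under assumed conventions, not the editorialist's code: nodes are numbered from 0 with node 0 as the root, children[v] lists the children of v, special[v] counts the input elements located at v, and the function name allDistanceSums is invented here.

```cpp
#include <cassert>
#include <vector>

// Returns, for every node v, the sum of distances from v to all input elements.
std::vector<long long> allDistanceSums(const std::vector<std::vector<int>>& children,
                                       const std::vector<int>& special) {
    const int m = static_cast<int>(children.size());
    assert(m >= 1 && static_cast<int>(special.size()) == m);

    // Iterative DFS from the root; every parent appears before its children in `order`.
    std::vector<int> order, parent(m, -1), depthOf(m, 0), stack{0};
    while (!stack.empty()) {
        int v = stack.back();
        stack.pop_back();
        order.push_back(v);
        for (int c : children[v]) {
            parent[c] = v;
            depthOf[c] = depthOf[v] + 1;
            stack.push_back(c);
        }
    }

    // total = N; rootSum = sum of depths of all input elements.
    long long total = 0, rootSum = 0;
    for (int v = 0; v < m; ++v) {
        total += special[v];
        rootSum += 1LL * special[v] * depthOf[v];
    }

    // cnt[v] = number of input elements in the subtree of v (children folded in first).
    std::vector<long long> cnt(special.begin(), special.end());
    for (int i = m - 1; i > 0; --i) cnt[parent[order[i]]] += cnt[order[i]];

    // Moving from the parent to v: cnt[v] elements get closer, total - cnt[v] get further.
    std::vector<long long> sum(m, 0);
    sum[0] = rootSum;
    for (int i = 1; i < m; ++i) {
        int v = order[i];
        sum[v] = sum[parent[v]] - cnt[v] + (total - cnt[v]);
    }
    return sum;
}
```

Because special[v] is a count rather than a flag, two equal input elements are handled naturally.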

Note that two memory addresses in the input may be equal. Each copy must be counted separately, or you will get Wrong Answer.

Subtask 2

The previous solution is too slow, because now Q is not necessarily 1, so the tree from the previous solution can have up to 10^{15} nodes. To solve this subtask, we further observe that we only need the nodes of the virtual tree of the Subtask 1 tree, that is, the root, the special nodes and the nodes where paths to special nodes branch apart. When we replace a path of length x with one edge, that one edge should have weight x (our tree is now weighted).

There are at most 10^6 of these nodes. We can construct this virtual tree using a trie (though there are other ways). Since our tree is weighted, we have to slightly modify our formula for distance sums from Subtask 1 (multiply the change by x where x is the weight of the edge between X and Y).
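Written out, the modified update might look as follows. The symbols D(v) for the distance sum from node v and c_Y for the number of input elements in the subtree of Y are notation introduced for this sketch, not taken from the problem statement; x is the weight of the edge between X and Y.

$$
D(Y) = D(X) - x \cdot c_Y + x \cdot (N - c_Y) = D(X) + x \cdot (N - 2 c_Y)
$$

Since x > 0, stepping from X into Y lowers the sum exactly when 2 c_Y > N. At most one child of any node can satisfy this, so it is enough to start at the root and keep walking into such a child while one exists.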

A tricky test case to consider is one where all of the inputs consist of the same single prime factor with different exponents. If the exponents arrive in increasing order, each insertion walks down the whole chain built so far, which can cause an O(N^2) blowup during tree creation. One way of solving this is to add special nodes to the tree in descending sorted order of exponent. See code for details.
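One possible way to obtain such an order is to negate the exponents and sort the factorisations lexicographically, so that among numbers sharing a prefix the larger exponent comes first. The sketch below assumes each number is stored as a list of (prime, exponent) pairs with primes in increasing order; the alias Factorisation and the function insertionOrder are names invented for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using Factorisation = std::vector<std::pair<int, int>>;  // (prime, exponent) pairs

// Returns the inputs in an order where, for equal primes, larger exponents come first.
std::vector<Factorisation> insertionOrder(std::vector<Factorisation> nums) {
    // Negate exponents so that the default ascending sort puts large exponents first.
    for (auto& f : nums)
        for (auto& pe : f) {
            assert(pe.second > 0);
            pe.second = -pe.second;
        }
    std::sort(nums.begin(), nums.end());  // lexicographic comparison of pair lists
    // Restore the original exponents.
    for (auto& f : nums)
        for (auto& pe : f) pe.second = -pe.second;
    return nums;
}
```

With this order, a newly inserted exponent on a shared prime edge is never larger than the edge already present, so an insertion either follows the edge exactly or splits it once instead of descending through a chain of earlier splits.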


Setter's Solution
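#include <bits/stdc++.h>
using namespace std;

#define int long long

int n;
int dist_sum; // sum of all exponents = sum of distances from the root (node 1)

struct trie{
    struct node{
        int isSpecial;    // number of input elements that end at this node
        int subtree_size; // number of input elements in the subtree of this node
        map<int,pair<int,node*> > ch; // prime -> (edge length, child)
        node(){isSpecial = 0; subtree_size = 0;}
        // Walk q copies of prime p down from this node, splitting an edge if needed.
        node* add(int p, int q){
            auto nx = ch.find(p);
            if (nx == ch.end()){
                node* nn = new node();
                ch[p] = make_pair(q,nn);
                return nn;
            }
            else{
                int edg_len = nx->second.first;
                node* nxt = nx->second.second;
                if (edg_len == q) return nxt;
                else if (q > edg_len){
                    return nxt->add(p,q-edg_len);
                }
                else{
                    // q < edg_len: split the edge at distance q from this node
                    node* nn = new node();
                    ch[p] = make_pair(q,nn);
                    nn->ch[p] = make_pair(edg_len-q,nxt);
                    return nn;
                }
            }
        }
    };
    node* root;
    trie(){root = new node();}
    // nv holds (prime, -exponent) pairs, so the exponent is -v.second.
    void add (vector<pair<int,int> > nv){
        node* cn = root;
        for (auto v : nv){
            cn = cn->add(v.first,-v.second);
        }
        cn->isSpecial++; // equal inputs are counted separately
    }
    int dfs(node* crr){
        int sm = crr->isSpecial;
        for (auto vv : crr->ch){
            sm += dfs(vv.second.second);
        }
        crr->subtree_size = sm;
        return sm;
    }
    void dfs_util(){
        dfs(root);
    }
    int solve(){
        dfs_util();
        node* cn = root;
        int ans = dist_sum;
        // Walk down while some child holds more than half of the inputs.
        while (true){
            if (cn->ch.empty()) return ans;
            bool is_better=0;
            for (auto vv : cn->ch){
                node* look = vv.second.second;
                int len = vv.second.first;
                if (look->subtree_size > (n/2)){
                    ans -= (look->subtree_size*len);
                    int ky = n-(look->subtree_size);
                    ans += (ky*len);
                    cn = look; is_better = 1; break;
                }
            }
            if (!is_better) return ans;
        }
    }
};

int32_t main(){
    ios_base::sync_with_stdio(false); cin.tie(NULL);
    dist_sum = 0;
    cin >> n;
    trie T;
    vector<vector<pair<int,int> > > stu(n);
    for (int i=0; i<n; i++){
        int k; cin >> k; vector<pair<int,int> > cr;
        int p,q;
        for (int j=0; j<k; j++){
            cin >> p >> q;
            dist_sum += q;
            cr.push_back(make_pair(p,-q)); // negated so that sorting gives descending exponents
        }
        stu[i] = cr;
    }
    sort(stu.begin(), stu.end());
    for (int i=0; i<n; i++) T.add(stu[i]);
    cout << T.solve();
}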
Tester's Solution (jtnydv25)
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define sd(x) scanf("%d", &(x))
#define pii pair<int, int>
#define F first
#define S second
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
#define ld long double

template<class T,class U>
ostream& operator<<(ostream& os,const pair<T,U>& p){
    os<<"("<<p.first<<", "<<p.second<<")";
    return os;
}

template<class T>
ostream& operator <<(ostream& os,const vector<T>& v){
    for(int i = 0;i < (int)v.size(); i++){
        if(i)os<<", ";
        os<<v[i];
    }
    return os;
}

#ifdef LOCAL
#define cerr cout
#endif

#define TRACE

#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
    cerr << name << " : " << arg1 << std::endl;
}
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
    const char* comma = strchr(names + 1, ',');cerr.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
#define trace(...)
#endif

struct trie{
    struct node{
        map<int, pair<int, node*>> mp; // prime -> (edge length, child)
        int num;   // number of inputs whose path passes through this node
        ll depth;  // weighted depth of this node
        node(ll d = 0){
            num = 0;
            depth = d;
        }
        node* add(int p, int q){
            auto it = mp.find(p);
            if(it == mp.end()){
                node* new_node = new node(depth + q);
                mp[p] = {q, new_node};
                return new_node;
            } else{
                node* nxt = it->second.second;
                int t = it->second.first;
                if(t == q){
                    return nxt;
                } else if(q < t){
                    // split the edge; the new middle node inherits the count below it
                    node* new_node = new node(depth + q);
                    mp[p] = {q, new_node};
                    new_node->mp[p] = {t - q, nxt};
                    new_node->num = nxt->num;
                    return new_node;
                } else{
                    assert(0); // input is given in such an order that this never happens
                    return nxt->add(p, q - t);
                }
            }
        }
    };
    node * root;
    ll sum_depth;
    trie(){root = new node(); sum_depth = 0;}
    void add(vector<pii> vec){
        node* now = root;
        now->num++;
        for(auto pq : vec){
            int p = pq.F, q = -pq.S;
            now = now->add(p, q);
            now->num++;
        }
        sum_depth += now->depth;
    }
    ll findCentroid(){
        node* now = root;
        int total = now->num;
        ll currAns = sum_depth;
        while(true){
            if(now->mp.empty()) return currAns;
            bool found = false;
            for(auto it : now->mp){
                node* nd = it.S.S;
                if(2 * nd->num > total){
                    found = true;
                    currAns -= (2 * nd->num - total) * (ll) (it.S.F);
                    now = nd;
                    break;
                }
            }
            if(!found) return currAns;
        }
    }
};

int main(){
    int numT = 1; /* sd(numT); */ while(numT--){
        int n= 100000;
        sd(n);
        trie T;
        vector<vector<pii>> vec(n);
        for(int i = 0; i < n; i++){
            int k= 1;
            sd(k);
            for(int j = 0; j < k; j++){
                int p = 2, q = i;
                sd(p); sd(q);
                vec[i].push_back({p, -q});
            }
        }
        sort(all(vec)); // negated exponents: larger exponents are inserted first
        for(auto it : vec){
            T.add(it);
        }
        printf("%lld\n", T.findCentroid());
    }
}
Tester's Solution (taran_1407)
import java.net.Inet4Address;
import java.util.*;
import java.io.*;
import java.text.*;
public class Main{
    //Into the Hardware Mode
    void pre() throws Exception{}
    void solve(int TC)throws Exception {
        int n = ni();
        int[][][] p = new int[n][][];
        long ans = 0;
        Node root = new Node();
        for(int i = 0; i< n; i++){
            int x = ni();long sum = 0;
            p[i] = new int[x][];
            for(int ii = 0; ii< x; ii++) {
                p[i][ii] = new int[]{ni(), ni()};
                sum += p[i][ii][1];
            }
            ans += sum;
        }
        //Sort by prime ascending, then exponent descending
        Arrays.sort(p, (int[][] i1, int[][] i2) -> {
            for(int i = 0; i< Math.min(i1.length, i2.length); i++){
                if(i1[i][0] != i2[i][0])return Integer.compare(i1[i][0], i2[i][0]);
                if(i1[i][1] != i2[i][1])return Integer.compare(i2[i][1], i1[i][1]);
            }
            return Integer.compare(i1.length, i2.length);
        });
        for(int i = 0; i< n; i++)insert(root, p[i]);
        //Walk down while some child holds more than half of the inputs
        boolean flag = true;
        while(flag){
            flag = false;
            for(Edge e:root.nxt.values()){
                if(e.to.deg*2 > n){
                    ans -= (e.to.deg*2-n)*(long)e.q;
                    root = e.to;
                    flag = true;
                    break;
                }
            }
        }
        pn(ans);
    }
    void insert(Node root, int[][] a){
        Node tmp = root;
        for(int i = 0; i< a.length; ){
            Edge e = tmp.nxt.getOrDefault(a[i][0], null);
            if(e == null){
                tmp.nxt.put(a[i][0], new Edge(new Node(), a[i][1]));
                tmp = tmp.nxt.get(a[i][0]).to;
                i++;
            }else if(e.q <= a[i][1]){
                a[i][1] -= e.q;
                tmp = e.to;
                if(a[i][1] == 0)i++;
            }else{
                //Split the edge; the middle node inherits the count below it
                Node mid = new Node();
                mid.deg = e.to.deg;
                mid.nxt.put(a[i][0], new Edge(e.to, e.q-a[i][1]));
                tmp.nxt.put(a[i][0], new Edge(mid, a[i][1]));
                tmp = mid;
                i++;
            }
            tmp.deg++;//every node on the path counts this input
        }
    }
    class Node{
        int p;
        HashMap<Integer, Edge> nxt;
        int deg;
        public Node(){
            nxt = new HashMap<>();
            deg = 0;
        }
    }
    HashMap<Integer, Edge>[] nxt;
    class Edge{
        int q;
        Node to;
        public Edge(Node T, int Q){to = T;q = Q;}
    }
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    void exit(boolean b){if(!b)System.exit(0);}
    long IINF = (long)1e15;
    final int INF = (int)1e9+2, MX = (int)2e6+5;
    DecimalFormat df = new DecimalFormat("0.00000000000");
    double PI = 3.141592653589793238462643383279502884197169399, eps = 1e-8;
    static boolean multipleTC = false, memory = false, fileIO = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        long ct = System.currentTimeMillis();
        if (fileIO) {
            in = new FastReader("");
            out = new PrintWriter("");
        } else {
            in = new FastReader();
            out = new PrintWriter(System.out);
        }
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC) ? ni() : 1;
        pre();
        for (int t = 1; t <= T; t++) solve(t);
        out.flush();
        out.close();
        System.err.println(System.currentTimeMillis() - ct);
    }
    public static void main(String[] args) throws Exception{
        if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
        else new Main().run();
    }
    int find(int[] set, int u){return set[u] = (set[u] == u?u:find(set, set[u]));}
    int digit(long s){int ans = 0;while(s>0){s/=10;ans++;}return ans;}
    long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
    int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str;
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}