MAXIMUM_SUM - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: Tejas Pandey
Editorialist: Nishank Suresh




Prerequisites: Sieve of Eratosthenes for fast prime factorization


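Problem: You are given an array A of N positive integers. In one move, you can pick an index i (1 \leq i \lt N) and replace A_i and A_{i+1} by \gcd(A_i, A_{i+1}) and \text{lcm}(A_i, A_{i+1}) respectively. What is the maximum possible sum of A after performing any number of moves? (All the reference codes below output this sum modulo 10^9+7.)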


In tasks involving operations, it’s always worth first observing how things change under a single operation, and whether they change at all.

Say we apply an operation on two integers x and y, with x \leq y.

  • If x is a factor of y, then \gcd(x, y) = x and \text{lcm}(x, y) = y, so the values don’t change at all (at most, the pair is swapped into sorted order).
  • Otherwise, \text{lcm}(x, y) is a multiple of y that differs from y, so it must be at least 2y. Since \gcd(x, y) \geq 1, this gives \gcd(x, y) + \text{lcm}(x, y) \geq 1 + 2y \gt x + y, so performing this move always strictly improves the sum.
  • Also note that the operation allows us to make the final array sorted, since the lcm is always placed after the gcd.
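The observations above can be checked with a minimal Python sketch of a single move. The helper `apply_move` is a hypothetical name of ours, not part of the problem or the reference codes; it computes the lcm as x // gcd(x, y) * y.

```python
from math import gcd


def apply_move(arr, i):
    """Return a copy of arr with arr[i], arr[i+1] replaced by their gcd and lcm."""
    x, y = arr[i], arr[i + 1]
    g = gcd(x, y)
    l = x // g * y  # lcm(x, y) = x * y / gcd(x, y)
    out = list(arr)
    out[i], out[i + 1] = g, l  # gcd goes first, lcm second
    return out


# 4 does not divide 6, so the sum strictly increases: 4 + 6 -> 2 + 12
print(apply_move([4, 6], 0))  # [2, 12]
# 3 divides 6, so nothing changes
print(apply_move([3, 6], 0))  # [3, 6]
# the smaller value divides the larger one, which comes first: the pair is just swapped
print(apply_move([6, 3], 0))  # [3, 6]
```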

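Because \gcd(x, y)\cdot\text{lcm}(x, y) = xy, the product of the array never changes, so every element stays bounded by it and the sum cannot increase forever. Moreover, if some adjacent pair has A_i \nmid A_{i+1}, a move at i either strictly increases the sum or, when A_{i+1} \mid A_i, swaps the pair into order. This tells us what the final array can be taken to look like in the optimal case: sorted in non-decreasing order, with A_i \mid A_{i+1} for each 1 \leq i \lt N.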

In fact, under these conditions, the final array that maximizes the sum is unique. The proof of this uniqueness will also allow us to construct the array, and hence compute its sum.
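To sanity-check this on tiny inputs, one can simply explore every reachable array. The sketch below (`brute_max_sum` is our own hypothetical helper) runs a depth-first search over reachable states. The state space is finite because \gcd(x, y)\cdot\text{lcm}(x, y) = xy keeps the product of the array fixed, but it grows quickly, so this is only feasible for very small arrays.

```python
from math import gcd


def brute_max_sum(arr):
    """Maximum sum over all arrays reachable from arr by moves (tiny inputs only)."""
    start = tuple(arr)
    seen = {start}
    stack = [start]
    while stack:
        cur = stack.pop()
        for i in range(len(cur) - 1):
            x, y = cur[i], cur[i + 1]
            g = gcd(x, y)
            nxt = cur[:i] + (g, x // g * y) + cur[i + 2:]
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return max(sum(s) for s in seen)


print(brute_max_sum([6, 4, 3]))  # 19, reached e.g. by the chain [1, 6, 12]
```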


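Let’s look at a single prime p. Let pw_i denote the exponent of p in A_i, i.e. the largest k such that p^k divides A_i.
Since the exponent of p in \gcd(A_i, A_{i+1}) is \min(pw_i, pw_{i+1}) and its exponent in \text{lcm}(A_i, A_{i+1}) is \max(pw_i, pw_{i+1}), an operation at index i changes pw as follows: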

  • If pw_i \leq pw_{i+1}, do nothing
  • Otherwise, swap pw_i with pw_{i+1}

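This, of course, allows us to sort the pw array, since we have adjacent swaps (just like bubble sort). Note that operations only reorder pw: the multiset of its values never changes. On the other hand, the condition A_i \mid A_{i+1} forces pw_i \leq pw_{i+1}, so pw must be sorted in the end. Hence the final pw array is exactly the sorted version of the initial one.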
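Here is a small Python check of this min/max behaviour; `vp` is our own hypothetical helper returning the exponent of p in x.

```python
from math import gcd


def vp(x, p):
    """Exponent of the prime p in x, i.e. the largest k with p**k dividing x (x >= 1)."""
    k = 0
    while x % p == 0:
        x //= p
        k += 1
    return k


x, y, p = 12, 18, 2  # 12 = 2^2 * 3, 18 = 2 * 3^2
g = gcd(x, y)        # 6
l = x // g * y       # 36
print(vp(x, p), vp(y, p))  # 2 1
print(vp(g, p), vp(l, p))  # 1 2, so the pair (2, 1) was swapped into sorted order
```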

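\gcd and \text{lcm} operate independently on each prime: one operation at index i performs this compare-and-swap on the pw arrays of all primes simultaneously. So the discussion above applies to each of them, and repeatedly sweeping over the array (as in bubble sort) sorts all the pw arrays at once. In the end, each prime’s pw array must be sorted.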
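A direct way to see this is to sweep the operation over the array, left to right, until a full pass changes nothing. The sketch below (`settle` is a hypothetical name) mirrors bubble sort; it works with the actual values, which can get large, so it is meant only as an illustration on small inputs.

```python
from math import gcd


def settle(arr):
    """Apply the move at every index, pass after pass, until a full pass changes nothing."""
    a = list(arr)
    changed = True
    while changed:
        changed = False
        for i in range(len(a) - 1):
            g = gcd(a[i], a[i + 1])
            l = a[i] // g * a[i + 1]
            if (g, l) != (a[i], a[i + 1]):
                a[i], a[i + 1] = g, l
                changed = True
    return a


final = settle([6, 4, 3])
print(final)  # [1, 6, 12]
# the result is a divisibility chain: each element divides the next one
assert all(final[i + 1] % final[i] == 0 for i in range(len(final) - 1))
```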

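Knowing the final pw arrays also tells us the final elements, since we effectively know the prime factorization of each one: the final A_i is the product of p^{pw_i} over all primes p, where each prime uses its own pw array.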

That gives us a working solution:

  • For each prime p, compute its pw array and sort it.
  • Rebuild every A_i from the sorted pw arrays of all the primes, and sum them up.
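The working solution can be sketched in Python as follows, assuming small inputs: it builds a full pw array (zeros included) for every prime up to \max(A) and keeps exact integers instead of reducing modulo anything. The name `naive_max_sum` is hypothetical.

```python
def naive_max_sum(a):
    """Working but slow: one full pw array per prime up to max(a)."""
    n = len(a)
    m = max(a)
    # all primes up to m, via a simple sieve
    is_prime = [True] * (m + 1)
    primes = []
    for i in range(2, m + 1):
        if is_prime[i]:
            primes.append(i)
            for j in range(i * i, m + 1, i):
                is_prime[j] = False
    final = [1] * n
    for p in primes:
        pw = []
        for x in a:
            k = 0
            while x % p == 0:
                x //= p
                k += 1
            pw.append(k)
        pw.sort()  # non-decreasing: position i gets the i-th smallest exponent
        for i in range(n):
            final[i] *= p ** pw[i]
    return sum(final)


print(naive_max_sum([6, 4, 3]))  # 19
```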

However, this by itself is too slow: there are \approx 8\cdot 10^4 primes less than 10^6 (78498, to be exact), and building an \mathcal{O}(N)-sized pw array for every one of them costs about 8\cdot 10^4 \cdot N operations per test case.
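The prime count is easy to confirm with a plain sieve of Eratosthenes; a minimal sketch (the helper name `count_primes_below` is ours):

```python
from math import isqrt


def count_primes_below(limit):
    """Count primes p < limit (limit >= 2) with a sieve of Eratosthenes."""
    sieve = bytearray([1]) * limit
    sieve[0:2] = b"\x00\x00"  # 0 and 1 are not prime
    for i in range(2, isqrt(limit) + 1):
        if sieve[i]:
            # cross out i*i, i*i + i, ... below limit
            sieve[i * i::i] = bytes(len(range(i * i, limit, i)))
    return sum(sieve)


print(count_primes_below(10 ** 6))  # 78498
```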

Instead, note that most entries of most pw arrays are just 0. Only \mathcal{O}(N\log 10^6) positions are non-zero across all arrays, since each of them corresponds to a prime factor of some input element.
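For a concrete bound, assume every A_i \leq 10^6 (the range the sieves in the reference codes cover). Writing \Omega(x) for the number of prime factors of x counted with multiplicity and \omega(x) for the number of distinct ones:

$$\Omega(A_i) \leq \lfloor \log_2 10^6 \rfloor = 19, \qquad \omega(A_i) \leq 7 \text{ since } 2\cdot 3\cdot 5\cdot 7\cdot 11\cdot 13\cdot 17 = 510510 \leq 10^6 \lt 510510 \cdot 19.$$

So at most 7N positions are non-zero across all pw arrays, and factorizing one element by repeatedly dividing out a known prime factor takes at most 19 successful divisions.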

So, prime factorizing the input gives us all the information we need. The 0’s in a pw array don’t affect anything: after sorting they sit at the front, and they contribute p^0 = 1 to the product. Keeping only the non-zero exponents (a compressed pw array) for each prime allows us to solve the problem quickly.
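A sketch of the compressed representation, assuming plain trial division for factorization (as in the setter’s code); `compressed_pw` is a hypothetical helper returning a dict from each prime to its list of non-zero exponents. The last lines check that padding such a list with zeros and sorting it simply puts the zeros in front.

```python
def compressed_pw(a):
    """Map each prime p to the non-zero exponents of p over the elements of a."""
    pw = {}
    for x in a:
        d = 2
        while d * d <= x:
            if x % d == 0:
                k = 0
                while x % d == 0:
                    x //= d
                    k += 1
                pw.setdefault(d, []).append(k)
            d += 1
        if x > 1:  # leftover prime factor larger than sqrt of the original x
            pw.setdefault(x, []).append(1)
    return pw


pw = compressed_pw([6, 4, 3])
print(pw)  # {2: [1, 2], 3: [1, 1]}
n = 3
for p, exps in pw.items():
    padded = sorted([0] * (n - len(exps)) + exps)
    # zeros land in front; the sorted non-zero exponents fill the last len(exps) slots
    assert padded[n - len(exps):] == sorted(exps)
```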

Thus, the final solution is as follows:

  • Prime factorize each A_i quickly.
    • This can be done in \mathcal{O}(\log 10^6) per element with a sieve of Eratosthenes that precomputes, for each integer, its smallest prime factor.
  • Use these prime factorizations to build the compressed pw arrays, which don’t include the zeros.
  • Sort each compressed array and hand its exponents out to positions in reverse order, starting from N: the largest exponent of p goes to A_N, the second largest to A_{N-1}, and so on.
  • Finally, compute the sum of the resulting array (modulo 10^9+7 in the reference codes).
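Putting the steps together, here is a Python sketch of the final solution, assuming A_i \leq 10^6 and the answer taken modulo 10^9+7 as in the reference codes. The function names `smallest_prime_factors` and `max_sum` are hypothetical; input parsing is left out.

```python
from math import isqrt

MOD = 10 ** 9 + 7


def smallest_prime_factors(limit):
    """spf[x] = smallest prime factor of x, for 2 <= x <= limit."""
    spf = list(range(limit + 1))
    for i in range(2, isqrt(limit) + 1):
        if spf[i] == i:  # i is prime
            for j in range(i * i, limit + 1, i):
                if spf[j] == j:
                    spf[j] = i
    return spf


def max_sum(a, spf):
    n = len(a)
    pw = {}  # prime -> non-zero exponents (compressed pw array)
    for x in a:
        while x > 1:
            p = spf[x]
            k = 0
            while x % p == 0:
                x //= p
                k += 1
            pw.setdefault(p, []).append(k)
    b = [1] * n
    for p, exps in pw.items():
        exps.sort(reverse=True)
        # the j-th largest exponent goes to 0-indexed position n-1-j, i.e. from the back
        for j, e in enumerate(exps):
            b[n - 1 - j] = b[n - 1 - j] * pow(p, e, MOD) % MOD
    return sum(b) % MOD


spf = smallest_prime_factors(10 ** 6)  # precomputed once for all test cases
print(max_sum([6, 4, 3], spf))  # 19
```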


Time complexity: \mathcal{O}(N\log 10^6) per test case, after a one-time sieve precomputation.


Setter's code (C++)
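```cpp
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
const int N=1000010;
const int LOGN=28;
const ll  TMD=1000000007;
const ll  INF=2147483647;
int T,n;
ll  ans;
int a[N];
ll  b[N];
vector<int> v[N];	// v[p] = non-zero exponents of the prime p among the inputs

// x^p modulo TMD, by recursive fast exponentiation
ll pw(ll x,ll p)
{
	if(!p) return 1;
	ll y=pw(x,p>>1);
	y=(y*y)%TMD;
	if(p&1) y=(y*x)%TMD;
	return y; 
}

int main()
{
	scanf("%d",&T);
	while(T--)
	{
		scanf("%d",&n);
		for(int i=1;i<=n;i++) scanf("%d",&a[i]);
		for(int i=2;i<N;i++)  v[i].clear();
		// factorize each a[i] by trial division, recording the exponent of each prime
		for(int i=1;i<=n;i++)
		{
			int t=a[i];
			for(int j=2;j*j<=t;j++)
			{
				if(t%j) continue;
				int cnt=0;
				while(!(t%j)) cnt++,t/=j;
				v[j].push_back(cnt);
			}
			if(t!=1) v[t].push_back(1);	// leftover prime factor > sqrt(a[i])
		}
		// exponents in non-increasing order, handed out from position 1:
		// this builds the sorted final array in reverse, which has the same sum
		for(int i=2;i<N;i++) sort(v[i].begin(),v[i].end(),greater<int>());
		for(int i=1;i<=n;i++) b[i]=1;
		for(int i=2;i<N;i++)
			for(int j=0;j<(int)v[i].size();j++) b[j+1]=(b[j+1]*pw(i,v[i][j]))%TMD;
		ans=0;
		for(int i=1;i<=n;i++) ans=(ans+b[i])%TMD;
		printf("%lld\n",ans);
	}
	return 0;
}
```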
Tester's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define maxn 100007
#define mod 1000000007

vector<long long int> primes;
bool isp[maxn];	// isp[x] = 1 if x is composite

// a^b modulo mod, by iterative fast exponentiation
long long int mpow(long long int a, long long int b) {
    long long int res = 1;
    while(b) {
        if(b&1) res *= a, res %= mod;
        a *= a;
        a %= mod;
        b >>= 1;
    }
    return res;
}

// sieve of Eratosthenes: collect all primes below maxn
void seive() {
    for(int i = 2; i < maxn; i++) {
        if(isp[i]) continue;
        primes.push_back(i);
        for(int j = i*2; j < maxn; j += i)
            isp[j] = 1;
    }
}

int main() {
	seive();
	int t;
	cin >> t;
	while(t--) {
	    int n;
	    cin >> n;
	    long long int a[n], b[n];
	    for(int i = 0; i < n; i++) cin >> a[i], b[i] = 1;
	    map<int, vector<int>> pows;	// prime -> non-zero exponents
	    for(int i = 0; i < n; i++) {
	        // trial division by primes up to sqrt(a[i])
	        int now = 0, pp = primes[now];
	        while(pp*pp <= a[i]) {
	            int cnt = 0;
	            while(a[i]%pp == 0) a[i] /= pp, cnt++;
	            if(cnt) pows[pp].push_back(cnt);
	            now++;
	            pp = primes[now];
	        }
	        if(a[i] > 1) pows[a[i]].push_back(1);	// leftover large prime
	    }
	    for(auto it: pows) {
	        // largest exponent goes to the last position, the next one before it, ...
	        sort(it.second.rbegin(), it.second.rend());
	        for(int i = 0; i < (int)it.second.size(); i++)
	            b[n - 1 - i] *= mpow(it.first, it.second[i]), b[n - 1 - i] %= mod;
	    }
	    long long int ans = 0;
	    for(int i = 0; i < n; i++) ans += b[i], ans %= mod;
	    cout << ans << "\n";
	}
	return 0;
}
```
Editorialist's code (Python)
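```python
maxn = int(10 ** 6 + 5)
mod = int(10**9 + 7)

prmfac = [0]*maxn	# prmfac[x] = a prime factor of x (the largest, since later primes overwrite)
id = [0]*maxn		# id[p] = index of the prime p in `primes`
primes = []
curid = 0
for i in range(2, maxn):
	if prmfac[i] > 0: continue
	id[i] = curid
	curid += 1
	primes.append(i)
	for j in range(i, maxn, i):
		prmfac[j] = i

for _ in range(int(input())):
	n = int(input())
	a = list(map(int, input().split()))
	pw = [[] for _ in range(curid)]	# compressed pw array of each prime
	for x in a:
		while x > 1:
			p = prmfac[x]
			ct = 0
			while x%p == 0:
				x //= p
				ct += 1
			pw[id[p]].append(ct)
	ans = [1]*n
	for i in range(curid):
		pw[i].sort()
		sz = len(pw[i])
		# largest exponent goes to the last position, the next one before it, ...
		for j in range(len(pw[i])):
			ans[n-1-j] *= pow(primes[i], pw[i][sz-1-j], mod)
			ans[n-1-j] %= mod
	print(sum(ans) % mod)
```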