I solved this problem using an adjacency list and a set, and I used a map iterator to find the index of the maximum value. Here is my implementation, which takes 0.52 sec to get AC.
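The implementation itself isn't shown here, so the following is only a Python sketch of the same idea, under assumptions: sort the planets once by population, then for each planet V walk down that order and skip V and its neighbours, much like iterating a map from its largest key downwards. The names `capitals_sorted_scan`, `pop` and `adj` are hypothetical; `pop` is 1-indexed with `pop[0]` unused, and 0 is assumed as the "no planet left" value.

```python
def capitals_sorted_scan(pop, adj):
    """pop[v]: population of planet v (1-indexed, pop[0] unused).
    adj[v]: list of neighbours of planet v.
    Returns, for each v, the most populous planet outside v and its
    neighbours, or 0 (assumed sentinel) if no such planet exists."""
    n = len(pop) - 1
    # Planets from most to least populous, sorted once per test case.
    order = sorted(range(1, n + 1), key=lambda v: pop[v], reverse=True)
    result = []
    for v in range(1, n + 1):
        blocked = set(adj[v])
        blocked.add(v)
        answer = 0
        # At most len(blocked) planets are skipped before the first hit.
        for u in order:
            if u not in blocked:
                answer = u
                break
        result.append(answer)
    return result
```

For a path 1-2-3-4 with populations 5, 10, 3, 8, `capitals_sorted_scan([0, 5, 10, 3, 8], [[], [2], [1, 3], [2, 4], [3]])` gives `[4, 4, 1, 2]`.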
This can also be done with two heaps if you aren't familiar with sets. Maintain two max-heaps: the first holds all the P_i values, the second holds the P_i values of V and its neighbours. Then:
1. If the top (max) element of the first heap is greater than that of the second, stop: the top of the first heap is the answer.
2. If the two tops are equal, pop both of them.
3. If the top of the first heap is less than that of the second, pop the top of the second heap.
Repeat until case 1 applies, then push the popped elements back into the first heap before handling the next V.
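A sketch of those steps with Python's `heapq`, which is a min-heap, so populations are stored negated to get max-heap behaviour. Entries are `(-population, planet)` tuples, so equal tops mean the same planet. The names `capitals_two_heaps`, `pop` and `adj` are hypothetical; `pop` is 1-indexed with `pop[0]` unused, `adj[v]` lists the neighbours of v, and 0 is assumed as the "no planet left" value.

```python
import heapq


def capitals_two_heaps(pop, adj):
    n = len(pop) - 1
    # First heap: every planet, keyed by negated population.
    all_heap = [(-pop[v], v) for v in range(1, n + 1)]
    heapq.heapify(all_heap)
    result = []
    for v in range(1, n + 1):
        # Second heap: v and its neighbours.
        blocked = [(-pop[u], u) for u in [v] + adj[v]]
        heapq.heapify(blocked)
        removed = []
        answer = 0  # assumed sentinel when every planet is blocked
        while all_heap:
            if blocked and all_heap[0] == blocked[0]:
                # Same planet on top of both heaps: pop both.
                removed.append(heapq.heappop(all_heap))
                heapq.heappop(blocked)
            elif not blocked or all_heap[0] < blocked[0]:
                # Top of the first heap is more populous than anything blocked.
                answer = all_heap[0][1]
                break
            else:
                # Cannot trigger here, because the second heap is always a
                # subset of the first; kept to mirror the steps described.
                heapq.heappop(blocked)
        # Restore the first heap for the next query.
        for item in removed:
            heapq.heappush(all_heap, item)
        result.append(answer)
    return result
```

Each query pops and re-pushes at most deg(V) + 1 entries, so it costs O((deg(V) + 1) log N).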
Seeing so many people happy with their solutions passing in 0.16-0.19 seconds… I feel weird, because my brute force passed in 0.09 seconds xD. Unoptimized, too :p
PS: Is it really brute force? Can we bound the number of operations using the fact that "it is a tree" and some observations about node degrees? IDK, that's for you to decide~
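One way to bound it, assuming the "brute force" walks the planets in decreasing population order and stops at the first planet outside the closed neighbourhood N[V] = {V} ∪ adj(V): each query examines at most |N[V]| + 1 = deg(V) + 2 entries, and the degrees of a tree with N nodes sum to 2(N - 1), so over all queries

$$\sum_{V=1}^{N} \bigl(\deg(V) + 2\bigr) = 2(N-1) + 2N = 4N - 2.$$

Under that assumption, after one O(N log N) sort all N queries together cost O(N), however the degrees are distributed.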
As an interesting aside: because the network is a tree, there can be at most three alternative capitals (answers other than the capital itself), one of which is of course the next most populous planet, NxtPop:
case 1: NxtPop is directly connected to the capital; then you need at most two more planets, drawn from three groups: the capital's other neighbours, NxtPop's other neighbours, and all other planets.
case 2: NxtPop and the capital have a common neighbour; then you need one more planet, the most populous one not adjacent to that common neighbour.
case 3: NxtPop and the capital have no common neighbour; then you do not need any other planet.
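A quick way to sanity-check the "at most three" claim is a brute-force experiment on random trees. This sketch assumes distinct populations (the case analysis relies on a unique NxtPop) and treats 0 as "no planet left"; the helper names and the seed are hypothetical.

```python
import random


def brute_force_answers(pop, adj):
    # For every planet v: the most populous planet outside {v} and its neighbours, 0 if none.
    n = len(pop) - 1
    result = []
    for v in range(1, n + 1):
        blocked = set(adj[v]) | {v}
        best = 0
        for u in range(1, n + 1):
            if u not in blocked and (best == 0 or pop[u] > pop[best]):
                best = u
        result.append(best)
    return result


def random_tree(n, rng):
    # Attach each planet v >= 2 to a uniformly random earlier planet.
    adj = [[] for _ in range(n + 1)]
    for v in range(2, n + 1):
        p = rng.randint(1, v - 1)
        adj[v].append(p)
        adj[p].append(v)
    return adj


def max_alternative_capitals(trials=2000, max_n=12, seed=1):
    rng = random.Random(seed)
    worst = 0
    for _ in range(trials):
        n = rng.randint(1, max_n)
        adj = random_tree(n, rng)
        pop = [0] + rng.sample(range(1, 100), n)  # distinct populations
        capital = max(range(1, n + 1), key=lambda v: pop[v])
        # Alternative capitals: answers that are neither the capital nor "none".
        alternatives = set(brute_force_answers(pop, adj)) - {0, capital}
        worst = max(worst, len(alternatives))
    return worst


print(max_alternative_capitals())  # at most 3, per the case analysis
```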
It's giving TLE on the later test cases.
Where can I improve it?
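```python
def weight(dmap, l, ind, d):
    # Candidates: every planet except `ind` and its neighbours
    l1 = l[:]
    l1.remove(ind)
    for i in dmap.get(ind, []):  # .get(): a single-planet test has no edges
        l1.remove(i)
    # Linear scan for the most populous remaining planet
    m = 0
    ans = -1
    for i in l1:
        m1 = d[i]
        if m < m1:
            m = m1
            ans = i
    return ans


t = int(input())
while t > 0:
    n = int(input())
    indexs = []
    for i in range(n):
        indexs.append(i + 1)
    d = list(map(int, input().strip().split()))
    d1 = {}
    for i in range(1, n + 1):
        d1[i] = d[i - 1]
    # Build the adjacency list, skipping duplicate edges
    lmap = {}
    for i in range(n - 1):
        l = list(map(int, input().strip().split()))
        if l[0] in lmap:
            if l[1] not in lmap[l[0]]:
                lmap[l[0]].append(l[1])
        else:
            lmap[l[0]] = [l[1]]
        if l[1] in lmap:
            if l[0] not in lmap[l[1]]:
                lmap[l[1]].append(l[0])
        else:
            lmap[l[1]] = [l[0]]
    # Each weight() call copies the list and removes elements in O(n) each,
    # so this loop is O(n^2) per test case
    for i in range(1, n + 1):
        m = weight(lmap, indexs, i, d1)
        print(m, end=" ")
    print()  # finish this test case's output line
    t -= 1
```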
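Besides the O(n^2) work per test in `weight`, calling `input()` once per edge and running a `not in` check on every edge adds overhead. A sketch of reading all input at once, assuming the layout the code above parses (T; then for each test N, the N populations, and N - 1 edges). `read_cases` is a hypothetical helper, and the duplicate-edge check is dropped on the assumption that a tree has no repeated edges.

```python
def read_cases(data):
    """Parse raw input bytes into a list of (pop, adj) pairs.
    pop is 1-indexed (pop[0] unused); adj[v] lists the neighbours of v."""
    tokens = data.split()
    pos = 0
    t = int(tokens[pos])
    pos += 1
    cases = []
    for _ in range(t):
        n = int(tokens[pos])
        pos += 1
        pop = [0] + [int(x) for x in tokens[pos:pos + n]]
        pos += n
        adj = [[] for _ in range(n + 1)]
        for _ in range(n - 1):
            u = int(tokens[pos])
            v = int(tokens[pos + 1])
            pos += 2
            adj[u].append(v)
            adj[v].append(u)
        cases.append((pop, adj))
    return cases
```

Called as `read_cases(sys.stdin.buffer.read())`, it returns one `(pop, adj)` pair per test, which can then be fed to a per-planet routine that avoids copying the whole planet list for every planet.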