Help me in solving TEMPBAL problem

My issue

t=int(input())
for j in range(t):
n=int(input())
lst=list(map(int,input().split()))
count=0
while(any(val!=0 for val in lst)):

    for i in range(n-1):
        if lst[i]>lst[i+1]:
            lst[i]=lst[i]-1
            lst[i+1]=lst[i+1]+1
            count=count+1
        
        elif lst[i]<lst[i+1]:
            lst[i]=lst[i]+1
            lst[i+1]=lst[i+1]-1
            count=count+1
print(count)

My code

t=int(input())
for j in range(t):
    n=int(input())
    lst=list(map(int,input().split()))
    count=0
    while(any(val!=0 for val in lst)):
        
        for i in range(n-1):
            if lst[i]>lst[i+1]:
                lst[i]=lst[i]-1
                lst[i+1]=lst[i+1]+1
                count=count+1
            
            elif lst[i]<lst[i+1]:
                lst[i]=lst[i]+1
                lst[i+1]=lst[i+1]-1
                count=count+1
    print(count)
    

Problem Link: Temperature Balance Practice Coding Problem