Koala - 10기/코딩테스트 준비 스터디

[BAEKJOON/Python] 27730 견우와 직녀

beans3142 2023. 3. 19. 17:12
https://www.acmicpc.net/problem/27730
 

27730번: 견우와 직녀

견우는 정점의 개수가 $N$인 무향 가중치 트리 $E$에 살고 있고, 직녀는 정점의 개수가 $M$인 무향 가중치 트리 $W$에 살고 있다. 두 사람은 각자 다른 트리에 살고 있으므로 만날 수 없다... 슬픔에

www.acmicpc.net



문제 분석

난이도

플래티넘5

분류

그래프 탐색, DP, DFS

들어가기 전에

이번 2022 SHAKE!에 나왔던 문제. 점화식을 떠올리는 것도 꽤나 어려운 문제라고 생각한다.

문제

 

문제 풀이

 

풀이

DFS를 이용해 각 노드의 자식 노드의 개수를 구해준다. 그 이유는 자식의 노드의 개수가 곧 해당 노드와 해당 노드의 부모 노드가 연결된 간선의 호출 횟수이기 때문이다.

그렇게 구해준 자식의 노드수를 통해서 한 점에 대해 그 점으로부터 모든 점까지 거리의 합을 구해준다. 그리고 그 점에 대해 구해주면서 다른 노드들에는 그 노드로부터 자식 노드들까지의 거리를 함께 구해줄 수 있다.

for usetime,now in order:
        dist[par[now]]+=tree[par[now]][now]*usetime
        dist[par[now]]+=dist[now]

그 방법은 위와 같다. 앞서 구해준 자식 노드의 개수를 통해 자기 자신과 부모 노드사이의 간선의 가중치를 곱해서 더해주고 자식 노드가 갖는 누적 거리의 합을 더해준다.

그리고 이 값들을 이용해 O(1)에 다른 노드에서 모든 노드까지의 거리를 구할 수 있다.

그 원리는 다음과 같다. 다른 모든 노드들까지의 거리를 구해놓은 노드 a에 대해 그 노드 a로부터 구한 거리는 a->b + a->c + a->d .... 일 것이다. 그리고 다른 노드 b는 자신의 자식 노드들까지의 거리의 합을 가지고 있을 것이다.

a에서 구한 거리 - b에서 구한 거리를 빼준다면 그 값이 b노드와 그 자식 노드들을 제외한 모든 노드의 거리가 될 것이다.
이제 b노드를 중심으로 a와 연결시켜주면 된다.

이때 점화식은 다음과 같다.

dist[now]+=dist[bef]-dist[now]-tree[bef][now]*use[now]+tree[bef][now]*(use[bef]-use[now])

 

그렇게 다른 모든 정점들까지 거리의 합이 최소가 되는 노드를 구할 수 있다.

그 노드들을 연결 시켜준 뒤 E의 모든 노드와 W의 모든 노드 사이의 거리를 구해주면 된다.

소스코드

from sys import stdin,setrecursionlimit
setrecursionlimit(120000)
from collections import defaultdict as dd
from collections import deque
input=stdin.readline

def solve(n):
    def getdist(now):
        use[now]=1
        for i in tree[now]:
            if vi[i]==0:
                vi[i]=1
                getdist(i)
                par[i]=now
                use[now]+=use[i]

    def getminloc():
        vi=[0]*(n+1)
        vi[1]=1
        queue=deque([])
        for i in tree[1]:
            queue.append((i,1))
        while queue:
            now,bef=queue.popleft()
            vi[now]=1
            dist[now]+=dist[bef]-dist[now]-tree[bef][now]*use[now]+tree[bef][now]*(use[bef]-use[now])
            use[now]=n
            for i in tree[now]:
                if vi[i]==0:
                    queue.append((i,now))

    tree=[dd(int) for i in range(n+1)]
    dist=[0]*(n+1)
    vi=[0]*(n+1)
    vi[1]=1
    use=[0]*(n+1)
    par=[i for i in range(n+1)]
    for i in range(n-1):
        a,b,c=map(int,input().split())
        tree[a][b]=tree[b][a]=c

    a=getdist(1)
    order=[]
    for i in range(1,n):
        order.append((use[i+1],i+1))
    order.sort()

    for usetime,now in order:
        dist[par[now]]+=tree[par[now]][now]*usetime
        dist[par[now]]+=dist[now]
        
    getminloc()
    mn=float('inf')
    mnidx=''
    for i in range(n):
        if dist[i+1]<mn:
            mn=dist[i+1]
            mnidx=i+1
    return mn,mnidx

n=int(input())
mn1,mnidx1=solve(n)
m=int(input())
mn2,mnidx2=solve(m)
print(mnidx1,mnidx2)
print(mn1*m+mn2*n+m*n)

후기