문제 분석
난이도
플래티넘5
분류
그래프 탐색, DP, DFS
들어가기 전에
이번 2022 SHAKE!에 나왔던 문제. 점화식을 떠올리는 것도 꽤나 어려운 문제라고 생각한다.
문제
문제 풀이
풀이
DFS를 이용해 각 노드의 자식 노드의 개수를 구해준다. 그 이유는 자식의 노드의 개수가 곧 해당 노드와 해당 노드의 부모 노드가 연결된 간선의 호출 횟수이기 때문이다.
그렇게 구해준 자식의 노드수를 통해서 한 점에 대해 그 점으로부터 모든 점까지 거리의 합을 구해준다. 그리고 그 점에 대해 구해주면서 다른 노드들에는 그 노드로부터 자식 노드들까지의 거리를 함께 구해줄 수 있다.
for usetime,now in order:
dist[par[now]]+=tree[par[now]][now]*usetime
dist[par[now]]+=dist[now]
그 방법은 위와 같다. 앞서 구해준 자식 노드의 개수를 통해 자기 자신과 부모 노드사이의 간선의 가중치를 곱해서 더해주고 자식 노드가 갖는 누적 거리의 합을 더해준다.
그리고 이 값들을 이용해 O(1)에 다른 노드에서 모든 노드까지의 거리를 구할 수 있다.
그 원리는 다음과 같다. 다른 모든 노드들까지의 거리를 구해놓은 노드 a에 대해 그 노드 a로부터 구한 거리는 a->b + a->c + a->d .... 일 것이다. 그리고 다른 노드 b는 자신의 자식 노드들까지의 거리의 합을 가지고 있을 것이다.
a에서 구한 거리 - b에서 구한 거리를 빼준다면 그 값이 b노드와 그 자식 노드들을 제외한 모든 노드의 거리가 될 것이다.
이제 b노드를 중심으로 a와 연결시켜주면 된다.
이때 점화식은 다음과 같다.
dist[now]+=dist[bef]-dist[now]-tree[bef][now]*use[now]+tree[bef][now]*(use[bef]-use[now])
그렇게 다른 모든 정점들까지 거리의 합이 최소가 되는 노드를 구할 수 있다.
그 노드들을 연결 시켜준 뒤 E의 모든 노드와 W의 모든 노드 사이의 거리를 구해주면 된다.
소스코드
from sys import stdin,setrecursionlimit
setrecursionlimit(120000)
from collections import defaultdict as dd
from collections import deque
input=stdin.readline
def solve(n):
def getdist(now):
use[now]=1
for i in tree[now]:
if vi[i]==0:
vi[i]=1
getdist(i)
par[i]=now
use[now]+=use[i]
def getminloc():
vi=[0]*(n+1)
vi[1]=1
queue=deque([])
for i in tree[1]:
queue.append((i,1))
while queue:
now,bef=queue.popleft()
vi[now]=1
dist[now]+=dist[bef]-dist[now]-tree[bef][now]*use[now]+tree[bef][now]*(use[bef]-use[now])
use[now]=n
for i in tree[now]:
if vi[i]==0:
queue.append((i,now))
tree=[dd(int) for i in range(n+1)]
dist=[0]*(n+1)
vi=[0]*(n+1)
vi[1]=1
use=[0]*(n+1)
par=[i for i in range(n+1)]
for i in range(n-1):
a,b,c=map(int,input().split())
tree[a][b]=tree[b][a]=c
a=getdist(1)
order=[]
for i in range(1,n):
order.append((use[i+1],i+1))
order.sort()
for usetime,now in order:
dist[par[now]]+=tree[par[now]][now]*usetime
dist[par[now]]+=dist[now]
getminloc()
mn=float('inf')
mnidx=''
for i in range(n):
if dist[i+1]<mn:
mn=dist[i+1]
mnidx=i+1
return mn,mnidx
n=int(input())
mn1,mnidx1=solve(n)
m=int(input())
mn2,mnidx2=solve(m)
print(mnidx1,mnidx2)
print(mn1*m+mn2*n+m*n)
후기
'Koala - 10기 > 코딩테스트 준비 스터디' 카테고리의 다른 글
[백준/Python] #14495 피보나치 비스무리한 수열 (0) | 2023.03.19 |
---|---|
[백준/Python] 9184번 신나는 함수 실행 (0) | 2023.03.19 |
[백준/C++] 15624번 피보나치 수 7 (0) | 2023.03.19 |
[백준/Python] 9465 스티커 (0) | 2023.03.19 |
[백준/Python]15624번: 피보나치 수 7 (0) | 2023.03.19 |