Programming/백준

[골드 1] 백준 2213 - 트리의 독립집합 (파이썬)

pental 2025. 4. 24. 10:59

https://www.acmicpc.net/problem/2213

풀이

  • 정점 포함 vs 미포함 두 가지 상태로 나눠서 각각 최적값을 저장
  • 서브트리의 정보를 재귀적으로 계산하면서 합산
  • 각 노드가 포함되는 경우, 자식 노드는 반드시 제외
  • 각 노드가 포함되지 않는 경우, 자식 노드는 포함 여부 선택 가능
  • D[u] : 정점 u를 포함하는 경우 최대 독립집합 가중치 합
  • E[u] : 정점 u를 포함하지 않는 경우 최대 독립집합 가중치 합
  • D_sol[u] : 정점 u를 포함하는 경우 선택된 정점 목록
  • E_sol[u] : 정점 u를 포함하지 않는 경우 선택된 정점 목록

DFS는 다음과 같이 처리한다.

def dfs(u):
    visit[u] = True

    D[u] = A[u]  # 자기 자신 포함
    D_sol[u].append(u)

    E[u] = 0  # 자기 자신 미포함

    for v in adj[u]:
        if not visit[v]:
            dfs(v)  # 자식 노드 방문

            D[u] += E[v]              # 자식은 반드시 제외
            D_sol[u].extend(E_sol[v])

            # u가 제외된 경우, 자식은 포함 여부 자유
            if D[v] < E[v]:
                E[u] += E[v]
                E_sol[u].extend(E_sol[v])
            else:
                E[u] += D[v]
                E_sol[u].extend(D_sol[v])

0-Index 를 1-Index로 변환 후 정렬한다.

if D[0] < E[0]:
    print(E[0])          # 루트 포함 X일 때가 더 큰 경우
    E_sol[0].sort()
    print(*list(map(lambda x: x + 1, E_sol[0])))
else:
    print(D[0])
    D_sol[0].sort()
    print(*list(map(lambda x: x + 1, D_sol[0])))

코드

# 백준 2213 - 트리의 독립집합
# 분류 : 다이나믹 프로그래밍

N = int(input())
A = list(map(int, input().split()))
adj = [[] for _ in range(N)]

for _ in range(N - 1) :
    u, v = map(int, input().split())
    u -= 1
    v -= 1

    adj[u].append(v)
    adj[v].append(u)

D = [0] * N
E = [0] * N
D_sol = [[] for _ in range(N)]
E_sol = [[] for _ in range(N)]

visit = [False] * N

def dfs(u) :
    visit[u] = True

    D[u] = A[u]
    D_sol[u].append(u)
    E[u] = 0
    for v in adj[u] :
        if not visit[v] :
            dfs(v)
            D[u] += E[v]
            D_sol[u].extend(E_sol[v])

            if D[v] < E[v] :
                E[u] += E[v]
                E_sol[u].extend(E_sol[v])
            else :
                E[u] += D[v]
                E_sol[u].extend(D_sol[v])

dfs(0)
if D[0] < E[0] :
    print(E[0])
    E_sol[0].sort()
    print(*list(map(lambda x : x + 1, E_sol[0])))
else :
    print(D[0])
    D_sol[0].sort()
    print(*list(map(lambda x : x + 1, D_sol[0])))