4 minute read

1. 전국시대

  • 전국시대엔 N개의 국가가 존재한다. 각 국가는 1부터 N까지의 번호를 가지고 있다.

  • 또한, 모든 국가는 각자 자신의 국가의 힘을 상징하는 병력을 가지고 있다.
  • 이때 M개의 기록이 주어진다. 각각의 기록은 다음과 같다.

    • 동맹 - 두 나라가 서로 동맹을 맺는다. 두 나라의 병력이 하나로 합쳐진다.
    • 전쟁 - 두 나라가 서로 전쟁을 벌인다. 병력이 더 많은 나라가 승리하며 패배한 나라는 속국이 된다.
      • 이때 남은 병력은 승리한 나라의 병력에서 패배한 나라의 병력을 뺀 수치가 된다.
      • 두 나라의 병력이 같을 경우 두 나라 모두 멸망한다.
  • 모든 나라는 정직하기 때문에 내 동맹의 동맹도 나의 동맹이고, 내 동맹이 적과 전쟁을 시작하면 같이 참전한다.
  • 속국인 경우도 동맹의 경우와 마찬가지이다.

  • 따라서, 전쟁에서 진 국가와 동맹인 다른 국가 또한 전쟁에서 이긴 국가의 속국이 된다.

  • 모든 기록이 끝났을 때 남아있는 국가의 수를 출력하고, 그 국가들의 남은 병력의 수를 오름차순으로 출력하는 프로그램을 작성하시오.

  • 단, 여러 국가가 서로 동맹이거나 속국 관계인 경우는 한 국가로 취급한다.

입력

  • 첫 번째 줄에 국가의 수를 나타내는 N과 기록의 수 M이 주어진다. (1 ≤ N, M ≤ 100,000)

  • 두 번째 줄 부터 N개의 줄에 걸쳐 i번째 국가의 병력 Ai (1 ≤ i ≤ N)가 자연수로 주어진다. (1 ≤ Ai ≤ 10,000)

  • 다음 M개의 줄에는 기록이 3개의 정수 O, P, Q로 주어진다. O가 1인 경우 P, Q가 서로 동맹을 맺었음을 의미하고, O가 2인 경우 P, Q가 서로 전쟁을 벌였음을 의미한다.

  • 동맹끼리 다시 동맹을 맺거나 전쟁하는 입력은 주어지지 않는다.

출력

  • 첫째 줄에 남아있는 국가의 수를 출력한다.

  • 다음 줄에 각 국가의 남은 병력의 수를 띄어쓰기를 간격으로 오름차순으로 출력한다.

풀이

# your code goes here
 
import sys
 
# 첫 번째 줄에 국가의 수를 나타내는 N과 기록의 수 M이 주어진다. (1 ≤ N, M ≤ 100,000)
 
N, M = list(map(int,sys.stdin.readline().split()))
 
# 두 번째 줄 부터 N개의 줄에 걸쳐 i번째 국가의 병력 Ai (1 ≤ i ≤ N)가 자연수로 주어진다. (1 ≤ Ai ≤ 10,000)
 
 
power = {} # i번째 국가의 병력은 Ai이다.
alias = {} # 국가들의 동맹관계를 의미한다. 같은 배열 안의 국가는 동맹으로 취급한다.
 
for i in range(0,N):
	Ai = int(sys.stdin.readline())
	power[i+1] = Ai
	alias[i+1] = [i+1] # 초기에는 동맹이 자기자신뿐이다.
 
# 다음 M개의 줄에는 기록이 3개의 정수 O, P, Q로 주어진다.
# O가 1인 경우 P, Q가 서로 동맹을 맺었음을 의미하고, O가 2인 경우 P, Q가 서로 전쟁을 벌였음을 의미한다.
 
for _ in range(0,M):
	O, P, Q = list(map(int,sys.stdin.readline().split()))
 
	# case O = 1. 동맹
	if O == 1:
		tmp1 = power[P]
		tmp2 = power[Q]
		# 모든 나라는 정직하기 때문에 내 동맹의 동맹도 나의 동맹이다.
		# 두 나라가 서로 동맹을 맺는다. 두 나라의 병력이 하나로 합쳐진다.
		for a in [x for x in list(set(alias[P]+alias[Q]))]:
			alias[a] = list(set(alias[P]+alias[Q]))
			power[a] = tmp1 + tmp2
 

	# case O = 2. 전쟁
	elif O == 2:
		# 두 나라가 서로 전쟁을 벌인다. 
		# 두 나라의 병력이 같을 경우 두 나라 모두 멸망한다.
		if power[P] == power[Q]:
			# 동맹 리스트에서 제거
			# 파워 리스트에서 제거
			for a in [x for x in alias[P] + alias[Q]]:
				del alias[a]
				del power[a]
 
 
		else:
			# 병력이 더 많은 나라가 승리하며 패배한 나라는 속국이 된다. 
			winner = P if power[P] > power[Q] else Q
			loser = Q if power[P] > power[Q] else P
			# 이때 남은 병력은 승리한 나라의 병력에서 패배한 나라의 병력을 뺀 수치가 된다. 
			tmp = power[winner]
			power[winner] = tmp - power[loser]
 
			# 패배한 동맹을 승리한 동맹에 편입한다.
			tmp = alias[winner].copy()
			tmp2 = alias[loser].copy()
			for a in [x for x in alias[loser]]:
				alias[a] = sorted(tmp + alias[a])
				power[a] = power[winner]
 
 
			# 승리한 동맹의 모든 병력을 동일하게 맞춘다.
			for a in [x for x in alias[winner]]:
				alias[a] = sorted(tmp2 + alias[a])
				power[a] = power[winner]
 
 
 
alias_no_dup = {str(val): key for key,val in alias.items()}
alias_result = {val: key for key,val in alias_no_dup.items()}
 
power_result = []
 
# 단, 여러 국가가 서로 동맹이거나 속국 관계인 경우는 한 국가로 취급한다.
# 모든 기록이 끝났을 때 남아있는 국가의 수를 출력하고, 
print(len(alias_result))
 
# 그 국가들의 남은 병력의 수를 오름차순으로 출력하는 프로그램을 작성하시오.
for a in [x for x in alias_result.keys()]:
	power_result.append(power[a])
 
print(' '.join([str(x) for x in sorted(power_result)]))

채점 결과

  • 오답

개념정리 - 서로소 집합

  • Disjoint-set 또는 union-find라고도 한다.

  • 서로소 집합이란, 공통 원소가 없는 두 집합을 의미한다.

  • find 연산과 Union 연산을 통해 문제를 해결한다.

    • find : 어떤 원소가 주어졌을 때, 이 원소가 속한 집합을 반환한다.

    • Union : 두개의 집합을 하나의 집합으로 합친다.

  • 트리 자료구조의 관점에서 보면,

    • find연산은 트리의 루트노드를 찾고, 그 루트노드를 통해 특정 집합을 표현하는것이다.
    • Union연산은 두 원소에 대해 find연산을 수행하여 각각의 루트노드를 찾고, 한쪽의 루트노드를 다른쪽에 연결하여 하나의 트리로 만드는 연산이다.

개념 정리후 풀이

# your code goes here
 
import sys
 
# 첫 번째 줄에 국가의 수를 나타내는 N과 기록의 수 M이 주어진다. (1 ≤ N, M ≤ 100,000)
 
N, M = list(map(int,sys.stdin.readline().split()))
 
# 두 번째 줄 부터 N개의 줄에 걸쳐 i번째 국가의 병력 Ai (1 ≤ i ≤ N)가 자연수로 주어진다. (1 ≤ Ai ≤ 10,000)

root = [x for x in range(0,N+1)] # 동맹 상황 트리
power = [x for x in range(0,N+1)] # 각 국가의 병력




for i in range(1,N+1):
	Ai = int(sys.stdin.readline())
	root[i] = i
	power[i] = Ai
	

	
	
def find(x):
	if root[x] != x:
		root[x] = find(root[x])
	return root[x] 
	
	
def alias(a,b):
	a = find(a)
	b = find(b)
	
	root[b] = a
	power[a] += power[b]
	
	
def fight(a,b):
	a = find(a)
	b = find(b)
	
	if power[a] > power[b]:
		power[a] -= power[b]
		root[b] = a
	elif power[a] < power[b]:
		power[b] -= power[a]
		root[a] = b
	else:
		root[a] = 0
		root[b] = 0
		

	
# 다음 M개의 줄에는 기록이 3개의 정수 O, P, Q로 주어진다.
# O가 1인 경우 P, Q가 서로 동맹을 맺었음을 의미하고, O가 2인 경우 P, Q가 서로 전쟁을 벌였음을 의미한다.
 
for _ in range(0,M):
	O, P, Q = map(int,sys.stdin.readline().split())
	if O == 1:
		alias(P,Q)
	elif O == 2:
		fight(P,Q)
		

a = set()

for i in range(1,N+1):
	p = find(i)
	if p != 0:
		a.add(p)

print(len(a))
power_result = sorted([power[x] for x in a])
print(' '.join([str(x) for x in power_result]))

재채점 결과

  • 정답