https://www.acmicpc.net/problem/11658
알고리즘 유형
- 자료 구조
- 세그먼트 트리
- 누적 합
- 다차원 세그먼트 트리
문제
N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다. (x1, y1)부터 (x2, y2)까지 합이란 x1 ≤ x ≤ x2, y1 ≤ y ≤ y2를 만족하는 모든 (x, y)에 있는 수의 합이다.
예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.
1 | 2 | 3 | 4 |
2 | 3 | 4 | 5 |
3 | 4 | 5 | 6 |
4 | 5 | 6 | 7 |
여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.
표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.
입력
첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2가 주어진다. w = 0인 경우는 (x, y)를 c (1 ≤ c ≤ 1,000)로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ x1 ≤ x2 ≤ N, 1 ≤ y1 ≤ y2 ≤ N) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.
출력
w = 1인 입력마다 구한 합을 순서대로 한 줄에 하나씩 출력한다.
예제 입출력
풀이
특정 면적의 합을 연속적으로 구하므로 (O(N)^4) , 누적합을 이용해야한다.(O(N)^2)
단순 2차원 배열의 경우 1칸이 변경될 때마다 전체 누적합을 매번 갱신해야하므로 (O(N)^2) 펜윅트리를 통해 더 빠르게 구간합을 처리해야한다. (O(logN)^2)
즉, 단순 완전탐색(O(N)^4)이 불가능함을 깨닫고, 누적합(O(N)^2)을 사용할 것을 떠올린 다음, 지속적 갱신(O(N)^2)으로 인해 발생되는 연산을 펜윅트리(O(logN)^2)를 통해 구간합을 처리한다.
코드
import sys
input = sys.stdin.readline
def update(prefix, x, y, diff, n):
i = x
while i <= n:
j = y
while j <= n:
prefix[i][j] += diff
j += (j & -j)
i += (i & -i)
def query(prefix, x, y):
s = 0
i = x
while i > 0:
j = y
while j > 0:
s += prefix[i][j]
j -= (j & -j)
i -= (i & -i)
return s
def calc(prefix, x1, y1, x2, y2):
return (query(prefix, x2, y2) - query(prefix, x2, y1 - 1) - query(prefix, x1 - 1, y2) + query(prefix, x1 - 1, y1 - 1) )
n, m = map(int, input().split())
graphs = [list(map(int, input().split())) for _ in range(n)]
prefix = [[0]*(n+1) for _ in range(n+1)]
for i in range(1, n+1):
for j in range(1, n+1):
update(prefix, i, j, graphs[i-1][j-1], n)
for _ in range(m):
line = list(map(int, input().split()))
if line[0] == 0:
_, x, y, c = line
old_value = graphs[x-1][y-1]
diff = c - old_value
graphs[x-1][y-1] = c
update(prefix, x, y, diff, n)
else:
_, x1, y1, x2, y2 = line
print(calc(prefix, x1, y1, x2, y2))
'Koala - 17기 > 코딩테스트 심화 스터디' 카테고리의 다른 글
[백준/C++] 2352번: 반도체 설계 (0) | 2025.01.31 |
---|---|
[백준/Python] 21608번: 상어 초등학교 (0) | 2025.01.26 |
[백준/Python] #30648 트릭플라워 (0) | 2025.01.26 |
[백준/Python] 2003번: 수들의 합 2 (0) | 2025.01.26 |
[BOJ/Python3] 2842번 : 집배원 한상덕 (0) | 2025.01.26 |