Koala - 17기/코딩테스트 심화 스터디

[BOJ/Python3] 11658번 : 구간합 구하기 3

cje_ 2025. 1. 31. 15:11

https://www.acmicpc.net/problem/11658


알고리즘 유형

  • 자료 구조
  • 세그먼트 트리
  • 누적 합
  • 다차원 세그먼트 트리

문제

N×N개의 수가 N×N 크기의 표에 채워져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 표의 i행 j열은 (i, j)로 나타낸다. (x1, y1)부터 (x2, y2)까지 합이란 x1 ≤ x ≤ x2, y1 ≤ y ≤ y2를 만족하는 모든 (x, y)에 있는 수의 합이다.

예를 들어, N = 4이고, 표가 아래와 같이 채워져 있는 경우를 살펴보자.

1 2 3 4
2 3 4 5
3 4 5 6
4 5 6 7


여기서 (2, 2)부터 (3, 4)까지 합을 구하면 3+4+5+4+5+6 = 27이 된다. (2, 3)을 7로 바꾸고 (2, 2)부터 (3, 4)까지 합을 구하면 3+7+5+4+5+6=30 이 된다.

표에 채워져 있는 수와 변경하는 연산과 합을 구하는 연산이 주어졌을 때, 이를 처리하는 프로그램을 작성하시오.

입력

첫째 줄에 표의 크기 N과 수행해야 하는 연산의 수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네 개의 정수 w, x, y, c 또는 다섯 개의 정수 w, x1, y1, x2, y2가 주어진다. w = 0인 경우는 (x, y)를 c (1 ≤ c ≤ 1,000)로 바꾸는 연산이고, w = 1인 경우는 (x1, y1)부터 (x2, y2)의 합을 구해 출력하는 연산이다. (1 ≤ x1 ≤ x2 ≤ N, 1 ≤ y1 ≤ y2 ≤ N) 표에 채워져 있는 수는 1,000보다 작거나 같은 자연수이다.

출력

w = 1인 입력마다 구한 합을 순서대로 한 줄에 하나씩 출력한다.

예제 입출력


풀이

 특정 면적의 합을 연속적으로 구하므로 (O(N)^4) , 누적합을 이용해야한다.(O(N)^2) 

 단순 2차원 배열의 경우 1칸이 변경될 때마다 전체 누적합을 매번 갱신해야하므로 (O(N)^2) 펜윅트리를 통해 더 빠르게 구간합을 처리해야한다. (O(logN)^2)

 즉, 단순 완전탐색(O(N)^4)이 불가능함을 깨닫고, 누적합(O(N)^2)을 사용할 것을 떠올린 다음, 지속적 갱신(O(N)^2)으로 인해 발생되는 연산을 펜윅트리(O(logN)^2)를 통해 구간합을 처리한다. 

코드

import sys

input = sys.stdin.readline

def update(prefix, x, y, diff, n):
    i = x
    while i <= n:
        j = y
        while j <= n:
            prefix[i][j] += diff
            j += (j & -j)
        i += (i & -i)

def query(prefix, x, y):
    s = 0
    i = x
    while i > 0:
        j = y
        while j > 0:
            s += prefix[i][j]
            j -= (j & -j)
        i -= (i & -i)
    return s

def calc(prefix, x1, y1, x2, y2):
    return (query(prefix, x2, y2) - query(prefix, x2, y1 - 1) - query(prefix, x1 - 1, y2) + query(prefix, x1 - 1, y1 - 1) )

n, m = map(int, input().split())
graphs = [list(map(int, input().split())) for _ in range(n)]
prefix = [[0]*(n+1) for _ in range(n+1)]

for i in range(1, n+1):
    for j in range(1, n+1):
        update(prefix, i, j, graphs[i-1][j-1], n)
for _ in range(m):
    line = list(map(int, input().split()))
    if line[0] == 0:
        _, x, y, c = line
        old_value = graphs[x-1][y-1]
        diff = c - old_value
        graphs[x-1][y-1] = c
        update(prefix, x, y, diff, n)
    else:
        _, x1, y1, x2, y2 = line
        print(calc(prefix, x1, y1, x2, y2))