(JAVA)간격의 합과 차를 빠르게 찾을 수 있는 세그먼트 트리

문제는 쉽습니다. != 실력이 향상되었습니다.

다만 시간을 힘들게 버티고 있어서 생각대로 넘기지 못하고 있을 뿐입니다.

실제로 Interval sum이나 difference와 같은 문제를 만났을 때 주로 Prefix Sum으로 문제를 해결하려고 합니다.

그런데 PrefixSum은 특정 구간의 합을 계산할 때 분명히 O(1)의 장점이 있지만,

단점은 중간 값을 수정하는 데 O(N) 시간이 걸린다는 것입니다.

그래서 그걸 해결하면 시간이 초과되는 경우가 있습니다.

여기서 배워야 할 것은 세그먼트 트리입니다.

세그먼트 트리를 업데이트하면 시간을 O(NlogN)로 줄일 수 있습니다.

모르겠습니다 ? 한 번 외워

import java.util.*;
import java.io.*;

class SegmentTree {
	long () tree;
	int treeSize;
	
	SegmentTree(int n) {
		int h = (int)Math.ceil(Math.log(n)/Math.log(2));
		this.treeSize = (int)Math.pow(2, h+1);
		tree = new long(treeSize);
	}
	
	long init(long () arr, int node, int start, int end) {
		
		// 리프노드 인거
		if(start == end) {
			return tree(node) = arr(start);
		} 
		
		return tree(node) = init(arr, node*2, start, (start+end)/2)
				+ init(arr,node*2+1, (start+end)/2+1, end);
	}
	
	void update(int node, int start, int end, int idx, long diff) {
		if(idx < start || end < idx) {
			return;
		}
		
		tree(node) += diff;
		
		if(start != end) {
			update(node*2, start, (start+end)/2, idx, diff);
			update(node*2+1, (start+end)/2+1, end, idx, diff);
		}
	}
	
	long sum(int node, int start, int end, int left, int right) {
		
		if(left > end || right < start) {
			return 0;
		}
		
		if(left <= start && end <= right) {
			return tree(node);
		}
		
		return sum(node*2, start, (start+end)/2, left, right) +
				sum(node*2+1, (start+end)/2+1, end, left, right);
	}

}


public class 세그먼트트리 {

	public static void main(String() args) throws IOException{
		// TODO Auto-generated method stub
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		String () tmp = br.readLine().strip().split(" ");
		int n = Integer.parseInt(tmp(0));
		int times = Integer.parseInt(tmp(1))+Integer.parseInt(tmp(2));
		
		long () arr = new long(n+1);
		for(int i=1;i<=n;i++) {
			arr(i) = Integer.parseInt(br.readLine().strip());
		}
		
		SegmentTree tree = new SegmentTree(n);
		tree.init(arr, 1, 1, n);
		
		for(int i=0;i<times;i++) {
			tmp = br.readLine().strip().split(" ");
			int type = Integer.parseInt(tmp(0).strip());
			if(type == 1) {
				int update_node = Integer.parseInt(tmp(1));
				int update_num = Integer.parseInt(tmp(2));
				long diff = update_num - arr(update_node);
				tree.update(1, 1, n, update_node, diff);
				
			} else {
				int left = Integer.parseInt(tmp(1));
				int right = Integer.parseInt(tmp(2));
				long sum_value = tree.sum(1, 1, n, left, right);
				System.out.println(sum_value);
			}
		}
	}

}