mo1lusca의 블로그
[백준] 13510 트리와 쿼리 1 - C++ 본문
https://www.acmicpc.net/problem/13510
트리상의 임의의 두 노드 u, v를 잇는 경로 중 비용이 가장 큰 간선의 비용을 출력하면 된다.
HLD (Heavy Light Decomposition) 를 사용하면 된다!
이 글에서는 이론적인 내용보단 구현적인 내용을 다룬다.
부분 코드 설명
#include <bits/stdc++.h>
using namespace std;
const int MAX = 100'005;
int sz[MAX];
int dep[MAX];
int par[MAX];
int top[MAX];
int in[MAX], out[MAX];
vector<int> g[MAX];
vector<pair<pair<int, int>, int>> edge;
int seg[MAX * 4 + 16];
선언부이다.. 와우...
뭐가 많아보이지만 사실 간단하다!
- sz[i]는 노드 i를 루트로 삼는 서브트리의 크기이다.
- dep[i]는 노드 i의 깊이이다.
- par[i]는 노드 i의 부모 노드이다.
- top[i]는 노드 i가 속한 체인의 최상단 노드이다.
- in[], out[]은 ETT처럼 방문시간을 기록한 것이다.
- g는 문제에서 주는 원본 트리이다. 무방향이다.
- edge는 간선 정보를 저장한다. (문제 조건이다.) {{u,v},w}의 구조로 노드 두개, 비용을 저장한다,
- seg는 세그트리다!
void dfs1(int v = 1, int p = 0) {
par[v] = p;
sz[v] = 1;
for (auto& i : g[v]) {
if (p == i) {
continue;
}
dep[i] = dep[v] + 1;
par[i] = v;
dfs1(i, v);
sz[v] += sz[i];
if (g[v][0] == p || sz[i] > sz[g[v][0]]) {
swap(i, g[v][0]);
}
}
}
첫번째 DFS 함수이다.
par, sz, dep 배열을 채우고, 서브트리가 가장 큰 자식을 맨 앞으로 보낸다.
이로 인해 dfs1 함수가 끝나면, 노드 v에서 뻗어나가는 heavy edge는 ( v, g[v][0] ) 이 된다.
int cnt = 0;
void dfs2(int v = 1, int p = 0) {
in[v] = ++cnt;
for (auto i : g[v]) {
if (p == i) {
continue;
}
top[i] = (i == g[v][0] ? top[v] : i); //<1>
dfs2(i, v);
}
out[v] = cnt;
}
두번째 DFS 함수이다. in과 out, top을 채운다.
top은 아까 말했듯 해당 노드가 속한 체인의 최상단에 있는 정점을 나타낸다.
<1> : 노드 v의 자식 노드 i의 top을 채운다.
- i가 g[v][0]이라면, 즉 i가 heavy child라면 v의 top을 물려받는다. (같은 heavy chain에 속하기 때문에)
- i가 g[v][0]이 아니라면, i는 v가 속한 heavy chain에 포함되지 않기 때문에, 새로운 체인을 시작해준다 (최상단이 자기 자신)
void update(int idx, int left, int right, int cidx, int cval) {
if (cidx < left || right < cidx) {
return;
}
if (left == right) {
seg[idx] = cval;
return;
}
int mid = (left + right) / 2;
update(idx * 2, left, mid, cidx, cval);
update(idx * 2 + 1, mid + 1, right, cidx, cval);
seg[idx] = max(seg[idx * 2], seg[idx * 2 + 1]);
return;
}
int s_query(int idx, int left, int right, int qleft, int qright) {
if (qright < left || right < qleft) {
return 0;
}
if (qleft <= left && right <= qright) {
return seg[idx];
}
int mid = (left + right) / 2;
return max(s_query(idx * 2, left, mid, qleft, qright), s_query(idx * 2 + 1, mid + 1, right, qleft, qright));
}
점업뎃 구간쿼리 맥스세그를 구현하자.
쿼리를 수행하는 함수명이 query가 아니라 s_query인 이유는... 쿼리를 수행하려면 한 단계가 더 필요하기 때문이다.
int query(int u, int v) {
int res = 0;
while (top[u] != top[v]) { //<1>
if (dep[top[u]] < dep[top[v]]) { //<2>
swap(u, v);
}
int st = top[u]; //<3>
res = max(res, s_query(1, 1, n, in[st], in[u])); //<4>
u = par[st]; //<5>
}
if (dep[u] > dep[v]) { //<6-0>
swap(u, v);
}
res = max(res, s_query(1, 1, n, in[u] + 1, in[v])); //<6>
return res;
}
HLD에 기반한 경로쿼리 함수이다.
<1> : 노드 u와 v가 같은 heavy chain에 속할 때 까지 반복한다.
<2> : top[u]의 depth가 항상 top[v]의 depth보다 같거나 크도록 만든다. depth가 큰 쪽부터 쿼리를 처리하기 때문이다.
- u가 속한 체인이 지금 처리할 체인이라는 뜻이다.
<3> : st는 u가 속한 체인의 최상단 노드를 뜻한다.
<4> : st ~ u 구간에 쿼리를 날린다. 이로서 u가 속했던 heavy chain에서의 최댓값을 찾을 수 있다.
-- in[st] ~ in[u] 구간 내의 모든 노드가 하나의 heavy chain에 속한다는 것을 어떻게 알 수 있을까?
in은 그저 dfs2가 방문한 순서이니 heavy chain에 속하지 않은 노드도 포함되는 것 아닐까?
하지만 우리는 dfs1에서 heavy child를 맨 앞( g[v][0] ) 에 오게 해두었다.
이것이 dfs2에서 g[v]를 순회하며 탐색할 때, 항상 heavy child인 노드를 먼저 탐색하니
결과적으로 heavy chain을 연속적으로 탐색할 수 있게 해주는것이다!
따라서 같은 heavy chain에 속한 노드는 in 내에서 연속적인 구간을 갖는다.
<5> : u가 속한 체인의 최상단 노드의 부모를 u에 저장한다. <4>에서 쿼리를 처리했으니 상위 체인으로 커서를 옮기는 느낌이다.
<6> : 이 코드는 <1>의 while문이 종료되었을 때에 실행된다. 그 말은 즉 현재 u와 v가 같은 heavy chain 내에 있고, 이외의 구간에 대한 처리는 모두 <4>를 통해 처리했다는 의미이다.
한마디로, 현재 u와 v가 같은 체인에 있으니 마지막으로 u~v 구간에 대해 쿼리를 날려주면 된다.
- in[u] ~ in[v] 구간이 아니라 in[u]+1 ~ in[v] 구간인데에는 이유가 있다.
문제상 노드가 아니라 간선에 대해 쿼리를 수행하는 것이기 때문에 HLD를 쓰기 위해서 간선을 고유한 노드에 매핑할 필요가 있다.
그리고 임의의 노드에서 부모로 향하는 간선은 유일하기 때문에 노드와 해당 노드의 부모로 향하는 간선을 서로 매핑하게 된다.
이때 <6-0>에 의해 u~v중 depth가 제일 낮은 (루트와 제일 가까운) 노드인 u는 u의 부모로 향하는 간선을 나타내게 되고, 이는 u~v의 구간에서 벗어난다.
따라서 in[u]+1을 해줌으로 u~v 구간을 나타나게 한다.
아무튼 이런 절차를 거쳐 얻은 res를 return해준다!
main함수는 특별한게 딱히 없으므로 자세한 설명은 생략한다.
전체 코드
#include <bits/stdc++.h>
using namespace std;
const int MAX = 100'005;
int sz[MAX]; //i를 루트로 하는 서브트리 크기
int dep[MAX]; //i의 깊이
int par[MAX]; //i의 부모 노드
int top[MAX]; //i가 속한 체인의 최상단 노드
int in[MAX], out[MAX]; //DFS ordering
vector<int> g[MAX]; // graph
vector<pair<pair<int, int>, int>> edge;
int seg[MAX * 4 + 16];
int n;
void dfs1(int v = 1, int p = 0) {
par[v] = p;
sz[v] = 1;
for (auto& i : g[v]) {
if (p == i) {
continue;
}
dep[i] = dep[v] + 1;
par[i] = v;
dfs1(i, v);
sz[v] += sz[i];
if (g[v][0] == p || sz[i] > sz[g[v][0]]) {
swap(i, g[v][0]);
}
}
}
int cnt = 0;
void dfs2(int v = 1, int p = 0) {
in[v] = ++cnt;
for (auto i : g[v]) {
if (p == i) {
continue;
}
top[i] = (i == g[v][0] ? top[v] : i);
dfs2(i, v);
}
out[v] = cnt;
}
void update(int idx, int left, int right, int cidx, int cval) {
if (cidx < left || right < cidx) {
return;
}
if (left == right) {
seg[idx] = cval;
return;
}
int mid = (left + right) / 2;
update(idx * 2, left, mid, cidx, cval);
update(idx * 2 + 1, mid + 1, right, cidx, cval);
seg[idx] = max(seg[idx * 2], seg[idx * 2 + 1]);
return;
}
int s_query(int idx, int left, int right, int qleft, int qright) {
if (qright < left || right < qleft) {
return 0;
}
if (qleft <= left && right <= qright) {
return seg[idx];
}
int mid = (left + right) / 2;
return max(s_query(idx * 2, left, mid, qleft, qright), s_query(idx * 2 + 1, mid + 1, right, qleft, qright));
}
int query(int u, int v) {
int res = 0;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u, v);
}
int st = top[u];
res = max(res, s_query(1, 1, n, in[st], in[u]));
u = par[st];
}
if (dep[u] > dep[v]) {
swap(u, v);
}
res = max(res, s_query(1, 1, n, in[u] + 1, in[v]));
return res;
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int m;
edge.push_back({ {-1,-1},-1 });
cin >> n;
for (int i = 1; i <= n - 1; i++) {
int u, v, w;
cin >> u >> v >> w;
edge.push_back({ {u,v},w });
g[u].push_back(v);
g[v].push_back(u);
}
dep[1] = 0;
par[1] = 0;
top[1] = 1;
dfs1();
dfs2();
for (int i = 1; i <= n - 1; i++) {
int u = edge[i].first.first;
int v = edge[i].first.second;
int w = edge[i].second;
if (dep[u] < dep[v]) {
swap(u, v);
}
update(1, 1, n, in[u], w);
}
cin >> m;
for (int i = 1; i <= m; i++) {
int command, a, b;
cin >> command >> a >> b;
if (command == 1) {
int u = edge[a].first.first;
int v = edge[a].first.second;
if (dep[u] < dep[v]) {
swap(u, v);
}
update(1, 1, n, in[u], b);
}
else {
cout << query(a, b) << "\n";
}
}
return 0;
}
세상에 마상에 이런 아름다운 알고리즘이 있다니
'PS' 카테고리의 다른 글
| [백준] 13263 나무 자르기 - C++ (0) | 2025.12.19 |
|---|---|
| [Codeforces] Round 1065 (Div. 3) 업솔빙 (0) | 2025.11.27 |
| [백준] 18227 성대나라의 물탱크 - C++ (0) | 2025.10.23 |
| [백준] 14268 회사 문화 2 - C++ (0) | 2025.10.22 |
| [백준] 1517 버블 소트 - C++ (3) | 2025.08.18 |