mo1lusca의 블로그
[백준] 18227 성대나라의 물탱크 - C++ 본문
https://www.acmicpc.net/problem/18227
루트(수도)와 임의의 노드 사이의 경로에 있는 노드들에,
루트부터 +1, +2, +3...이 더해진다.
이렇게 경로라고 생각하면? 뭔가 HLD + Lazy prop을 써야할 것 같지만
ETT만으로도 풀 수 있다!
핵심 아이디어는 노드 당 물이 부어진 횟수와,
물을 한번 부을때마다 얼마나 부어질지를 따로 생각하는 것이다.

이러한 트리가 있다고 해보자.
6번 노드에 물을 부으려면 1 - 4 - 6 순서로 1, 2, 3 만큼의 물을 부어야 한다.
항상 루트를 첫번째로, 목표 노드를 마지막으로 물을 붓게 되고..
이 경로에 포함되는 노드에 부어지는 물의 양은 루트로부터의 거리, 즉 depth임을 알 수 있다.
노드 별로 항상 같은 양만큼 물이 부어지니, 노드 당 물이 부어진 횟수만 알면 임의의 노드에 담긴 물의 총량을 알 수 있다.
위의 사진의 트리에서 4번 노드에 물이 부어지는 경우는
4번 노드에 직접 물이 부어지는 경우, 6번이나 7번 노드에 물이 부어지는 경우가 있다.
이 점에서 우리는 어떤 노드에 물이 부어지는 횟수는 해당 노드의 자식노드의 방문횟수 + 해당노드의 방문횟수 임을 알 수 있다.
한마디로, 물이 부어지는 횟수는 해당 노드를 루트로 갖는 서브트리의 모든 노드의 방문횟수이다.
방문횟수를 구하는 합세그를 만들고, ETT로 서브트리 쿼리를 돌리면 될 것 같다.
부분 코드 설명
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAX = 200'005;
vector<ll> in(MAX);
vector<ll> out(MAX);
vector<ll> depth(MAX);
ll seg[MAX * 4 + 16];
vector<int> g[MAX];
vector<bool> visited(MAX, 0);
선언부이다.
ETT에 사용할 in과 out, 노드의 깊이를 나타낼 depth를 만들어준다.
구간합을 위한 세그트리를 만들고, 원본 트리를 나타낼 g(graph), dfs 중 방문상태를 저장할 visited를 만든다.
int cnt = 0;
void dfs(int cur, int dep) {
in[cur] = ++cnt;
depth[cur] = dep;
visited[cur] = 1;
for (auto& next : g[cur]) {
if (visited[next] != 1) {
dfs(next, dep + 1);
}
}
out[cur] = cnt;
}
ETT를 위해 dfs를 수행하는 함수이다.
2025.10.22 - [백준] - [백준] 14268 회사 문화 2 - C++
이 글의 구현과 비슷하지만, depth를 구하는 부분이 추가되었고,
문제에서 그래프를 양방향으로 주기 때문에 visited를 체크하는 부분이 추가되었다.
물을 부은 횟수를 세기 위한 간단한 합세그를 만든다.
void update(ll idx, ll left, ll right, ll cidx, ll cval) {
if (cidx < left || right < cidx) {
return;
}
if (left == right) {
seg[idx] += cval;
return;
}
ll mid = (left + right) / 2;
update(idx * 2, left, mid, cidx, cval);
update(idx * 2 + 1, mid + 1, right, cidx, cval);
seg[idx] = seg[idx * 2] + seg[idx * 2 + 1];
return;
}
ll query(ll idx, ll left, ll right, ll qleft, ll qright) {
if (qright < left || right < qleft) {
return 0;
}
if (qleft <= left && right <= qright) {
return seg[idx];
}
ll mid = (left + right) / 2;
return query(idx * 2, left, mid, qleft, qright) + query(idx * 2 + 1, mid + 1, right, qleft, qright);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n, m, r;
cin >> n >> r;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(r, 1); //<1>
cin >> m;
for (int i = 0; i < m; i++) {
int command, a;
cin >> command >> a;
if (command == 1) {
update(1, 1, n, in[a], 1); //<2>
}
else {
cout << query(1, 1, n, in[a], out[a]) * depth[a] << "\n"; //<3>
}
}
return 0;
}
main이다. 딱히 특별한게 엄청 있진 않지만..
<1> : ETT를 돌린다! 입력으로 받은 수도를 루트로 돌려주어야 하기 때문에 인자를 잘 넣어주도록 하자.
<2> : 위에서 말했듯 경로에 포함되는 노드의 방문횟수는 구간합으로 구해지기에 경로의 끝 노드에만 +1 해준다.
<3> : 이것도 위에서 말했듯, 임의의 노드에 담겨있는 물은 (방문횟수)*(depth) 이기 때문에,
서브트리의 구간합으로 방문횟수를 구해주고, dfs를 돌리며 만들어 둔 depth배열에서 값을 꺼내와 곱해주면 된다.
전체 코드
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAX = 200'005;
vector<ll> in(MAX);
vector<ll> out(MAX);
vector<ll> depth(MAX);
ll seg[MAX * 4 + 16];
vector<int> g[MAX];
vector<bool> visited(MAX, 0);
int cnt = 0;
void dfs(int cur, int dep) {
in[cur] = ++cnt;
depth[cur] = dep;
visited[cur] = 1;
for (auto& next : g[cur]) {
if (visited[next] != 1) {
dfs(next, dep + 1);
}
}
out[cur] = cnt;
}
void update(ll idx, ll left, ll right, ll cidx, ll cval) {
if (cidx < left || right < cidx) {
return;
}
if (left == right) {
seg[idx] += cval;
return;
}
ll mid = (left + right) / 2;
update(idx * 2, left, mid, cidx, cval);
update(idx * 2 + 1, mid + 1, right, cidx, cval);
seg[idx] = seg[idx * 2] + seg[idx * 2 + 1];
return;
}
ll query(ll idx, ll left, ll right, ll qleft, ll qright) {
if (qright < left || right < qleft) {
return 0;
}
if (qleft <= left && right <= qright) {
return seg[idx];
}
ll mid = (left + right) / 2;
return query(idx * 2, left, mid, qleft, qright) + query(idx * 2 + 1, mid + 1, right, qleft, qright);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n, m, r;
cin >> n >> r;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(r, 1);
cin >> m;
for (int i = 0; i < m; i++) {
int command, a;
cin >> command >> a;
if (command == 1) {
update(1, 1, n, in[a], 1);
}
else {
cout << query(1, 1, n, in[a], out[a]) * depth[a] << "\n";
}
}
return 0;
}
구간합 아이디어를 떠올리는게 약간 어려웠다...
'PS' 카테고리의 다른 글
| [Codeforces] Round 1065 (Div. 3) 업솔빙 (0) | 2025.11.27 |
|---|---|
| [백준] 13510 트리와 쿼리 1 - C++ (0) | 2025.10.24 |
| [백준] 14268 회사 문화 2 - C++ (0) | 2025.10.22 |
| [백준] 1517 버블 소트 - C++ (3) | 2025.08.18 |
| [백준] 12738 가장 긴 증가하는 부분 수열 3 - C++ (3) | 2025.08.03 |