1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

// FOR(i, x, y): ascending loop; i starts at x and runs while i < y.
// The bound y is evaluated once and cached in the hidden variable _i.
// The loop variable takes the (decayed) type of y.
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
// FORD(i, x, y): descending loop; i starts at x and runs while i > y.
// The loop variable takes the (decayed) type of x.
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)

// dbg(...) prints "<expression text> -> <values>" in colour when the file is
// built with -Dzerol, and expands to nothing otherwise.
#ifdef zerol
#define dbg(...) do { cout << "\033[32;1m" << #__VA_ARGS__ << " -> "; err(__VA_ARGS__); } while (0)
#else
#define dbg(...)
#endif

// Terminates a dbg() line: resets the terminal colour and flushes.
void err() { cout << "\033[39;0m" << endl; }

// Prints each argument followed by a space, then finishes with err().
template<typename T, typename... Args>
void err(T a, Args... args) {
    cout << a << ' ';
    err(args...);
}
// -----------------------------------------------------------------------------
const int maxn = 1E5 + 10;  // capacity of the per-vertex arrays below
const int INF = 1E9;        // magnitude of the "no path" sentinel used by dfs

bool f[maxn];               // f[u] is set for the K vertices read first in main
int n, K;                   // vertex count and number of marked vertices (read in main)
int sz[maxn];               // sz[u]: marked vertices in u's subtree (filled by dfs)
int SUM;                    // accumulated by dfs: weights of edges with 0 < sz[child] < K
int D;                      // accumulated by dfs: largest distance found between marked vertices

// One directed half of an undirected weighted edge.
struct E {
    int to;  // neighbouring vertex
    int d;   // edge weight
};
vector<E> G[maxn];          // adjacency lists, indexed by vertex id
int dfs(int u, int fa) { int m1 = f[u] ? 0 : -INF, m2 = -INF; sz[u] = f[u]; for (E& e: G[u]) { int v = e.to; if (v == fa) continue; m2 = max(m2, dfs(v, u) + e.d); if (m2 > m1) swap(m1, m2); sz[u] += sz[v]; if (0 < sz[v] && sz[v] < K) SUM += e.d; } D = max(D, m1 + m2); if (f[u]) D = max(D, m1); return m1; }
int main() { int u, v, d; cin >> n >> K; FOR (i, 0, K) { scanf("%d", &u); f[u] = 1; } FOR (_, 1, n) { scanf("%d%d%d", &u, &v, &d); G[u].push_back({v, d}); G[v].push_back({u, d}); } dfs(1, -1); cout << SUM * 2 - D << endl; }
|