Codeforces 437D: The Child and Zoo —— 并查集

题目链接

题目大意:

给出一张 $n$ 个点, $m$ 条边的图每个节点有一个权值 $a_i$ 。定义 $f(p,\ q)$ 为节点 $p$ 到节点 $p$ 简单路径上经过的所有节点的最小权值。求所有 $f(p, q)$ ( $p \not = q$ ) 的平均值。

输入格式:

第一行两个整数 $n,\ m$ ( $2 \leq n \leq 10^5,\ 0 \leq m \leq 10^5$ )。

接下来一行 $n$ 个整数,第 $i$ 个整数代表第 $i$ 个节点的权值 $a_i$ ( $0 \leq a_i \leq 10^5$ )。

接下来 $m$ 行,每行两个整数 $x,\ y$ ( $1 \leq x,\ y \leq n,\ x \not = y$ ),代表 $x$ 到 $y$ 有一条无向边。

输出格式:

输出一行一个整数,代表 $\frac{\sum_{p, q, p \not = q}^{}{f(p,\ q)}}{n(n-1)}$ 。

精度误差在 $10^{-4}$ 内。

测试样例:

Input
4 3
10 20 30 40
1 3
2 3
4 3
Output
16.666667
Input
3 3
10 20 30
1 2
2 3
3 1
Output
13.333333
Input
7 8
40 20 10 30 20 50 40
1 2
2 3
3 4
4 5
5 6
6 7
1 4
5 7
Output
18.571429

分析:

由于这道题 $n$ 最大有 $10^5$ ,所以是不能两两枚举节点的。所以这道题我们可以考虑用并查集

首先我们要将点权转化为边权。因为 $f(p,\ q)$ ,是路径上最小的点权,所以对于每一条边,只有点权更小的端点才可能影响到 $f(p,\ q)$ 的值,所以我们可以把每条边的边权设置为两个端点中权值更小的点的点权。

然后我们把边按边权从大到小排序。为什么从大到小排序呢?从大到小可以保证在选择一条新的边加入并查集的时候,这条边的权值一定是之前所有边里最小的。这样我们只需要统计两个联通块中一共有多少条路径经过这道边就好。怎么统计呢?乘法原理,联通块 A 的节点数 乘上 联通块 B 的节点数。最后再合并这两个联通块。

统计分子的 $ans$ 最后要 $\times 2$ ,因为 $f(p,\ q) = f(q,\ p)$ 。

还要注意不要爆 int 了。

代码:

#include <bits/stdc++.h>
using namespace std;

template<class _T> inline void read(_T &_x) {
  int _t, _flag = 0;
  while ((_t = getchar()) != '-' && (_t < '0' || _t > '9'));
  if (_t == '-')  _t = getchar(), _flag = 1;
  _x = _t - '0';
  while ((_t = getchar()) >= '0' && _t <= '9')
    _x = _x * 10 + _t - '0';
  if (_flag)  _x = -_x;
}

typedef long long ll;
typedef pair<int, int> pii;

const int MAX_N = 1e5 + 100;
struct E {
  int u, v, w;
  bool operator < (E A) const {
    return w > A.w;
  }
} e[MAX_N];
int father[MAX_N], n, m, a[MAX_N], cnt[MAX_N];
double ans = 0;

void init() {
  for (int i = 1; i <= n; ++i)
    father[i] = i;
}

int _find(int x) {
  if (father[x] == x) return x;
  else return father[x] = _find(father[x]);
}


void solve() {
  sort(e + 1, e + m + 1);
  for (int i = 1; i <= m; ++i) {
    int f1 = _find(e[i].u), f2 = _find(e[i].v);
    if (_find(e[i].u) != _find(e[i].v)) {
      ans += (double) e[i].w * cnt[f1] * cnt[f2];
      cnt[f1] += cnt[f2];
      father[f2] = f1;
    }
  }
  printf("%.6f", ans * 2 / (n * 1.0 * (n - 1)));
}

int main() {
  read(n), read(m);
  init();
  for (int i = 1; i <= n; ++i) {
    read(a[i]);
    cnt[i] = 1;
  }
  for (int i = 1; i <= m; ++i) {
    read(e[i].u), read(e[i].v);
    e[i].w = min(a[e[i].u], a[e[i].v]);
  }
  solve();
}