Description
给定一棵 nn 个节点的树,求它的最小覆盖集个数。
Solution
我们考虑同时维护节点的最小覆盖集和最小覆盖集个数。我们发现一个点被覆盖只有 种情况:父亲被选入集合,自己被选入集合和至少一个儿子被选入集合。我们写出 DPDP 状态:
⎧⎩⎨⎪⎪ f[u][i][0] 表示 u 的父亲被选入集合,自己和任意一个儿子不被选入集合 f[u][i][1] 表示 u 有至少一个儿子被选入集合 f[u][i][2] 表示 u 自己被选入集合{ f[u][i][0] 表示 u 的父亲被选入集合,自己和任意一个儿子不被选入集合 f[u][i][1] 表示 u 有至少一个儿子被选入集合 f[u][i][2] 表示 u 自己被选入集合
而 g[u][i]g[u][i] 表示 f[u][i]f[u][i] 情况下的方案数。
考虑转移 ff。
再转移 gg。
首先转移 。我们发现父亲节点状态为 00 时儿子节点状态必为 ,所以 g[u][0]=∏v∈son(u)g[v][1]g[u][0]=∏v∈son(u)g[v][1]
然后转移 g[2]g[2] 。我们发现 g[2]g[2] 的儿子可以随便选,所以只要用满足 min f[v][i]=f[v][k]min f[v][i]=f[v][k] 的 kk 更新即可。
最后转移 。我们发现 f[u][1]=min(f[u][1]+min(f[v][1],f[v][2]),f[u][0]+f[v][2])f[u][1]=min(f[u][1]+min(f[v][1],f[v][2]),f[u][0]+f[v][2])。于是使用这几种状态分别计算 f[u][1]f[u][1] 是不是等于它,如果是就直接计算方案数即可。细节较多,这里不列举。详见代码。
Code
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int maxn = 1000005;
const int maxm = 2000005;
const int P = 1000000007;
int n, f[maxn][3], g[maxn][3];
int m, ter[maxm], nxt[maxm], lnk[maxn];
// 0: up, 1: down, 2: self;
void addedge(int u, int v) {
ter[++m] = v;
nxt[m] = lnk[u];
lnk[u] = m;
}
void update(int &a, int b) {
a += b, (a >= P) && (a -= P);
}
void dfs(int u, int p) {
int t, mn;
f[u][1] = n + 1, f[u][2] = g[u][0] = g[u][1] = g[u][2] = 1;
for (int i = lnk[u]; i; i = nxt[i]) {
int v = ter[i];
if (v == p) continue;
dfs(v, u);
t = 0;
mn = min(f[v][0], min(f[v][1], f[v][2]));
if (f[v][0] == mn) update(t, g[v][0]);
if (f[v][1] == mn) update(t, g[v][1]);
if (f[v][2] == mn) update(t, g[v][2]);
f[u][2] += mn;
g[u][2] = 1ll * g[u][2] * t % P;
t = 0;
mn = min(f[u][1] + min(f[v][1], f[v][2]), f[u][0] + f[v][2]);
if (f[u][1] + f[v][1] == mn) update(t, g[v][1]);
if (f[u][1] + f[v][2] == mn) update(t, g[v][2]);
f[u][1] = mn;
g[u][1] = 1ll * t * g[u][1] % P;
if (f[u][0] + f[v][2] == mn) update(g[u][1], 1ll * g[u][0] * g[v][2] % P);
f[u][0] += f[v][1];
g[u][0] = 1ll * g[u][0] * g[v][1] % P;
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
addedge(u, v);
addedge(v, u);
}
dfs(1, 0);
if (f[1][1] < f[1][2]) {
printf("%d\n%d\n", f[1][1], g[1][1]);
} else if (f[1][1] > f[1][2]) {
printf("%d\n%d\n", f[1][2], g[1][2]);
} else {
printf("%d\n%d\n", f[1][1], (g[1][1] + g[1][2]) % P);
}
return 0;
}