我们通常采用递归的方式实现树形dp。
对于每个节点,先递归在它的每个子节点上进行dp,在回溯时,从子节点向根节点进行状态转移。
顺序一般为从叶子结点到根节点递推。
以下是写的一些树形dp题目:
一. P1352 没有上司的舞会
以子树的根作为dp状态的第一维。容易发现,每个员工是否参加至于他的上司是否参加有关。
不妨设
1. f[x,0] 表示从以 x 为根的子树中选员工参会,而且不选 x 的最大价值。
那么我们得到:x的子节点 s 可以选,也可以不选,那么 f[x,0] 即为子节点选或不选的最大值。
f[x,0]=∑max( f[s,0] , f[s,1] )
2. f[x,1] 表示选 x 的最大价值。
此时发现:f[x,1]只能等于f[s,0]+自己的快乐指数。
f[x,1]=∑f[s,0] + h[x]
于是基本解决了这道题。
#include<bits/stdc++.h>
using namespace std;
vector<int>son[10010];
int f[10010][2],v[10010],h[10010],n;
void dp(int x)
{
f[x][0]=0;
f[x][1]=h[x];
for(int i=0;i<son[x].size();i++)
{
int y=son[x][i];
dp(y);
f[x][0]+=max(f[y][0],f[y][1]);
f[x][1]+=f[y][0];
}
}
//注意:需要先找到总上司。
int main()
{
cin>>n;
for(int i=1;i<=n;i++) cin>>h[i];
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
v[x]=1;
son[y].push_back(x);
}
int u;
for(int i=1;i<=n;i++) if(!v[i]) u=i;
dp(u);
cout<<max(f[u][0],f[u][1]);
return 0;
}
二.背包类树形dp(树上背包)
参考树形 DP - OI Wiki (oi-wiki.org)
三.换根dp。
换根dp在模拟赛的时候到处跑。想出来就是萌萌题。
[POI2008]STA-Station 作为例子。
我们需要进行两次dfs。
dfs1,处理出每个点的深度以及子树中点的个数。
考虑每次换根的时候子树的节点数变化。
观察。
设根节点从 x 变为 y ,那么在 y 的子树内的节点深度都将+1,而其他节点深度都将-1。
先处理出当1为根节点时的总深度 f[1] 作为初始状态。那么对于每次换根操作,f[y]=f[x]+n-2*s[y]。
这样我们的换根操作就解决了,即dfs2。
最后统计答案即可。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1000010;
int h[N<<1],ne[N<<1],idx,e[N<<1];
ll n;
ll sz[N],dep[N],f[N];
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0' && ch<='9')
x=x*10+ch-'0',ch=getchar();
return x*f;
}
void add(int a,int b)
{
e[++idx]=b;
ne[idx]=h[a];
h[a]=idx;
}
void dfs1(int u,int fa)
{
sz[u] = 1;
dep[u] = dep[fa] + 1;
for(int i=h[u];i;i=ne[i])
{
int j=e[i];
if(j!=fa)
{
dfs1(j,u);
sz[u]+=sz[j];
}
}
}
void dfs2(int u, int fa) {
for (int i=h[u];i; i =ne[i]) {
int j = e[i];
if (j != fa) {
f[j] = f[u] - sz[j] * 2 + n;
dfs2(j, u);
}
}
}
signed main()
{
// memset(h,-1,sizeof h);
// memset(ne,-1,sizeof ne);
n=read();
for(int i=1;i<n;i++)
{
int x,y;
// cin>>x>>y;
x=read(),y=read();
add(x,y);
add(y,x);
}
dfs1(1,1);
for (int i = 1; i <= n; i++) f[1] += dep[i];
dfs2(1, 1);
long long int ans = -1;
int id;
for (int i = 1; i <= n; i++) {
if (f[i] > ans) {
ans = f[i];
id = i;
}
}
cout<<id;
return 0;
}