Monster Hunter 树形DP

题目大意:

给你一棵n个节点的树,每一个节点有一个怪兽,这个怪兽有一个生命值。

消灭一个怪兽:

  • 前提:这个节点的直接父亲节点怪兽已经被消灭
  • 代价:这个节点的权值+它所有的子节点怪兽的生命值之和

特别的是,你可以使用魔咒,每一次魔咒使用,你可以直接消灭一个怪兽,而不会有任何的花费,也不需要任何的前提。

m = 0,1,2,3...,n。问使用m次魔咒,消灭所有怪兽的最小的代价是多少?

题解:

这个题目一开始我们队想用贪心,后来发现贪心是错误的,然后就尝试使用dp来解决。

\(dp[u][j][0/1]\) :表示u这个节点,使用了 j 次魔法,0表示u节点没有使用魔法,1表示有使用魔法。

这个转移我写了很久,一直写不出来,后来发现是需要一个辅助数组,每次需要辅助数组的题目我好像都没有写出来,感觉好可惜。

#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
const int maxn = 2e3+10;
typedef long long ll;
int cnt,head[maxn],to[maxn<<1],nxt[maxn<<1];
ll hp[maxn],dp[maxn][maxn][2];
void init(int n){
    cnt = 0;
    for(int i=0;i<=n+10;i++) {
        head[i] = 0;
        for(int j=0;j<=n+10;j++){
            dp[i][j][0] = dp[i][j][1] = 0;
        }
    }
}
void add(int u,int v){
    ++cnt,to[cnt] = v,nxt[cnt] = head[u],head[u] = cnt;
}
int siz[maxn];
ll tmp[maxn][2];
void dfs(int u){
    siz[u] = 1;
    for(int i=head[u];i;i=nxt[i]){
        int v = to[i];
        dfs(v);
        memset(tmp,inf64,sizeof(tmp));
        for(int j=0;j<=siz[u];j++){
            for(int k=0;k<=siz[v];k++){
                tmp[j+k][0] = min(tmp[j+k][0],dp[u][j][0]+dp[v][k][0]+hp[v]);
                if(k>0) tmp[j+k][0] = min(tmp[j+k][0],dp[u][j][0]+dp[v][k][1]);
                if(j>0) tmp[j+k][1] = min(tmp[j+k][1],dp[u][j][1]+dp[v][k][0]);
                if(j>0&&k>0) tmp[j+k][1] = min(tmp[j+k][1],dp[u][j][1]+dp[v][k][1]);
//                printf("u = %d v = %d j = %d k = %d \n",u,v,j,k);
//                printf("tmp[%d][%d]=%lld tmp[%d][%d]=%lld\n",j+k,0,tmp[j+k][0],j+k,1,tmp[j+k][1]);
            }
        }
        for(int j=0;j<=siz[u]+siz[v];j++) {
            dp[u][j][0] = tmp[j][0],dp[u][j][1] = tmp[j][1];
//            printf("dp[%d][%d][%d]=%lld dp[%d][%d][%d]=%lld\n",u,j,0,dp[u][j][0],u,j,1,dp[u][j][1]);
        }
        siz[u]+=siz[v];
    }
    for(int i=0;i<=siz[u];i++) {
        dp[u][i][0]+=hp[u];
//        printf("u = %d dp[%d][%d][%d] = %lld dp[%d][%d][%d]=%lld\n",u,u,i,0,dp[u][i][0],u,i,1,dp[u][i][1]);
    }
}
int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        int n;
        scanf("%d",&n);init(n);
        for(int i=2,x;i<=n;i++){
            scanf("%d",&x);
            add(x,i);
        }
        for(int i=1;i<=n;i++) scanf("%lld",&hp[i]);
        dfs(1);
        for(int i=0;i<=n;i++){
            printf("%lld",min(dp[1][i][0],dp[1][i][1]));
            if(i==n) printf("\n");
            else printf(" ");
        }
    }
}