Bit Sequence 数位dp

题目大意:

定义 \(f(x)\) 表示 \(x\) 二进制表示的 \(1\) 的数量。给你n个是0或者1的数,再给你一个 \(L\),问在区间 \([0,L]\) 之间有多少个数 \(x\) 满足 \(∀i∈[0,m−1],f(x+i) \,mod\,2=a_i\)

题解:

很明显是一个数位 \(dp\) ,但是如何定义这个状态呢?

  • 因为这个需要比较的数只有100个,也就是 \(i\) 的最大贡献是100,所以可以拆出前面 7 位来看
  • 因为前面7位可能会进位,所以需要定义一个 \(x\) 表示从第7位往后连续1的数量的奇偶性
  • 要计算这个数的奇偶性还需要计算7位往后的1的数量的奇偶性
  • 最后还需要一个 \(pos\) 位表示此时是第几位
  • 因为每一个 \(L\) 都是不同的,所以每一次都需要清空这个数组,这个时候如果加一个 \(limit\) 放在数组中,可以有很大的优化,所以还需要一个 \(limit\)

最后的状态就是 :\(dp[64][2][128][2][2]\)

\(dp[pos][sum][sta][even][limit]\)

我觉得这个难的地方可能在于发现用前7位来表示一个状态。

#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int maxn = 2e5+10;
ll dp[64][2][128][2][2],L;
bool vis[2000000];
int n,a[maxn],v[maxn];
int judge(int x){
    int ans = 0;
    while(x) {
        if(x&1) ans^=1;
        x>>=1;
    }
    return ans;
}
bool check(int sum,int sta,int even){
    int S = 128;
    for(int i=0;i<n;i++){
        int num,cur = sta+i;
        if(sta+i>=S) {
            cur -= S;
            num = even - sum + judge(cur) + 1;
        }
        else  num = even + judge(cur);
        num = (num%2+2)%2;
        if(num!=a[i]) return 0;
    }
    return 1;
}
bool ok[2][128][2];
void init(){
    for(int i=0;i<2;i++){
        for(int j=0;j<128;j++){
            for(int k=0;k<2;k++){
                ok[i][j][k] = check(i,j,k);
            }
        }
    }
}
ll dfs(int pos,int sum,int sta,int even,int limit,int now){
    if(pos==-1) {
        return ok[sum][sta][even];
    }
    if(dp[pos][sum][sta][even][limit]!=-1) return dp[pos][sum][sta][even][limit];
    int up = limit?v[pos]:1;
    ll ans = 0;
    for(int i=0;i<=up;i++){
        if(pos>=7){
            int x = 0;
            if(i) x = (sum + 1)%2;
            ans += dfs(pos-1,x,sta,(even+i)%2,i==up&&limit,now*2+i);
        }
        else{
            int x = 0;
            if(i) x = 1<<pos;
            ans+=dfs(pos-1,sum,sta+x,even,i==up&&limit,now*2+i);
        }
    }
    return dp[pos][sum][sta][even][limit] = ans;
}
ll solve(ll x){
    int pos = 0;
    while(x){
        v[pos++] = x&1;
        x>>=1;
    }
    return dfs(pos-1,0,0,0,1,0);
}
// 用来测试,debug
ll solution(){
    int ans = 0;
    for(int i=0;i<=L;i++){
        bool flag = true;
        for(int j=1;j<=n;j++){
            if(a[j]!=judge(i+j)){
                flag = false;
                break;
            }
        }
        ans+=flag;
        vis[i] = flag;
    }
    return ans;
}
int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        memset(dp,-1,sizeof(dp));
        scanf("%d%lld",&n,&L);
        for(int i=0;i<n;i++) scanf("%d",&a[i]);
        init();
        ll ans = solve(L);
        printf("%lld\n",ans);
    }
}