题目

Description
【JZOJ 省选模拟】6640. Lowbit_题组

Input
一行一共 5 个整数 p, a, b, L, R,意义如题。

Output
【JZOJ 省选模拟】6640. Lowbit_题组_02

Sample Input
Sample Input1:
2 1 2 1 7

Sample Input2:
2 1 3 1 10

Sample Input3:
5 4 7 10 100

Sample Output
Sample Output1:
19
解释:
f[1] = 2/1, f[2] = 2/1, f[3] = 3/1, f[4] = 2/1, f[5] = 7/2, f[6] = 3/1, f[7] = 7/2加起来 = 19。

Sample Output2:
110916062

Sample Output3:
977096547

Data Constraint
【JZOJ 省选模拟】6640. Lowbit_题组_03

思路

考虑p=2只求一个数怎么做

发现 x 会走到的点只有 log个。
设 dp[i][0/1] 表示 x 的前 i 位 = 0 了,0/1 表示是否有进位,这个数的 f 值。
初值dp[inf][1]=1/(1-a/b)
再套一个数位dp+整体转移即可

考虑如何推广到p>=2

可以发现两者的dp是一样的
我们只需要预处理出
p0[x]:最低位 = x,最后到 0 的概率。
p1[x]:最低位 = x,最后进位的概率。
f0[x]:最低位 = x,到 0 或进位前走的期望步数。

代码

#include<bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define mod 998244353
#define ll long long
using namespace std;
const int N=1e5+77;
ll s,A,B,L,R,ans,S,f[N],g[N],p1[N],p2[N];
int p,i,j,k,l,tot;
map<ll,ll> hs;
map<ll,ll>::iterator I;
ll power(ll x,ll t)
{
	ll b=1;
	while(t)
	{
		if(t&1) b=b*x%mod;
		x=x*x%mod; t>>=1;
	}
	return b;
}
ll work(ll t)
{
	ll sum=0;
	int i;
	if(t<p) return g[t];
	I=hs.find(t);
	if(I!=hs.end()) return hs[t];
	sum=work(t/p);
	fo(i,1,p-1) sum=(sum+(work((t+(p-i))/p)*p2[i]+work((t-i)/p)*p1[i]+((t-i)/p+1)%mod*f[i])%mod)%mod;
	hs[t]=sum;
	return sum;
}
void work1(int t,ll k,ll b)
{
	if(t==p) return;
	ll S=power(1-(1-s)*k%mod,(mod-2));
	work1(t+1,s*S%mod,(1+b*(1-s)%mod)*S%mod);
	f[t]=(k*f[t+1]+b)%mod;
}
void work2(int t,ll k,ll b)
{
	if(t==p) return;
	ll S=power(1-(1-s)*k%mod,(mod-2));
	work2(t+1,s*S%mod,b*(1-s)%mod*S%mod);
	p1[t]=(k*p1[t+1]+b)%mod;
}
void work3(int t,ll k,ll b)
{
	if(t==p) return;
	ll S=power(1-(1-s)*k%mod,(mod-2));
	work3(t+1,s*S%mod,b*(1-s)%mod*S%mod);
	p2[t]=(k*p2[t+1]+b)%mod;
}
int main()
{
//	freopen("lowbit.in","r",stdin); freopen("lowbit.out","w",stdout);
	scanf("%d%lld%lld%lld%lld",&p,&A,&B,&L,&R);
	s=A*power(B,(mod-2))%mod;
	if(p==2) f[1]=1,p1[1]=1-s,p2[1]=s,g[1]=power(1-s,(mod-2));
	else
	{
		p1[0]=p2[p]=1,work1(1,s,1),work2(1,s,1-s),work3(1,s,0);
		g[1]=f[1]*power(1-p2[1],(mod-2))%mod;
		fo(i,2,p-1) g[i]=(p2[i]*g[1]+f[i])%mod;
	}
	fo(i,1,p-1) g[i]=(g[i]+g[i-1])%mod;
	printf("%lld\n",((work(R)-work(L-1))%mod+mod)%mod);
}