题目

Description

（原题面为图片，已缺失；以下根据代码与样例还原。）
给定整数 $n, m$ 以及一个 $m$ 次多项式 $f(x)=\sum_{i=0}^{m} a_i x^i$ 的系数，求
$$\sum_{k=0}^{n} \binom{n}{k} f(k) \bmod 998244353 .$$
Input

（原输入格式为图片，已缺失；以下根据代码还原。）
第一行两个整数 $n, m$。
第二行 $m+1$ 个整数 $a_0, a_1, \dots, a_m$，依次为多项式 $f$ 的各项系数。
Output
一行一个整数表示答案。

Sample Input
2 2
2 3 1

Sample Output
26

Data Constraint
（原数据范围为图片，已缺失。）下面代码的数组大小为 $N=10^5+5$、NTT 长度上限为 $2^{18}$，因此要求 $m < 10^5+5$（乘积长度 $2m+1 \le 2^{18}$）。

思路

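（原思路为图片，已缺失；以下为根据代码整理的做法。）
设 $g(x)=f(x)-a_0$，常数项的贡献为 $a_0\sum_{k}\binom{n}{k}=a_0 2^n$。
用差分展开 $g(k)=\sum_{j=0}^{m}\binom{k}{j}\Delta^j g(0)$，再利用 $\sum_{k}\binom{n}{k}\binom{k}{j}=\binom{n}{j}2^{n-j}$，得
$$\text{ans}=a_0 2^n+\sum_{j=1}^{m}\binom{n}{j}2^{n-j}\,\Delta^j g(0).$$
其中 $\dfrac{\Delta^j g(0)}{j!}=\sum_{i=0}^{j}\dfrac{g(i)}{i!}\cdot\dfrac{(-1)^{j-i}}{(j-i)!}$ 是一次卷积；$g(1),\dots,g(m)$ 用多点求值得到（分治乘出 $\prod_{i}(x-i)$，再沿线段树逐层取模）。总复杂度 $O(m\log^2 m)$。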

代码

#include<bits/stdc++.h>
// Container shorthands used throughout: `v.si` expands to `v.size()`,
// `v.re(k)` to `v.resize(k)`, and `ll` to `long long`.
#define si size()
#define re resize
#define ll long long
// Loop helpers; the bound y is evaluated once into _b.
//   fo: i = x .. y      (inclusive, ascending)
//   ff: i = x .. y - 1  (half-open, ascending)
//   fd: i = x .. y      (inclusive, descending)
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i <  _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
using namespace std;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; ntt::build() uses 3 as the root generator.
const int mod = 998244353;
// Modular exponentiation: returns x^t mod `mod` by binary exponentiation,
// as a value in [0, mod).
// The base is reduced into [0, mod) first so that x*x cannot overflow even
// when the caller passes x >= mod or a negative residue.
// A negative exponent is treated as (x^{-1})^{-t}, with the inverse taken by
// Fermat's little theorem (mod is prime). Without this, a plain `t >>= 1`
// loop never terminates for t < 0, because -1 >> 1 == -1.
ll power(ll x,ll t)
{
	x %= mod;
	if(x < 0) x += mod;
	if(t < 0)
	{
		x = power(x, mod - 2); // modular inverse of x (0 stays 0)
		t = -t;
	}
	ll b=1;
	while(t)
	{
		if(t&1) b=b*x%mod;
		x=x*x%mod; t>>=1;
	}
	return b;
}

namespace ntt 
{
	// Largest supported transform length (2^18); every buffer below has exactly nm slots.
	const int nm = 1 << 18;
	// w[i + j] (i a power of two, 0 <= j < i) holds v^j, where v = 3^((mod-1)/(2i)).
	// a and b are scratch buffers; r is the bit-reversal permutation table.
	ll w[nm],a[nm],b[nm];
	int r[nm];
	// Fill the twiddle table w. Must be called once before any transform.
	void build() 
	{
		for(int i = 1; i < nm; i *= 2) 
		{
			w[i] = 1; ll v = power(3,(mod - 1) / 2 / i);
			for(int j=1; j<i; j++) w[i + j] = w[i + j - 1] * v % mod;
		}
	}
	// In-place transform of a[0..n-1]; n must be a power of two with n <= nm.
	// f == 1: forward transform. Outputs may be negative residues in (-mod, mod).
	// f == -1: inverse transform, including the 1/n scaling; outputs are in [0, mod).
	// Assumes every input satisfies |a[i]| < mod, so that the unreduced
	// butterfly product a * w fits in 64 bits.
	void dft(ll *a,int n,int f) 
	{
		// Only indices 0..n-1 belong to the transform. Visiting index n would
		// read and write one past the end of r (and of a) when n == nm.
		for(int i=0; i<n; i++) 
		{
			r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
			if(i < r[i]) swap(a[i],a[r[i]]);
		}
		ll v;
		for(int i = 1; i < n; i<<=1) for(int j = 0; j < n; j+=i<<1) for(int k=0; k<i; k++) 
		{
			v = a[i + j + k] * w[i + k],a[i + j + k] = (a[j + k] - v) % mod,a[j + k] = (a[j + k] + v) % mod;
		}
		if(f == -1) 
		{
			// Inverse via the forward transform: reversing a[1..n-1] turns
			// evaluation at omega^k into evaluation at omega^{-k}.
			reverse(a + 1,a + n);
			v = power(n,mod - 2);
			for(int i=0; i<n; i++) a[i] = (a[i] + mod) * v % mod;
		}
	}
	// Product of polynomials p and q (coefficients, lowest degree first) via NTT.
	// Returns p.size() + q.size() - 1 coefficients in [0, mod), or an empty
	// vector if either operand is empty. Requires p.size() + q.size() - 1 <= nm.
	vector<ll> operator * (vector<ll> p,vector<ll> q) {
		if(p.empty() || q.empty()) return vector<ll>();
		int n0 = p.si + q.si - 1,n = 1;
		while(n < n0) n *= 2;
		ff(i,0,n) a[i] = b[i] = 0;
		ff(i,0,p.si) a[i] = p[i];
		ff(i,0,q.si) b[i] = q[i];
		dft(a,n,1); dft(b,n,1);
		ff(i,0,n) a[i] = a[i] * b[i] % mod;
		dft(a,n,-1);
		p.re(n0);
		ff(i,0,n0) p[i] = a[i];
		return p;
	}
	// In-place transform of p (size must be a power of two, <= nm), staged
	// through the scratch buffer a.
	void dft(vector<ll> &p,int f) {
		int n = p.si;
		ff(i,0,n) a[i] = p[i];
		dft(a,n,f);
		ff(i,0,n) p[i] = a[i];
	}
}

using ntt :: operator *;
using ntt :: dft;

// Coefficient-wise difference a - b of two polynomials, reduced with % mod
// (entries lie in (-mod, mod)). The shorter operand is treated as
// zero-padded, so the result has max(a.size(), b.size()) coefficients.
vector<ll> operator - (vector<ll> a,vector<ll> b) {
	a.re(max(a.si,b.si));
	// b is a by-value copy: pad it too, so b[i] is valid for every i < a.size().
	b.re(a.si);
	ff(i,0,a.si) a[i] = (a[i] - b[i]) % mod;
	return a;
}

// Power-series inverse: returns inv (a.size() coefficients) such that
// a * inv == 1 (mod x^{a.size()}). Uses the Newton step
// inv <- inv * (2 - a * inv), which doubles the precision each round.
// a[0] must be invertible modulo `mod`.
vector<ll> qni(vector<ll> a) {
	const int want = a.si;
	vector<ll> inv(1);
	inv[0] = power(a[0], mod - 2);
	int len = 2;
	while (len < 2 * want) {
		const int sz = 2 * len;
		// First `len` coefficients of a, zero-padded to the transform size.
		vector<ll> head(sz, 0);
		const int take = min(len, want);
		for (int k = 0; k < take; ++k) head[k] = a[k];
		inv.re(sz);
		dft(head, 1);
		dft(inv, 1);
		for (int k = 0; k < sz; ++k) {
			const ll sq = head[k] * inv[k] % mod * inv[k];
			inv[k] = (2 * inv[k] - sq) % mod;
		}
		dft(inv, -1);
		inv.re(len);
		len *= 2;
	}
	inv.re(want);
	return inv;
}

// Polynomial remainder a mod b (coefficients, lowest degree first).
// Returns exactly b.size() - 1 coefficients, or a unchanged when deg a < deg b.
// The leading coefficient of b must be invertible modulo `mod`.
// Reversal trick: rev(q) = rev(a) * rev(b)^{-1} (mod x^{n-m+1}), then r = a - b q.
vector<ll> qmod(vector<ll> a,vector<ll> b) {
	int n = a.si - 1,m = b.si - 1;
	if(n < m) return a;
	vector<ll> a0 = a,b0 = b;
	reverse(a.begin(),a.end());
	reverse(b.begin(),b.end());
	// The quotient has n-m+1 coefficients, so rev(a) and the inverse of rev(b)
	// are only needed modulo x^{n-m+1}. Truncating here avoids computing the
	// series inverse and the product to the full length n+1.
	a.re(n - m + 1);
	b.re(n - m + 1);
	vector<ll> c = a * qni(b);
	c.re(n - m + 1);
	reverse(c.begin(),c.end());
	vector<ll> d = a0 - b0 * c;
	d.re(m);
	return d;
}

// Array capacity: the tables below are indexed 0..m, so m must stay below N.
const int N = 1e5 + 5;

// n, m and the coefficients f[0..m]; all are read in main.
int n,m; ll f[N];
// fac[i] = i! and nf[i] = (i!)^{-1} (filled by build);
// h[i] = n(n-1)...(n-i+1) / i! = C(n, i) (filled by build_h); all values mod `mod`.
ll fac[N],nf[N],h[N];

// Precompute fac[k] = k! and nf[k] = 1/k! (mod `mod`) for every 0 <= k <= n.
// Only one modular inverse is taken; the rest follow from 1/(k-1)! = k / k!.
void build(int n) {
	fac[0] = 1;
	for (int k = 1; k <= n; ++k) fac[k] = fac[k - 1] * k % mod;
	nf[n] = power(fac[n], mod - 2);
	for (int k = n; k >= 1; --k) nf[k - 1] = nf[k] * k % mod;
}

void build_h() {
	h[0] = 1;
	fo(i,1,m) h[i] = h[i - 1] * (n - i + 1) % mod;
	fo(i,0,m) h[i] = h[i] * nf[i] % mod;
}

// Binomial coefficient C(n, m) mod `mod` from the factorial tables.
// Returns 0 when m < 0 or m > n. Requires build() to have filled fac/nf up to n.
ll C(int n,int m) {
	if(m < 0 || n < m) return 0;
	return fac[n] * nf[n - m] % mod * nf[m] % mod;
}

// g[j] = j! * 2^(n-j) * h[j] (mod `mod`): the weight of the j-th transformed coefficient; filled in main.
ll g[N];

// Subproduct tree built by dg: t[i] holds the coefficients (lowest degree first)
// of prod_{k=x}^{y} (X - k) for the node i covering [x, y].
vector<ll> t[N * 4];

// Child indices of segment-tree node i (left = 2i, right = 2i + 1). They expand
// against whatever variable named `i` is in scope. The bodies are parenthesised
// so they stay correct inside larger expressions such as `2 * i0`.
#define i0 (i + i)
#define i1 (i + i + 1)
// Build the subproduct tree: t[i] becomes the coefficient vector (lowest degree
// first) of prod_{k=x}^{y} (X - k) for node i covering [x, y]; children are 2i, 2i+1.
void dg(int i,int x,int y) {
	if (x != y) {
		const int mid = (x + y) >> 1;
		const int lc = 2 * i, rc = 2 * i + 1;
		dg(lc, x, mid);
		dg(rc, mid + 1, y);
		t[i] = t[lc] * t[rc];
		return;
	}
	// Leaf: the linear factor X - x.
	t[i].re(2);
	t[i][0] = -x;
	t[i][1] = 1;
}

// p[j] receives the value at X = j of the polynomial handed to fz.
ll p[N];

// Multipoint evaluation on the subproduct tree built by dg: writes a(j) into p[j]
// for every j in [x, y]. a is first reduced modulo t[i], which vanishes on
// [x, y]. Segments of fewer than 128 points (including single points) are
// evaluated directly; larger segments are split between the two children.
void fz(vector<ll> a,int i,int x,int y) {
	a = qmod(a, t[i]);
	if (y - x + 1 >= 128) {
		const int mid = (x + y) >> 1;
		fz(a, 2 * i, x, mid);
		fz(a, 2 * i + 1, mid + 1, y);
		return;
	}
	for (int pt = x; pt <= y; ++pt) {
		ll acc = 0, pw = 1; // pw = pt^k mod `mod`
		for (size_t k = 0; k < a.size(); ++k) {
			acc = (acc + pw * a[k]) % mod;
			pw = pw * pt % mod;
		}
		p[pt] = acc;
	}
}

// Reads n, m and the coefficients f[0..m] of F(X) = sum_i f_i X^i from number.in,
// and writes sum_{k=0}^{n} C(n, k) F(k) mod `mod` to number.out.
//
// Method: with G = F - f_0, G(k) = sum_j C(k, j) Delta^j G(0), and
// sum_k C(n, k) C(k, j) = C(n, j) 2^{n-j}, so
//   answer = f_0 * 2^n + sum_{j=1}^{m} C(n, j) 2^{n-j} Delta^j G(0).
// G(1..m) come from multipoint evaluation (dg/fz); Delta^j G(0) / j! is the
// convolution of G(i)/i! with (-1)^i / i!.
int main() {
	ntt :: build();
	freopen("number.in","r",stdin);freopen("number.out","w",stdout);
	scanf("%d %d",&n,&m);
	// Keep the coefficients reduced, so that the NTT's unreduced products cannot overflow.
	fo(i,0,m) { scanf("%lld",&f[i]); f[i] %= mod; }
	build(m);
	build_h();
	ll ans = f[0] * power(2,n) % mod;
	// For m == 0, F is constant and the answer is just f_0 * 2^n. dg(1,1,0)
	// would also recurse forever on the empty range, so skip the polynomial part.
	if(m >= 1) {
		// C(n, j) == 0 for j > n: skip those terms rather than raising 2 to a
		// negative power.
		fo(j,1,m) g[j] = j > n ? 0 : fac[j] * power(2,n - j) % mod * h[j] % mod;
		vector<ll> a(m + 1),b(m + 1);
		dg(1,1,m);
		vector<ll> c(m + 1); // coefficients of G = F - f_0 (c[0] stays 0)
		fo(i,1,m) c[i] = f[i];
		fz(c,1,1,m);
		fo(i,1,m) a[i] = nf[i] * p[i] % mod;
		fo(i,0,m) b[i] = nf[i] * (i % 2 ? -1 : 1);
		a = a * b;
		fo(i,1,m) ans = (ans + a[i] * g[i]) % mod;
	}
	ans = (ans % mod + mod) % mod;
	printf("%lld\n",ans);
}