题目
Description
Input
Output
一行一个整数表示答案。(Output a single line containing one integer: the answer.)
Sample Input
2 2
2 3 1
Sample Output
26
Data Constraint
思路
代码
#include<bits/stdc++.h>
// Contest shorthand macros used throughout the file.
// si / re: std::vector size() and resize.
#define si size()
#define re resize
// ll: the integer type used for all modular values.
#define ll long long
// fo: i = x .. y inclusive; ff: i = x .. y-1; fd: i = x down to y inclusive.
// The bound y is evaluated once into _b before the loop starts.
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i < _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
using namespace std;
const int mod = 998244353;

// Modular exponentiation: returns x^t modulo `mod` in [0, mod) by binary exponentiation.
// The base is first reduced into [0, mod), so x * x cannot overflow for large or
// negative bases.
// A negative exponent is interpreted through Fermat's little theorem
// (x^(mod-1) == 1 for x != 0 mod `mod`), i.e. as a power of the modular inverse.
// Previously `t >>= 1` on a negative t never reached 0 and the loop never ended.
long long power(long long x, long long t)
{
    x %= mod;
    if (x < 0) x += mod;
    if (t < 0) t = t % (mod - 1) + (mod - 1);  // maps t into [1, mod-1]
    long long result = 1;
    while (t > 0)
    {
        if (t & 1) result = result * x % mod;
        x = x * x % mod;
        t >>= 1;
    }
    return result;
}
namespace ntt
{
    // Maximum supported transform length (2^18); all tables below use this size.
    const int nm = 1 << 18;
    // w: twiddle-factor table; a / b: scratch buffers used by operator* and dft(vector).
    ll w[nm], a[nm], b[nm];
    // r: bit-reversal permutation for the current transform length.
    int r[nm];

    // Precomputes twiddle factors. For every power of two i < nm,
    // w[i + j] = v^j (0 <= j < i) with v = 3^((mod-1) / (2i)).
    // Assumes 3 is a primitive root of `mod`, so v is a primitive (2i)-th root of unity.
    void build()
    {
        for (int i = 1; i < nm; i *= 2)
        {
            w[i] = 1;
            ll v = power(3, (mod - 1) / 2 / i);
            for (int j = 1; j < i; j++) w[i + j] = w[i + j - 1] * v % mod;
        }
    }

    // In-place iterative NTT of a[0..n-1]. n must be a power of two and n <= nm.
    // f == 1 is the forward transform; f == -1 is the inverse, including 1/n scaling.
    // Forward outputs lie in (-mod, mod); inverse outputs are normalised to [0, mod).
    void dft(ll *a, int n, int f)
    {
        // Bit-reversal permutation over the valid indices 0..n-1 only.
        // The former `i <= n` bound touched r[n] and a[n], which is out of bounds
        // when n == nm.
        for (int i = 0; i < n; i++)
        {
            r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
            if (i < r[i]) swap(a[i], a[r[i]]);
        }
        ll v;
        for (int i = 1; i < n; i <<= 1)
            for (int j = 0; j < n; j += i << 1)
                for (int k = 0; k < i; k++)
                {
                    // v is left unreduced (< mod^2); the sums below still fit in 64 bits.
                    v = a[i + j + k] * w[i + k];
                    a[i + j + k] = (a[j + k] - v) % mod;
                    a[j + k] = (a[j + k] + v) % mod;
                }
        if (f == -1)
        {
            // Inverse transform: reverse indices 1..n-1 of the forward result,
            // then scale by n^{-1}.
            reverse(a + 1, a + n);
            v = power(n, mod - 2);
            for (int i = 0; i < n; i++) a[i] = (a[i] + mod) * v % mod;
        }
    }

    // Polynomial product p * q (coefficients lowest degree first), reduced to [0, mod).
    // The product of an empty (zero) polynomial is returned as an empty vector.
    // Previously, two empty operands underflowed the size computation and called resize(-1).
    vector<ll> operator * (vector<ll> p, vector<ll> q) {
        if (p.empty() || q.empty()) return {};
        int n0 = p.si + q.si - 1, n = 1;
        while (n < n0) n *= 2;
        ff(i, 0, n) a[i] = b[i] = 0;
        ff(i, 0, p.si) a[i] = p[i];
        ff(i, 0, q.si) b[i] = q[i];
        dft(a, n, 1); dft(b, n, 1);
        ff(i, 0, n) a[i] = a[i] * b[i] % mod;
        dft(a, n, -1);
        p.re(n0);
        ff(i, 0, n0) p[i] = a[i];
        return p;
    }

    // Transforms p in place (see dft above). p.size() must be a power of two <= nm.
    void dft(vector<ll> &p, int f) {
        int n = p.si;
        ff(i, 0, n) a[i] = p[i];
        dft(a, n, f);
        ff(i, 0, n) p[i] = a[i];
    }
}
using ntt :: operator *;
using ntt :: dft;
// Coefficient-wise difference a - b of two polynomials. Each result coefficient is
// (a[i] - b[i]) % mod, which may be a negative representative in (-mod, mod).
// Both operands are zero-padded to the longer length. Previously only `a` was
// resized, so b[i] was read out of bounds whenever b was shorter than a.
vector<ll> operator - (vector<ll> a, vector<ll> b) {
    const size_t len = max(a.si, b.si);
    a.re(len);
    b.re(len);
    ff(i, 0, len) a[i] = (a[i] - b[i]) % mod;
    return a;
}
// Power-series inverse: returns inv with a * inv == 1 (mod x^{a.size()}).
// Uses Newton iteration inv <- inv * (2 - a * inv), doubling the precision each round.
// a[0] must be invertible modulo `mod`.
vector<ll> qni(vector<ll> a) {
    const int want = a.si;
    vector<ll> inv(1, power(a[0], mod - 2));
    for (int len = 2; len < want * 2; len *= 2) {
        // a truncated to `len` terms, zero-padded to 2*len for the transform.
        vector<ll> head(2 * len, 0);
        copy(a.begin(), a.begin() + min(len, want), head.begin());
        inv.re(2 * len);
        dft(head, 1);
        dft(inv, 1);
        ff(k, 0, 2 * len) {
            const ll prod = head[k] * inv[k] % mod * inv[k];
            inv[k] = (2 * inv[k] - prod) % mod;
        }
        dft(inv, -1);
        inv.re(len);
    }
    inv.re(want);
    return inv;
}
// Polynomial remainder a mod b (coefficients lowest degree first).
// Returns `a` unchanged when deg a < deg b; otherwise returns exactly deg b coefficients.
// Uses the reversal trick: rev(q) = rev(a) * rev(b)^{-1} mod x^{deg a - deg b + 1}.
// The leading coefficient of b must be invertible.
vector<ll> qmod(vector<ll> a, vector<ll> b) {
    const int n = a.si - 1, m = b.si - 1;
    if (n < m) return a;
    const int qlen = n - m + 1;  // number of quotient coefficients
    vector<ll> ra(a.rbegin(), a.rend()), rb(b.rbegin(), b.rend());
    // Only the low `qlen` terms of rev(a) and rev(b)^{-1} affect the quotient.
    // Truncating here avoids inverting and multiplying deg a + 1 terms, which at the
    // root of the evaluation tree was a full degree-m inverse for a 1-term quotient.
    ra.re(qlen);
    rb.re(qlen);
    vector<ll> q = ra * qni(rb);
    q.re(qlen);
    reverse(q.begin(), q.end());
    vector<ll> d = a - b * q;
    d.re(m);
    return d;
}
// Bound for the arrays below; the input degree m must satisfy m < N.
const int N = 1e5 + 5;
// n, m: read in main(); f[0..m]: input coefficients read in main().
int n,m; ll f[N];
// fac / nf: factorials and inverse factorials (filled by build);
// h[i]: C(n, i) mod `mod` for i in [0, m] (filled by build_h).
ll fac[N],nf[N],h[N];
// Fills fac[k] = k! and nf[k] = (k!)^{-1} modulo `mod` for every k in [0, n].
void build(int n) {
    fac[0] = 1;
    for (int k = 1; k <= n; ++k) {
        fac[k] = fac[k - 1] * k % mod;
    }
    // Invert n! once, then walk downwards using (k-1)!^{-1} = k!^{-1} * k.
    nf[n] = power(fac[n], mod - 2);
    for (int k = n; k >= 1; --k) {
        nf[k - 1] = nf[k] * k % mod;
    }
}
void build_h() {
h[0] = 1;
fo(i,1,m) h[i] = h[i - 1] * (n - i + 1) % mod;
fo(i,0,m) h[i] = h[i] * nf[i] % mod;
}
// Binomial coefficient C(n, m) mod `mod` from the factorial tables.
// Returns 0 when m lies outside [0, n]. Without the m < 0 check,
// nf[m] would be indexed with a negative index.
// Requires fac / nf to be filled at least up to n.
ll C(int n, int m) {
    if (m < 0 || n < m) return 0;
    return fac[n] * nf[n - m] % mod * nf[m] % mod;
}
// g[j]: per-degree weight multiplied into the convolution result in main().
ll g[N];
// t[i]: subproduct-tree node i (heap indexing, root 1), filled by dg().
vector<ll> t[N * 4];
// Heap-style child indices of tree node `i`: left = 2i, right = 2i + 1.
// Parenthesised so the macros expand correctly inside larger expressions.
#define i0 (i + i)
#define i1 (i + i + 1)
// Builds the subproduct tree over the points x..y: t[i] = prod_{k=x}^{y} (X - k),
// coefficients lowest degree first. The children of node i are i0 and i1.
void dg(int i, int x, int y) {
    if (x != y) {
        const int mid = (x + y) >> 1;
        dg(i0, x, mid);
        dg(i1, mid + 1, y);
        t[i] = t[i0] * t[i1];
        return;
    }
    // Leaf: the linear factor (X - x).
    t[i] = {-x, 1};
}
ll p[N];
// Multipoint evaluation: for every integer point j in [x, y], stores a(j) mod `mod`
// into p[j] (possibly as a negative representative in (-mod, mod)).
// `a` is first reduced modulo the tree node t[i] = prod_{k=x}^{y} (X - k), which does
// not change its values at those points.
void fz(vector<ll> a, int i, int x, int y) {
    a = qmod(a, t[i]);
    // Small ranges (including single points, so no separate x == y case is needed):
    // evaluate directly with Horner's rule.
    if (y - x + 1 < 128) {
        fo(j, x, y) {
            ll s = 0;
            fd(k, (int)a.si - 1, 0) s = (s * j + a[k]) % mod;
            p[j] = s;
        }
        return;
    }
    const int mid = (x + y) >> 1;
    fz(a, i0, x, mid);
    fz(a, i1, mid + 1, y);
}
// Reads n, m and f[0..m] from number.in, then writes to number.out
//   sum_{x=0}^{n} C(n, x) * F(x)  mod `mod`,   where F(x) = sum_k f[k] x^k.
int main() {
    ntt :: build();
    freopen("number.in","r",stdin);freopen("number.out","w",stdout);
    scanf("%d %d",&n,&m);
    fo(i,0,m) scanf("%lld",&f[i]);
    build(m);
    build_h();
    // Constant term: f[0] * sum_x C(n, x) = f[0] * 2^n.
    ll ans = f[0] * power(2,n) % mod;
    // For m == 0 only the constant term exists; dg(1, 1, 0) would recurse forever.
    if (m >= 1) {
        // g[j] = j! * C(n, j) * 2^(n-j) = n(n-1)...(n-j+1) * 2^(n-j),
        // which equals sum_x C(n, x) * x(x-1)...(x-j+1).
        // For j > n the falling factorial (and h[j]) is 0. Skipping that case also
        // keeps power() away from a negative exponent.
        fo(j,1,m) g[j] = j > n ? 0 : fac[j] * power(2,n - j) % mod * h[j] % mod;
        // Evaluate p[i] = F(i) - f[0] at the points 1..m via the subproduct tree.
        dg(1,1,m);
        vector<ll> c(m + 1);
        fo(i,1,m) c[i] = f[i];
        fz(c,1,1,m);
        // Convert point values to falling-factorial coefficients with one convolution:
        // coef[k] = sum_i (p[i] / i!) * ((-1)^(k-i) / (k-i)!).
        vector<ll> a(m + 1), b(m + 1);
        fo(i,1,m) a[i] = nf[i] * p[i] % mod;
        fo(i,0,m) b[i] = nf[i] * (i % 2 ? -1 : 1);
        a = a * b;
        fo(i,1,m) ans = (ans + a[i] * g[i]) % mod;
    }
    ans = (ans % mod + mod) % mod;
    printf("%lld\n",ans);
}