题目链接 戳我
最短路DAG上计数
经过点的最短路计数和经过边的最短路计数实在不能一概而论。。。。qwqwqwqwqwq
在做这道题之前,让我先捞上一张来自xyz32768大佬的图
对于一条边(x,y)——求S到x的最短路径个数是很好求的一件事,对于一条边两个端点u,v,sum[v]+=sum[u]。那么顺序应当怎么确定呢?拓扑排序啊!但是怎么保证正确性呢?没有环啊!
求上面那个好求,但是求y到T的最短路径个数就不太容易了。。。因为我们必须保证这些最短路都经过S->x,x->y,而不是从S直接到了y。还按照上面那个方法的话是有后效性的,无法DP。
但是!但是!我们没有办法正着DP,难道就不能反过来了吗!sum[u]+=sum[v]不就可以了吗?
至此,大功告成。
代码如下:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#define MAXN 100010
#define mod 1000000007
using namespace std;
int n,m,t,tot,ll,rr;
int cnt[MAXN],cnt1[MAXN],cnt2[MAXN],done[MAXN];
int head[MAXN],ans[MAXN],ins[MAXN],dis[MAXN],pq[MAXN];
struct Edge{int nxt,to,dis;}edge[MAXN<<1];
struct Line{int u,v,w;}a[MAXN];
int q[MAXN];
inline void add(int from,int to,int dis)
{
edge[++t].nxt=head[from],edge[t].to=to;
edge[t].dis=dis,head[from]=t;
}
inline void spfa(int x)
{
memset(dis,0x3f,sizeof(dis));
memset(ins,0,sizeof(ins));
ll=1,rr=0;
q[++rr]=x;dis[x]=0;done[x]=1;
while(ll<=rr)
{
int u=q[ll++];done[u]=0;
for(int i=head[u];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(dis[u]+edge[i].dis<dis[v])
{
dis[v]=dis[u]+edge[i].dis;
if(!done[v])
q[++rr],done[v]=1;
}
}
}
for(int i=1;i<=m;i++)
if(dis[a[i].u]+a[i].w==dis[a[i].v])
ins[i]=1;
}
inline void topo_sort(int x)
{
memset(cnt,0,sizeof(cnt));
memset(cnt1,0,sizeof(cnt1));
memset(cnt2,0,sizeof(cnt2));
ll=1,rr=0,tot=0;
cnt1[x]=1;
q[++rr]=x;
for(int i=1;i<=m;i++) if(ins[i]) cnt[a[i].v]++;
while(ll<=rr)
{
int u=q[ll++];
pq[++tot]=u;
for(int i=head[u];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(!ins[i]) continue;
cnt[v]--;
if(!cnt[v]) q[++rr]=v;
cnt1[v]=(cnt1[v]+cnt1[u])%mod;
}
}
for(int i=tot;i;i--)
{
int u=pq[i]; cnt2[u]++;
for(int j=head[u];j;j=edge[j].nxt)
{
if(!ins[j]) continue;
int v=edge[j].to;
cnt2[u]=(cnt2[u]+cnt2[v])%mod;
}
}
}
inline void print()
{
for(int i=1;i<=n;i++) printf("ins[%d]=%d\n",i,ins[i]);
for(int i=1;i<=n;i++) printf("dis[%d]=%d\n",i,dis[i]);
for(int i=1;i<=n;i++) printf("cnt1[%d]=%d\n",i,cnt1[i]);
for(int i=1;i<=n;i++) printf("cnt2[%d]=%d\n",i,cnt2[i]);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("ce.in","r",stdin);
#endif
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
a[i].u=u,a[i].v=v,a[i].w=w;
add(u,v,w);
}
for(int i=1;i<=m;i++)
{
printf("i=%d\n",i);
spfa(i);
topo_sort(i);
print();
for(int i=1;i<=m;i++)
if(ins[i])
ans[i]=(ans[i]+1ll*cnt1[a[i].u]*cnt2[a[i].v]%mod)%mod;
}
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}
呜呜呜好吧我承认这个代码在BZOJ上没有办法过。。。会T一个点。。。。。。但是把STL队列手写一下。。应该就可以了。。。。。。
不开O2惨无人道啊qwqwq