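// Key observation: after joining the two trees with a single edge (x, y), the diameter of the
// new tree is the larger of (a) the longer of the two original diameters and (b) the longest
// path from x in its tree plus the longest path from y in its tree plus 1 (the new edge).
// Approach: precompute every vertex's longest path (eccentricity) in both trees, sort the
// values of one tree, then for each vertex of the other tree use binary search plus suffix
// sums to accumulate the answer.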
#include <iostream>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
// Capacity of every vertex-indexed array (vertices are numbered from 1).
#define maxn 100005
// Capacity of the shared edge pool E (every undirected tree edge occupies two slots).
#define maxm 500005
// Generic template constants; of these only `mod` (used by powmod) is referenced below.
#define eps 1e-7
#define mod 1000000007
#define INF 0x3f3f3f3f
#define PI (acos(-1.0))
// Generic helpers (lowest set bit, make_pair, what look like segment-tree child
// index/range helpers). Arguments and expansions are parenthesized so the macros
// compose safely with surrounding operators, e.g. lowbit(a+b) or rs*2.
// lson/rson expand to argument lists, so they cannot be wrapped in parentheses.
#define lowbit(x) ((x)&(-(x)))
#define mp make_pair
#define ls (o<<1)
#define rs (o<<1 | 1)
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
#define pii pair<int, int>
// MSVC-style request for a larger stack; other compilers generally ignore this pragma.
#pragma comment(linker, "/STACK:16777216")
typedef long long LL;
typedef unsigned long long ULL;
//typedef int LL;
using namespace std;
// Computes a^b (b >= 0 expected) by binary exponentiation in plain LL arithmetic,
// with no modulus: the caller must ensure the true result fits in LL.
// The base is squared only while higher bits of b remain. Squaring after the last
// bit would overflow (UB) even when the result fits, e.g. qpow(2, 62) would square 2^32.
LL qpow(LL a, LL b)
{
    LL res = 1, base = a;
    while (b) {
        if (b % 2) res = res * base;
        b /= 2;
        // Square only if a remaining bit of b will actually use the squared value.
        if (b) base = base * base;
    }
    return res;
}
// Computes a^b mod `mod` (b >= 0 expected) by binary exponentiation; the result lies
// in [0, mod). `a` is first reduced into [0, mod), so large values of `a` cannot
// overflow base*base and negative values yield the canonical non-negative residue.
LL powmod(LL a, LL b)
{
    LL res = 1, base = (a % mod + mod) % mod;
    while (b) {
        if (b % 2) res = res * base % mod;
        base = base * base % mod;
        b /= 2;
    }
    return res;
}
// head
// One directed arc stored in the shared pool E. Arcs leaving the same vertex are
// chained through `next`; -1 terminates a chain.
struct Edge
{
    int v;     // vertex the arc points to
    int next;  // pool index of the next arc from the same source, or -1
    Edge(int to = 0, int nxt = 0) : v(to), next(nxt) {}
} E[maxm];
// FIFO shared by bfs(); bfs() drains it completely before returning.
queue<int> q;
// Adjacency-list heads for tree A and tree B: index of the first arc in E, -1 = none.
int HA[maxn];
int HB[maxn];
// Hop distances from the source of the most recent bfs() call (-1 = not reached).
int dis[maxn];
// res1[x] / res2[x]: largest bfs() distance recorded so far for vertex x of tree A / B.
// After BFS from both diameter endpoints this is x's eccentricity (longest path from x).
int res1[maxn];
int res2[maxn];
// Not referenced in the code shown here.
int son[maxn];
// Suffix sums over the sorted res2 values: sum[i] = res2[i] + ... + res2[m].
LL sum[maxn];
// n, m: vertex counts of tree A and tree B; cntE: number of arcs used in E so far.
int n, m, cntE;
void addedges(int u, int v, int H[])
{
E[cntE] = Edge(v, H[u]);
H[u] = cntE++;
}
// Resets per-test-case state: empties both adjacency lists and the edge pool,
// and clears the accumulated per-vertex maximum distances of both trees.
void init()
{
    std::fill(std::begin(HA), std::end(HA), -1);
    std::fill(std::begin(HB), std::end(HB), -1);
    std::fill(std::begin(res1), std::end(res1), 0);
    std::fill(std::begin(res2), std::end(res2), 0);
    cntE = 0;
}
int bfs(int u, int H[], int flag)
{
memset(dis, -1, sizeof dis);
q.push(u), dis[u] = 0;
int res = 0, t = -1;
while(!q.empty()) {
u = q.front();
q.pop();
if(!flag) res1[u] = max(res1[u], dis[u]);
else res2[u] = max(res2[u], dis[u]);
if(dis[u] > t) t = dis[u], res = u;
for(int e = H[u]; ~e; e = E[e].next) {
int v = E[e].v;
if(dis[v] == -1) {
dis[v] = dis[u] + 1;
q.push(v);
}
}
}
return res;
}
void work()
{
int u, v;
for(int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
addedges(u, v, HA);
addedges(v, u, HA);
}
for(int i = 1; i < m; i++) {
scanf("%d%d", &u, &v);
addedges(u, v, HB);
addedges(v, u, HB);
}
LL tt = 0;
u = bfs(1, HA, 0);
u = bfs(u, HA, 0);
tt = max(tt, (LL)dis[u]);
bfs(u, HA, 0);
u = bfs(1, HB, 1);
u = bfs(u, HB, 1);
tt = max(tt, (LL)dis[u]);
bfs(u, HB, 1);
sort(res1+1, res1+n+1);
sort(res2+1, res2+m+1);
sum[m+1] = 0;
for(int i = m; i >= 1; i--) sum[i] = sum[i+1] + res2[i];
LL ans = 0;
for(int i = 1; i <= n; i++) {
int t = lower_bound(res2+1, res2+m+1, tt - res1[i] - 1) - res2;
ans += sum[t] + (LL)(m - t + 1) * (1 + res1[i]);
ans += tt * (t - 1);
}
printf("%lld\n", ans);
}
// Processes test cases until the input runs out. Each case starts with the
// vertex counts n (tree A) and m (tree B).
int main()
{
    // Require both counts. Comparing against EOF alone would loop forever on a
    // matching failure (scanf returns 0) and would reuse a stale m on a partial read.
    while (scanf("%d%d", &n, &m) == 2) {
        init();
        work();
    }
    return 0;
}