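// Idea: after joining the two trees with one new edge (u, v), the diameter of the
// merged tree is either the larger of the two original diameters, or
// ecc(u) + ecc(v) + 1, where ecc(x) is the longest path starting at x inside its
// own tree. So: precompute every vertex's longest path (eccentricity) in both
// trees, sort one tree's values and build suffix sums over them, then for each
// vertex of the other tree use binary search plus the suffix sums to add up its
// contribution over all partners.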


#include <iostream>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
// ---- Generic contest template: limits, constants, shorthand macros ----
#define maxn 100005                  // bound for per-vertex arrays
#define maxm 500005                  // bound for the shared arc pool
#define eps 1e-7
#define mod 1000000007               // modulus used by powmod
#define INF 0x3f3f3f3f
#define PI (acos(-1.0))
#define lowbit(x) ((x)&(-(x)))       // lowest set bit; argument parenthesized so lowbit(a+b) is correct
#define mp make_pair
#define ls (o<<1)                    // heap-style left child 2*o (parenthesized so ls+1 is correct)
#define rs (o<<1 | 1)                // heap-style right child 2*o+1
#define lson o<<1, L, mid            // expands to an argument list, so deliberately NOT parenthesized
#define rson o<<1 | 1, mid+1, R      // expands to an argument list, so deliberately NOT parenthesized
#define pii pair<int, int>
#pragma comment(linker, "/STACK:16777216")   // MSVC linker directive for a 16 MiB stack; other compilers ignore it
typedef long long LL;
typedef unsigned long long ULL;
//typedef int LL; 
using namespace std;

// a^b by binary exponentiation, without a modulus.
// b is expected to be non-negative. The result must fit in LL; overflow is UB.
// base is only squared while bits of b remain, so an exactly-fitting result
// such as qpow(2, 62) does not trigger a pointless overflowing squaring.
LL qpow(LL a, LL b)
{
	LL res = 1, base = a;
	while (b) {
		if (b % 2) res = res * base;
		b /= 2;
		if (b) base = base * base;
	}
	return res;
}

// a^b modulo `mod` by binary exponentiation; b is expected to be non-negative.
// a is reduced into [0, mod) first, so every product stays below mod^2 and
// cannot overflow LL. Negative bases also yield a non-negative residue.
LL powmod(LL a, LL b)
{
	LL res = 1, base = (a % mod + mod) % mod;
	while (b) {
		if (b % 2) res = res * base % mod;
		b /= 2;
		if (b) base = base * base % mod;
	}
	return res;
}
// ---- end of generic template; problem-specific code below ----

// One directed arc in the shared forward-star pool E.
// Each vertex's outgoing arcs form a singly linked list threaded through
// `next`; a value of -1 ends the list.
struct Edge
{
	int v;     // head (target) vertex of the arc
	int next;  // index in E of the next arc of the same list, or -1
	Edge(int to = 0, int nxt = 0) : v(to), next(nxt) {}
}E[maxm];

// ---- Per-test-case state shared by the routines below ----
int n, m, cntE;   // vertex counts of tree A and tree B; number of arcs used in E
int HA[maxn];     // first arc index per vertex of tree A (-1 = no arcs)
int HB[maxn];     // first arc index per vertex of tree B (-1 = no arcs)
queue<int> q;     // BFS work queue (drained by every bfs call)
int dis[maxn];    // distances from the latest BFS source (-1 = unreached)
int res1[maxn];   // running max BFS distance per vertex of tree A (its eccentricity after work's sweeps)
int res2[maxn];   // same for tree B
LL sum[maxn];     // suffix sums over the sorted res2 values
int son[maxn];    // not referenced by the code in this file

// Prepends the arc u -> v to vertex u's list in adjacency table H (HA or HB),
// taking the next free slot of the shared pool E.
void addedges(int u, int v, int H[])
{
	int slot = cntE++;
	E[slot].v = v;
	E[slot].next = H[u];
	H[u] = slot;
}

void init()
{
	cntE = 0;
	memset(HA, -1, sizeof HA);
	memset(HB, -1, sizeof HB);
	memset(res1, 0, sizeof res1);
	memset(res2, 0, sizeof res2);
}

int bfs(int u, int H[], int flag)
{
	memset(dis, -1, sizeof dis);
	q.push(u), dis[u] = 0;
	int res = 0, t = -1;
	while(!q.empty()) {
		u = q.front();
		q.pop();
		if(!flag) res1[u] = max(res1[u], dis[u]);
		else res2[u] = max(res2[u], dis[u]);
		if(dis[u] > t) t = dis[u], res = u;
		for(int e = H[u]; ~e; e = E[e].next) {
			int v = E[e].v;
			if(dis[v] == -1) {
				dis[v] = dis[u] + 1;
				q.push(v);
			}
		}
	}
	return res;
}

void work()
{
	int u, v;
	for(int i = 1; i < n; i++) {
		scanf("%d%d", &u, &v);
		addedges(u, v, HA);
		addedges(v, u, HA);
	}
	for(int i = 1; i < m; i++) {
		scanf("%d%d", &u, &v);
		addedges(u, v, HB);
		addedges(v, u, HB);
	}

	LL tt = 0;

	u = bfs(1, HA, 0);
	u = bfs(u, HA, 0);
	tt = max(tt, (LL)dis[u]);
	bfs(u, HA, 0);

	u = bfs(1, HB, 1);
	u = bfs(u, HB, 1);
	tt = max(tt, (LL)dis[u]);
	bfs(u, HB, 1);

	sort(res1+1, res1+n+1);
	sort(res2+1, res2+m+1);
	sum[m+1] = 0;
	for(int i = m; i >= 1; i--) sum[i] = sum[i+1] + res2[i];

	LL ans = 0;
	for(int i = 1; i <= n; i++) {
		int t = lower_bound(res2+1, res2+m+1, tt - res1[i] - 1) - res2;
		ans += sum[t] + (LL)(m - t + 1) * (1 + res1[i]);
		ans += tt * (t - 1);
	}
	printf("%lld\n", ans);
}

// Processes test cases until the input is exhausted; each case starts with "n m".
int main()
{
	// Require both numbers to be read. With `!= EOF`, malformed input (scanf
	// returning 0) loops forever, and a half-read header runs with a stale m.
	while(scanf("%d%d", &n, &m) == 2) {
		init();
		work();
	}

	return 0;
}