知识蒸馏：Hinton 等人于 2015 年在论文《Distilling the Knowledge in a Neural Network》中首次提出这一方法，并将其应用于分类任务。其中，大模型称为 Teacher（教师模型），小模型称为 Student（学生模型）；来自 Teacher 模型输出的监督信息称为 Knowledge（知识），而 Student 学习并迁移来自 Teacher 的监督信息的过程称为 Distillation（蒸馏）。下面是一个知识蒸馏的入门小示例（代码可直接运行）：

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchinfo import summary


class TeacherModel(nn.Module):
    """Wide two-stage CNN used as the teacher network.

    Each stage is a 3x3 same-padded convolution followed by ReLU and a 2x2
    max-pool with stride 2. The classifier's ``256 * 7 * 7`` input size
    therefore assumes 28x28 images (e.g. MNIST).

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Shared convolution settings: 3x3 kernel, stride 1, padding 1 keeps H and W.
        conv_opts = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # Registration order (conv1, pool, conv2, fc1) is kept so that seeded
        # parameter initialisation and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(in_channels, 64, **conv_opts)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(64, 256, **conv_opts)
        self.fc1 = nn.Linear(256 * 7 * 7, num_classes)

    def forward(self, x):
        """Return raw (unnormalised) class logits for the batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        batch_size = x.shape[0]
        return self.fc1(x.reshape(batch_size, -1))


class StudentModel(nn.Module):
    """Narrow two-stage CNN used as the student network.

    Same layout as the teacher (two 3x3 same-padded conv + ReLU + 2x2 max-pool
    stages, then one linear layer) but with 8 and 16 channels instead of 64
    and 256. The ``16 * 7 * 7`` classifier input assumes 28x28 images.

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        conv_opts = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # Keep registration order identical so seeded init and state_dict keys match.
        self.conv1 = nn.Conv2d(in_channels, 8, **conv_opts)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(8, 16, **conv_opts)
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        """Return raw (unnormalised) class logits for the batch ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.reshape(features.shape[0], -1)
        return self.fc1(flat)


def check_accuracy(loader, model, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is run in eval mode with gradients disabled. Its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        loader: iterable of ``(inputs, labels)`` batches.
        model: module whose output has one score per class along dim 1.
        device: device the inputs and labels are moved to.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    num_correct = 0
    num_samples = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                predictions = model(x).argmax(dim=1)
                # Accumulate as a tensor; converted once at the end to avoid
                # a host sync per batch.
                num_correct += (predictions == y).sum()
                num_samples += predictions.size(0)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    if num_samples == 0:
        raise ValueError("loader yielded no samples")
    return int(num_correct) / num_samples


def train_model(model, epochs, *, lr=1e-4, train_loader=None, test_loader=None, device=None):
    """Train ``model`` with cross-entropy and Adam, printing test accuracy each epoch.

    Args:
        model: classifier to train. It is updated in place and also returned.
        epochs: number of passes over the training data.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-4).
        train_loader: iterable of ``(inputs, labels)`` training batches.
            Defaults to the module-level ``train_loader`` created by the
            script entry point.
        test_loader: batches used for the per-epoch accuracy report.
            Defaults to the module-level ``test_loader``.
        device: device the batches are moved to. Defaults to the device of
            the model's first parameter.

    Returns:
        The trained ``model``.
    """
    # Fall back to the globals set up under ``__main__`` so existing script
    # usage keeps working; explicit arguments make the function reusable.
    if train_loader is None:
        train_loader = globals()["train_loader"]
    if test_loader is None:
        test_loader = globals()["test_loader"]
    if device is None:
        device = next(model.parameters()).device

    # Parameter counts (via torchinfo.summary): teacher 273,802; student 9,098.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        losses = []

        pbar = tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}")
        for data, labels in pbar:
            data, labels = data.to(device), labels.to(device)
            # forward
            loss = criterion(model(data), labels)
            losses.append(loss.item())
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Guard against an empty loader, which would otherwise divide by zero.
        avg_loss = sum(losses) / len(losses) if losses else float("nan")
        acc = check_accuracy(test_loader, model, device)
        print(f"Loss:{avg_loss:.2f}\tAccuracy:{acc:.2f}")

    return model


def train_distillation(teacher, student, epochs, temp=7, alpha=0.3, *, lr=1e-4,
                       train_loader=None, test_loader=None, device=None):
    """Train ``student`` on hard labels plus ``teacher``'s temperature-softened outputs.

    The per-batch loss follows Hinton et al. (2015)::

        loss = alpha * CE(student_logits, labels)
             + (1 - alpha) * T**2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    Args:
        teacher: trained model that supplies soft targets. It is switched to
            eval mode and is never updated.
        student: model being trained. It is updated in place and returned.
        epochs: number of passes over the training data.
        temp: softmax temperature ``T`` applied to both sets of logits.
        alpha: weight of the hard-label cross-entropy term; ``1 - alpha``
            weights the distillation term.
        lr: Adam learning rate for the student.
        train_loader: iterable of ``(inputs, labels)`` training batches.
            Defaults to the module-level ``train_loader``.
        test_loader: batches used for the per-epoch accuracy report.
            Defaults to the module-level ``test_loader``.
        device: device the batches are moved to. Defaults to the device of
            the student's first parameter. The teacher is assumed to already
            be on the same device.

    Returns:
        The trained ``student``.
    """
    # Fall back to the script-level globals so existing usage keeps working.
    if train_loader is None:
        train_loader = globals()["train_loader"]
    if test_loader is None:
        test_loader = globals()["test_loader"]
    if device is None:
        device = next(student.parameters()).device

    student_loss_fn = nn.CrossEntropyLoss()
    # KLDivLoss expects log-probabilities as input and probabilities as target.
    divergence_loss_fn = nn.KLDivLoss(reduction="batchmean")
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)

    teacher.eval()
    for epoch in range(epochs):
        student.train()
        losses = []
        pbar = tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}")
        for data, labels in pbar:
            data, labels = data.to(device), labels.to(device)
            # forward: the teacher only provides targets, so build no graph for it.
            # Use the ``teacher`` argument, not a module-level model.
            with torch.no_grad():
                teacher_preds = teacher(data)
            student_preds = student(data)

            student_loss = student_loss_fn(student_preds, labels)
            distillation_loss = divergence_loss_fn(
                F.log_softmax(student_preds / temp, dim=1),
                F.softmax(teacher_preds / temp, dim=1),
            )
            # Scale by T**2 so soft-target gradients stay comparable in
            # magnitude to the hard-label term as the temperature changes.
            loss = alpha * student_loss + (1 - alpha) * temp ** 2 * distillation_loss
            losses.append(loss.item())

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Guard against an empty loader, which would otherwise divide by zero.
        avg_loss = sum(losses) / len(losses) if losses else float("nan")
        acc = check_accuracy(test_loader, student, device)
        print(f"Loss:{avg_loss:.2f}\tAccuracy:{acc:.2f}")

    return student


if __name__ == '__main__':
    # Seed the global RNG so runs are repeatable.
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Let cuDNN autotune convolution algorithms for speed.
    # NOTE(review): autotuning may select nondeterministic kernels on GPU,
    # which works against the seed above; set to False if exact repeatability matters.
    torch.backends.cudnn.benchmark = True

    # MNIST train split first, then test split, both downloaded into data/ if missing.
    train_dataset, test_dataset = (
        torchvision.datasets.MNIST(
            root="data/",
            train=is_train,
            transform=transforms.ToTensor(),
            download=True,
        )
        for is_train in (True, False)
    )

    # Only the training batches are shuffled.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    epochs = 6

    # 1) Teacher trained alone.
    print("---------- teacher ----------")
    teacher_model = TeacherModel().to(device)
    teacher_model = train_model(teacher_model, epochs)

    # 2) Student trained alone, as a baseline.
    print("---------- student ----------")
    baseline_student = StudentModel().to(device)
    train_model(baseline_student, epochs)

    # 3) Fresh student trained with distillation from the teacher.
    print("---------- student and teacher ----------")
    distilled_student = StudentModel().to(device)
    train_distillation(teacher_model, distilled_student, epochs)