PyTorch 加载 pth 模型

在深度学习中,模型的训练通常需要花费大量的时间和计算资源。因此,为了节省时间和资源,我们可以将训练好的模型保存下来,以备后续使用。在 PyTorch 中,我们可以将模型保存为 .pth 文件,并在需要的时候加载它们。

本文将介绍如何使用 PyTorch 加载 .pth 模型,并提供相应的代码示例。

准备工作

首先,我们需要安装 PyTorch。可以使用以下命令安装最新版本的 PyTorch:

pip install torch

加载模型

要加载保存的 .pth 模型,我们需要创建一个与原始模型相同结构的模型类,并使用该类实例化一个模型对象。然后,我们可以使用 load_state_dict 方法加载保存的模型参数到新创建的模型对象中。

以下是一个加载 .pth 模型的示例:

import torch
import torchvision.models as models

# 创建一个与原始模型相同结构的模型类
model = models.resnet18()

# 加载保存的模型参数到新创建的模型对象中
model.load_state_dict(torch.load('model.pth'))

# 将模型设置为评估模式
model.eval()

在上面的示例中,我们使用了 torchvision 中的 ResNet-18 模型作为示例模型,并将其保存为 model.pth 文件。然后,我们创建了一个与原始模型相同结构的模型类,并加载了保存的模型参数到新创建的模型对象中。最后,我们将模型设置为评估模式。

序列图

下面是加载 .pth 模型的示例的序列图,用于更清晰地理解加载模型的过程:

sequenceDiagram
    participant User
    participant Code
    participant Model

    User->>Code: 传递保存的.pth文件路径
    Code->>Model: 创建相同结构的模型类
    Model->>Code: 返回新创建的模型对象
    Code->>Model: 加载保存的模型参数
    Model->>Code: 返回加载后的模型对象
    Code->>Model: 设置模型为评估模式
    Model->>Code: 返回评估模式的模型对象

总结

在本文中,我们了解了如何使用 PyTorch 加载保存为 .pth 文件的模型。我们首先创建了一个与原始模型相同结构的模型类,然后使用 load_state_dict 方法加载保存的模型参数到新创建的模型对象中。最后,我们将模型设置为评估模式。通过这些步骤,我们可以方便地加载预训练的模型,并在需要的时候使用它们进行推理。

希望本文对初学者有所帮助,让大家能够更好地理解和使用 PyTorch 加载 .pth 模型。更多关于 PyTorch 的内容,可以查看 PyTorch 官方文档。