前言

在Pytorch中,有一些预训练模型或者预先封装的功能往往通过torch.hub模块中的一些方法进行加载,会保存一些文件在本地,通常默认地址是在C盘。考虑到某些预加载的资源很大,保存在C盘十分的占用存储空间,因此有时候需要修改这个保存地址。

注意!本文有较长篇幅分析Pytorch缓存路径的设置逻辑,若无相关需求,可直接跳到总结部分查看具体配置方法。

分析

其实不论是使用torch.hub.load()或者是Pytorch提供的预训练模型的服务,通过对源码的跟踪分析,会发现它们下载资源的方式都是通过torch.hub模块进行完成的,以最常见的预训练模型下载函数load_state_dict_from_url() 为例,可以在其函数声明部分看到 model_dir 参数。

def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: Optional[Union[Callable[[str], str], Dict[str, str]]] = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None
) -> Dict[str, Any]:

model_dir 参数处于缺省状态时,该函数会调用同一模块下的 get_dir() 函数获取默认缓存地址。

if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

进入 get_dir() 函数可以看到,其调用了一个私有方法 _get_torch_home() 获取默认路径。

def get_dir():
    r"""
    Get the Torch Hub cache directory used for storing downloaded models & weights.

    If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
    filesystem layout, with a default value ``~/.cache`` if the environment
    variable is not set.
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), 'hub')

函数的相关代码以及一些常量定义如下:

ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'

def _get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(ENV_TORCH_HOME,
                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
                                         DEFAULT_CACHE_DIR), 'torch')))
    return torch_home

从其调用了os.getenv() 来看,显然是通过读取环境变量来确定默认目录的。接下来我们依次分析该代码段。首先是对于最外层的os.getenv(ENV_TORCH_HOME) ,其获取常量ENV_TORCH_HOME所指向的环境变量中的值,即环境变量TORCH_HOME中的值,若未找到该环境变量的值,则返回如下值:

os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
                                         DEFAULT_CACHE_DIR), 'torch')

该值由第一个参数与第二个参数拼接而成,其中第一个参数同样是尝试获取环境变量,而第二个参数则是常量文本 ‘torch’ ,对于第一个参数而言,其获取常量ENV_XDG_CACHE_HOME所指向的环境变量XDG_CACHE_HOME的值,若获取失败则返回常量DEFAULT_CACHE_DIR的值,即 ‘~/.cache’

同时,当返回路径如果是 **'~/.cache/torch’时,最外层 os.path.expanduser会自动替换~**关键符号为当前计算机的用户路径,例如 C:\Users\Administrator.cache\torch

总结

在hub.py 文件中

_get_torch_home 函数获取缓存默认存储位置

通常优先取环境变量 ‘TORCH_HOME’ 中的值,在代码中期被声明为ENV_TORCH_HOME 变量。

若不存在则取环境变量 ‘XDG_CACHE_HOME’ 的值拼接 ‘torch’ 为默认位置,其中 ‘XDG_CACHE_HOME’ 在代码中被声明为变量ENV_XDG_CACHE_HOME。

若依旧不存在,则返回 ‘~/.cache’ +‘torch’ ,并替换 ~ 为本地用户路径。

因此,可以通过配置环境变量来修改Pytorch的默认缓存位置,具体如下:

‘XDG_CACHE_HOME’ = Pytorch相关包存放缓存的默认位置

‘TORCH_HOME’ = %XDG_CACHE_HOME%\torch

具体步骤如下:

首先打开计算机的属性面板

PyTorch C盘 pytorch c盘 缓存_pytorch


接着在属性面板右上角打开 “高级系统设置”

PyTorch C盘 pytorch c盘 缓存_缓存_02


从高级设置中进入环境变量设置界面

PyTorch C盘 pytorch c盘 缓存_缓存_03


通过点击新建,完成对环境变量的新增,其中用户变量仅对当前用户有效,而环境变量则对本机器所有用户生效。

PyTorch C盘 pytorch c盘 缓存_环境变量_04


在我的设置中,我设置如下:

XDG_CACHE_HOME=D:\Python\cache

TORCH_HOME=%XDG_CACHE_HOME%\torch即用D:\Python\cache存储Pytorch相关包下载的缓存,并使用D:\Python\cache\torch缓存Pytorch本身下载的一些缓存文件。其中 %XDG_CACHE_HOME% 可看做对环境变量XDG_CACHE_HOME的引用。

PyTorch C盘 pytorch c盘 缓存_缓存_05

或者在项目运行的代码前加上临时环境变量设置
os.environ[‘TORCH_HOME’]=‘E:/Data/torch-model’

// 全文完

因笔者能力有限,若文章内容存在错误或不恰当之处,欢迎留言、私信批评指正。
Email:YePeanut[at]foxmail.com