前言
在Pytorch中,有一些预训练模型或者预先封装的功能往往通过torch.hub模块中的一些方法进行加载,会保存一些文件在本地,通常默认地址是在C盘。考虑到某些预加载的资源很大,保存在C盘十分的占用存储空间,因此有时候需要修改这个保存地址。
注意!本文有较长篇幅分析Pytorch缓存路径的设置逻辑,若无相关需求,可直接跳到总结部分查看具体配置方法。
分析
其实不论是使用torch.hub.load()或者是Pytorch提供的预训练模型的服务,通过对源码的跟踪分析,会发现它们下载资源的方式都是通过torch.hub模块进行完成的,以最常见的预训练模型下载函数load_state_dict_from_url() 为例,可以在其函数声明部分看到 model_dir 参数。
def load_state_dict_from_url(
url: str,
model_dir: Optional[str] = None,
map_location: Optional[Union[Callable[[str], str], Dict[str, str]]] = None,
progress: bool = True,
check_hash: bool = False,
file_name: Optional[str] = None
) -> Dict[str, Any]:
当 model_dir 参数处于缺省状态时,该函数会调用同一模块下的 get_dir() 函数获取默认缓存地址。
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
进入 get_dir() 函数可以看到,其调用了一个私有方法 _get_torch_home() 获取默认路径。
def get_dir():
r"""
Get the Torch Hub cache directory used for storing downloaded models & weights.
If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
filesystem layout, with a default value ``~/.cache`` if the environment
variable is not set.
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_HUB'):
warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
if _hub_dir is not None:
return _hub_dir
return os.path.join(_get_torch_home(), 'hub')
函数的相关代码以及一些常量定义如下:
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
def _get_torch_home():
torch_home = os.path.expanduser(
os.getenv(ENV_TORCH_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
DEFAULT_CACHE_DIR), 'torch')))
return torch_home
从其调用了os.getenv() 来看,显然是通过读取环境变量来确定默认目录的。接下来我们依次分析该代码段。首先是对于最外层的os.getenv(ENV_TORCH_HOME) ,其获取常量ENV_TORCH_HOME所指向的环境变量中的值,即环境变量TORCH_HOME中的值,若未找到该环境变量的值,则返回如下值:
os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
DEFAULT_CACHE_DIR), 'torch')
该值由第一个参数与第二个参数拼接而成,其中第一个参数同样是尝试获取环境变量,而第二个参数则是常量文本 ‘torch’ ,对于第一个参数而言,其获取常量ENV_XDG_CACHE_HOME所指向的环境变量XDG_CACHE_HOME的值,若获取失败则返回常量DEFAULT_CACHE_DIR的值,即 ‘~/.cache’。
同时,当返回路径如果是 **'~/.cache/torch’时,最外层 os.path.expanduser会自动替换~**关键符号为当前计算机的用户路径,例如 C:\Users\Administrator.cache\torch。
总结
在hub.py 文件中
_get_torch_home 函数获取缓存默认存储位置
通常优先取环境变量 ‘TORCH_HOME’ 中的值,在代码中期被声明为ENV_TORCH_HOME 变量。
若不存在则取环境变量 ‘XDG_CACHE_HOME’ 的值拼接 ‘torch’ 为默认位置,其中 ‘XDG_CACHE_HOME’ 在代码中被声明为变量ENV_XDG_CACHE_HOME。
若依旧不存在,则返回 ‘~/.cache’ +‘torch’ ,并替换 ~ 为本地用户路径。
因此,可以通过配置环境变量来修改Pytorch的默认缓存位置,具体如下:
‘XDG_CACHE_HOME’ = Pytorch相关包存放缓存的默认位置
‘TORCH_HOME’ = %XDG_CACHE_HOME%\torch
具体步骤如下:
首先打开计算机的属性面板
接着在属性面板右上角打开 “高级系统设置”
从高级设置中进入环境变量设置界面
通过点击新建,完成对环境变量的新增,其中用户变量仅对当前用户有效,而环境变量则对本机器所有用户生效。
在我的设置中,我设置如下:
XDG_CACHE_HOME=D:\Python\cache
TORCH_HOME=%XDG_CACHE_HOME%\torch即用D:\Python\cache存储Pytorch相关包下载的缓存,并使用D:\Python\cache\torch缓存Pytorch本身下载的一些缓存文件。其中 %XDG_CACHE_HOME% 可看做对环境变量XDG_CACHE_HOME的引用。
或者在项目运行的代码前加上临时环境变量设置
os.environ[‘TORCH_HOME’]=‘E:/Data/torch-model’
// 全文完
因笔者能力有限,若文章内容存在错误或不恰当之处,欢迎留言、私信批评指正。
Email:YePeanut[at]foxmail.com