最近正在尝试将pytorch框架下一个风格化网络(MCCNET)的代码转换到jittor框架下。在转换的过程中,大部分torch库中的函数都能在jittor库中找到作用相同的同名函数;小部分没能找到同名函数的也可以通过jittor库中的其他函数做到同样的效果。下面也主要是讲一下如何完成这几个空缺同名函数的实现方式,以及发现的一些因为jittor框架自身的特点带来的问题。我会在目录中将pytorch的函数名打出来,同时标注完成的操作,大家可以根据自己遇到的问题在目录中查询。
MCCNET是中科院自动化所模式识别实验室多媒体计算团队2020年发表于人工智能国际顶级会议AAAI 2021的论文Arbitrary Video Style Transfer via Multi-Channel Correlation中提出的“基于多通道相关性的任意连续视频风格化网络(MCCNet)”,该网络用于完成视频的风格化迁移,风格化效果好且不会产生闪烁问题。
MCCNET代码也可在GitHub中找到。本文就是实现了该代码的jittor迁移。https://github.com/diyiiyiii/MCCNet 转换后的Jittor也可在Jittor分钟中找到:https://github.com/diyiiyiii/MCCNet/tree/Jittor


目录

  • 1. 关于jittor
  • 1.1 安装jittor
  • 1.2 torch->jittor转换脚本
  • 2. 代码转换
  • 2.1 基础操作
  • 2.1.1 torch.Tensor(data) 将numpy数组转换为jt数据类型Var
  • 2.1.2 np.empty(shape) 创建空Var
  • 2.1.3 np.random.permutation(n) 生成1-n的乱序列表
  • 2.1.4 np.mm() 矩阵乘法
  • 2.1.5 np.var() 求方差
  • 2.2 Train阶段问题
  • 2.2.1 dataset和dataloader 加载图片集
  • 2.2.2 load() 加载网络参数
  • 2.2.3 pytorch训练好的模型参数加载到jittor的模型中
  • 2.2.4 children() 获取网络中的各层
  • 2.2.5 MaxPool2d
  • 2.2.6 backward 反向传播
  • 2.2.7 requires_grad 设置参数学习模式
  • 2.2.8 GPU模式
  • 3. Test阶段问题
  • 4. 总结


1. 关于jittor

1.1 安装jittor

因为我是在远程服务器上进行的安装,并且就是Ubuntu操作系统,所以直接pip安装就可以了,非常方便。

python3.7 -m pip install jittor

如果是要在windows系统下安装就需要使用docker安装了。
docker安装教程

1.2 torch->jittor转换脚本

jittor官网提供了一个简单的Pytorch模型代码转Jittor模型的脚本

pytorch 转置 pytorch转jittor_python


但就像官方自己说的,脚本只能用于模型代码的转换,而且得是class+module声明的,sequential声明的模型转换过程中会被直接删去…对于模型之外的代码,转换过程中也是大部分保留,少部分删去。所以这个转换脚本最好还是当做一个对照作用吧,不能直接拿来用。

2. 代码转换

2.1 基础操作

2.1.1 torch.Tensor(data) 将numpy数组转换为jt数据类型Var

使用jittor.array(data) /jittor.float(data)/jittor.float32(data) 数据类型最好为float32,因为矩阵乘法暂时不支持32和64位混用。(array()默认保持类型不变,剩余两个默认转化为float32类型)

import jittor as jt
import numpy as np

data = np.random.randn(1,2,3).astype("float32")
>>[[[ 0.692951   0.3800234 -0.0999987]
  [-2.727701  -2.4574485  1.1308112]]]
  
output = jt.array(data)
>>jt.Var([[[ 0.692951   0.3800234 -0.0999987]
  [-2.727701  -2.4574485  1.1308112]]], dtype=float32)

访问Var中的数据使用Var.data就可以了。

output.data
>>[[[ 0.692951   0.3800234 -0.0999987]
  [-2.727701  -2.4574485  1.1308112]]]

2.1.2 np.empty(shape) 创建空Var

使用jittor.random(shape,dtype,uniform)jittor.empty(shape)

jt.random((2,2,4))
>>jt.Var([[[0.6693543  0.83819515 0.5461786  0.6237627 ]
  [0.58049047 0.3033327  0.11268225 0.85048825]]

 [[0.1557529  0.7390003  0.3779687  0.52000093]
  [0.6558841  0.4438333  0.5063377  0.01796175]]], dtype=float32)

这样模型中声明的数据,默认是可以训练的参数。
或使用np.random创建一个array,再转换成Var类型,像2.1.1中举例那样。

2.1.3 np.random.permutation(n) 生成1-n的乱序列表

使用get_random_list(n)

from jittor.dataset.utils import get_random_list
get_random_list(4)
>>[2, 0, 1, 3]

里面就是封装了一个np.random.permutation(n)

2.1.4 np.mm() 矩阵乘法

jittor没有自带的mm函数,只有bmm函数,所以想要实现矩阵相乘就只能自己写一个函数了。但是jittor官网有直接给出矩阵乘法实现的函数,可以直接拿来用。

def matmul(a, b):
    (n, m), k = a.shape, b.shape[-1]
    a = a.broadcast([n,m,k], dims=[2])
    b = b.broadcast([n,m,k], dims=[0])
    return (a*b).sum(dim=1)

2.1.5 np.var() 求方差

jittor同样没有var()函数,但是有std()函数,如果要通过方差计算标准差的化可以直接使用std(data)函数。但是std函数是没有dim参数的,只会求出data中所有数据的var。如果想要指定维度算方差的话,还是需要自己写一个函数。

#dim=2的情况
def calc_mean_std(feat, eps=1e-5):
   N, C, H, W = feat.size()
   assert (len(feat.size()) == 4)
   dims = list(range(2,feat.ndim))
   X = ( H * W ) / (H * W - 1 )  #用于将方差转换为样本方差
   mean = jt.mean(feat, dims=dims)
   xmean = mean * X
   x2mean = jt.mean(feat * feat, dims=dims) * X
   xvar = (x2mean - xmean * xmean).maximum(0.0)
   return mean.view(N, C, 1, 1), jt.sqrt(xvar+eps).view(N, C, 1, 1)

2.2 Train阶段问题

2.2.1 dataset和dataloader 加载图片集

jittor中没有单独的dataset和dataloader函数,但有集两个操作于一体的Dataset类以及它的子类ImageFolder.

class jittor.dataset.Dataset(batch_size=16, shuffle=False, drop_last=False, num_workers=0, buffer_size=536870912, stop_grad=True, keep_numpy_array=False)
class jittor.dataset.ImageFolder(root, transform=None)

对于训练集读取操作可以用一句语句实现:

#torch代码
#class FlatFolderDataset(data.Dataset):
#      ... ...
content_dataset = FlatFolderDataset(content_dir, content_transform)
content_loader = data.DataLoader(
    content_dataset, batch_size=args.batch_size,
    sampler=InfiniteSamplerWrapper(content_dataset),
    num_workers=args.n_threads)

#jitter代码
content_dataset_loader = ImageFolder(args.content_dir, content_transform).set_attrs(batch_size = args.batch_size, num_workers = args.n_threads)
#通过设置属性的方式达到输入参数的效果

ImageFolder的缺陷是没有sampler选项。如果想实现torch中的sampler效果,就需要自己定义一个loader类,重写它的__iter__参数。关于__iter__的重写可以参考官方文档,里面对于各种情况下的sampler都有讲解。
下面是MCCNET中的sampler在__iter__中实现的例子,实现了单卡单线程当训练集大小小于iter次数时,对训练集进行重复随机取index。sampler.py代码

torch的sampler定义
#import numpy as np
#from torch.utils import data
def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31
jittor代码
class FlatFolderDataset(Dataset):
    def __iter__(self):
        i = self.len - 1
        batch_size = args.batch_size
        index_list = get_random_list(self.len)
        batch_data = []
        while True:
            for x in range(batch_size):
                y = i
                if i >= self.len:
                    y = i - self.len
                batch_data.append(self[index_list[y]])
                i += 1

            if (i >= self.len):
                index_list = get_random_list(self.len)
                i = 0

            batch_data = self.collate_batch(batch_data)
            batch_data = self.to_jittor(batch_data)
            yield jt.float(batch_data)
            batch_data = []

但由于单卡单线程较慢,所以改成了多线程的方式,具体代码可以在GitHub中看到。

2.2.2 load() 加载网络参数

jittor中是有同名参数的。

def load(path):
    if path.endswith(".pth"):
        try:
            dirty_fix_pytorch_runtime_error()
            import torch
        except:
            raise RuntimeError("pytorch need to be installed when load pth format.")
        model_dict = torch.load(path, map_location=torch.device('cpu'))
    else:
        model_dict = safeunpickle(path)
    return model_dict

可以看到load()函数会判断读入的参数文件后缀是否为".pth"(即torhc的参数文件)。如果是,则调用torch自带的load()函数,并加上了参数"map_location=torch.device(‘cpu’)",而这个参数会导致这句语句中断
但是不知道为什么每次调用都跑不动,于是就单拎出来测试了一下jittor包装的load函数。

model_dict = torch.load(path, map_location=torch.device('cpu'))

发现还是跑不动,但把map_location参数去掉就好了。
在调用.pth文件时,推荐直接使用torch的load函数。
顺便说一下,Jittor中的参数文件后缀为.pkl

2.2.3 pytorch训练好的模型参数加载到jittor的模型中

只需导入torch自带的模型,再将它的参数赋值给jittor的模型

import torch 
import torchvision.models as tcmodels 
import jittor.models as jtmodels

pytorch_model =tcmodels.__dict__['vgg19']()
jittor_model = jtmodels.__dict__['vgg19']() 
 # Set eval to avoid dropout layer 
 pytorch_model.eval() 
 jittor_model.eval() 
 # Jittor loads pytorch parameters to ensure forward alignment 
 jittor_model.load_parameters(pytorch_model.state_dict())

如果是本地的pth文件,直接使用load_state_dict()加载就可以。

vgg.load_state_dict(torch.load(vgg_path))

2.2.4 children() 获取网络中的各层

有时候我们并不想要输入通过整个网络后的结果,而是在通过某个特定层之后的结果,就需要将model中的层单独取出来,然后再输入数据、运算。
当你使用jittor自带的模型时,要注意通过children()取出来的结果,可能并不是想你想象得那样。

import jittor.models as jtmodels
vgg = jtmodels.vgg19()
enc_layers = list(vgg.children())
print(enc_layers)
>>[Sequential(
    0: Conv(3, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[64,], None, Kw=None, fan=None, i=None, bound=None)
    1: relu()
    2: Conv(64, 64, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[64,], None, Kw=None, fan=None, i=None, bound=None)
    3: relu()
    4: Pool(2, 2, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    5: Conv(64, 128, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[128,], None, Kw=None, fan=None, i=None, bound=None)
    6: relu()
    7: Conv(128, 128, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[128,], None, Kw=None, fan=None, i=None, bound=None)
    8: relu()
    9: Pool(2, 2, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    10: Conv(128, 256, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[256,], None, Kw=None, fan=None, i=None, bound=None)
    11: relu()
    12: Conv(256, 256, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[256,], None, Kw=None, fan=None, i=None, bound=None)
    13: relu()
    14: Conv(256, 256, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[256,], None, Kw=None, fan=None, i=None, bound=None)
    15: relu()
    16: Conv(256, 256, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[256,], None, Kw=None, fan=None, i=None, bound=None)
    17: relu()
    18: Pool(2, 2, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    19: Conv(256, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    20: relu()
    21: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    22: relu()
    23: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    24: relu()
    25: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    26: relu()
    27: Pool(2, 2, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
    28: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    29: relu()
    30: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    31: relu()
    32: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    33: relu()
    34: Conv(512, 512, (3, 3), (1, 1), (1, 1), (1, 1), 1, float32[512,], None, Kw=None, fan=None, i=None, bound=None)
    35: relu()
    36: Pool(2, 2, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=False, op=maximum)
), AdaptiveAvgPool2d((7, 7)), Sequential(
    0: Linear(25088, 4096, float32[4096,], None)
    1: relu()
    2: Dropout(0.5, is_train=False)
    3: Linear(4096, 4096, float32[4096,], None)
    4: relu()
    5: Dropout(0.5, is_train=False)
    6: Linear(4096, 1000, float32[1000,], None)
)]

可以发现他的结果实际上是两个Sequential加一个单独的AdaptiveAvgPool2d,并不是所有的层单独排列在列表中。
当你使用以下代码:

enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
output1 = enc_1(input_feat)
output2 = enc_2(output1)
output3 = enc_3(output2)

想象中的outpu1应该是input通过relu1_1之后的feat,output2应该是通过relu2_1之后的结果,但实际上output1是input通过整个vgg的output,因为在第一步取enc_layers[:4]时就将两个sequential和一个layer全取出来了,即取了整个网络,之后的enc_2和enc_3都为none。
针对这个问题秩只需做一点小小的改进,将enc_layers赋值为enc_layers[0]即可。

enc_layers = list(enc_layers[0].children())

2.2.5 MaxPool2d

jittor库中是没有MaxPool2d的,但有Pool类

class jittor.nn.Pool(kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op='maximum')

用Pool替换MaxPool2d就可以了。要注意的是Pool的kernal_size,stride和padding参数虽然可以像(2,2)的元组,但是进入底层的运算会报类型错误。

pytorch 转置 pytorch转jittor_方差_02

所以还是要将(2,2)和(0,0)写成2和0。

2.2.6 backward 反向传播

jittor中的正向传播函数为execute,相当于torch中的forward函数。
同时jittor没有backward函数,而是融合到了step函数中。

torch代码
    optimizer.zero_grad()
    loss.sum().backward()
    optimizer.step()

jittor代码
    #optimizer.zero_grad()也封装到了step函数中
    optimizer.step(loss.sum())

2.2.7 requires_grad 设置参数学习模式

通过设置模型参数的requires_grad = False确实是会生效的。但当你查看输入通过模型后获得的结果的requires_grad属性时,就会发现它又被改回了True。
解决办法是使用with no_grads()语句。

with jt.no_grad():
	style_feats = self.encode_with_intermeidate(style)

style_feats中的requires_grad属性就会被设置为False。

2.2.8 GPU模式

jt.flags.use_cuda属性可以设置运行在GPU模式还是CPU模式。

jt.flags.use_cuda = 0 # jt.flags.use_cuda 表示是否使用 gpu 训练。
# 如果 jt.flags.use_cuda=1,表示使用GPU训练 如果 jt.flags.use_cuda = 0 表示使用 CPU

要在每个py开头都加上这句语句,不然就会导致一些奇怪的bug。在文件开头声明后,就不需要再代码中再加入如to(device)的声明。jittor使用同一内存管理。

pytorch 转置 pytorch转jittor_深度学习_03


个人的理解是大概不用像torch那样分为CPU Tensor和CPU Tensor,即不用考虑数据在CPU和GPU之间切换,也不用像torch那样设置to(device)

3. Test阶段问题

基本上所有问题都在train代码中出现过,所以没有什么太大的问题。但是还是要手动转换的,官网的辅助转换工具不能转换除模型以外的代码

4. 总结

以上就是将MCCNET从torch框架转换到jittor框架的过程中遇到的问题。有兴趣的朋友也可以自己尝试一下转换,如果遇到新的问题的话可以交流一下!
MCCNET代码:https://github.com/diyiiyiii/MCCNet