1. 概述

对于使用 PyTorch 建立的模型，有时需要把它转换为 Keras 模型。本文使用 TensorFlow 作为 Keras 的后端（backend）。

2. 依赖

依赖的第三方库：


  • pytorch
  • keras
  • tensorflow
  • pytorch2keras

3. 安装方式

git clone https://github.com/nerox8664/pytorch2keras.git
cd pytorch2keras
python setup.py install

4. 代码

import numpy as np
import torch
from torch.autograd import Variable
from pytorch2keras import converter

class Pytorch2KerasTestNet(torch.nn.Module):
    """Small test network used to exercise the PyTorch -> Keras conversion.

    Pipeline: ConvLayer (3 -> 32 channels, 9x9 kernel, stride 1),
    then InstanceNorm2d with learnable affine parameters, then ReLU.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order and under the same
        # attribute names (conv1, in1, relu), so parameter names are unchanged.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        features = self.conv1(x)
        normalized = self.in1(features)
        return self.relu(normalized)


class ConvLayer(torch.nn.Module):
    """2-D convolution preceded by reflection padding of ``kernel_size // 2``.

    The input is padded with ``torch.nn.ReflectionPad2d`` instead of using
    ``Conv2d``'s built-in zero padding. For an odd kernel and stride 1 this
    keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = kernel_size // 2
        # Attribute names are kept as-is: they determine the parameter names.
        self.reflection_pad = torch.nn.ReflectionPad2d(pad)
        self.conv2d = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride
        )

    def forward(self, x):
        padded = self.reflection_pad(x)
        # Debug trace printed on every forward call.
        print("conv2d")
        return self.conv2d(padded)

def check_error(output, k_model, input_np, epsilon=1e-5):
    """Compare a PyTorch output with a Keras model's prediction on the same input.

    Args:
        output: PyTorch tensor produced by the original model; its ``.data``
            is converted to NumPy for the comparison.
        k_model: Converted Keras model; ``k_model.predict(input_np)`` is the
            value compared against ``output``.
        input_np: NumPy array passed to ``k_model.predict``.
        epsilon: Largest acceptable absolute element-wise difference (exclusive).

    Returns:
        The maximum absolute element-wise difference between the two outputs.

    Raises:
        AssertionError: If that difference is not below ``epsilon``.
    """
    pytorch_output = output.data.numpy()
    keras_output = k_model.predict(input_np)

    # Take the absolute value before the max. A signed max ignores every
    # element where the Keras output is larger than the PyTorch one, so a
    # large mismatch could still report an error <= 0 and pass the check.
    error = np.max(np.abs(pytorch_output - keras_output))
    print('Error:', error)

    assert error < epsilon
    return error

# Build the PyTorch reference model and run one random input through it.
model = Pytorch2KerasTestNet()
input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
input_var = Variable(torch.FloatTensor(input_np))
output = model(input_var)

# Convert to Keras. The shape list passed here is the input shape without
# the leading batch dimension of input_np.
k_model = converter.pytorch_to_keras(
    model,
    input_var,
    [(3, 224, 224,)],
    verbose=True,
)
k_model.summary()

# Compare both frameworks on the same input; check_error raises
# AssertionError when the difference is not below its epsilon.
max_error = 0
error = check_error(output, k_model, input_np)
max_error = max(max_error, error)
print('Max error: {0}'.format(max_error))

# Save the converted Keras model.
k_model.save('my_model.h5')

# Reload the saved model. The converted graph contains Lambda layers that
# refer to the name `tf`, so it is supplied through custom_objects.
from keras.models import load_model
import tensorflow as tf

model = load_model('my_model.h5', custom_objects=dict(tf=tf))
model.summary()

输出结果（`summary()` 打印的模型结构）：

Layer (type)                 Output Shape              Param #
=================================================================
input_0 (InputLayer)         (None, 3, 224, 224)       0
_________________________________________________________________
5 (Lambda)                   (None, 3, 232, 232)       0
_________________________________________________________________
6 (Conv2D)                   (None, 32, 224, 224)      7808
_________________________________________________________________
7 (Lambda)                   (None, 32, 224, 224)      0
_________________________________________________________________
output_0 (Activation)        (None, 32, 224, 224)      0
=================================================================
Total params: 7,808
Trainable params: 7,808
Non-trainable params: 0