Basic structure of a GAN

A GAN consists of two main components: a generator G and a discriminator D.

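A GAN trains two neural networks through an adversarial process: the two networks compete with each other until they reach an ideal equilibrium. The police officer and the criminal in our example correspond to these two networks. One network, the generator G(z), takes random noise as input and produces data that closely resembles the existing dataset; what it learns is the data distribution. The other network, the discriminator D(x), receives both real and generated samples and tries to tell which is which. At its core the discriminator is a binary classifier: its output is the probability that the input comes from the real dataset (as opposed to synthetic, or fake, data).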

Formally, the objective of the whole process can be written as:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

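The ideal equilibrium mentioned above means that the generator reproduces the real data distribution and the discriminator outputs a probability of 0.5 for every input, i.e. generated data is indistinguishable from real data. In other words, the discriminator cannot tell whether a new sample from the generator is real or fake; both are equally likely, which is the point of maximum entropy.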
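To make the equilibrium concrete, here is a minimal sketch in plain Python (the helper names gan_value and binary_entropy are illustrative, not part of the original code). It evaluates the value function for one real and one generated sample and shows that a discriminator answering 0.5 gives log(0.5) + log(0.5) = -log 4, about -1.386, which is the "-1.38" shown in the plot text of the sine example further down.

```python
import math

def gan_value(d_real, d_fake):
    """V(D, G) for a single real sample and a single generated sample."""
    return math.log(d_real) + math.log(1.0 - d_fake)

def binary_entropy(p):
    """Entropy (in nats) of a Bernoulli(p) prediction."""
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))

# At the ideal equilibrium the discriminator answers 0.5 for every input.
value_at_equilibrium = gan_value(0.5, 0.5)    # equals -log(4)
entropy_at_equilibrium = binary_entropy(0.5)  # equals log(2), the maximum

# A confident discriminator has lower entropy than an undecided one.
entropy_confident = binary_entropy(0.9)

print(round(value_at_equilibrium, 3),
      round(entropy_at_equilibrium, 3),
      round(entropy_confident, 3))
```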

Here we use a GAN to generate a sine signal. The code is given below:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt




# torch.manual_seed(1)       # reproducible
# np.random.seed(1)

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001  # learning rate for generator
LR_D = 0.0001  # learning rate for discriminator
N_IDEAS = 8  # size of the random noise vector ("ideas") fed to the Generator
ART_COMPONENTS = 15  # number of points G draws on the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])



def artist_works():  # painting from the famous artist (real target)
    # a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    r = 0.02 * np.random.randn(1, ART_COMPONENTS)
    paintings = np.sin(PAINT_POINTS * np.pi) + r
    paintings = torch.from_numpy(paintings).float()
    return paintings


# G = nn.Sequential(  # Generator
#     nn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)
#     nn.ReLU(),
#     nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
# )
#
# D = nn.Sequential(  # Discriminator
#     nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
#     nn.ReLU(),
#     nn.Linear(128, 1),
#     nn.Sigmoid(),  # tell the probability that the art work is made by artist
# )

class Ge(nn.Module):
    def __init__(self):
        super(Ge,self).__init__()
        self.fc1=nn.Linear(N_IDEAS,128)
        self.fc2=nn.Linear(128,ART_COMPONENTS)

    def forward(self, x):
        x=F.relu(self.fc1(x))
        x=self.fc2(x)
        return x


class De(nn.Module):
    def __init__(self):
        super(De,self).__init__()
        self.fc1=nn.Linear(ART_COMPONENTS,128)
        self.fc2=nn.Linear(128,1)

    def forward(self,x):
        x=F.relu(self.fc1(x))
        x=torch.sigmoid(self.fc2(x))  # F.sigmoid is deprecated
        return x


G=Ge()
D=De()


opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()  # interactive mode, so the figure can be redrawn during training

D_loss_history = []
G_loss_history = []
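for step in range(10000):
    artist_paintings = artist_works()  # real paintings from the artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)  # fake paintings from G (random ideas)

    # Update D. detach() keeps this update from backpropagating into G.
    prob_artist0 = D(artist_paintings)  # D tries to increase this prob
    prob_artist1 = D(G_paintings.detach())  # D tries to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # Update G against the updated D. Re-running D avoids calling backward()
    # through parameters that opt_D.step() has just modified in place,
    # which recent PyTorch versions reject with a RuntimeError.
    prob_artist1 = D(G_paintings)
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    # Store plain floats so the history does not keep the graphs alive.
    D_loss_history.append(D_loss.item())
    G_loss_history.append(G_loss.item())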





    if step % 1000 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.detach().numpy()[0], c='r', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='b', lw=3, label='target sine')
        plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.detach().numpy().mean(),
                 fontdict={'size': 13})
        plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.item(), fontdict={'size': 13})
        plt.ylim((-1, 1))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

# plt.ioff()
# plt.show()

In the code above, the artist_works() function produces the target sine signal:

def artist_works():  # painting from the famous artist (real target)
    # a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    r = 0.02 * np.random.randn(1, ART_COMPONENTS)
    paintings = np.sin(PAINT_POINTS * np.pi) + r
    paintings = torch.from_numpy(paintings).float()
    return paintings
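One detail of artist_works() is easy to miss: the noise r has shape (1, ART_COMPONENTS), so NumPy broadcasts the same noise row onto every sample, and all 64 "paintings" in a batch are identical. The sketch below checks this with NumPy alone. The per-sample noise variant is a hypothetical alternative for comparison, not something the original code does.

```python
import numpy as np

BATCH_SIZE = 64
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])

# Same construction as artist_works(): one noise row, broadcast to every sample.
shared_noise = 0.02 * np.random.randn(1, ART_COMPONENTS)
batch_shared = np.sin(PAINT_POINTS * np.pi) + shared_noise

# Hypothetical variant (not in the original): independent noise per sample.
per_sample_noise = 0.02 * np.random.randn(BATCH_SIZE, ART_COMPONENTS)
batch_varied = np.sin(PAINT_POINTS * np.pi) + per_sample_noise

# Compare every row with the first row of the same batch.
rows_identical_shared = bool(np.allclose(batch_shared, batch_shared[0]))
rows_identical_varied = bool(np.allclose(batch_varied, batch_varied[0]))
print(batch_shared.shape, rows_identical_shared, rows_identical_varied)
```

Whether per-sample noise helps training here is untested; the sketch only shows what each choice produces.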

The following code builds the generator and discriminator networks, implemented in PyTorch.

class Ge(nn.Module):
    def __init__(self):
        super(Ge,self).__init__()
        self.fc1=nn.Linear(N_IDEAS,128)
        self.fc2=nn.Linear(128,ART_COMPONENTS)

    def forward(self, x):
        x=F.relu(self.fc1(x))
        x=self.fc2(x)
        return x


class De(nn.Module):
    def __init__(self):
        super(De,self).__init__()
        self.fc1=nn.Linear(ART_COMPONENTS,128)
        self.fc2=nn.Linear(128,1)

    def forward(self,x):
        x=F.relu(self.fc1(x))
        x=torch.sigmoid(self.fc2(x))  # F.sigmoid is deprecated
        return x
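Both networks are small two-layer perceptrons. As a rough size check, the following sketch counts their trainable parameters by hand (linear_params is an illustrative helper; the layer sizes are taken from the Ge and De classes above).

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

N_IDEAS, ART_COMPONENTS, HIDDEN = 8, 15, 128

# Generator: N_IDEAS -> 128 -> ART_COMPONENTS
g_params = linear_params(N_IDEAS, HIDDEN) + linear_params(HIDDEN, ART_COMPONENTS)
# Discriminator: ART_COMPONENTS -> 128 -> 1
d_params = linear_params(ART_COMPONENTS, HIDDEN) + linear_params(HIDDEN, 1)

print(g_params, d_params)
```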

The following code defines the loss functions of the discriminator and the generator:

D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
G_loss = torch.mean(torch.log(1. - prob_artist1))
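To see what these two expressions evaluate to, here is a dependency-free sketch that applies the same formulas to hand-picked probabilities (the numbers are made up for illustration). It also includes the widely used non-saturating generator loss, -mean(log D(G(z))), purely for comparison; the code in this post does not use it.

```python
import math

def d_loss(p_real, p_fake):
    """Discriminator loss: -mean(log D(x) + log(1 - D(G(z))))."""
    terms = [math.log(r) + math.log(1.0 - f) for r, f in zip(p_real, p_fake)]
    return -sum(terms) / len(terms)

def g_loss_saturating(p_fake):
    """Generator loss used in the listing: mean(log(1 - D(G(z))))."""
    return sum(math.log(1.0 - f) for f in p_fake) / len(p_fake)

def g_loss_non_saturating(p_fake):
    """Common alternative (not used in the listing): -mean(log D(G(z)))."""
    return -sum(math.log(f) for f in p_fake) / len(p_fake)

# Illustrative probabilities: early on D rejects fakes easily,
# at the ideal equilibrium it outputs 0.5 everywhere.
cases = {
    "early": ([0.9, 0.95], [0.05, 0.1]),
    "equilibrium": ([0.5, 0.5], [0.5, 0.5]),
}
for name, (p_real, p_fake) in cases.items():
    print(name,
          round(d_loss(p_real, p_fake), 3),
          round(g_loss_saturating(p_fake), 3),
          round(g_loss_non_saturating(p_fake), 3))
```

In the "early" case the saturating loss is close to zero while the non-saturating one is large, which is the usual argument for the alternative when the discriminator is far ahead.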

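Results: the first figure shows the curve generated from random input at the start of training; the second shows the curve once the discriminator output is close to 0.5, where the generated curve follows the target sine well: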

[Figure: curve generated from random input at the start of training]

[Figure: generated curve once the discriminator output is about 0.5]

With this GAN background in place, we now turn to generative adversarial imitation learning (GAIL):

The project consists of two files: env_OppositeV4.py builds the environment, and GAIL_OppositeV4.py runs the program.

We start with env_OppositeV4.py, which builds the environment. Here is a rendering of the environment:

[Figure: rendered grid environment]

In the figure, the red cell is the start and the green cell is the goal. The code of env_OppositeV4.py is given below:

import numpy as np
import cv2

class EnvOppositeV4(object):
    def __init__(self, size):
        self.map_size = size
        self.raw_occupancy = np.zeros((self.map_size, self.map_size))
        for i in range(self.map_size):
            self.raw_occupancy[0][i] = 1
            self.raw_occupancy[self.map_size - 1][i] = 1
            self.raw_occupancy[i][0] = 1
            self.raw_occupancy[i][self.map_size - 1] = 1
            self.raw_occupancy[i][int((self.map_size - 1) / 2)] = 1
        self.raw_occupancy[1][int((self.map_size - 1) / 2)] = 0
        self.raw_occupancy[self.map_size - 2][int((self.map_size - 1) / 2)] = 0

        self.occupancy = self.raw_occupancy.copy()

        self.agt1_pos = [int((self.map_size - 1) / 2), 1]
        self.goal1_pos = [int((self.map_size - 1) / 2), self.map_size - 2]
        self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

    def reset(self):
        self.occupancy = self.raw_occupancy.copy()

        self.agt1_pos = [int((self.map_size - 1) / 2), 1]
        self.goal1_pos = [int((self.map_size - 1) / 2), self.map_size - 2]
        self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

    def get_state(self):
        state = np.zeros((1, 2))
        state[0, 0] = self.agt1_pos[0] / self.map_size
        state[0, 1] = self.agt1_pos[1] / self.map_size
        return state

    def step(self, action_list):
        reward = 0
        # agent1 move
        if action_list[0] == 0:  # move up
            if self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] != 1:  # if can move
                self.agt1_pos[0] = self.agt1_pos[0] - 1
                self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
        elif action_list[0] == 1:  # move down
            if self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] != 1:  # if can move
                self.agt1_pos[0] = self.agt1_pos[0] + 1
                self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
        elif action_list[0] == 2:  # move left
            if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] != 1:  # if can move
                self.agt1_pos[1] = self.agt1_pos[1] - 1
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
        elif action_list[0] == 3:  # move right
            if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] != 1:  # if can move
                self.agt1_pos[1] = self.agt1_pos[1] + 1
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] = 0
                self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

        if self.agt1_pos == self.goal1_pos:
            reward = reward + 5

        done = False
        if reward == 5:
            done = True
        return reward, done

    def get_global_obs(self):
        obs = np.zeros((self.map_size, self.map_size, 3))
        for i in range(self.map_size):
            for j in range(self.map_size):
                if self.occupancy[i][j] == 0:
                    obs[i, j, 0] = 1.0
                    obs[i, j, 1] = 1.0
                    obs[i, j, 2] = 1.0
        obs[self.agt1_pos[0], self.agt1_pos[1], 0] = 1.0
        obs[self.agt1_pos[0], self.agt1_pos[1], 1] = 0.0
        obs[self.agt1_pos[0], self.agt1_pos[1], 2] = 0.0
        return obs

    def render(self):
        obs = self.get_global_obs()
        enlarge = 30
        new_obs = np.ones((self.map_size*enlarge, self.map_size*enlarge, 3))
        for i in range(self.map_size):
            for j in range(self.map_size):

                if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 0), -1)
                if obs[i][j][0] == 1.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 255), -1)
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 1.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 255, 0), -1)
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 1.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (255, 0, 0), -1)
        cv2.imshow('image', new_obs)
        cv2.waitKey(100)

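In the code above, the following part builds the grid shown below. Cells set to 1 are walls and are drawn black later; cells set to 0 are free and are drawn white. It also sets the agent's start position and its goal position.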

def __init__(self, size):
    self.map_size = size
    self.raw_occupancy = np.zeros((self.map_size, self.map_size))
    for i in range(self.map_size):
        self.raw_occupancy[0][i] = 1
        self.raw_occupancy[self.map_size - 1][i] = 1
        self.raw_occupancy[i][0] = 1
        self.raw_occupancy[i][self.map_size - 1] = 1
        self.raw_occupancy[i][int((self.map_size - 1) / 2)] = 1
    self.raw_occupancy[1][int((self.map_size - 1) / 2)] = 0
    self.raw_occupancy[self.map_size - 2][int((self.map_size - 1) / 2)] = 0

    self.occupancy = self.raw_occupancy.copy()

    self.agt1_pos = [int((self.map_size - 1) / 2), 1]
    self.goal1_pos = [int((self.map_size - 1) / 2), self.map_size - 2]
    self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

[Figure: occupancy grid with walls (1) and free cells (0)]
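The layout can also be inspected without OpenCV. The sketch below rebuilds the same occupancy grid with nested lists for size 9 (the value used in GAIL_OppositeV4.py) and prints it as text: "#" is a wall, "." a free cell, "S" the start and "G" the goal. build_occupancy is an illustrative helper, not part of the original class.

```python
def build_occupancy(size):
    """Rebuild raw_occupancy from EnvOppositeV4.__init__ as nested lists."""
    mid = (size - 1) // 2
    grid = [[0] * size for _ in range(size)]
    for i in range(size):
        grid[0][i] = 1            # top wall
        grid[size - 1][i] = 1     # bottom wall
        grid[i][0] = 1            # left wall
        grid[i][size - 1] = 1     # right wall
        grid[i][mid] = 1          # dividing wall down the middle column
    grid[1][mid] = 0              # upper gap in the dividing wall
    grid[size - 2][mid] = 0       # lower gap in the dividing wall
    return grid

size = 9
grid = build_occupancy(size)
start = ((size - 1) // 2, 1)        # agt1_pos
goal = ((size - 1) // 2, size - 2)  # goal1_pos

rows = []
for i in range(size):
    row = ""
    for j in range(size):
        if (i, j) == start:
            row += "S"
        elif (i, j) == goal:
            row += "G"
        else:
            row += "#" if grid[i][j] == 1 else "."
    rows.append(row)
print("\n".join(rows))
```

The start and goal sit on opposite sides of the middle wall, so the agent has to pass through one of the two gaps.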

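The following code colours the grid: cells with value 1 become black, cells with value 0 become white, and the agent's cell is then overwritten with red. The result is shown below.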
def get_global_obs(self):
    obs = np.zeros((self.map_size, self.map_size, 3))
    for i in range(self.map_size):
        for j in range(self.map_size):
            if self.occupancy[i][j] == 0:
                obs[i, j, 0] = 1.0
                obs[i, j, 1] = 1.0
                obs[i, j, 2] = 1.0
    obs[self.agt1_pos[0], self.agt1_pos[1], 0] = 1.0
    obs[self.agt1_pos[0], self.agt1_pos[1], 1] = 0.0
    obs[self.agt1_pos[0], self.agt1_pos[1], 2] = 0.0
    return obs

[Figure: coloured grid before enlargement]

The following code enlarges the grid for display: each cell is drawn as a 30 x 30 pixel square.

[Figure: rendered grid environment]

def render(self):
    obs = self.get_global_obs()
    enlarge = 30
    new_obs = np.ones((self.map_size*enlarge, self.map_size*enlarge, 3))
    for i in range(self.map_size):
        for j in range(self.map_size):

            if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 0), -1)
            if obs[i][j][0] == 1.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 255), -1)
            if obs[i][j][0] == 0.0 and obs[i][j][1] == 1.0 and obs[i][j][2] == 0.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 255, 0), -1)
            if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 1.0:
                cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (255, 0, 0), -1)
    cv2.imshow('image',new_obs)
    cv2.waitKey(100)

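The following code defines the agent's actions and the reward. Actions 0 to 3 move the agent up, down, left and right when the target cell is free; any other action (GAIL_OppositeV4.py uses five actions, so action 4) leaves the agent where it is. Reaching the goal yields a reward of 5 and ends the episode.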

def step(self, action_list):
    reward = 0
    # agent1 move
    if action_list[0] == 0:  # move up
        if self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] != 1:  # if can move
            self.agt1_pos[0] = self.agt1_pos[0] - 1
            self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
    elif action_list[0] == 1:  # move down
        if self.occupancy[self.agt1_pos[0] + 1][self.agt1_pos[1]] != 1:  # if can move
            self.agt1_pos[0] = self.agt1_pos[0] + 1
            self.occupancy[self.agt1_pos[0] - 1][self.agt1_pos[1]] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
    elif action_list[0] == 2:  # move left
        if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] != 1:  # if can move
            self.agt1_pos[1] = self.agt1_pos[1] - 1
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1
    elif action_list[0] == 3:  # move right
        if self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] + 1] != 1:  # if can move
            self.agt1_pos[1] = self.agt1_pos[1] + 1
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1] - 1] = 0
            self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

    if self.agt1_pos == self.goal1_pos:
        reward = reward + 5

    done = False
    if reward == 5:
        done = True
    return reward, done
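Given these movement rules, it is useful to know how short an episode can be. The following sketch runs a breadth-first search over the same wall layout for size 9, using the four move actions of step(). It is an independent check written for this explanation, with illustrative names (build_occupancy, shortest_path_length), and it is not part of the original project.

```python
from collections import deque

def build_occupancy(size):
    """Same wall layout as EnvOppositeV4.__init__, as nested lists."""
    mid = (size - 1) // 2
    grid = [[0] * size for _ in range(size)]
    for i in range(size):
        grid[0][i] = grid[size - 1][i] = 1   # top and bottom walls
        grid[i][0] = grid[i][size - 1] = 1   # left and right walls
        grid[i][mid] = 1                     # dividing wall
    grid[1][mid] = 0                         # upper gap
    grid[size - 2][mid] = 0                  # lower gap
    return grid

# Row/column offsets for actions 0..3, matching step(): up, down, left, right.
MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

def shortest_path_length(grid, start, goal):
    """Breadth-first search; returns the minimum number of moves or None."""
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        (r, c), dist = queue.popleft()
        if (r, c) == goal:
            return dist
        for dr, dc in MOVES.values():
            nxt = (r + dr, c + dc)
            # Free cells are enclosed by the border walls, so indices stay valid.
            if grid[nxt[0]][nxt[1]] == 0 and nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return None

size = 9
grid = build_occupancy(size)
start = ((size - 1) // 2, 1)
goal = ((size - 1) // 2, size - 2)
shortest = shortest_path_length(grid, start, goal)
print(shortest)
```

This number gives a baseline against which the length of any recorded trajectory can be judged.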

This completes the description of the agent's environment.

The code of GAIL_OppositeV4.py is given below:

from torch.distributions.categorical import Categorical
import torch
import torch.nn as nn
import torch.nn.functional as F
from env_OppositeV4 import EnvOppositeV4
import numpy as np
import os



class Actor(nn.Module):
    def __init__(self, N_action):
        super(Actor, self).__init__()
        self.N_action = N_action
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, self.N_action)

    def get_action(self, h):
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        h = F.softmax(self.fc3(h), dim=1)
        m = Categorical(h.squeeze(0))
        a = m.sample()
        log_prob = m.log_prob(a)
        return a.item(), h, log_prob

class Discriminator(nn.Module):
    def __init__(self, s_dim, N_action):
        super(Discriminator, self).__init__()
        self.s_dim = s_dim
        self.N_action = N_action
        self.fc1 = nn.Linear(self.s_dim + self.N_action, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        state_action = torch.cat([state, action], 1)
        x = torch.relu(self.fc1(state_action))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

class GAIL(object):
    def __init__(self, s_dim, N_action):
        self.s_dim = s_dim
        self.N_action = N_action
        self.actor1 = Actor(self.N_action)
        self.disc1 = Discriminator(self.s_dim, self.N_action)
        self.d1_optimizer = torch.optim.Adam(self.disc1.parameters(), lr=1e-3)
        self.a1_optimizer = torch.optim.Adam(self.actor1.parameters(), lr=1e-3)
        self.loss_fn = torch.nn.MSELoss()
        self.adv_loss_fn = torch.nn.BCELoss()
        self.gamma = 0.9

    def get_action(self, obs1):
        action1, pi_a1, log_prob1 = self.actor1.get_action(torch.from_numpy(obs1).float())
        return action1, pi_a1, log_prob1

    def int_to_tensor(self, action):
        temp = torch.zeros(1, self.N_action)
        temp[0, action] = 1
        return temp

    def train_D(self, s1_list, a1_list, e_s1_list, e_a1_list):
        p_s1 = torch.from_numpy(s1_list[0]).float()
        p_a1 = self.int_to_tensor(a1_list[0])
        for i in range(1, len(s1_list)):
            temp_p_s1 = torch.from_numpy(s1_list[i]).float()
            p_s1 = torch.cat([p_s1, temp_p_s1], dim=0)
            temp_p_a1 = self.int_to_tensor(a1_list[i])
            p_a1 = torch.cat([p_a1, temp_p_a1], dim=0)

        e_s1 = torch.from_numpy(e_s1_list[0]).float()
        e_a1 = self.int_to_tensor(e_a1_list[0])
        for i in range(1, len(e_s1_list)):
            temp_e_s1 = torch.from_numpy(e_s1_list[i]).float()
            e_s1 = torch.cat([e_s1, temp_e_s1], dim=0)
            temp_e_a1 = self.int_to_tensor(e_a1_list[i])
            e_a1 = torch.cat([e_a1, temp_e_a1], dim=0)

        p1_label = torch.zeros(len(s1_list), 1)
        e1_label = torch.ones(len(e_s1_list), 1)

        e1_pred = self.disc1(e_s1, e_a1)
        # print('e1_pred', e1_pred)
        loss = self.adv_loss_fn(e1_pred, e1_label)
        p1_pred = self.disc1(p_s1, p_a1)
        # print('p1_pred', p1_pred)
        loss = loss + self.adv_loss_fn(p1_pred, p1_label)
        self.d1_optimizer.zero_grad()
        loss.backward()
        self.d1_optimizer.step()

    def train_G(self, s1_list, a1_list, log_pi_a1_list, r1_list, e_s1_list, e_a1_list):
        T = len(s1_list)
        p_s1 = torch.from_numpy(s1_list[0]).float()
        p_a1 = self.int_to_tensor(a1_list[0])
        for i in range(1, len(s1_list)):
            temp_p_s1 = torch.from_numpy(s1_list[i]).float()
            p_s1 = torch.cat([p_s1, temp_p_s1], dim=0)
            temp_p_a1 = self.int_to_tensor(a1_list[i])
            p_a1 = torch.cat([p_a1, temp_p_a1], dim=0)

        e_s1 = torch.from_numpy(e_s1_list[0]).float()
        e_a1 = self.int_to_tensor(e_a1_list[0])
        for i in range(1, len(e_s1_list)):
            temp_e_s1 = torch.from_numpy(e_s1_list[i]).float()
            e_s1 = torch.cat([e_s1, temp_e_s1], dim=0)
            temp_e_a1 = self.int_to_tensor(e_a1_list[i])
            e_a1 = torch.cat([e_a1, temp_e_a1], dim=0)

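        p1_pred = self.disc1(p_s1, p_a1)
        # The discriminator score acts as a fixed reward signal for the actor,
        # so detach it: no gradient should flow into the discriminator here.
        fake_reward = p1_pred.mean().detach()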

        a1_loss = torch.FloatTensor([0.0])
        for t in range(T):
            a1_loss = a1_loss + fake_reward * log_pi_a1_list[t]
        a1_loss = -a1_loss / T

        # print(a1_loss)
        self.a1_optimizer.zero_grad()
        a1_loss.backward()
        self.a1_optimizer.step()

class REINFORCE(object):
    def __init__(self, N_action):
        self.N_action = N_action
        self.actor1 = Actor(self.N_action)
        # Create the optimizer once so Adam keeps its moment estimates between updates.
        self.a1_optimizer = torch.optim.Adam(self.actor1.parameters(), lr=1e-3)

    def get_action(self, obs):
        action1, pi_a1, log_prob1 = self.actor1.get_action(torch.from_numpy(obs).float())
        return action1, pi_a1, log_prob1

    def train(self, a1_list, pi_a1_list, r_list):
        T = len(r_list)
        # Discounted return G_k = r_k + 0.95 * G_{k+1}, computed backwards.
        G_list = torch.zeros(1, T)
        G_list[0, T - 1] = float(r_list[T - 1])
        for k in range(T - 2, -1, -1):
            G_list[0, k] = r_list[k] + 0.95 * G_list[0, k + 1]

        # Policy-gradient loss: -(1/T) * sum_t G_t * log pi(a_t | s_t)
        a1_loss = torch.FloatTensor([0.0])
        for t in range(T):
            a1_loss = a1_loss + G_list[0, t] * torch.log(pi_a1_list[t][0, a1_list[t]])
        a1_loss = -a1_loss / T
        self.a1_optimizer.zero_grad()
        a1_loss.backward()
        self.a1_optimizer.step()

    def save_model(self):
        # Save only the weights; pickling the whole module is fragile across versions.
        torch.save(self.actor1.state_dict(), 'V4_actor.pkl')

    def load_model(self):
        # Load into the existing module so the optimizer keeps tracking its parameters.
        self.actor1.load_state_dict(torch.load('V4_actor.pkl'))

if __name__ == '__main__':
    torch.set_num_threads(1)
    env = EnvOppositeV4(9)
    max_epi_iter = 100
    max_MC_iter = 100

    # train expert policy by REINFORCE algorithm
    agent = REINFORCE(N_action=5)
    if os.path.exists('./V4_actor.pkl'):
        agent.load_model()
    else:
        print('No saved model found; training from scratch.')

    for epi_iter in range(max_epi_iter):
        env.reset()
        a1_list = []
        pi_a1_list = []
        r_list = []
        acc_r = 0
        for MC_iter in range(max_MC_iter):
            env.render()
            state = env.get_state()
            action1, pi_a1, log_prob1 = agent.get_action(state)
            a1_list.append(action1)
            pi_a1_list.append(pi_a1)
            reward, done = env.step([action1, 0])
            acc_r = acc_r + reward
            r_list.append(reward)
            if done:
                break
        # MC_iter is the index of the last step, so the episode has MC_iter + 1 steps.
        print('Train expert, Episode', epi_iter, 'average reward', acc_r / (MC_iter + 1))
        if done:
            agent.train(a1_list, pi_a1_list, r_list)

    # record expert policy
    agent.save_model()
    exp_s_list = []
    exp_a_list = []
    env.reset()
    for MC_iter in range(max_MC_iter):
        env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        exp_s_list.append(state)
        exp_a_list.append(action1)
        reward, done = env.step([action1, 0])
        print('step', MC_iter, 'agent 1 at', exp_s_list[MC_iter], 'agent 1 action', exp_a_list[MC_iter], 'reward', reward, 'done', done)
        if done:
            break

    # generative adversarial imitation learning from [exp_s_list, exp_a_list]
    agent = GAIL(s_dim=2, N_action=5)
    for epi_iter in range(max_epi_iter):
        env.reset()
        s1_list = []
        a1_list = []
        r1_list = []
        log_pi_a1_list = []
        acc_r = 0
        for MC_iter in range(max_MC_iter):
            # env.render()
            state = env.get_state()
            action1, pi_a1, log_prob1 = agent.get_action(state)
            s1_list.append(state)
            a1_list.append(action1)
            log_pi_a1_list.append(log_prob1)
            reward, done = env.step([action1, 0])
            acc_r = acc_r + reward
            r1_list.append(reward)
            if done:
                break
        # MC_iter is the index of the last step, so the episode has MC_iter + 1 steps.
        print('Imitate by GAIL, Episode', epi_iter, 'average reward', acc_r / (MC_iter + 1))
        # train Discriminator
        agent.train_D(s1_list, a1_list, exp_s_list, exp_a_list)

        # train Generator
        agent.train_G(s1_list, a1_list, log_pi_a1_list, r1_list, exp_s_list, exp_a_list)

    # learnt policy
    print('expert trajectory')
    for i in range(len(exp_a_list)):
        print('step', i, 'agent 1 at', exp_s_list[i], 'agent 1 action', exp_a_list[i])

    print('learnt trajectory')
    env.reset()
    for MC_iter in range(max_MC_iter):
        # env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        reward, done = env.step([action1, 0])
        # Print the policy's own state and action; do not index into the expert lists.
        print('step', MC_iter, 'agent 1 at', state, 'agent 1 action', action1)
        if done:
            break
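Two pieces of the GAIL class are worth isolating: how a state-action pair is turned into the discriminator's input, and how the discriminator loss is formed with label 1 for expert pairs and label 0 for policy pairs. The sketch below reproduces both in plain Python. The predictions 0.8 and 0.3 are made-up numbers, and one_hot and bce are illustrative stand-ins for int_to_tensor and torch.nn.BCELoss.

```python
import math

def one_hot(action, n_actions):
    """List version of GAIL.int_to_tensor."""
    vec = [0.0] * n_actions
    vec[action] = 1.0
    return vec

def bce(pred, label):
    """Binary cross-entropy for a single prediction."""
    return -(label * math.log(pred) + (1.0 - label) * math.log(1.0 - pred))

# Discriminator input: normalised state (2 values) + one-hot action (5 values).
state = [4 / 9, 1 / 9]              # start cell, scaled as in get_state()
disc_input = state + one_hot(3, 5)  # action 3 = move right

# Hypothetical discriminator outputs for one expert pair and one policy pair.
expert_pred, policy_pred = 0.8, 0.3
d_loss = bce(expert_pred, 1.0) + bce(policy_pred, 0.0)

print(len(disc_input), disc_input[2:], round(d_loss, 4))
```

The loss falls as the discriminator pushes expert pairs towards 1 and policy pairs towards 0; train_G then rewards the actor for pairs that the discriminator scores highly.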

The output of a run is:

expert trajectory
step 0 agent 1 at [[0.44444444 0.11111111]] agent 1 action 1
step 1 agent 1 at [[0.55555556 0.11111111]] agent 1 action 4
step 2 agent 1 at [[0.55555556 0.11111111]] agent 1 action 3
step 3 agent 1 at [[0.55555556 0.22222222]] agent 1 action 1
step 4 agent 1 at [[0.66666667 0.22222222]] agent 1 action 0
step 5 agent 1 at [[0.55555556 0.22222222]] agent 1 action 0
step 6 agent 1 at [[0.44444444 0.22222222]] agent 1 action 3
step 7 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 8 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 9 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 10 agent 1 at [[0.33333333 0.33333333]] agent 1 action 4
step 11 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 12 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 13 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 14 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 15 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 16 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 17 agent 1 at [[0.55555556 0.33333333]] agent 1 action 2
step 18 agent 1 at [[0.55555556 0.22222222]] agent 1 action 3
step 19 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 20 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 21 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 22 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 23 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 24 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 25 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 26 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 27 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 28 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 29 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 30 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 31 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 32 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 33 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 34 agent 1 at [[0.22222222 0.33333333]] agent 1 action 2
step 35 agent 1 at [[0.22222222 0.22222222]] agent 1 action 3
step 36 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 37 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 38 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 39 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 40 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 41 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 42 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 43 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 44 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 45 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 46 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 47 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 48 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 49 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 50 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 51 agent 1 at [[0.66666667 0.33333333]] agent 1 action 0
step 52 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 53 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 54 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 55 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 56 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 57 agent 1 at [[0.66666667 0.33333333]] agent 1 action 4
step 58 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 59 agent 1 at [[0.77777778 0.33333333]] agent 1 action 1
step 60 agent 1 at [[0.77777778 0.33333333]] agent 1 action 4
step 61 agent 1 at [[0.77777778 0.33333333]] agent 1 action 0
step 62 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 63 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 64 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 65 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 66 agent 1 at [[0.77777778 0.33333333]] agent 1 action 0
step 67 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 68 agent 1 at [[0.77777778 0.33333333]] agent 1 action 3
step 69 agent 1 at [[0.77777778 0.44444444]] agent 1 action 3
step 70 agent 1 at [[0.77777778 0.55555556]] agent 1 action 0
step 71 agent 1 at [[0.66666667 0.55555556]] agent 1 action 0
step 72 agent 1 at [[0.55555556 0.55555556]] agent 1 action 0
step 73 agent 1 at [[0.44444444 0.55555556]] agent 1 action 0
step 74 agent 1 at [[0.33333333 0.55555556]] agent 1 action 1
step 75 agent 1 at [[0.44444444 0.55555556]] agent 1 action 4
step 76 agent 1 at [[0.44444444 0.55555556]] agent 1 action 0
step 77 agent 1 at [[0.33333333 0.55555556]] agent 1 action 1
step 78 agent 1 at [[0.44444444 0.55555556]] agent 1 action 3
step 79 agent 1 at [[0.44444444 0.66666667]] agent 1 action 0
step 80 agent 1 at [[0.33333333 0.66666667]] agent 1 action 3
step 81 agent 1 at [[0.33333333 0.77777778]] agent 1 action 1
learnt trajectory
step 0 agent 1 at [[0.44444444 0.11111111]] agent 1 action 1
step 1 agent 1 at [[0.55555556 0.11111111]] agent 1 action 4
step 2 agent 1 at [[0.55555556 0.11111111]] agent 1 action 3
step 3 agent 1 at [[0.55555556 0.22222222]] agent 1 action 1
step 4 agent 1 at [[0.66666667 0.22222222]] agent 1 action 0
step 5 agent 1 at [[0.55555556 0.22222222]] agent 1 action 0
step 6 agent 1 at [[0.44444444 0.22222222]] agent 1 action 3
step 7 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 8 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 9 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 10 agent 1 at [[0.33333333 0.33333333]] agent 1 action 4
step 11 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 12 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 13 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 14 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 15 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 16 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 17 agent 1 at [[0.55555556 0.33333333]] agent 1 action 2
step 18 agent 1 at [[0.55555556 0.22222222]] agent 1 action 3
step 19 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 20 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 21 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 22 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 23 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 24 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 25 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 26 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 27 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 28 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 29 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 30 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 31 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 32 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 33 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 34 agent 1 at [[0.22222222 0.33333333]] agent 1 action 2
step 35 agent 1 at [[0.22222222 0.22222222]] agent 1 action 3
step 36 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 37 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 38 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 39 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 40 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 41 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 42 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 43 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 44 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 45 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 46 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 47 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 48 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 49 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 50 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 51 agent 1 at [[0.66666667 0.33333333]] agent 1 action 0
step 52 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 53 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 54 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 55 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 56 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 57 agent 1 at [[0.66666667 0.33333333]] agent 1 action 4
step 58 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 59 agent 1 at [[0.77777778 0.33333333]] agent 1 action 1
step 60 agent 1 at [[0.77777778 0.33333333]] agent 1 action 4
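The two printouts above agree line for line through step 60. Note, however, that this output came from a final loop that printed exp_s_list[MC_iter] and exp_a_list[MC_iter], i.e. entries of the expert's recording, instead of the state and action1 of the GAIL policy. The match therefore does not by itself show that the learnt trajectory equals the expert trajectory; to compare the two, print state and action1 inside that loop.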

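Now let us look at some of the details.

We built this environment ourselves, so no expert trajectories exist for it. What can we do? We produce the expert trajectories ourselves, by first training an expert policy with REINFORCE and then recording one of its rollouts.

The following code collects the samples: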

for epi_iter in range(max_epi_iter):
    env.reset()
    a1_list = []
    pi_a1_list = []
    r_list = []
    acc_r = 0
    for MC_iter in range(max_MC_iter):
        env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        a1_list.append(action1)
        pi_a1_list.append(pi_a1)
        reward, done = env.step([action1, 0])
        acc_r = acc_r + reward
        r_list.append(reward)
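In the loop above, agent.get_action turns the actor's outputs into a softmax distribution, samples an action from it and returns the log-probability of that action. A dependency-free sketch of the same idea follows; the logits are invented for illustration, and softmax and sample_action are hypothetical helpers rather than functions from the project.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(logits, rng):
    """Sample an action index and return it with the probabilities and its log-probability."""
    probs = softmax(logits)
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, probs, math.log(probs[action])

rng = random.Random(0)
logits = [0.2, -1.0, 0.5, 1.5, 0.0]  # hypothetical outputs of fc3 for one state
action, probs, log_prob = sample_action(logits, rng)
print(action, [round(p, 3) for p in probs], round(log_prob, 3))
```

Sampling, rather than always taking the most likely action, is what lets the policy explore the grid during training.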

The following code updates the network parameters only when the agent has reached the green goal.

if done:
    agent.train(a1_list, pi_a1_list, r_list)
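Inside agent.train, the rewards are first converted into discounted returns with a factor of 0.95, computed backwards from the last step. The sketch below shows that recursion on a tiny hand-made reward list (discounted_returns is an illustrative helper). It also makes clear why an episode that never reaches the goal carries no learning signal: all of its rewards are 0, so all of its returns are 0 as well.

```python
def discounted_returns(rewards, gamma=0.95):
    """G_k = r_k + gamma * G_{k+1}, computed from the last step backwards."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for k in range(len(rewards) - 1, -1, -1):
        running = rewards[k] + gamma * running
        returns[k] = running
    return returns

# A three-step episode that reaches the goal on its last step (reward 5).
successful = discounted_returns([0, 0, 5])
# An episode that never reaches the goal.
failed = discounted_returns([0, 0, 0])
print(successful, failed)
```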