DQN PyTorch:深度强化学习的基础

注:本文所用代码基于PyTorch 1.9和Python 3.8。

引言

深度强化学习(Deep Reinforcement Learning,DRL)是结合了深度学习和强化学习的一种方法,它通过让智能体(Agent)从环境中获取数据,使用深度神经网络来学习环境的动态变化并作出相应的决策。DQN(Deep Q-Network)是一种经典的深度强化学习算法,它基于Q-learning算法,并使用深度神经网络来估计Q值函数。本文将介绍DQN算法的基本原理和实现,以及使用PyTorch库来实现一个简单的DQN模型。

DQN算法原理

DQN算法的核心思想是使用深度神经网络来估计Q值函数,从而避免了传统Q-learning算法中需要使用表格存储所有状态和动作的缺点。DQN算法的基本原理如下:

  1. 初始化一个深度神经网络,用于估计Q值函数。
  2. 在每个时间步t,智能体观察当前状态St,并基于当前策略选择一个动作At。
  3. 执行动作At,观察环境的反馈,包括奖励Rt+1和下一个状态St+1。
  4. 将当前状态St和执行的动作At存储到经验回放缓冲区(Experience Replay Buffer)中。
  5. 从经验回放缓冲区中随机采样一批样本,用于更新深度神经网络的参数。
  6. 重复步骤2-5直到达到预设的终止条件。

DQN算法的一个重要改进是使用目标网络(Target Network)来稳定训练过程。目标网络是一个与估计网络(Estimation Network)结构相同的神经网络,但其参数在一段时间内固定不变。通过定期将估计网络的参数复制给目标网络,可以减少训练过程中的目标值的变化,从而提高训练的稳定性。

DQN算法实现

接下来,我们将使用PyTorch库来实现一个简单的DQN模型,并使用OpenAI Gym中的CartPole环境进行训练和测试。在实现DQN模型之前,我们需要安装相应的库:

!pip install gym
!pip install torch

首先,我们导入所需的库:

import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque

然后,我们定义DQN模型的神经网络结构:

class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

接下来,我们定义DQN算法的主要逻辑:

class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.batch_size = 32
        self.model = DQN(state_size, action_size)
        self.target_model = DQN(state_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

    def act(self, state):
        if random.random() <= self.epsilon:
            return random.randrange(self.action_size)
        else:
            state = torch.from_numpy(state).float().unsqueeze(0)
            q_values = self.model(state)
            return torch.argmax(q_values).item()

    def remember(self, state, action