Unity+ML-Agents训练机械臂到达目标点

  • 环境搭建:
  • 训练效果:



首先,你需要从 Unity 官方 GitHub 下载 ML-Agents(机器学习代理)并完成安装。此外,还需要在电脑上配置 TensorFlow。我以前跑过一些深度学习项目,TensorFlow 已经配置好了;还没有安装的小伙伴可以参考这篇教程,推荐使用 Anaconda 进行安装,安全快捷无痛苦。

环境搭建:

首先是最基础的部分:按照流程创建 Academy 和 Brain。新建一个空物体,名称随意,但最好具有辨识度;然后为它添加一个 C# 脚本,同样命名随意,把脚本的基类由 MonoBehaviour 改为 Academy,其他不用修改。

Unity 机械臂避障 unity控制机械臂_机械臂


Unity 机械臂避障 unity控制机械臂_PPO_02


然后点击 Add Component,输入 brain,添加 Brain 组件:

Unity 机械臂避障 unity控制机械臂_Unity3d_03

然后开始设置智能体,其创建流程与 Academy 相同,只是将脚本的基类改为 Agent。补充一点:Academy 和 Agent 脚本中都要添加 `using MLAgents;`。因为要进行多智能体并行训练,所以我将机械臂模型放在 Agent 物体下,作为它的子物体。

Unity 机械臂避障 unity控制机械臂_Unity3d_04


机械臂一定要按层级关系设置好父子关系:

Unity 机械臂避障 unity控制机械臂_机械臂_05


然后在 Agent 脚本中声明机械臂各关节和目标物体对应的字段,再在 Agent 组件上为这些字段指定对应的物体。

Unity 机械臂避障 unity控制机械臂_PPO_06


Unity 机械臂避障 unity控制机械臂_机械臂_07


以上内容设置好后就可以开始训练了。Agent 训练代码如下:

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;

/// <summary>
/// ML-Agents (legacy 0.x API) agent that rotates a six-joint arm (S0, S1, E0, E1, W0, W1)
/// so that its end effector (<see cref="endJoint"/>) reaches <see cref="target"/>.
/// </summary>
public class all_baxter_agent : Agent {
    // Arm joints, assigned in the Inspector. Each one is rotated about a single local axis in AgentAction.
    public GameObject S0;
    public GameObject S1;
    public GameObject E0;
    public GameObject E1;
    public GameObject W0;
    public GameObject W1;

    // End effector whose distance to the target drives the reward.
    public GameObject endJoint;
    // Target the end effector should reach; relocated by AgentReset after a success.
    public GameObject target;

    // End-effector-to-target distance seen on the previous step; float.MaxValue means "no previous step yet".
    private float previousDistance = float.MaxValue;
    // Set when the target was reached; tells AgentReset to move the target instead of resetting the arm.
    public bool goal = false;

    // Episode succeeds when the end effector is closer than this to the target.
    private const float GoalDistance = 0.2f;
    private const float GoalReward = 25.0f;
    // Penalty for leaving the allowed joint range or end-effector height band.
    private const float FailurePenalty = -5.0f;
    // Per-step shaping reward: +StepReward when getting closer, -StepReward otherwise.
    private const float StepReward = 0.5f;
    // Allowed world-space height band for the end effector.
    private const float MinEndHeight = 3.2f;
    private const float MaxEndHeight = 9f;

    public void Start()
    {
        // Mark the agent done right away so the first episode starts from the AgentReset pose.
        Done();
    }

    /// <summary>
    /// Starts a new episode. After a success only the target is moved to a random spot and the arm keeps
    /// its pose; after any other episode end the arm returns to its home pose and the target stays put.
    /// NOTE(review): `goal` starts false, so the target is never randomized before the first success — confirm intended.
    /// </summary>
    public override void AgentReset()
    {
        previousDistance = float.MaxValue;
        if (goal)
        {
            // Spawn region for the target, in the target's local coordinates.
            target.transform.localPosition = new Vector3(Random.Range(-22f, -17f), 7f, Random.Range(-16f, -12f));
            goal = false;
        }
        else
        {
            S0.transform.localRotation = Quaternion.Euler(new Vector3(0f, 0f, 0f));
            S1.transform.localRotation = Quaternion.Euler(new Vector3(60f, 0f, 0f));
            E0.transform.localRotation = Quaternion.Euler(new Vector3(0f, 0f, 0f));
            E1.transform.localRotation = Quaternion.Euler(new Vector3(-60f, 0f, 0f));
            W0.transform.localRotation = Quaternion.Euler(new Vector3(0f, 0f, 0f));
            W1.transform.localRotation = Quaternion.Euler(new Vector3(-60f, 0f, 0f));
        }
    }

    /// <summary>
    /// Emits 10 observations (the policy network's input): one quaternion component per joint (6),
    /// the agent-root-minus-target offset (3), and the end-effector-to-target distance (1).
    /// </summary>
    public override void CollectObservations()
    {
        // NOTE(review): this is the agent root's offset from the target, not the end effector's;
        // the end effector is only described through distanceToTarget. Confirm intended.
        Vector3 relativePosition = gameObject.transform.position - target.transform.position;
        float distanceToTarget = Vector3.Distance(target.transform.position, endJoint.transform.position);

        // NOTE(review): these read world-space `rotation`, while the joint limits in AgentAction
        // use `localRotation`; for child joints the two differ.
        AddVectorObs(S0.transform.rotation.y);
        AddVectorObs(S1.transform.rotation.x);
        AddVectorObs(E0.transform.rotation.z);
        AddVectorObs(E1.transform.rotation.x);
        AddVectorObs(W0.transform.rotation.z);
        AddVectorObs(W1.transform.rotation.x);
        AddVectorObs(relativePosition.x);
        AddVectorObs(relativePosition.y);
        AddVectorObs(relativePosition.z);
        AddVectorObs(distanceToTarget);
    }

    /// <summary>
    /// Scores the current pose, ends the episode on failure or success, and otherwise rotates each joint
    /// by its action value.
    /// </summary>
    /// <param name="vectorAction">Six continuous actions (S0, S1, E0, E1, W0, W1), each clamped to [-1, 1].</param>
    /// <param name="textAction">Unused.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Measured before this step's rotation is applied.
        // NOTE(review): the shaping reward therefore reflects the previous action's effect.
        float distanceToTarget = Vector3.Distance(target.transform.position, endJoint.transform.position);

        // Return right after Done() so one step cannot collect several terminal rewards
        // (e.g. both penalties, or a penalty plus the goal bonus), and the arm is not moved after the episode ended.
        if (JointLimitsExceeded() || EndJointOutOfHeightRange())
        {
            AddReward(FailurePenalty);
            Done();
            return;
        }

        // Shaping: reward moving closer, penalize standing still or moving away.
        AddReward(distanceToTarget < previousDistance ? StepReward : -StepReward);

        if (distanceToTarget < GoalDistance)
        {
            goal = true;
            AddReward(GoalReward);
            Done();
            return;
        }
        previousDistance = distanceToTarget;

        RotateJoint(S0, Vector3.up, vectorAction[0]);
        RotateJoint(S1, Vector3.right, vectorAction[1]);
        RotateJoint(E0, Vector3.forward, vectorAction[2]);
        RotateJoint(E1, Vector3.right, vectorAction[3]);
        RotateJoint(W0, Vector3.forward, vectorAction[4]);
        RotateJoint(W1, Vector3.right, vectorAction[5]);
    }

    // True when S0, S1, E1 or W1 has left its allowed range (E0 and W0 are unconstrained).
    // localRotation.x/.y are quaternion components, not Euler angles: for a joint turned by θ about a
    // single local axis the component equals sin(θ/2), so e.g. 0.5 corresponds to 60°.
    private bool JointLimitsExceeded()
    {
        float s0 = S0.transform.localRotation.y;
        float s1 = S1.transform.localRotation.x;
        float e1 = E1.transform.localRotation.x;
        float w1 = W1.transform.localRotation.x;
        return Mathf.Abs(s0) > 0.5f
            || s1 > 0.85f || s1 < -0.17f
            || e1 > 0f || e1 < -0.85f
            || Mathf.Abs(w1) > 0.72f;
    }

    // True when the end effector is outside the allowed world-space height band.
    private bool EndJointOutOfHeightRange()
    {
        float height = endJoint.transform.position.y;
        return height < MinEndHeight || height > MaxEndHeight;
    }

    // Rotates `joint` about its own local `axis` by the action clamped to [-1, 1].
    // Transform.Rotate takes the angle in degrees, so a joint moves at most 1 degree per decision.
    private static void RotateJoint(GameObject joint, Vector3 axis, float action)
    {
        joint.transform.Rotate(axis, Mathf.Clamp(action, -1f, 1f));
    }
}

训练效果:

训练初期,机械臂在状态空间中乱窜:

Unity 机械臂避障 unity控制机械臂_Unity3d_08


大约经过 4 个小时的训练,效果还不错:

Unity 机械臂避障 unity控制机械臂_PPO_09

另外,我还使用了多智能体并行训练来提高训练效果,发现很有用:

Unity 机械臂避障 unity控制机械臂_Unity 机械臂避障_10