前言

Tensorflow2.0中引入了keras库,极大的简化了我们搭建网络的复杂度;同时eager模式的引入,更加方便了我们对代码的编写及其调试。我们知道原始的keras框架可定制性不强,但在tensorflow2.0中可以自定义我们的每一层和模型。

正文:

下列为使用Sequential(容器)搭建多全连接层网络,我没有去查看源码,但我确定它肯定是继承了keras.Model类;因此它可以调用build()配置参数,调用fit()进行训练。

在tf2.0中,我们创建的Sequential的model直接可以直接model(输入)进行正向传播,拿到输出值,不需要调用第一层的call()方法将输入值输入到第一层,获得输出再输入第二层,这样一步一步得到最终的输出;其实这些操作在model.call()中已经被实现;即执行model(输入)操作时,底层会自动调用model.call()方法。

tensorflow2模型搭建 tensorflow2自定义模型_tensorflow

1)自定义层:

如果我们不仅限使用上述的layers.Dense(),而是自定义我们的全连接层,该怎么做呢?其实非常简单,只需要我们自定义类继承keras.layers.Layers类,同时我们要实现该父类的一些方法,包括:

  1. __init()__ 
  2. call() :这里实现自定义逻辑

这里我们自定义一个全连接层MyDense(),同时实现了 __init()__ 方法,在这里定义了2个变量分别为w和b;注意:我们的自定义的2个变量一定要使用add_variable()这种方法创建,因为我们要让创建的变量交由上上文管理器进行管理,而不能使用tf.constant()类似的方法创建变量。call() 方法中返回该层进行自定义操作的结果,下面这里直接传统的全连接运算,返回运算结果。training也是要经常根据自己的业务逻辑进行处理的。

tensorflow2模型搭建 tensorflow2自定义模型_自定义_02

2)自定义模型:

对比自定义层,自定义模型稍微复杂些,同样自定义类继承keras.Model类,同时我们要实现该父类的一些方法,包括:

  1. __init()__ 
  2. call() :这里实现自定义逻辑
  3. compile()
  4. fit()
  5. evaluate()
  6. predict()

 下面自定义了一个名为MyModel的模型,在__init__()中我们自定义的添加一些全连接层;同时在call()方法中指定这些全连接层如何进行传递的;我们没有重写compile()、fit()、evaluate()、predict()等方法。

tensorflow2模型搭建 tensorflow2自定义模型_tensorflow_03

后面会抽出时间去读一些这些方法的源码,做到查缺补漏。