一.引言

函数式 API 的重要特性是能够多次重复使用一个层实例,如果对一个层实例调用两次,而不是每次调用都实例化一个新层,那么每次调用就可以重复使用相同的权重。这样可以构建具有共享分支的模型。

二.共享层权重

1.模型结构

假设模型判断两个句子的相似度,模型有两个输入,分别为句子A,句子B,并输出一个 0-1 的分数代表相似度。在这种前提下,句子AB是具备交换性的,即A与B的相似性应该与B与A的相似性是一致的,所以只需要用一个 LSTM 处理即可,其权重是根据两个输入共同学习的,模型也成为连体 LSTM 或者 共享 LSTM 模型。

2.模型构建

类似于之前初始化一个 Flatten 层一直用一样,这里初始化一个 LSTM 一直用,区别是后者会根据多个输入同时学习。

# 只实例化LSTM层一次
    lstm = layers.LSTM(32)

    left_input = Input(shape=(None, 128))
    left_output = lstm(left_input)

    right_input = Input(shape=(None, 128))
    right_output = lstm(right_input)

    merged = layers.concatenate([left_output, right_output], axis=-1)
    prediction = layers.Dense(1, activation='sigmoid')(merged)

    model = Model([left_input, right_input], prediction)

两个输入层都接入了 LSTM :  

CNN共享权重 lstm权重共享_keras

3.模型训练

正常情况下语料是 None x 128 即不定长的句子,这里偷个懒都打成 128x128,训练目标是 0-1 的浮点数:

CNN共享权重 lstm权重共享_CNN共享权重_02

vocabulary_size = 10000
    num_samples = 10000
    max_length = 128
    left_data = np.random.randint(1, vocabulary_size, size=(num_samples, max_length, max_length))
    right_data = np.random.randint(1, vocabulary_size, size=(num_samples, max_length, max_length))
    targets = np.random.rand(num_samples)
    print(targets[0:100])

    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy')

    model.fit([left_data, right_data], targets, batch_size=128, epochs=10)
Epoch 1/10
79/79 [==============================] - 4s 43ms/step - loss: 0.7331

......

Epoch 9/10
79/79 [==============================] - 3s 40ms/step - loss: 0.6935
Epoch 10/10
79/79 [==============================] - 3s 38ms/step - loss: 0.6927

三.共享模型权重

1.模型结构

除了共享层,理论上一个☝️ 模型也可以看做是一个抽象的大层,所以也可以通过输入张量给模型得到新张量,这和层的用法看起来差不多,共享层重复使用层的权重,共享模型(大层)重复使用很多的权重。 假设临近的两个摄像头一前一后采集图像,可以通过共享层将两个摄像头的图像进行提取,最后连接在一起进行后续操作。

2.模型构建

这里从 keras.application 加载了训练好的模型,include_top 选择 Flase 只留下卷积基,最后拼接出结果。 

xception_base = applications.Xception(weights=None, include_top=False)

    left_input = Input(shape=(250, 250, 3))
    right_input = Input(shape=(250, 250, 3))

    left_features = xception_base(left_input)
    right_features = xception_base(right_input)

    merged_features = layers.concatenate([left_features, right_features], axis=-1)
    merged_features = Flatten()(merged_features)
    output = layers.Dense(10, activation='softmax')(merged_features)

    model = Model([left_input, right_input], output)
    model.summary()

CNN共享权重 lstm权重共享_tensorflow_03

3.模型训练

这里 Xception 模型的参数太多了,本地实在跑步起来,这一步就忽略了。。。 总之我们可以多次调用相同的层或模型实例,在不同的处理分支之间重复使用层或者模型的权重,同时配合多输入或者多输出。