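In TensorFlow, `fit()` can accept NumPy arrays directly, provided the data is small enough to be loaded into memory all at once. If the data is too large for that, it has to be read in batches and exposed as an iterable input pipeline, i.e. a `tf.data.Dataset`.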
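The following is a minimal sketch of reading data batch by batch instead of loading everything at once. It wraps a Python generator with `tf.data.Dataset.from_generator`. `read_batches` is a hypothetical loader that fabricates random batches where a real one would read chunks from disk, and `NUM_BATCHES`/`BATCH_SIZE` are illustrative values. `output_signature` assumes TensorFlow 2.4 or later.

```python
import numpy as np
import tensorflow as tf

NUM_BATCHES = 10   # hypothetical: number of chunks "on disk"
BATCH_SIZE = 32    # hypothetical chunk size


def read_batches():
    """Hypothetical loader: yields one (features, labels) batch at a time.

    Random data stands in for reading a chunk from disk, so only one
    batch is held in memory at any moment.
    """
    rng = np.random.default_rng(0)
    for _ in range(NUM_BATCHES):
        x = rng.random((BATCH_SIZE, 784), dtype=np.float32)
        y = rng.integers(0, 10, size=(BATCH_SIZE,), dtype=np.int64)
        yield x, y


# Wrap the generator; each element it yields is already a full batch
dataset = tf.data.Dataset.from_generator(
    read_batches,
    output_signature=(
        tf.TensorSpec(shape=(None, 784), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int64),
    ),
)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# The generator function is called again at the start of every epoch
history = model.fit(dataset, epochs=2, verbose=0)
```

Because the generator already yields batches, the dataset is not passed through `.batch()` again.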

A `Dataset` instance can be passed directly to `fit()`, `evaluate()`, and `predict()`.
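A short sketch of the same `Dataset` being used for evaluation and prediction is shown below. Small random arrays stand in for MNIST; they are an assumption made so that the example runs without a download. Since `predict()` only needs the features, the labels are dropped with `map()`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical synthetic data standing in for MNIST
rng = np.random.default_rng(0)
x = rng.random((256, 784), dtype=np.float32)
y = rng.integers(0, 10, size=(256,))

ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(64)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.fit(ds, epochs=1, verbose=0)

# evaluate() takes the same (features, labels) dataset and returns [loss, accuracy]
loss, acc = model.evaluate(ds, verbose=0)

# predict() only needs features, so strip the labels before passing the dataset
probs = model.predict(ds.map(lambda features, labels: features), verbose=0)
print(probs.shape)  # (256, 10)
```

The per-batch predictions are concatenated, so the result has one row per sample rather than one per batch.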

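When you use a `Dataset`, you do not pass `batch_size` to `fit()` as you would with NumPy arrays. The dataset already yields batches (here through `.batch(64)`), so the batch size is fixed when the pipeline is built.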
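To see where the batch size lives, the sketch below builds a pipeline from 1000 hypothetical zero-filled samples and inspects it without training. It checks how many batches the dataset yields and what size they are.

```python
import numpy as np
import tensorflow as tf

# Hypothetical data: 1000 flattened "images" and their labels
x = np.zeros((1000, 784), dtype=np.float32)
y = np.zeros((1000,), dtype=np.int64)

# NumPy route: the batch size would be an argument to fit(), e.g.
#   model.fit(x, y, batch_size=64)
# Dataset route: the batch size is set here, when the pipeline is built
ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(buffer_size=1024).batch(64)

num_batches = int(ds.cardinality().numpy())   # ceil(1000 / 64) = 16
batch_sizes = [int(xb.shape[0]) for xb, _ in ds]
print(num_batches, batch_sizes[0], batch_sizes[-1])  # 16 64 40
```

The last batch holds the 40 leftover samples. Passing `drop_remainder=True` to `batch()` would discard it if every batch must have the same size.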

Complete code:

```python
"""
* Created with PyCharm
* Author: 阿光
* Date: 2022/1/2
* Time: 19:29
* Description:
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

# Load MNIST as NumPy arrays
(train_images, train_labels), (val_images, val_labels) = keras.datasets.mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1]
train_images, val_images = train_images / 255.0, val_images / 255.0

# Flatten each 28x28 image into a 784-dimensional vector
train_images = train_images.reshape(60000, 784)
val_images = val_images.reshape(10000, 784)

# Training pipeline: shuffle, then group into batches of 64
train_datasets = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_datasets = train_datasets.shuffle(buffer_size=1024).batch(64)

# Validation pipeline: batching only
val_datasets = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_datasets = val_datasets.batch(64)


def get_model():
    # A single softmax layer over the 784 input pixels
    inputs = Input(shape=(784,))
    outputs = Dense(10, activation='softmax')(inputs)
    model = Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )
    return model


model = get_model()

# No batch_size argument: each element of train_datasets is already a batch
model.fit(
    train_datasets,
    epochs=5,
    validation_data=val_datasets
)
```