# CIFAR-100 dataset exercise
import tensorflow as tf
import os
class CNNMnist(object):
    """Container for the CIFAR-100 train/test splits used by this script.

    Despite the class name, the data comes from
    ``tf.keras.datasets.cifar100``, not MNIST.

    Attributes set by ``__init__``:
        train, test: image arrays reshaped to ``(-1, 32, 32, 3)`` and
            divided by 255.0.
        train_label, test_label: label arrays exactly as returned by
            ``load_data()``.
    """

    def __init__(self):
        """Load CIFAR-100 via Keras and rescale both image arrays."""
        train_split, test_split = tf.keras.datasets.cifar100.load_data()
        raw_train, self.train_label = train_split
        raw_test, self.test_label = test_split

        # Channels-last layout; -1 lets the sample count be inferred.
        image_shape = (-1, 32, 32, 3)
        # Dividing by 255.0 converts the images to floating point
        # (presumably mapping 8-bit pixel values into [0, 1]).
        self.train = raw_train.reshape(image_shape) / 255.0
        self.test = raw_test.reshape(image_shape) / 255.0
def _build_model():
    """Return a fresh, uncompiled CNN classifier with 100 softmax outputs.

    Architecture: two (5x5 'same' Conv2D + 2x2 MaxPool2D) stages with 32 and
    64 filters, followed by Flatten, Dense(1024, relu) and Dense(100, softmax).
    Kept in a helper so that the trained model and the model that the weights
    are restored into are guaranteed to have identical layers.
    """
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, kernel_size=5, strides=1,
                               padding='same', data_format='channels_last',
                               activation=tf.nn.relu),
        tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same'),
        tf.keras.layers.Conv2D(64, kernel_size=5, strides=1,
                               padding='same', data_format='channels_last',
                               activation=tf.nn.relu),
        tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1024, activation=tf.nn.relu),
        tf.keras.layers.Dense(100, activation=tf.nn.softmax),
    ])


# Train for a single epoch and report loss/accuracy on the test split.
model = _build_model()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
data = CNNMnist()
model.fit(data.train, data.train_label, epochs=1, batch_size=32)
test_loss, test_acc = model.evaluate(data.test, data.test_label)
print(test_loss, test_acc)

# Make sure the target directory exists before writing the checkpoint.
os.makedirs('./checkpoints', exist_ok=True)
# NOTE(review): newer Keras releases may require a '.weights.h5' suffix for
# save_weights -- confirm against the installed TensorFlow/Keras version.
model.save_weights('./checkpoints/my_checkpoint')

# The original script called load_weights on an undefined name `model1`,
# which raised NameError. Restore into a second model of the same architecture.
model1 = _build_model()
# Create the variables up front so there is something to load the weights into;
# the input shape matches the (32, 32, 3) images produced by CNNMnist.
model1.build(input_shape=(None, 32, 32, 3))
model1.load_weights("./checkpoints/my_checkpoint")