- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
一、前言
实验环境:
- 语言环境:Python3.10
- 开发工具(IDE):PyCharm
- 深度学习环境:TensorFlow 2.10.0
1.设置GPU
import os

# These environment variables must be set BEFORE importing TensorFlow:
# TF_CPP_MIN_LOG_LEVEL = '2' shows only warnings and errors from the C++ core;
# KMP_DUPLICATE_LIB_OK works around duplicate-OpenMP-runtime crashes on some
# Windows/conda setups.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import pathlib

import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# ---------------- GPU setup ----------------
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]  # if several GPUs are present, use only the first one
    # Grow GPU memory on demand instead of grabbing it all up front.
    tf.config.experimental.set_memory_growth(gpu0, True)
    tf.config.set_visible_devices([gpu0], "GPU")
    print(gpus)
代码输出:
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
2.导入数据
代码输入:
# Root directory of the face-image dataset (one sub-folder per class).
data_dir = pathlib.Path("F:/lsydata/48-data/")
3.查看数据
代码输入:
# Count every .jpg exactly one directory level below the dataset root.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
print("图片总数为:", image_count)
代码输出:
图片总数为: 1800
代码输入:
# Open and display the first image of one class folder as a sanity check.
sample_paths = list(data_dir.glob('Jennifer Lawrence/*.jpg'))
image = PIL.Image.open(str(sample_paths[0]))
image.show()
代码输出:
二、数据预处理
1.加载数据
代码输入:
# Input pipeline hyper-parameters.
batch_size = 32
img_height = 224
img_width = 224

# Build the training subset (90% of the images).  The fixed seed guarantees
# that the training and validation splits are disjoint.
# image_dataset_from_directory() reference:
# https://mtyjkh.blog.csdn.net/article/details/117018789
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="training",
    label_mode="categorical",  # one-hot encoded labels
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
代码输出:
Found 1800 files belonging to 17 classes.
Using 1620 files for training.
代码输入:
# Build the validation subset (the remaining 10%), using the same seed so it
# never overlaps the training split.
# image_dataset_from_directory() reference:
# https://mtyjkh.blog.csdn.net/article/details/117018789
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="validation",
    label_mode="categorical",  # one-hot encoded labels
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
代码输出:
Found 1800 files belonging to 17 classes.
Using 180 files for validation.
代码输入:
# The class names are the alphabetically-sorted sub-folder names.
class_names = train_ds.class_names
print(class_names)
代码输出:
['Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman', 'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio', 'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr', 'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks', 'Will Smith']
2.可视化数据
代码输入:
# Show the first 20 images of one training batch, titled by class name.
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        # labels are one-hot vectors, so argmax recovers the class index
        plt.title(class_names[np.argmax(labels[i].numpy())])
        plt.axis("off")
plt.show()
代码输出:
3.再次检查数据
# Inspect the tensor shapes of a single batch, then stop iterating.
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)   # (batch, height, width, channels)
    print(labels_batch.shape)  # (batch, num_classes) — one-hot labels
    break
代码输出:
(32, 224, 224, 3)
(32, 17)
4.配置数据集
AUTOTUNE = tf.data.AUTOTUNE

# Cache decoded images in memory, shuffle the training data, and prefetch
# batches so preprocessing overlaps with training on the accelerator.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
三、构建简单的CNN网络
# A simple CNN: three conv blocks (average pooling + dropout) followed by a
# small dense classifier.  Dropout(0.3) mitigates overfitting.
# Conv arithmetic reference: https://blog.csdn.net/qq_38251616/article/details/114278995
# Dropout reference: https://mtyjkh.blog.csdn.net/article/details/115826689
model = models.Sequential([
    # Normalize pixels to [0, 1]; as the first layer it also fixes the input
    # shape, so repeating input_shape on the next Conv2D (which Keras would
    # silently ignore) has been removed.
    layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, (3, 3), activation='relu'),   # conv block 1, 3x3 kernels
    layers.AveragePooling2D((2, 2)),                # 2x2 down-sampling
    layers.Conv2D(32, (3, 3), activation='relu'),   # conv block 2, 3x3 kernels
    layers.AveragePooling2D((2, 2)),                # 2x2 down-sampling
    layers.Dropout(0.3),
    layers.Conv2D(64, (3, 3), activation='relu'),   # conv block 3, 3x3 kernels
    layers.Dropout(0.3),
    layers.Flatten(),                               # bridge conv features to the dense head
    layers.Dense(128, activation='relu'),           # feature combination
    layers.Dense(len(class_names))                  # raw logits, one per class
])
model.summary()  # print the network structure
代码输出:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
rescaling (Rescaling) (None, 224, 224, 3) 0
conv2d (Conv2D) (None, 222, 222, 16) 448
average_pooling2d (AverageP (None, 111, 111, 16) 0
ooling2D)
conv2d_1 (Conv2D) (None, 109, 109, 32) 4640
average_pooling2d_1 (Averag (None, 54, 54, 32) 0
ePooling2D)
dropout (Dropout) (None, 54, 54, 32) 0
conv2d_2 (Conv2D) (None, 52, 52, 64) 18496
dropout_1 (Dropout) (None, 52, 52, 64) 0
flatten (Flatten) (None, 173056) 0
dense (Dense) (None, 128) 22151296
dense_1 (Dense) (None, 17) 2193
=================================================================
Total params: 22,177,073
Trainable params: 22,177,073
Non-trainable params: 0
_________________________________________________________________
四、训练模型
1.设置动态学习率
# Exponentially-decaying learning rate fed into the Adam optimizer.
initial_learning_rate = 1e-4
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=60,    # NOTE: measured in optimizer steps, NOT epochs
    decay_rate=0.96,   # each decay multiplies lr by decay_rate
    staircase=True)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# The model's last Dense layer emits raw logits, hence from_logits=True.
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
2.早停与保存最佳模型参数
代码输入:
epochs = 50

# Persist only the weights of the epoch with the best validation accuracy.
checkpointer = ModelCheckpoint(
    'best_model.h5',
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)

# Stop training when val_accuracy fails to improve by at least 0.001 over
# 20 consecutive epochs.
earlystopper = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=20,
    verbose=1,
)
3.训练模型
代码输入:
# Train the model; the callbacks checkpoint the best weights and may stop
# training early.  `history` keeps the per-epoch metrics for plotting.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[checkpointer, earlystopper],
)
代码输出:
Epoch 1/50
51/51 [==============================] - ETA: 0s - loss: 2.8314 - accuracy: 0.1093
Epoch 1: val_accuracy improved from -inf to 0.14444, saving model to best_model.h5
51/51 [==============================] - 21s 327ms/step - loss: 2.8314 - accuracy: 0.1093 - val_loss: 2.7463 - val_accuracy: 0.1444
Epoch 2/50
51/51 [==============================] - ETA: 0s - loss: 2.7182 - accuracy: 0.1377
Epoch 2: val_accuracy improved from 0.14444 to 0.15556, saving model to best_model.h5
51/51 [==============================] - 16s 316ms/step - loss: 2.7182 - accuracy: 0.1377 - val_loss: 2.7199 - val_accuracy: 0.1556
Epoch 3/50
51/51 [==============================] - ETA: 0s - loss: 2.6453 - accuracy: 0.1574
Epoch 3: val_accuracy improved from 0.15556 to 0.16667, saving model to best_model.h5
51/51 [==============================] - 16s 314ms/step - loss: 2.6453 - accuracy: 0.1574 - val_loss: 2.6856 - val_accuracy: 0.1667
Epoch 4/50
51/51 [==============================] - ETA: 0s - loss: 2.5845 - accuracy: 0.1710
Epoch 4: val_accuracy improved from 0.16667 to 0.17778, saving model to best_model.h5
51/51 [==============================] - 15s 295ms/step - loss: 2.5845 - accuracy: 0.1710 - val_loss: 2.6639 - val_accuracy: 0.1778
Epoch 5/50
51/51 [==============================] - ETA: 0s - loss: 2.4537 - accuracy: 0.2210
Epoch 5: val_accuracy did not improve from 0.17778
51/51 [==============================] - 14s 283ms/step - loss: 2.4537 - accuracy: 0.2210 - val_loss: 2.6044 - val_accuracy: 0.1611
Epoch 6/50
51/51 [==============================] - ETA: 0s - loss: 2.3626 - accuracy: 0.2481
Epoch 6: val_accuracy improved from 0.17778 to 0.18889, saving model to best_model.h5
51/51 [==============================] - 15s 293ms/step - loss: 2.3626 - accuracy: 0.2481 - val_loss: 2.5409 - val_accuracy: 0.1889
Epoch 7/50
51/51 [==============================] - ETA: 0s - loss: 2.2313 - accuracy: 0.2790
Epoch 7: val_accuracy did not improve from 0.18889
51/51 [==============================] - 15s 286ms/step - loss: 2.2313 - accuracy: 0.2790 - val_loss: 2.5864 - val_accuracy: 0.1556
Epoch 8/50
51/51 [==============================] - ETA: 0s - loss: 2.1252 - accuracy: 0.3222
Epoch 8: val_accuracy did not improve from 0.18889
51/51 [==============================] - 14s 284ms/step - loss: 2.1252 - accuracy: 0.3222 - val_loss: 2.5553 - val_accuracy: 0.1778
Epoch 9/50
51/51 [==============================] - ETA: 0s - loss: 2.0189 - accuracy: 0.3457
Epoch 9: val_accuracy improved from 0.18889 to 0.19444, saving model to best_model.h5
51/51 [==============================] - 15s 298ms/step - loss: 2.0189 - accuracy: 0.3457 - val_loss: 2.4892 - val_accuracy: 0.1944
Epoch 10/50
51/51 [==============================] - ETA: 0s - loss: 1.8827 - accuracy: 0.4068
Epoch 10: val_accuracy did not improve from 0.19444
51/51 [==============================] - 14s 279ms/step - loss: 1.8827 - accuracy: 0.4068 - val_loss: 2.5209 - val_accuracy: 0.1889
Epoch 11/50
51/51 [==============================] - ETA: 0s - loss: 1.7860 - accuracy: 0.4407
Epoch 11: val_accuracy improved from 0.19444 to 0.21111, saving model to best_model.h5
51/51 [==============================] - 15s 294ms/step - loss: 1.7860 - accuracy: 0.4407 - val_loss: 2.4616 - val_accuracy: 0.2111
Epoch 12/50
51/51 [==============================] - ETA: 0s - loss: 1.6464 - accuracy: 0.4864
Epoch 12: val_accuracy did not improve from 0.21111
51/51 [==============================] - 14s 283ms/step - loss: 1.6464 - accuracy: 0.4864 - val_loss: 2.5705 - val_accuracy: 0.2000
Epoch 13/50
51/51 [==============================] - ETA: 0s - loss: 1.5519 - accuracy: 0.5185
Epoch 13: val_accuracy improved from 0.21111 to 0.22778, saving model to best_model.h5
51/51 [==============================] - 15s 293ms/step - loss: 1.5519 - accuracy: 0.5185 - val_loss: 2.5380 - val_accuracy: 0.2278
Epoch 14/50
51/51 [==============================] - ETA: 0s - loss: 1.4248 - accuracy: 0.5765
Epoch 14: val_accuracy did not improve from 0.22778
51/51 [==============================] - 14s 282ms/step - loss: 1.4248 - accuracy: 0.5765 - val_loss: 2.4964 - val_accuracy: 0.2278
Epoch 15/50
51/51 [==============================] - ETA: 0s - loss: 1.3002 - accuracy: 0.6302
Epoch 15: val_accuracy did not improve from 0.22778
51/51 [==============================] - 15s 284ms/step - loss: 1.3002 - accuracy: 0.6302 - val_loss: 2.5376 - val_accuracy: 0.2167
Epoch 16/50
51/51 [==============================] - ETA: 0s - loss: 1.1780 - accuracy: 0.6630
Epoch 16: val_accuracy improved from 0.22778 to 0.23333, saving model to best_model.h5
51/51 [==============================] - 16s 312ms/step - loss: 1.1780 - accuracy: 0.6630 - val_loss: 2.5487 - val_accuracy: 0.2333
Epoch 17/50
51/51 [==============================] - ETA: 0s - loss: 1.0784 - accuracy: 0.6870
Epoch 17: val_accuracy did not improve from 0.23333
51/51 [==============================] - 14s 283ms/step - loss: 1.0784 - accuracy: 0.6870 - val_loss: 2.7446 - val_accuracy: 0.2333
Epoch 18/50
51/51 [==============================] - ETA: 0s - loss: 0.9745 - accuracy: 0.7401
Epoch 18: val_accuracy did not improve from 0.23333
51/51 [==============================] - 15s 285ms/step - loss: 0.9745 - accuracy: 0.7401 - val_loss: 2.6188 - val_accuracy: 0.2222
Epoch 19/50
51/51 [==============================] - ETA: 0s - loss: 0.8790 - accuracy: 0.7648
Epoch 19: val_accuracy improved from 0.23333 to 0.26667, saving model to best_model.h5
51/51 [==============================] - 16s 315ms/step - loss: 0.8790 - accuracy: 0.7648 - val_loss: 2.6858 - val_accuracy: 0.2667
Epoch 20/50
51/51 [==============================] - ETA: 0s - loss: 0.8145 - accuracy: 0.7877
Epoch 20: val_accuracy improved from 0.26667 to 0.28333, saving model to best_model.h5
51/51 [==============================] - 15s 299ms/step - loss: 0.8145 - accuracy: 0.7877 - val_loss: 2.8673 - val_accuracy: 0.2833
Epoch 21/50
51/51 [==============================] - ETA: 0s - loss: 0.7555 - accuracy: 0.7988
Epoch 21: val_accuracy did not improve from 0.28333
51/51 [==============================] - 14s 282ms/step - loss: 0.7555 - accuracy: 0.7988 - val_loss: 2.8761 - val_accuracy: 0.2667
Epoch 22/50
51/51 [==============================] - ETA: 0s - loss: 0.6843 - accuracy: 0.8272
Epoch 22: val_accuracy did not improve from 0.28333
51/51 [==============================] - 15s 288ms/step - loss: 0.6843 - accuracy: 0.8272 - val_loss: 2.8389 - val_accuracy: 0.2444
Epoch 23/50
51/51 [==============================] - ETA: 0s - loss: 0.6049 - accuracy: 0.8494
Epoch 23: val_accuracy did not improve from 0.28333
51/51 [==============================] - 14s 284ms/step - loss: 0.6049 - accuracy: 0.8494 - val_loss: 2.7282 - val_accuracy: 0.2278
Epoch 24/50
51/51 [==============================] - ETA: 0s - loss: 0.5492 - accuracy: 0.8728
Epoch 24: val_accuracy did not improve from 0.28333
51/51 [==============================] - 14s 282ms/step - loss: 0.5492 - accuracy: 0.8728 - val_loss: 2.8577 - val_accuracy: 0.2556
Epoch 25/50
51/51 [==============================] - ETA: 0s - loss: 0.4938 - accuracy: 0.8926
Epoch 25: val_accuracy did not improve from 0.28333
51/51 [==============================] - 14s 283ms/step - loss: 0.4938 - accuracy: 0.8926 - val_loss: 2.9235 - val_accuracy: 0.2778
Epoch 26/50
51/51 [==============================] - ETA: 0s - loss: 0.4711 - accuracy: 0.8969
Epoch 26: val_accuracy did not improve from 0.28333
51/51 [==============================] - 14s 284ms/step - loss: 0.4711 - accuracy: 0.8969 - val_loss: 2.9309 - val_accuracy: 0.2667
Epoch 27/50
51/51 [==============================] - ETA: 0s - loss: 0.4189 - accuracy: 0.9086
Epoch 27: val_accuracy did not improve from 0.28333
51/51 [==============================] - 15s 286ms/step - loss: 0.4189 - accuracy: 0.9086 - val_loss: 2.9222 - val_accuracy: 0.2667
Epoch 28/50
51/51 [==============================] - ETA: 0s - loss: 0.3597 - accuracy: 0.9302
Epoch 28: val_accuracy did not improve from 0.28333
51/51 [==============================] - 14s 283ms/step - loss: 0.3597 - accuracy: 0.9302 - val_loss: 3.0324 - val_accuracy: 0.2722
Epoch 29/50
51/51 [==============================] - ETA: 0s - loss: 0.3427 - accuracy: 0.9395
Epoch 29: val_accuracy improved from 0.28333 to 0.28889, saving model to best_model.h5
51/51 [==============================] - 16s 312ms/step - loss: 0.3427 - accuracy: 0.9395 - val_loss: 3.0583 - val_accuracy: 0.2889
Epoch 30/50
51/51 [==============================] - ETA: 0s - loss: 0.3183 - accuracy: 0.9389
Epoch 30: val_accuracy did not improve from 0.28889
51/51 [==============================] - 15s 289ms/step - loss: 0.3183 - accuracy: 0.9389 - val_loss: 3.2093 - val_accuracy: 0.2667
Epoch 31/50
51/51 [==============================] - ETA: 0s - loss: 0.2951 - accuracy: 0.9488
Epoch 31: val_accuracy improved from 0.28889 to 0.30000, saving model to best_model.h5
51/51 [==============================] - 15s 291ms/step - loss: 0.2951 - accuracy: 0.9488 - val_loss: 3.1210 - val_accuracy: 0.3000
Epoch 32/50
51/51 [==============================] - ETA: 0s - loss: 0.2711 - accuracy: 0.9549
Epoch 32: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 284ms/step - loss: 0.2711 - accuracy: 0.9549 - val_loss: 3.2228 - val_accuracy: 0.2611
Epoch 33/50
51/51 [==============================] - ETA: 0s - loss: 0.2408 - accuracy: 0.9636
Epoch 33: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 282ms/step - loss: 0.2408 - accuracy: 0.9636 - val_loss: 3.2474 - val_accuracy: 0.2833
Epoch 34/50
51/51 [==============================] - ETA: 0s - loss: 0.2295 - accuracy: 0.9636
Epoch 34: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 283ms/step - loss: 0.2295 - accuracy: 0.9636 - val_loss: 3.3462 - val_accuracy: 0.2667
Epoch 35/50
51/51 [==============================] - ETA: 0s - loss: 0.2185 - accuracy: 0.9753
Epoch 35: val_accuracy did not improve from 0.30000
51/51 [==============================] - 15s 286ms/step - loss: 0.2185 - accuracy: 0.9753 - val_loss: 3.2879 - val_accuracy: 0.2778
Epoch 36/50
51/51 [==============================] - ETA: 0s - loss: 0.2046 - accuracy: 0.9710
Epoch 36: val_accuracy did not improve from 0.30000
51/51 [==============================] - 15s 285ms/step - loss: 0.2046 - accuracy: 0.9710 - val_loss: 3.3474 - val_accuracy: 0.2722
Epoch 37/50
51/51 [==============================] - ETA: 0s - loss: 0.1939 - accuracy: 0.9735
Epoch 37: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 281ms/step - loss: 0.1939 - accuracy: 0.9735 - val_loss: 3.4008 - val_accuracy: 0.2944
Epoch 38/50
51/51 [==============================] - ETA: 0s - loss: 0.1793 - accuracy: 0.9747
Epoch 38: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 283ms/step - loss: 0.1793 - accuracy: 0.9747 - val_loss: 3.3911 - val_accuracy: 0.2889
Epoch 39/50
51/51 [==============================] - ETA: 0s - loss: 0.1593 - accuracy: 0.9858
Epoch 39: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 283ms/step - loss: 0.1593 - accuracy: 0.9858 - val_loss: 3.4803 - val_accuracy: 0.2889
Epoch 40/50
51/51 [==============================] - ETA: 0s - loss: 0.1506 - accuracy: 0.9858
Epoch 40: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 281ms/step - loss: 0.1506 - accuracy: 0.9858 - val_loss: 3.5497 - val_accuracy: 0.2778
Epoch 41/50
51/51 [==============================] - ETA: 0s - loss: 0.1398 - accuracy: 0.9840
Epoch 41: val_accuracy did not improve from 0.30000
51/51 [==============================] - 15s 283ms/step - loss: 0.1398 - accuracy: 0.9840 - val_loss: 3.5341 - val_accuracy: 0.2889
Epoch 42/50
51/51 [==============================] - ETA: 0s - loss: 0.1281 - accuracy: 0.9883
Epoch 42: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 283ms/step - loss: 0.1281 - accuracy: 0.9883 - val_loss: 3.6202 - val_accuracy: 0.2722
Epoch 43/50
51/51 [==============================] - ETA: 0s - loss: 0.1187 - accuracy: 0.9870
Epoch 43: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 279ms/step - loss: 0.1187 - accuracy: 0.9870 - val_loss: 3.6078 - val_accuracy: 0.2833
Epoch 44/50
51/51 [==============================] - ETA: 0s - loss: 0.1246 - accuracy: 0.9901
Epoch 44: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 283ms/step - loss: 0.1246 - accuracy: 0.9901 - val_loss: 3.6804 - val_accuracy: 0.2889
Epoch 45/50
51/51 [==============================] - ETA: 0s - loss: 0.1191 - accuracy: 0.9889
Epoch 45: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 284ms/step - loss: 0.1191 - accuracy: 0.9889 - val_loss: 3.7361 - val_accuracy: 0.2778
Epoch 46/50
51/51 [==============================] - ETA: 0s - loss: 0.1091 - accuracy: 0.9944
Epoch 46: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 282ms/step - loss: 0.1091 - accuracy: 0.9944 - val_loss: 3.7907 - val_accuracy: 0.2722
Epoch 47/50
51/51 [==============================] - ETA: 0s - loss: 0.1011 - accuracy: 0.9914
Epoch 47: val_accuracy did not improve from 0.30000
51/51 [==============================] - 15s 293ms/step - loss: 0.1011 - accuracy: 0.9914 - val_loss: 3.7820 - val_accuracy: 0.2833
Epoch 48/50
51/51 [==============================] - ETA: 0s - loss: 0.0988 - accuracy: 0.9932
Epoch 48: val_accuracy did not improve from 0.30000
51/51 [==============================] - 14s 283ms/step - loss: 0.0988 - accuracy: 0.9932 - val_loss: 3.8408 - val_accuracy: 0.2833
Epoch 49/50
51/51 [==============================] - ETA: 0s - loss: 0.0907 - accuracy: 0.9957
Epoch 49: val_accuracy did not improve from 0.30000
51/51 [==============================] - 15s 286ms/step - loss: 0.0907 - accuracy: 0.9957 - val_loss: 3.8006 - val_accuracy: 0.2944
Epoch 50/50
51/51 [==============================] - ETA: 0s - loss: 0.0949 - accuracy: 0.9907
Epoch 50: val_accuracy did not improve from 0.30000
51/51 [==============================] - 15s 285ms/step - loss: 0.0949 - accuracy: 0.9907 - val_loss: 3.9532 - val_accuracy: 0.2833
五、模型评估
1.Loss与Accuracy图
代码输入:
# Plot the training curves: accuracy on the left, loss on the right.
hist = history.history
epochs_range = range(len(hist['loss']))

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, hist['accuracy'], label='Training Accuracy')
plt.plot(epochs_range, hist['val_accuracy'], label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, hist['loss'], label='Training Loss')
plt.plot(epochs_range, hist['val_loss'], label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()
代码输出:
2.指定图片进行预测
# Load the best checkpointed weights before predicting.  The original code
# only *said* it loaded them — without this call, `model` still holds the
# last-epoch weights, not the best ones saved by ModelCheckpoint.
model.load_weights('best_model.h5')

img = Image.open("F:/lsydata/48-data/Jennifer Lawrence/003_963a3627.jpg")  # image to predict
image = tf.image.resize(img, [img_height, img_width])  # resize to the model's input size
img_array = tf.expand_dims(image, 0)                   # add a batch dimension

# No manual /255 needed: the model's Rescaling layer normalizes inputs.
predictions = model.predict(img_array)
print("预测结果为:",class_names[np.argmax(predictions)])
代码输出:
预测结果为: Jennifer Lawrence