活动介绍
file-type

优化MNIST识别:滑动平均与LeNet-5参数调整

下载需积分: 9 | 5KB | 更新于2024-08-05 | 133 浏览量 | 0 下载量 举报 收藏
download 立即下载
本资源是一份针对初学者的TensorFlow代码示例,用于实现滑动平均参数范围在MNIST手写数字识别任务中的应用,特别采用了经典的LeNet-5架构。该代码库的问题已经修复并可供运行,以便学习者理解变量作用域(variable_scope)以及如何构建一个简单的全连接神经网络。 首先,我们关注标题中的关键点,"滑动平均参数范围"可能指的是在训练过程中使用移动平均(moving average)来平滑权重更新,这是一种防止过拟合的策略,通过在训练步骤中逐渐积累并平均模型参数来获得更稳定的性能。 代码中定义了几个重要的参数: 1. **INPUT_NODE** 和 **OUTPUT_NODE** 分别表示输入层和输出层的节点数量,这里是MNIST数据集的28x28像素图像(每个像素作为输入)和10个类别(输出节点)。 2. **LAYER1_NODE** 是隐藏层的节点数,这里是500个神经元。 3. **BATCH_SIZE** 指定每次训练的样本数量,这里是100个样本。 4. **LEARNING_RATE_BASE** 和 **LEARNING_RATE_DECAY** 是学习率的基础值和衰减率,分别设置为0.8和0.99。 5. **REGULARAZTION_RATE** 是正则化率,用于防止过拟合,这里设置为0.0001。 6. **TRAINING_STEPS** 是训练轮数,总共有5000步。 7. **MOVING_AVERAGE_DECAY** 是移动平均的衰减因子,用于控制滑动平均窗口大小。 核心函数`get_weight_variable`用于创建权重变量,其中`shape`是变量的形状,`regularizer`是一个可选的正则化器。如果提供了正则器,它会被添加到损失集合中。`inferface_v`函数负责构建神经网络的第一层,使用`variable_scope`来管理变量,确保在整个网络中具有清晰的命名空间,并且在多层结构中复用权重。 在`inferface_v`函数中,`variable_scope('layer1')`用于定义一个名为'layer1'的变量作用域,内部调用`get_weight_variable`生成第一层的权重。这展示了如何利用`variable_scope`来组织和管理不同层级的权重变量,使得代码更具模块化。 这份代码示例为初学者提供了一个实战教程,介绍了如何在TensorFlow中使用LeNet-5模型进行MNIST数据集的手写数字识别,同时还涉及了滑动平均参数范围的实践、变量作用域的使用以及正则化的应用。对于想学习深度学习基础和TensorFlow编程的读者来说,这是一个很好的学习资源。

相关推荐

filetype

``` import tensorflow as tf from keras import datasets, layers, models import matplotlib.pyplot as plt # 导入mnist数据,依次分别为训练集图片、训练集标签、测试集图片、测试集标签 (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data() # 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。) train_images, test_images = train_images / 255.0, test_images / 255.0 # 查看数据维数信息 print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape) #调整数据到我们需要的格式 train_images = train_images.reshape((60000, 28, 28, 1)) test_images = test_images.reshape((10000, 28, 28, 1)) print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape) train_images = train_images.astype("float32") / 255.0 def image_to_patches(images, patch_size=4): batch_size = tf.shape(images)[0] patches = tf.image.extract_patches( images=images[:, :, :, tf.newaxis], sizes=[1, patch_size, patch_size, 1], strides=[1, patch_size, patch_size, 1], rates=[1, 1, 1, 1], padding="VALID" ) return tf.reshape(patches, [batch_size, -1, patch_size*patch_size*1]) class TransformerBlock(tf.keras.layers.Layer): def __init__(self, embed_dim, num_heads): super().__init__() self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) self.ffn = tf.keras.Sequential([ tf.keras.layers.Dense(embed_dim*4, activation="relu"), tf.keras.layers.Dense(embed_dim) ]) self.layernorm1 = tf.keras.layers.LayerNormalization() self.layernorm2 = tf.keras.layers.LayerNormalization() def call(self, inputs): attn_output = self.att(inputs, inputs) out1 = self.layernorm1(inputs + attn_output) ffn_output = self.ffn(out1) return self.layernorm2(out1 + ffn_output) class PositionEmbedding(tf.keras.layers.Layer): def __init__(self, max_len, embed_dim): super().__init__() self.pos_emb = tf.keras.layers.Embedding(input_dim=max_len, output_dim=embed_dim) def call(self, x): positions = tf.range(start=0, limit=tf.shape(x)[1], delta=1) return x + self.pos_emb(positions) def build_transformer_model(): inputs = tf.keras.Input(shape=(49, 16)) # 4x4 patches x = tf.keras.layers.Dense(64)(inputs) # 嵌入维度64 # 添加位置编码 x = PositionEmbedding(max_len=49, embed_dim=64)(x) # 堆叠Transformer模块 x = TransformerBlock(embed_dim=64, num_heads=4)(x) x = TransformerBlock(embed_dim=64, num_heads=4)(x) # 分类头 x = tf.keras.layers.GlobalAveragePooling1D()(x) outputs = tf.keras.layers.Dense(10, activation="softmax")(x) return tf.keras.Model(inputs=inputs, outputs=outputs) model = build_transformer_model() model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) # 数据预处理 train_images_pt = image_to_patches(train_images[..., tf.newaxis]) test_images_pt = image_to_patches(test_images[..., tf.newaxis]) history = model.fit( train_images_pt, train_labels, validation_data=(test_images_pt, test_labels), epochs=10, batch_size=128 )```代码检查并添加注释