Keras深度学习框架第六讲:在TensorFlow中自定义fit()方法中的操作

1、绪论

在TensorFlow中,fit() 方法是一个用于训练模型的便捷函数,它封装了训练循环(training loop)的许多常见步骤,如前向传播(forward pass)、计算损失(loss)、反向传播(backward pass)以及模型权重的更新。然而,有时候你可能想要对训练过程进行更细粒度的控制,或者添加一些自定义的步骤。
当进行监督学习时,可以使用fit()方法,并且一切都会顺利进行。

但是,当需要控制每一个小细节时,就可以完全从头开始编写自己的训练循环。

但如果需要一个自定义的训练算法,但又想从fit()的便捷功能中受益,比如回调(callbacks)、内置的分布支持(built-in distribution support)或步骤融合(step fusing)时,又该如何呢?

Keras的一个核心原则是复杂性的逐步展现。总是能够逐步深入到更低级别的工作流程中。如果高级功能不完全符合你的测试用例,也不会突然陷入困境。我们可以在保留相应级别的高级便利性的同时,对细节获得更多的控制权。

当需要自定义fit()的行为时,我们应该重写Model类的训练步骤函数。这是fit()在处理每一批数据时调用的函数。然后你就可以像平常一样调用fit()——而它将会运行你自己的学习算法。

请注意,这种模式并不会阻止你使用函数式API构建模型。无论你是构建Sequential模型、函数式API模型还是子类化模型,都可以采用这种方法。

2、准备工作

#2.1 基础设置

开始操作前请按照如下进行基础设置

import os

# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
import numpy as np

2.2 操作示例

以下是一个使用TensorFlow来自定义fit()的示例:

首先需要创建一个新的类,该类继承自keras.Model
然后重写train_step(self, data)这个方法。
之后我们返回数据字典,该字典将度量指标名称(包括损失)映射到它们的当前值。
输入参数data是传递给fit方法的训练数据:

  • 如果你通过调用fit(x, y, ...)传递NumPy数组,那么data将是元组(x, y)
  • 如果你通过调用fit(dataset, ...)传递一个tf.data.Dataset,那么data将是dataset在每个批次中产生的数据。

train_step()方法的主体中,我们实现了一个常规的训练更新过程,类似于你已经熟悉的过程。重要的是,我们通过self.compute_loss()计算损失,该方法封装了在compile()方法中传递的损失函数。

类似地,我们调用metric.update_state(y, y_pred)来更新在compile()方法中传递的度量指标的状态,并在最后通过self.metrics查询结果来获取它们的当前值。

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        return {
   
   m.name: m.result() for m in self.metrics}

运行代码,看看输出结果

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.5089 - loss: 0.3778   
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 318us/step - mae: 0.3986 - loss: 0.2466
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 372us/step - mae: 0.3848 - loss: 0.2319

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699222602.443035       1 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

<keras.src.callbacks.history.History at 0x2a5599f00>

3、底层操作方法

在操作过程中,可以在compile()方法中省略损失函数的传递,而是在train_step中手动完成所有操作。对于度量指标(metrics)也是如此。

以下是一个更加底层操作的示例,它仅使用compile()方法来配置优化器:

首先,我们在__init__()方法中创建度量指标实例来跟踪损失和平均绝对误差(MAE)分数。

然后,我们实现一个自定义的train_step(),更新这些度量指标的状态(通过调用它们的update_state()方法),接着查询它们(通过result()方法)来返回当前平均值,以便进度条显示并传递给任何回调函数。

请注意,在每个epoch之间,我们需要调用度量指标的reset_states()方法!否则,调用result()会返回从训练开始以来的平均值,而我们通常处理的是每个epoch的平均值。幸运的是,框架可以为我们完成这一操作:只需在模型的metrics属性中列出你希望重置的任何度量指标对象。在每个fit() epoch的开始或调用evaluate()时,模型会自动调用这些对象上的reset_states()方法。

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean
学习 TensorFlow 框架及其应用开发是一个循序渐进的过程,需要从基础概念入手,逐步掌握其核心功能和实际应用。以下是一些关键的学习路径和建议: ### 1. 理解张量(Tensor)与计算图 TensorFlow 的核心是基于张量的计算,张量是一种多维数组结构,可以高效地在 CPU、GPU 或 TPU 上运行。理解张量的基本操作和数据流图(计算图)是学习 TensorFlow 的第一步。TensorFlow 使用静态图(在 TensorFlow 1.x 中)或即时执行模式(Eager Execution,在 TensorFlow 2.x 中)来管理计算流程[^4]。 ### 2. 掌握基本 API 和操作 TensorFlow 提供了丰富的张量运算函数,包括矩阵乘法、卷积、池化、求和等,这些函数构成了构建神经网络的基础。可以通过官方文档和教程熟悉这些函数的使用方法,并尝试在简单的数学问题中应用它们[^3]。 ### 3. 学习模型构建与训练 使用 TensorFlow 构建深度学习模型通常涉及以下几个步骤: - **定义模型结构**:使用 `tf.keras` 模块可以快速构建模型,它提供了预定义的层(如全连接层、卷积层等)和损失函数。 - **配置训练过程**:选择优化器(如 Adam、SGD)、损失函数和评估指标。 - **训练模型**:通过 `model.fit()` 方法进行模型训练,输入训练数据和标签。 - **评估与预测**:使用 `model.evaluate()` 和 `model.predict()` 进行模型评估和预测。 ### 4. 灵活控制与定制 TensorFlow 的一大优势是其高度的灵活性,允许开发者对模型进行深度定制。例如,可以使用 `tf.GradientTape` 实现自定义的训练循环,或者使用低级 API(如 `tf.nn` 模块)构建完全自定义的网络结构。这种灵活性使 TensorFlow 成为研究和生产环境中的首选框架之一[^2]。 ### 5. 实践项目与案例 通过实际项目来巩固所学知识是学习 TensorFlow 的关键。可以从简单的图像分类任务(如 MNIST 或 CIFAR-10 数据集)开始,逐步过渡到更复杂的任务,如目标检测、自然语言处理或生成对抗网络(GAN)。此外,可以参考官方文档中的案例实战,学习如何将 TensorFlow 应用于真实场景[^1]。 ### 6. 环境搭建与版本管理 安装 TensorFlow 时,建议使用 Python 的虚拟环境(如 `venv` 或 `conda`)来管理依赖。对于 CPU 版本,可以直接使用 `pip install tensorflow` 进行安装;如果需要 GPU 支持,则需安装 `tensorflow-gpu` 包,并确保系统中已安装合适的 CUDA 和 cuDNN 版本。此外,可以使用 Docker 或源码编译来部署特定版本的 TensorFlow[^5]。 ### 7. 学习资源推荐 - **官方文档**:TensorFlow 官方文档是学习最权威的资源,提供了详细的 API 参考和教程。 - **在线课程**:Coursera、Udacity 和 Google Developers 等平台提供了多个 TensorFlow 课程。 - **书籍**:《TensorFlow 深度学习实战》、《Deep Learning with TensorFlow 2.x》等书籍适合系统性学习。 - **社区与论坛**:Stack Overflow、GitHub 和 TensorFlow 官方论坛是解决问题和获取帮助的好地方。 --- ```python import tensorflow as tf # 示例:构建一个简单的全连接神经网络 model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10) ]) # 编译模型 model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) # 加载数据集 (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 # 训练模型 model.fit(x_train, y_train, epochs=5) # 评估模型 model.evaluate(x_test, y_test, verbose=2) ``` ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值