学习笔记[email protected]

本文介绍了如何在TensorFlow2中通过@tf.function将EagerExecution转换为图执行模式,提升模型训练速度。通过封装函数并加上修饰符,实现模型的高效执行,并探讨了注意事项和适用场景。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow 2 默认的即时执行模式(Eager Execution)为我们带来了灵活及易调试的特性,但为了追求更快的速度与更高的性能,我们依然希望使用 TensorFlow 1.X 中默认的图执行模式(Graph Execution)。此时,TensorFlow 2 为我们提供了 tf.function 模块,结合 AutoGraph 机制,使得我们仅需加入一个简单的 @tf.function 修饰符,就能轻松将模型以图执行模式运行。

实现方式

只需要将我们希望以图执行模式运行的代码封装在一个函数内,并在函数前加上 @tf.function 即可。

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import time

np.random.seed(42)  # 设置numpy随机数种子
tf.random.set_seed(42)  # 设置tensorflow随机数种子

# 生成训练数据
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1  # y=x^2+1 + 随机噪声
x_train = np.expand_dims(x, 1)  # 将一维数据扩展为二维
y_train = np.expand_dims(y, 1)  # 将一维数据扩展为二维
plt.plot(x, y, '.')  # 画出训练数据


def create_model():
    inputs = keras.Input((1,))
    x = keras.layers.Dense(10, activation='relu')(inputs)
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = create_model()  # 创建一个模型
loss_fn = keras.losses.MeanSquaredError()  # 定义损失函数
optimizer = keras.optimizers.SGD()  # 定义优化器


@tf.function  # 将训练过程转化为图执行模式
def train():
    with tf.GradientTape() as tape:
        y_pred = model(x_train, training=True)  # 前向传播,注意不要忘了training=True
        loss = loss_fn(y_train, y_pred)  # 计算损失
        tf.summary.scalar("loss", loss, epoch+1)  # 将损失写入tensorboard
    grads = tape.gradient(loss, model.trainable_variables)  # 计算梯度
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # 使用优化器进行反向传播
    return loss


epochs = 1000
begin_time = time.time()  # 训练开始时间
for epoch in range(epochs):
    loss = train()

    print('epoch:', epoch+1, '\t', 'loss:', loss.numpy())  # 打印训练信息
end_time = time.time()  # 训练结束时间

print("训练时长:", end_time-begin_time)

# 预测
y_pre = model.predict(x_train)

# 画出预测值
plt.plot(x, y_pre.squeeze())
plt.show()

通过实验得出结论:使用@tf.function比不使用@tf.function训练时间快了很多倍。
使用@tf.function的函数在执行时会生成一个计算图,里面的操作就是计算图的每个节点。下次调用相同的函数,且参数类型相同时,则会直接使用这个计算图计算。若函数名不同或参数类型不同时,则会另外生成一个新的计算图。

注意点

建议在函数内只使用 TensorFlow 的原生操作,不要使用过于复杂的 Python 语句,函数参数最好只包括 TensorFlow张量或 NumPy 数组。

  • 因为只有tf的原生操作才会在计算图中生产节点。(如python的原生print()函数不会生成节点,而tensorflow的tf.print()会)
  • 对于Tensorflow张量或Numpy数组作为参数的函数,只要类型相同便可重用之前的计算图。而对于python原生数据(如原生的整数、浮点数 1,1.5等)必须参数的值一模一样才会重用之前的计算图,否则的话会创建新的计算图。

另外,一般而言,当模型由较多小的操作组成的时候, @tf.function带来的提升效果较大。而当模型的操作数量较少,但单一操作均很耗时的时候,则 @tf.function 带来的性能提升不会太大。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值