tensorflow中的共享变量

在TensorFlow中,tf.Variable用于创建变量,但当需要在不同模型间共享变量时,应使用get_variable配合variable_scope。get_variable通过name属性实现变量共享,而Variable在定义时如果没有指定name,系统会自动生成。在同一作用域下,多次使用get_variable创建同名变量会导致错误,可通过variable_scope和reuse=True参数实现变量重用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

(1)用途

在构建模型时,需要使用tf.Variable来创建一个变量(也可以理解成节点)。但在某种情况下,一个模型需要使用其他模型创建的变量,两个模型一起训练。此时需要用到共享变量。这时就是通过引入get_variable方法,实现共享变量来解决这个问题。

(2) 使用get-variable获取变量

get_variable一般会配合variable_scope一起使用,以实现共享变量。variable_scope的意思是变量作用域。在某一作用域中的变量可以被设置成共享的方式,被其他网络模型使用。
get_variable函数的定义如下

    tf.get_variable(<name>, <shape>, <initializer>)

注意:
使用get_variable生成的变量是以指定的name属性为唯一标识,并不是定义的变量名称。使用时一般通过name属性定位到具体变量,并将其共享到其他模型中。

(3) get_variable和Variable的区别

Variable

import tensorflow as tf
x=tf.Variable(1.0,name="v1")
print(x,"x的名字为:",x.name)
x2=tf.Variable(2.0,name="v2")
print(x2,"x2的名字为:",x2.name)
with tf.Session() as sess:
    all_v=tf.global_variables_initializer()
    sess.run(all_v)
    print("X1=",x.eval()) #获取值
    print("X2=",x2.value()) #获取全部信息

注意:

Variable定义时没有指定名字,系统会自动给加上一个名字Variable:0。
当Variable定义多个相同的变量时,图只会当最后一个有效

get_variable

    16   get_var1 = tf.get_variable("firstvar", [1], initializer=tf.constant_
        initializer(0.3))
    17   print ("get_var1:", get_var1.name)
    18# 此时会出错,firstvar在前面已经定义,如果改为firstvar2则正常
    19   get_var1 = tf.get_variable("firstvar", [1], initializer=tf.constant_
        initializer(0.4))
    20   print ("get_var1:", get_var1.name)

(4)在特定的作用域下获取变量

在作用域下,使用get_variable,以及嵌套variable_scope。在前面的例子中已经知道使用get_variable创建两个同样名字的变量是行不通的,如果真的想要那么做,可以使用variable_scope将它们隔开,代码如下。

    import tensorflow as tf
    with tf.variable_scope("test1", ):     #定义一个作用域test1
        var1 = tf.get_variable("firstvar", shape=[2], dtype=tf.float32)

    with tf.variable_scope("test2"):
        var2 = tf.get_variable("firstvar", shape=[2], dtype=tf.float32)

    print ("var1:", var1.name)
    print ("var2:", var2.name)

其实,variable_scope里面有个reuse=True属性,表示使用已经定义过的变量。这时get_variable将不会再创建新的变量,而是去图(一个计算任务)中get_variable所创建过的变量中找与name相同的变量。

    11   with tf.variable_scope("test1", reuse=True ):
    12        var3= tf.get_variable("firstvar", shape=[2], dtype=tf.float32)
    13        with tf.variable_scope("test2"):
    14            var4 = tf.get_variable("firstvar", shape=[2], dtype=tf.float32)
    15
    16   print ("var3:", var3.name)
    17   print ("var4:", var4.name)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值