Keras自定义层(常用代码)

本文介绍如何使用Keras框架创建自定义层,包括乘法层和上采样层的具体实现方式。通过这些自定义层可以更好地控制神经网络模型的行为,并展示了如何将这些层应用于实际的模型构建中。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


from keras.models import Model
from keras.layers import Input,Conv2D,Reshape,GlobalAvgPool2D
from keras.layers import Lambda
import tensorflow as tf
import keras
from keras import backend
from keras.layers import multiply
class Mutiply(keras.layers.Layer):
    """ Keras layer for Mutiply a Tensor to be the same shape as another Tensor.
    """
    def __init__(self,**kwargs):
        super(Mutiply,self).__init__(**kwargs)
    def call(self, inputs, **kwargs):
        source, target = inputs
        target_shape = keras.backend.shape(target)
        source = tf.tile(source,[1,1,1,target_shape[3]])
        return tf.multiply(source,target)

    def compute_output_shape(self, input_shape):
        return (input_shape[1][0],) + input_shape[1][1:3] + (input_shape[1][-1],)
class UpsampleLike(keras.layers.Layer):
    """ Keras layer for upsampling a Tensor to be the same shape as another Tensor.
    """

    def call(self, inputs, **kwargs):
        source, target = inputs
        target_shape = keras.backend.shape(target)
        if keras.backend.image_data_format() == 'channels_first':
            source = backend.transpose(source, (0, 2, 3, 1))
            output = tf.image.resize_nearest_neighbor(source, (target_shape[2], target_shape[3]))
            #output = backend.resize_images(source, (target_shape[2], target_shape[3]), method='nearest')
            output = backend.transpose(output, (0, 3, 1, 2))
            return output
        else:
            #return backend.resize_images(source, (target_shape[1], target_shape[2]), method='bilinear')
            return tf.image.resize_bilinear(source, (target_shape[1], target_shape[2]))

    def compute_output_shape(self, input_shape):
        if keras.backend.image_data_format() == 'channels_first':
            return (input_shape[0][0], input_shape[0][1]) + input_shape[1][2:4]
        else:
            return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],)
def Interp(x, shape):
    ''' interpolation '''
    from keras.backend import tf as ktf
    new_height, new_width = shape
    resized = ktf.image.resize_images(
            x,
            [int(new_height), int(new_width)],
            align_corners=True)
    return resized
if __name__ == '__main__':
    x = Input(shape=(12,12,3))
    normed = Lambda(lambda z: z / 127.5 - 1.,  # Convert input feature range to [-1,1]
                    output_shape=(12, 12, 3),
                    name='lambda1')(x)
    global_feat = Lambda(
        Interp,
        arguments={'shape': (24,24)})(normed)
    global_avg = GlobalAvgPool2D()(global_feat)
    alpha = Reshape(target_shape=(1,1,3))(global_avg)
    final = multiply([alpha,global_feat])
    model = Model(x,final)
    model.summary()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值