Tensorflow实战:ResNet原理及实现(多注释)

本文详细介绍了ResNet残差神经网络的原理,并通过Tensorflow实现了ResNet_V2_152结构,包括网络配置、前向计算测评及运行日志展示,帮助读者理解ResNet如何解决信息丢失问题并简化学习目标。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        参考《Tensorflow实战》黄文坚,对Inception_V3进行了实现与改进,增加了自己的理解,欢迎提问!!

        残差神经单元:假定某段神经网络的输入是x,期望输出是H(x),如果我们直接将输入x传到输出作为初始结果,那么我们需要学习的目标就是F(x) = H(x) - x,这就是一个残差神经单元,相当于将学习目标改变了,不再是学习一个完整的输出H(x),只是输出和输入的差别 H(x) - x ,即残差。

        如下图分别为两层及三层的ResNet残差学习模块:

 

        本文使用三层的ResNet残差学习模块:

        可以看到普通的直连的卷积神经网络和ResNet的最大区别在于,ResNet有很多旁路的支线将输入直接连到后面的层,使得后面的层可以直接学习残差,这种结构也被称为shortcut或skip connections。

        传统的卷积层或全连接层在信息传递时,或多或少会存在信息丢失、损耗等问题。ResNet在某种程度上解决了这个问题,通过直接将输入信息绕道传到输出,保护信息的完整性,整个网络只需要学习输入、输出差别的那一部分,简化了学习目标和难度。

下图为ResNet不同层数时的网络配置,[ ]表示一个残差学习模块。本文使用152-layer的配置,构建ResNet_V2结构。

本文ResNet_V2结构如下图所示

为方便理解,贴出方法调用关系,如下图,红色显示在该位置调用其他方法,绿色表示在该处返回参数。

本文使用ResNet_V2_152的结构及参数,进行了前向计算的测评,代码及详细注释如下:

import tensorflow as tf
import collections
import time
from datetime import datetime
import math

'''############################################04《TensorFlow实战》实现ResNet_V2##################################################'''
slim = tf.contrib.slim

'''使用collections.namedtuple设计ResNet基本Block模块组的named tuple,并用它创建Block类'''
class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
    'a named tuple decribing a ResNet block.'

'''一个典型的Block
    需要输入参数,分别是scope、unit_fn、args
    以Block('block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)])为例,它可以定义一个典型的Block
    其中
        1、block1就是这个Block的名称(或scope)
        2、bottleneck是ResNet V2中的残差学习单元
        3、[(256, 64, 1)] * 2 + [(256, 64, 2)]时这个Block的args,args是一个列表,其中每一个元素都对应一个bottleneck残差学习单元,
        前面两个都是(256, 64, 1),最后一个是(256, 64, 2)。每个元素都是一个三元tuple,即(depth, depth_bottleneck, stride)
        比如
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值