ENAS代码解读

ENAS代码解读

参考代码:https://github.com/TDeVries/enas_pytorch

数据集:cifar10

main函数:

def main():
    """Build the ENAS controller and shared CNN, optionally resume, then train.

    All settings are read from the module-level ``args`` object. In order,
    this function:

    1. seeds the NumPy and PyTorch random number generators from ``args.seed``;
    2. replaces ``sys.stdout`` with a ``Logger`` writing under ``logs/``;
    3. loads the datasets and constructs the controller, the shared CNN,
       their optimizers and the cosine learning-rate scheduler;
    4. restores model/optimizer state from ``args.resume`` when it is set;
    5. calls ``train_enas`` (architecture search) or, when ``args.fixed_arc``
       is set, ``train_fixed``.

    Raises:
        ValueError: if ``args.resume`` is set but is not an existing file, or
            if ``args.fixed_arc`` is set without a checkpoint to resume from.
    """
    global args

    np.random.seed(args.seed)
    # torch.cuda.manual_seed alone leaves the CPU generator unseeded. The
    # modules below are built before .cuda() is called, so their parameter
    # initialisation presumably draws from the CPU generator; seed it too so
    # that runs are reproducible from args.seed.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Fixed-architecture runs get their own log file.
    log_suffix = '_fixed.log' if args.fixed_arc else '.log'
    sys.stdout = Logger(filename='logs/' + args.output_filename + log_suffix)

    print(args)

    data_loaders = load_datasets()

    controller = Controller(search_for=args.search_for,
                            search_whole_channels=True,
                            num_layers=args.child_num_layers,
                            num_branches=args.child_num_branches,
                            out_filters=args.child_out_filters,
                            lstm_size=args.controller_lstm_size,
                            lstm_num_layers=args.controller_lstm_num_layers,
                            tanh_constant=args.controller_tanh_constant,
                            temperature=None,
                            skip_target=args.controller_skip_target,
                            skip_weight=args.controller_skip_weight)
    controller = controller.cuda()

    shared_cnn = SharedCNN(num_layers=args.child_num_layers,
                           num_branches=args.child_num_branches,
                           out_filters=args.child_out_filters,
                           keep_prob=args.child_keep_prob)
    shared_cnn = shared_cnn.cuda()

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L218
    controller_optimizer = torch.optim.Adam(params=controller.parameters(),
                                            lr=args.controller_lr,
                                            betas=(0.0, 0.999),
                                            eps=1e-3)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L213
    shared_cnn_optimizer = torch.optim.SGD(params=shared_cnn.parameters(),
                                           lr=args.child_lr_max,
                                           momentum=0.9,
                                           nesterov=True,
                                           weight_decay=args.child_l2_reg)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L154
    shared_cnn_scheduler = CosineAnnealingLR(optimizer=shared_cnn_optimizer,
                                             T_max=args.child_lr_T,
                                             eta_min=args.child_lr_min)

    if args.resume:
        if not os.path.isfile(args.resume):
            raise ValueError("No checkpoint found at '{}'".format(args.resume))

        print("Loading checkpoint '{}'".format(args.resume))
        # NOTE: torch.load unpickles the file, so only resume from checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        shared_cnn.load_state_dict(checkpoint['shared_cnn_state_dict'])
        controller.load_state_dict(checkpoint['controller_state_dict'])
        shared_cnn_optimizer.load_state_dict(checkpoint['shared_cnn_optimizer'])
        controller_optimizer.load_state_dict(checkpoint['controller_optimizer'])
        # NOTE(review): the scheduler's own state is not restored here; this
        # only re-points it at the optimizer it was constructed with. The
        # original author was unsure whether this is sufficient - confirm.
        shared_cnn_scheduler.optimizer = shared_cnn_optimizer
        print("Loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        start_epoch = 0

    if not args.fixed_arc:
        train_enas(start_epoch,
                   controller,
                   shared_cnn,
                   data_loaders,
                   shared_cnn_optimizer,
                   controller_optimizer,
                   shared_cnn_scheduler)
    else:
        # An explicit raise instead of assert: asserts are stripped under
        # `python -O`, and `None != ''` would have let a missing checkpoint
        # through, silently training from untrained weights.
        if not args.resume:
            raise ValueError('A pretrained model should be used when training a fixed architecture.')
        train_fixed(start_epoch,
                    controller,
                    shared_cnn,
                    data_loaders)

再来看看Controller类的__init__方法:

class Controller(nn.Module):
    """ENAS controller module; this excerpt shows only its constructor.

    Reference implementation:
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py
    """

    def __init__(self,
                 search_for="macro",
                 search_whole_channels=True,
                 num_layers=12,
                 num_branches=6,
                 out_filters=36,
                 lstm_size=32,
                 lstm_num_layers=2,
                 tanh_constant=1.5,
                 temperature=None,
                 skip_target=0.4,
                 skip_weight=0.8):
        """Store the search hyperparameters, then build the parameters.

        Every argument is stored unchanged on the instance under the same
        name; ``_create_params`` is called last, once all of them are set.

        The original walkthrough annotated each attribute with a value,
        presumably the one passed in from ``main`` (confirm against the
        argument parser). Those annotations match the defaults except for
        ``lstm_size`` (annotated 64, default 32) and ``lstm_num_layers``
        (annotated 1, default 2).

        Args:
            search_for: Search mode label (default ``"macro"``).
            search_whole_channels: Flag stored as given (default ``True``).
            num_layers: Number of layers in the generated network.
            num_branches: Number of candidate operations per layer.
            out_filters: Filter count stored as given (default 36).
            lstm_size: Size of the controller LSTM.
            lstm_num_layers: Number of stacked LSTM layers.
            tanh_constant: Stored as given (default 1.5).
            temperature: Stored as given (default ``None``).
            skip_target: Stored as given (default 0.4).
            skip_weight: Stored as given (default 0.8).
        """
        super().__init__()

        # Search-space description.
        self.search_for, self.search_whole_channels = search_for, search_whole_channels
        self.num_layers, self.num_branches, self.out_filters = (
            num_layers, num_branches, out_filters)

        # Controller LSTM dimensions.
        self.lstm_size, self.lstm_num_layers = lstm_size, lstm_num_layers

        # Sampling-related constants.
        self.tanh_constant, self.temperature = tanh_constant, temperature

        # Skip-connection settings.
        self.skip_target, self.skip_weight = skip_target, skip_weight

        # Must run last: it relies on the attributes assigned above.
        self._create_params()

num_layers为12代表最终生成12层的网络;num_branches为6代表6种候选操作:3x3、5x5普通卷积,3x3、5x5深度可分离卷积,以及平均池化和最大池化。

再看Controller的 _create_params(self)函数:

def _create_params(self):
        '''
        https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L83
        '''
        self.w_lstm = nn.LSTM(input_size=self.lstm_size,
                              hidden_size=self.lstm_size,
                              num_layers=self.lstm_num_layers)

        self.g_emb = nn.Embedding(1, self.lstm_size)  # Learn the starting input

        if self.search_whole_channels:
            self.w_emb = nn.Embedding(self.num_branches, self.lstm_size)
            self.w_soft = nn.Linear(self.lstm_size, self.num_branches, bias=False)
        else:
            assert False, "Not implemented error: search_whole_channels = False"

        self.w_attn_1 = nn.Linear(self.lst
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值