ENAS Code Walkthrough
Reference code: https://github.com/TDeVries/enas_pytorch
Dataset: CIFAR-10
The main function:
# Excerpt from the repo's training script; Logger, load_datasets, Controller,
# SharedCNN, train_enas and train_fixed are defined elsewhere in the repo.
import os
import sys
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

def main():
    global args

    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.fixed_arc:
        sys.stdout = Logger(filename='logs/' + args.output_filename + '_fixed.log')
    else:
        sys.stdout = Logger(filename='logs/' + args.output_filename + '.log')

    print(args)

    data_loaders = load_datasets()

    # Controller RNN that samples child architectures
    controller = Controller(search_for=args.search_for,
                            search_whole_channels=True,
                            num_layers=args.child_num_layers,
                            num_branches=args.child_num_branches,
                            out_filters=args.child_out_filters,
                            lstm_size=args.controller_lstm_size,
                            lstm_num_layers=args.controller_lstm_num_layers,
                            tanh_constant=args.controller_tanh_constant,
                            temperature=None,
                            skip_target=args.controller_skip_target,
                            skip_weight=args.controller_skip_weight)
    controller = controller.cuda()

    # Shared CNN whose weights are reused by every sampled child network
    shared_cnn = SharedCNN(num_layers=args.child_num_layers,
                           num_branches=args.child_num_branches,
                           out_filters=args.child_out_filters,
                           keep_prob=args.child_keep_prob)
    shared_cnn = shared_cnn.cuda()

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L218
    controller_optimizer = torch.optim.Adam(params=controller.parameters(),
                                            lr=args.controller_lr,
                                            betas=(0.0, 0.999),
                                            eps=1e-3)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L213
    shared_cnn_optimizer = torch.optim.SGD(params=shared_cnn.parameters(),
                                           lr=args.child_lr_max,
                                           momentum=0.9,
                                           nesterov=True,
                                           weight_decay=args.child_l2_reg)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L154
    shared_cnn_scheduler = CosineAnnealingLR(optimizer=shared_cnn_optimizer,
                                             T_max=args.child_lr_T,
                                             eta_min=args.child_lr_min)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            # args = checkpoint['args']
            shared_cnn.load_state_dict(checkpoint['shared_cnn_state_dict'])
            controller.load_state_dict(checkpoint['controller_state_dict'])
            shared_cnn_optimizer.load_state_dict(checkpoint['shared_cnn_optimizer'])
            controller_optimizer.load_state_dict(checkpoint['controller_optimizer'])
            shared_cnn_scheduler.optimizer = shared_cnn_optimizer  # Not sure if this actually works
            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise ValueError("No checkpoint found at '{}'".format(args.resume))
    else:
        start_epoch = 0

    # Either run the ENAS search, or retrain a fixed architecture from a checkpoint
    if not args.fixed_arc:
        train_enas(start_epoch,
                   controller,
                   shared_cnn,
                   data_loaders,
                   shared_cnn_optimizer,
                   controller_optimizer,
                   shared_cnn_scheduler)
    else:
        assert args.resume != '', 'A pretrained model should be used when training a fixed architecture.'
        train_fixed(start_epoch,
                    controller,
                    shared_cnn,
                    data_loaders)
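Before drilling into the individual classes, it helps to keep the overall schedule in mind: train_enas alternates between training the shared CNN weights w and training the controller parameters with REINFORCE. The outline below is an illustrative sketch only, not the repo's exact code: train_enas_outline, validate_arc, num_controller_steps and the 'train_subset' key are hypothetical names, and it assumes the controller's forward() samples an architecture and exposes it via attributes like sample_arc and sample_log_prob (as this repo's Controller does).

import torch
import torch.nn.functional as F

def train_enas_outline(start_epoch, controller, shared_cnn, data_loaders,
                       shared_cnn_optimizer, controller_optimizer, scheduler,
                       num_epochs=310, num_controller_steps=50):
    """Illustrative outline of ENAS training; the repo's train_enas adds
    logging, entropy bonuses and gradient clipping on top of this."""
    baseline = None
    for epoch in range(start_epoch, num_epochs):
        # Phase 1: train shared CNN weights. The controller samples a new
        # sub-network per batch; only that path through the DAG is updated.
        shared_cnn.train()
        for images, labels in data_loaders['train_subset']:   # key name assumed
            images, labels = images.cuda(), labels.cuda()
            with torch.no_grad():
                controller()                                  # forward() samples an architecture
            pred = shared_cnn(images, controller.sample_arc)
            loss = F.cross_entropy(pred, labels)
            shared_cnn_optimizer.zero_grad()
            loss.backward()
            shared_cnn_optimizer.step()
        scheduler.step()  # cosine-annealed child learning rate

        # Phase 2: train the controller with REINFORCE, using the validation
        # accuracy of sampled architectures as the reward.
        controller.train()
        shared_cnn.eval()
        for _ in range(num_controller_steps):
            controller()                                      # sample; stores log-probs
            with torch.no_grad():
                reward = validate_arc(shared_cnn, controller.sample_arc,
                                      data_loaders)           # hypothetical helper: val accuracy (float)
            # Moving-average baseline reduces the variance of the gradient estimate.
            baseline = reward if baseline is None else 0.95 * baseline + 0.05 * reward
            ctrl_loss = -controller.sample_log_prob * (reward - baseline)
            controller_optimizer.zero_grad()
            ctrl_loss.backward()
            controller_optimizer.step()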
Now let's look at the Controller class's __init__:
class Controller(nn.Module):
    '''
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py
    '''
    def __init__(self,
                 search_for="macro",
                 search_whole_channels=True,
                 num_layers=12,
                 num_branches=6,
                 out_filters=36,
                 lstm_size=32,
                 lstm_num_layers=2,
                 tanh_constant=1.5,
                 temperature=None,
                 skip_target=0.4,
                 skip_weight=0.8):
        super(Controller, self).__init__()

        self.search_for = search_for                          # macro
        self.search_whole_channels = search_whole_channels    # True
        self.num_layers = num_layers                          # 12
        self.num_branches = num_branches                      # 6
        self.out_filters = out_filters                        # 36
        self.lstm_size = lstm_size                            # 64 (value passed in from args; default is 32)
        self.lstm_num_layers = lstm_num_layers                # 1 (value passed in from args; default is 2)
        self.tanh_constant = tanh_constant                    # 1.5
        self.temperature = temperature                        # None
        self.skip_target = skip_target                        # 0.4
        self.skip_weight = skip_weight                        # 0.8

        self._create_params()
num_layers = 12 means the generated network will ultimately have 12 layers, and num_branches = 6 means there are 6 candidate operations: 3x3 and 5x5 standard convolutions, 3x3 and 5x5 depthwise-separable convolutions, average pooling, and max pooling.
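To make the six branches concrete, a rough mapping from branch id to operation might look like the following. This is a hypothetical sketch (make_branch is not from the repo); the real implementations live inside SharedCNN's layer branches, wrapped with 1x1 convolutions and batch norm.

import torch.nn as nn

def make_branch(branch_id: int, channels: int) -> nn.Module:
    """Hypothetical branch-id -> operation mapping for the macro search space."""
    if branch_id == 0:   # 3x3 standard convolution
        return nn.Conv2d(channels, channels, 3, padding=1)
    if branch_id == 1:   # 5x5 standard convolution
        return nn.Conv2d(channels, channels, 5, padding=2)
    if branch_id == 2:   # 3x3 depthwise-separable convolution
        return nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
                             nn.Conv2d(channels, channels, 1))
    if branch_id == 3:   # 5x5 depthwise-separable convolution
        return nn.Sequential(nn.Conv2d(channels, channels, 5, padding=2, groups=channels),
                             nn.Conv2d(channels, channels, 1))
    if branch_id == 4:   # 3x3 average pooling
        return nn.AvgPool2d(3, stride=1, padding=1)
    if branch_id == 5:   # 3x3 max pooling
        return nn.MaxPool2d(3, stride=1, padding=1)
    raise ValueError(branch_id)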
Next, look at the Controller's _create_params(self) function:
    def _create_params(self):
        '''
        https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L83
        '''
        self.w_lstm = nn.LSTM(input_size=self.lstm_size,
                              hidden_size=self.lstm_size,
                              num_layers=self.lstm_num_layers)

        self.g_emb = nn.Embedding(1, self.lstm_size)  # Learn the starting input

        if self.search_whole_channels:
            self.w_emb = nn.Embedding(self.num_branches, self.lstm_size)
            self.w_soft = nn.Linear(self.lstm_size, self.num_branches, bias=False)
        else:
            assert False, "Not implemented error: search_whole_channels = False"
        self.w_attn_1 = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
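For intuition about how these parameters are used: at each of the 12 layers the controller runs one LSTM step and projects the hidden state through w_soft to get logits over the six branches. Below is a simplified single-decision sketch; sample_branch is a hypothetical name, and the repo's actual forward() additionally samples skip connections using the attention weights (w_attn_1 and friends).

import torch
import torch.nn.functional as F

def sample_branch(controller, inputs, hidden):
    """Simplified sketch: sample one layer's operation from the controller."""
    # One LSTM step; inputs has shape (1, 1, lstm_size).
    output, hidden = controller.w_lstm(inputs, hidden)
    logits = controller.w_soft(output.squeeze(0))             # (1, num_branches)
    if controller.temperature is not None:                    # optional logit scaling
        logits = logits / controller.temperature
    if controller.tanh_constant is not None:                  # bound the logits
        logits = controller.tanh_constant * torch.tanh(logits)
    probs = F.softmax(logits, dim=-1)
    branch_id = torch.multinomial(probs, 1)                   # sample one of the 6 ops
    # cross_entropy yields the negative log-probability of the sampled op,
    # which the REINFORCE update later weights by the reward.
    nll = F.cross_entropy(logits, branch_id.view(-1))
    next_inputs = controller.w_emb(branch_id.view(1, 1))      # choice fed back as next LSTM input
    return branch_id.item(), nll, next_inputs, hidden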