import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from hypernetwork import Mixer
from torch.optim import Adam
import math
import tensorboardX
from tensorboardX import SummaryWriter

import yaml
import os
import matplotlib.pyplot as plt


def target_func1(x, y):
    """Evaluate the target surface ``sin(2 * pi * x) + exp(y)``.

    Both terms are computed with NumPy ufuncs, so ``x`` and ``y`` may be
    scalars or arrays; array arguments are combined by NumPy broadcasting.
    """
    periodic_term = np.sin(2 * x * math.pi)
    growth_term = np.exp(y)
    return periodic_term + growth_term

# Load the experiment configuration. This section reads the keys
# 'log_to_tensorboard', 'abs' and 'seed'.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# The output directory is used by the figure-saving code later in this script
# regardless of whether TensorBoard logging is enabled, so it is always
# defined and created. Only the writer and the config snapshot are optional.
final_path = './log/{}/{}/'.format('IGM' if config['abs'] else 'NonIGM', config['seed'])
os.makedirs(final_path, exist_ok=True)

if config['log_to_tensorboard']:
    writer = SummaryWriter(final_path)
    # Store a copy of the configuration next to the event files. The `with`
    # statement closes the file, so no explicit close() is needed.
    with open(os.path.join(final_path, 'config.txt'), 'w') as f:
        yaml.dump(config, f)
else:
    # Downstream code checks `writer is not None` before logging.
    writer = None

# Sample the target function on a regular grid defined by the config.
x = np.arange(config['x_min'], config['x_max'], config['dx'])
y = np.arange(config['y_min'], config['y_max'], config['dy'])
# indexing='ij' gives the grid arrays (and therefore target_values) the shape
# (len(x), len(y)), with x varying along the first axis. That matches the row
# order of torch.cartesian_prod(x, y), which builds the network inputs further
# down, so predictions reshaped to target_values.size() line up element-wise
# with the targets. The default 'xy' indexing yields the transposed layout.
# NOTE(review): assumes the mixer returns one value per input row, in input
# order -- confirm against hypernetwork.Mixer.
X, Y = np.meshgrid(x, y, indexing='ij')
# Coordinates are rescaled by 1/100 before evaluating the target function and
# scaled back by 100 for the plot axes.
X, Y = X / 100, Y / 100
target_values = torch.tensor(target_func1(X, Y)).float()

# Save a filled-contour plot of the ground truth as PNG and PDF.
fig = plt.figure(figsize=(5, 5), dpi=500)
plt.contourf(X * 100, Y * 100, target_values.detach().cpu().numpy(), config['density'])
plt.savefig(final_path+'/ground_truth.png', bbox_inches='tight')
plt.savefig(final_path+'/ground_truth.pdf', bbox_inches='tight')
# Release the figure; pyplot otherwise keeps it alive for the whole run.
plt.close(fig)

# Integer-cast copies of the grid coordinates and every (x, y) pairing of
# them, one pair per row (presumably the discrete action indices expected by
# the mixer -- verify against hypernetwork.Mixer).
action_x = torch.from_numpy(x).long()
action_y = torch.from_numpy(y).long()
joint_action_xy = torch.cartesian_prod(action_x, action_y)

# Float copies of the same coordinates, paired the same way and rescaled by
# 1/100 like the grid used for the target values.
value_x = torch.from_numpy(x).float()
value_y = torch.from_numpy(y).float()
joint_value_xy = torch.cartesian_prod(value_x, value_y) / 100

# A constant state of 1.0 for every joint pair: shape (num_pairs, 1).
const_state = torch.ones(joint_action_xy.size(0), 1, dtype=torch.float32)

# Seed NumPy and PyTorch (CPU and CUDA generators) with the configured seed.
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(config["seed"])
# Request deterministic cuDNN kernels and turn off its benchmark autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# The network is built after seeding, presumably so that its parameter
# initialisation is reproducible for a given seed.
mixing_net = Mixer(config)
optimiser = Adam(mixing_net.parameters(), lr=config['lr'])


# When requested by the config, move the input tensors, the targets and the
# network to the GPU.
if config['use_cuda']:
    joint_action_xy, joint_value_xy, const_state, target_values = (
        tensor.cuda()
        for tensor in (joint_action_xy, joint_value_xy, const_state, target_values)
    )
    mixing_net = mixing_net.cuda()


# Full-batch regression of the mixer output onto the target surface with a
# mean-squared-error loss. The loop runs max_iter + 1 steps (0..max_iter).
# The counter is called `step` so the builtin iter() is not shadowed.
for step in range(config['max_iter']+1):
    # Evaluate every joint input at once and reshape the result to the
    # layout of target_values.
    predict_values = mixing_net(joint_value_xy, const_state, joint_action_xy).view(*target_values.size())
    if step % config['plot_freq'] == 0:
        # Snapshot of the current prediction, taken before this step's
        # parameter update.
        fig = plt.figure(figsize=(5, 5), dpi=500)
        plt.contourf(X * 100, Y * 100, predict_values.detach().cpu().numpy(), config['density'])
        plt.savefig(final_path+'/iter_{}.png'.format(step), bbox_inches='tight')
        plt.savefig(final_path+'/iter_{}.pdf'.format(step), bbox_inches='tight')
        # Close the figure so high-resolution canvases do not accumulate in
        # memory over the course of training.
        plt.close(fig)
    loss = torch.mean((target_values - predict_values) ** 2)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    if writer is not None:
        writer.add_scalar('train_loss', loss.item(), step)
