from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

from operator import attrgetter
from copy import deepcopy
import numpy as np
from absl import logging
import random


class Matrix_gameEnv_2(MultiAgentEnv):
    """One-step cooperative matrix game for multi-agent learning.

    Each agent picks one of 3 actions and the episode terminates after a
    single step (``episode_limit = 1``). The shared team reward is read from
    a 3x3 payoff table indexed by the actions of agents 0 and 1::

                    a1=0   a1=1   a1=2
            a0=0  [  -k,     0,    10 ]
            a0=1  [   0,     2,     0 ]
            a0=2  [   8,     0,    -k ]

    Only ``actions[0]`` and ``actions[1]`` influence the reward; actions of
    any further agents are ignored.

    Args:
        seed: Stored as ``self.seed``; it is not used to seed anything here.
            NOTE(review): this instance attribute shadows the ``seed()``
            method defined below, so ``env.seed()`` raises ``TypeError``.
            Kept as-is because callers may read ``env.seed``.
        k: Magnitude of the penalty placed at table entries (0, 0) and (2, 2).
        n_agents: Number of agents reported to the learner (default 2).
    """
    def __init__(
            self,
            seed,
            k,
            n_agents=2,
    ):
        # Map arguments
        self.n_agents = n_agents
        self.seed = seed

        # Actions
        self.n_actions = 3

        # Statistics
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.battles_won = 0
        self.battles_game = 0

        # One-shot game: every episode ends after a single joint action.
        self.episode_limit = 1
        self.k = k

        # Payoff table indexed as matrix_table[action_of_agent0][action_of_agent1].
        self.matrix_table = np.array([[-self.k, 0, 10], [0, 2, 0], [8, 0, -self.k]])
        # Alternative payoff table (unused):
        # self.matrix_table = np.array([[8, -12, -12], [-12, 6, 0], [-12, 0, 6]])

        # Per-agent unit feature size reported to Qatten via get_env_info().
        self.unit_dim = 1

    def step(self, actions):
        """Apply the joint action and return ``(reward, terminated, info)``.

        Args:
            actions: Indexable joint action; only ``actions[0]`` and
                ``actions[1]`` are used to look up the reward.

        Returns:
            reward: Entry of ``matrix_table`` for the two actions.
            terminated: True once ``episode_limit`` steps have been taken.
            info: Dict with key ``'battle_won'``.
        """
        self._total_steps += 1
        self._episode_steps += 1
        info = {}

        reward = self.matrix_table[actions[0]][actions[1]]

        terminated = False
        info['battle_won'] = False

        if self._episode_steps >= self.episode_limit:
            terminated = True

        if terminated:
            # NOTE(review): a "win" is a reward of exactly 8, although the
            # table's maximum is 10 at (0, 2). It matches the unused
            # alternative table's top payoff; confirm the intended condition.
            if reward == 8:
                info['battle_won'] = True
                # Count wins so get_stats() reports a meaningful win_rate.
                self.battles_won += 1
            self._episode_count += 1
            self.battles_game += 1

        return reward, terminated, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Return agent ``agent_id``'s observation.

        Every agent sees the same thing: the current episode step counter
        as a length-1 array. ``agent_id`` is unused.
        """
        return np.array([self._episode_steps])

    def get_obs_size(self):
        """Returns the size of the observation."""
        return 1

    def get_state(self):
        """Return the global state: an all-zero vector of length ``n_agents``."""
        return np.array([0 for _ in range(self.n_agents)])

    def get_state_size(self):
        """Returns the size of the global state."""
        return self.n_agents

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Return the availability mask for ``agent_id``; all actions are always available."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Start a new episode; returns ``(observations, state)``."""
        self._episode_steps = 0

        return self.get_obs(), self.get_state()

    def render(self):
        """No-op: nothing to render."""
        pass

    def close(self):
        """No-op: no resources to release."""
        pass

    def seed(self):
        """No-op.

        NOTE(review): unreachable on instances, because ``__init__`` assigns
        ``self.seed`` and that attribute shadows this method.
        """
        pass

    def save_replay(self):
        """Save a replay (no-op for this environment)."""
        pass

    def get_env_info(self):
        """Return the environment description dict consumed by the learner."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit,
                    "unit_dim": self.unit_dim}
        return env_info

    def get_stats(self):
        """Return win statistics over all finished episodes.

        ``win_rate`` is 0.0 until at least one episode has finished, rather
        than raising ``ZeroDivisionError``.
        """
        win_rate = (self.battles_won / self.battles_game
                    if self.battles_game else 0.0)
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "win_rate": win_rate
        }
        return stats
