#---------------------------------------------------------------------------
# A simple simulated world for Rosie the Robot, based on Chapter 8 of
# "Artificial Intelligence: A Guide for Thinking Humans" by Melanie Mitchell

# version 4: adds support for epsilon schedules
#---------------------------------------------------------------------------
import random

class RosieWorld:
    """A one-dimensional world holding Rosie the robot and a single ball.

    Locations are numbered 0 .. world_size-1.  The ball always starts at
    location 0 and Rosie starts at a random location.  Once the ball has
    been kicked, ball_location is None.
    """

    def __init__(self, size=26):
        self.world_size = size
        self.robot_location = random.randrange(size)
        self.ball_location = 0

    def __str__(self):
        """Return one 3-character cell per location: 'o' = ball, 'R' = Rosie."""
        cells = []
        for i in range(self.world_size):
            has_robot = self.robot_location == i
            # a kicked ball (None) never equals a location, so it is not drawn
            has_ball = self.ball_location == i
            if has_robot and has_ball:
                cells.append("oR ")
            elif has_ball:
                cells.append("o  ")
            elif has_robot:
                cells.append("R  ")
            else:
                cells.append("_  ")
        return "".join(cells)

    def forward(self):
        """Move Rosie one location toward location 0 (no-op at 0); reward 0."""
        if self.robot_location > 0:
            self.robot_location -= 1
        return 0

    def backward(self):
        """Move Rosie one location away from location 0 (no-op at the far
        end of the world); reward 0."""
        if self.robot_location < self.world_size-1:
            self.robot_location += 1
        return 0

    def kick(self):
        """Kick the ball if Rosie is at the ball's location.

        A successful kick removes the ball from the world and returns a
        reward of 10; otherwise nothing changes and the reward is 0.
        """
        if self.robot_location == self.ball_location:
            self.ball_location = None
            return 10
        return 0

    def can_see_ball(self):
        """Return True while the ball has not been kicked yet."""
        return self.ball_location is not None

    def get_state(self):
        """Return the state number for the current situation.

        state 0 = in location 0 with no ball present
        state 1 = in location 0 with ball present
        state 2 = in location 1
        state 3 = in location 2
        ...
        state N = in location N-1
        """
        if self.robot_location != 0:
            return self.robot_location + 1
        return 0 if self.ball_location is None else 1

#---------------------------------------------------------------------------
# simulates one episode of Rosie performing actions in her world

# human-readable action names; each action's code is its index in this list
ACTION_NAMES = ['forward', 'back', 'kick']
ACTION_CODES = list(range(len(ACTION_NAMES)))

# version that updates Q values (in learn mode)
def run_episode(Q=None, epsilon=0.5, discount_factor=0.8, learning_rate=1, mode='run'):
    """Simulate one episode of Rosie acting in a freshly created RosieWorld.

    The episode lasts until Rosie kicks the ball (or, in 'step' mode, until
    the user quits).

    Q: Q table indexed as Q[state][action], or None to choose every action
        at random.
    epsilon: when a Q table is given, the probability of choosing a random
        action instead of the highest-valued one.
    discount_factor: weight of the next state's best Q value in the update
        (used in 'learn' mode only).
    learning_rate: fraction of the correction applied to the old Q value
        (used in 'learn' mode only).
    mode: 'run' prints the world after every step; 'step' also waits for a
        line of input after every step ('q' quits, 'r' switches to 'run');
        'quiet' prints nothing; 'learn' prints nothing and updates Q in
        place.

    Returns the number of steps taken, or None if the user quit.
    Raises TypeError if mode is 'learn' and Q is None.
    """
    assert mode in ['run', 'step', 'quiet', 'learn'], "invalid mode"
    if mode == 'learn' and Q is None:
        # fail up front with a clear message rather than partway through the
        # first step, when the update would try to index None
        raise TypeError("mode='learn' requires a Q table, but Q is None")
    rosie = RosieWorld()
    mode = _show_world(rosie, mode)
    if mode is None:  # user quit before the first step
        return
    steps = 0
    while rosie.can_see_ball():
        # observe current state
        state = rosie.get_state()

        # Identity test on purpose: `Q == None` compares elementwise when Q
        # is a NumPy array, and the resulting array cannot be used in `if`.
        # Without a Q table every action is random; with one, a random
        # action is chosen with probability epsilon.
        if Q is None or random.uniform(0, 1) < epsilon:
            action = random.choice(ACTION_CODES)
        else:
            # with probability (1 - epsilon) choose "best" action
            action = argmax(Q[state])

        # perform action and observe reward
        if action == 0:
            reward = rosie.forward()
        elif action == 1:
            reward = rosie.backward()
        elif action == 2:
            reward = rosie.kick()

        # observe new state and update Q table
        if mode == 'learn':
            new_state = rosie.get_state()

            # update rule for deterministic environments:
            #Q[state][action] = reward + discount_factor * max(Q[new_state])

            # update rule for nondeterministic environments (this is equivalent
            # to the deterministic rule above when learning_rate=1).
            # After a successful kick new_state is state 0, a row this loop
            # never updates, so it contributes whatever the table was
            # initialized with.
            Q[state][action] += learning_rate * (reward + discount_factor * max(Q[new_state])
                                                 - Q[state][action])

        mode = _show_world(rosie, mode)
        if mode is None:  # user quit mid-episode
            return
        steps += 1

    if mode in ['run', 'step']:
        plural = '' if steps == 1 else 's'
        print(f"Rosie kicked the ball after {steps} step{plural}!")
    return steps


def _show_world(rosie, mode):
    """Display the world as required by mode and return the mode to use next.

    In 'step' mode this waits for a line of input: 'q' returns None (quit)
    and 'r' returns 'run' (run to completion).  Every other case returns
    mode unchanged.
    """
    if mode == 'run':
        print(rosie)
    elif mode == 'step':
        print(rosie, end="")
        command = input()
        if command == 'q':  # quit
            return None
        if command == 'r':  # run to completion
            return 'run'
    return mode

#---------------------------------------------------------------------------

# updates Q table over multiple training episodes
def train(Q, num_episodes, epsilon=0.5, discount_factor=0.8, learning_rate=1):
    """Run num_episodes learning episodes with one fixed epsilon.

    Q is updated in place by each episode, and one summary line is printed
    per episode.
    """
    for completed in range(num_episodes):
        episode_num = completed + 1
        steps = run_episode(Q, epsilon, discount_factor, learning_rate, mode='learn')
        plural = 's' if steps != 1 else ''
        print(f"Episode #{episode_num}: epsilon {epsilon:.4f}, {steps} step{plural}")

def test(Q, epsilon=0.01):
    """Run a single episode with the given Q table, printing the world after
    every step and leaving Q untouched (no learning)."""
    run_episode(Q, epsilon, mode='run')

# collects statistics over multiple episodes using the given Q table
def stats(Q=None, epsilon=0.5, trials=100):
    """Run `trials` quiet episodes and print the average number of steps
    Rosie needed to kick the ball."""
    total = sum(run_episode(Q, epsilon, mode='quiet') for _ in range(trials))
    avg = total / trials
    print(f"Over {trials} trials, Rosie averaged {avg:.1f} steps to kick the ball")

# creates a Q table with every value initialized to 0.0
def make_table(num_states=27): # for worlds of size 26
    """Return a new Q table: one independent [0.0, 0.0, 0.0] row (one value
    per action) for each of num_states states."""
    return [[0.0, 0.0, 0.0] for _ in range(num_states)]

# creates a Q table with random values, for demo purposes only
def make_random_table(num_states=27):
    """Return a Q table with num_states rows of three values, each drawn
    from random.uniform(0, 10)."""
    return [[random.uniform(0, 10) for _ in range(3)] for _ in range(num_states)]

# prints out a Q table nicely
def print_table(Q):
    """Print a header line followed by one line per state showing the
    state number and its forward, backward and kick values."""
    print("state    forward    backward    kick")
    for row, (f, b, k) in enumerate(Q):
        print(f"{row:3} {f:10.3} {b:10.3} {k:10.3}")

# returns the position/index number of the maximum value in a list
def argmax(values):
    """Return the index of the largest element of values; when several
    elements share the maximum, one of their indices is picked at random."""
    top = max(values)
    winners = [index for index, value in enumerate(values) if value == top]
    return random.choice(winners)

#---------------------------------------------------------------------------
# epsilon schedules

import numpy as np
import matplotlib.pyplot as plt

def constant_schedule(num_episodes, epsilon):
    """Return a schedule (list) holding the same epsilon for each of
    num_episodes episodes."""
    return [epsilon] * num_episodes

# creates a schedule of num_episodes epsilon values, starting with
# start_epsilon and decreasing by decrement every interval episodes,
# and going no lower than end_epsilon:

def decreasing_schedule(num_episodes, start_epsilon=1.0, interval=50,
                        decrement=0.01, end_epsilon=0.01):
    """Return a list of num_episodes epsilon values that step downward.

    The first `interval` entries equal start_epsilon.  After every
    interval-th episode the value drops by `decrement`, but never below
    end_epsilon.
    """
    values = []
    current = start_epsilon
    for episode in range(1, num_episodes + 1):
        values.append(current)
        if episode % interval != 0:
            continue
        # an interval just ended: lower epsilon, clamped at the floor
        current = max(current - decrement, end_epsilon)
    return values

# examples:

# 100 episodes: epsilon starts at 1.0 and drops by 0.2 after every
# 10 episodes, going no lower than 0.1
schedule1 = decreasing_schedule(
    num_episodes=100,
    start_epsilon=1.0,
    interval=10,
    decrement=0.2,
    end_epsilon=0.1,
)

# 100 episodes: epsilon starts at 1.0 and drops by 0.05 after every
# 5 episodes, going no lower than 0.01
schedule2 = decreasing_schedule(
    num_episodes=100,
    start_epsilon=1.0,
    interval=5,
    decrement=0.05,
    end_epsilon=0.01,
)

# displays a schedule as a graph
def show_schedule(schedule):
    """Plot the epsilon values in schedule as a red line against episode
    number (starting at 1) and display the figure."""
    last_episode = len(schedule)
    episodes = np.arange(1, last_episode + 1)
    plt.plot(episodes, schedule, 'r-')
    plt.title("Epsilon Schedule")
    # axis labels
    plt.xlabel("episode number")
    plt.ylabel("epsilon value")
    # axis ranges
    plt.xlim(1, last_episode)
    plt.ylim(0.0, 1.1)
    plt.show()

#show_schedule(schedule1)
#show_schedule(schedule2)

def train_by_schedule(Q, schedule, discount_factor=0.8, learning_rate=1):
    """Run one learning episode per entry of an epsilon schedule.

    Q: Q table, updated in place by each episode.
    schedule: list or tuple of epsilon values, one per episode.
    discount_factor, learning_rate: passed through to run_episode.

    Prints one summary line per episode.  If schedule is not a list or
    tuple (for example a single number), a message is printed and nothing
    is trained.  Always returns None.
    """
    # isinstance rather than an exact type check, so tuples and list
    # subclasses are accepted too; they are iterated just like a plain list
    if not isinstance(schedule, (list, tuple)):
        print("Second argument must be an epsilon schedule")
        return
    for episode_num, epsilon in enumerate(schedule, start=1):
        steps = run_episode(Q, epsilon, discount_factor, learning_rate, mode='learn')
        plural = '' if steps == 1 else 's'
        print(f"Episode #{episode_num}: epsilon {epsilon:.4f}, {steps} step{plural}")
