# Code developed in class Wednesday, February 25 (week 6)

#---------------------------------------------------------------------------
# A simple simulated world for Rosie the Robot, based on Chapter 8 of
# "Artificial Intelligence: A Guide for Thinking Humans" by Melanie Mitchell

# version 3: adds support for Q learning in deterministic and
# nondeterministic environments
#---------------------------------------------------------------------------
import random

class RosieWorld:
    """A one-dimensional world containing Rosie the robot and a ball.

    Locations are numbered 0 .. world_size - 1.  The ball starts at
    location 0 and the robot starts at a random location.  Once the ball
    has been kicked, ball_location is None.
    """

    def __init__(self, size=26):
        """Create a world with `size` locations and a randomly placed robot."""
        self.world_size = size
        self.robot_location = random.randrange(size)
        self.ball_location = 0

    def __str__(self):
        """Return one 3-character cell per location: o = ball, R = robot."""
        cells = []
        for i in range(self.world_size):
            has_robot = self.robot_location == i
            # when the ball has been kicked, ball_location is None and
            # never equals a location index
            has_ball = self.ball_location == i
            if has_robot and has_ball:
                cells.append("oR ")
            elif has_ball:
                cells.append("o  ")
            elif has_robot:
                cells.append("R  ")
            else:
                cells.append("_  ")
        return "".join(cells)

    def forward(self):
        """Move one location toward location 0 (no-op at 0); reward is 0."""
        if self.robot_location > 0:
            self.robot_location -= 1
        return 0

    def backward(self):
        """Move one location away from location 0 (no-op at the far end);
        reward is 0."""
        if self.robot_location < self.world_size - 1:
            self.robot_location += 1
        return 0

    def kick(self):
        """Kick the ball if the robot is at the ball's location.

        A successful kick removes the ball and returns a reward of 10;
        otherwise nothing changes and the reward is 0.
        """
        if self.robot_location == self.ball_location:
            self.ball_location = None
            return 10
        return 0

    def can_see_ball(self):
        """Return True while the ball is still in the world."""
        return self.ball_location is not None

    # state 0 = in location 0 with no ball present
    # state 1 = in location 0 with ball present
    # state 2 = in location 1
    # state 3 = in location 2
    # ...
    # state N = in location N-1
    def get_state(self):
        """Return the integer state number described in the table above."""
        if self.robot_location != 0:
            return self.robot_location + 1
        # location 0 is split into two states depending on the ball
        return 0 if self.ball_location is None else 1

#---------------------------------------------------------------------------
# simulates one episode of Rosie performing actions in her world

# Human-readable action names, listed in action-code order; an action's
# code is its position in this list (0 = forward, 1 = back, 2 = kick).
ACTION_NAMES = ['forward', 'back', 'kick']
ACTION_CODES = list(range(len(ACTION_NAMES)))

# new version that updates Q values (in learn mode)
def run_episode(Q=None, epsilon=0.5, discount_factor=0.8, learning_rate=1, mode='run'):
    """Simulate one episode: Rosie acts until she has kicked the ball.

    Parameters:
        Q: table indexed as Q[state][action], or None to choose every
            action at random.
        epsilon: when Q is given, the probability of choosing a random
            action instead of the highest-valued one.
        discount_factor, learning_rate: used only by the Q update in
            'learn' mode.
        mode: 'run' prints the world after every step; 'step' prints the
            world and waits for input after every step ('q' quits, 'r'
            switches to 'run'); 'quiet' prints nothing; 'learn' prints
            nothing and updates Q in place after every action.

    Returns the number of actions taken, or None if the user quit while
    in 'step' mode.

    Raises TypeError if mode is 'learn' and Q is None.
    """
    assert mode in ['run', 'step', 'quiet', 'learn'], "invalid mode"
    if mode == 'learn' and Q is None:
        # fail before taking any action instead of crashing mid-episode
        # on Q[state][action]
        raise TypeError("learn mode requires a Q table, but Q is None")

    rosie = RosieWorld()
    mode = _show_world(rosie, mode)
    if mode is None:
        return None

    # indexed by action code: 0 = forward, 1 = backward, 2 = kick
    actions = (rosie.forward, rosie.backward, rosie.kick)

    steps = 0
    while rosie.can_see_ball():
        # observe current state
        state = rosie.get_state()

        # `is None` (rather than `== None`) also works for array-like
        # tables whose == compares elementwise
        if Q is None or random.uniform(0, 1) < epsilon:
            # no table, or with probability epsilon: choose randomly
            action = random.choice(ACTION_CODES)
        else:
            # with probability (1 - epsilon) choose "best" action
            action = argmax(Q[state])

        # perform action and observe reward
        reward = actions[action]()

        # observe new state and update Q table
        if mode == 'learn':
            new_state = rosie.get_state()
            # Update rule for nondeterministic environments.  With
            # learning_rate=1 it reduces to the deterministic rule
            #   Q[state][action] = reward + discount_factor * max(Q[new_state])
            Q[state][action] += learning_rate * (
                reward + discount_factor * max(Q[new_state]) - Q[state][action])

        mode = _show_world(rosie, mode)
        if mode is None:
            return None
        steps += 1

    # quiet mode and learn mode are completely quiet
    if mode in ['run', 'step']:
        plural = '' if steps == 1 else 's'
        print(f"Rosie kicked the ball after {steps} step{plural}!")
    return steps


def _show_world(rosie, mode):
    """Display the world as required by `mode`; return the mode to use next.

    In 'run' mode the world is printed.  In 'step' mode the world is
    printed and a command is read from the user: 'q' returns None
    (quit), 'r' returns 'run' (run to completion), anything else keeps
    stepping.  Other modes print nothing and are returned unchanged.
    """
    if mode == 'run':
        print(rosie)
    elif mode == 'step':
        print(rosie, end="")
        command = input()
        if command == 'q':  # quit
            return None
        if command == 'r':  # run to completion
            return 'run'
    return mode

#---------------------------------------------------------------------------

# new
# updates Q table over multiple training episodes
def train(Q, num_episodes, epsilon=0.5, discount_factor=0.8, learning_rate=1):
    """Run `num_episodes` learning episodes on Q, reporting each one.

    Every episode is run in 'learn' mode with the given epsilon,
    discount_factor and learning_rate, then a one-line summary with the
    episode number, epsilon and step count is printed.
    """
    for index in range(num_episodes):
        episode_num = index + 1
        steps = run_episode(
            Q,
            epsilon=epsilon,
            discount_factor=discount_factor,
            learning_rate=learning_rate,
            mode='learn',
        )
        plural = 's' if steps != 1 else ''
        print(f"Episode #{episode_num}: epsilon {epsilon:.4f}, {steps} step{plural}")

# collects statistics over multiple episodes
def stats(Q=None, epsilon=0.5, trials=100):
    """Run `trials` quiet episodes and print the average step count.

    Q and epsilon are passed straight through to run_episode.
    """
    total = sum(run_episode(Q, epsilon, mode='quiet') for _ in range(trials))
    avg = total / trials
    print(f"Over {trials} trials, Rosie averaged {avg:.1f} steps to kick the ball")

# creates an empty Q table
def make_table(num_states=27): # for worlds of size 26
    """Return a Q table of `num_states` rows, each with three 0.0 entries.

    Every row is a separate list, so rows can be updated independently.
    """
    return [[0.0] * 3 for _ in range(num_states)]

# creates a Q table with random values, for demo purposes only
def make_random_table(num_states=27):
    """Return a table of `num_states` rows, each holding three values
    drawn with random.uniform(0, 10)."""
    return [
        [random.uniform(0, 10) for _ in range(3)]
        for _ in range(num_states)
    ]

# prints out a Q table nicely
def print_table(Q):
    """Print a header line, then one line per row of Q showing the row
    index and its three values (forward, backward, kick)."""
    print("state    forward    backward     kick")
    for state, (fwd, back, kick) in enumerate(Q):
        print(f"{state:3} {fwd:10.3} {back:10.3} {kick:10.3}")

# returns the position/index number of the maximum value in a list
def argmax(values):
    """Return the index of the largest element of `values`.

    When several elements share the maximum, one of their indices is
    picked at random.
    """
    largest = max(values)
    # collect every index holding the maximum so ties are broken randomly
    candidates = [index for index, value in enumerate(values) if value == largest]
    return random.choice(candidates)

