# Code developed in class Monday, March 16 (week 8)

# Implementation of the Delta Rule for a perceptron with 2 inputs
# and a continuous-valued output between 0.0 and 1.0

import random, math

class ContinuousPerceptron:

    # constructor
    def __init__(self):
        self.weight1 = random.uniform(-0.1, 0.1)
        self.weight2 = random.uniform(-0.1, 0.1)
        self.bias = random.uniform(-0.1, 0.1)
        self.learning_rate = 0.1
        # NEW
        self.tolerance = 0.1

    # sets the weights and bias to the given values
    def set_weights(self, w1, w2, b):
        self.weight1 = w1
        self.weight2 = w2
        self.bias = b

    # prints out the current weight and bias values
    def show_weights(self):
        print(f"{self.weight1:+.4f} {self.weight2:+.4f} {self.bias:+.4f}")

    # NEW
    def initialize(self):
        self.weight1 = random.uniform(-0.1, 0.1)
        self.weight2 = random.uniform(-0.1, 0.1)
        self.bias = random.uniform(-0.1, 0.1)
        print("weights randomized")
        self.show_weights()

    # MODIFIED
    # computes the output of the perceptron given an input pattern
    def propagate(self, pattern):
        x1, x2 = pattern
        total = self.weight1 * x1 + self.weight2 * x2 + self.bias
        output = 1 / (1 + math.exp(-total))
        return output

    # MODIFIED
    # returns a PAIR of values (e, c) where e is the total sum-squared error
    # for all patterns in the given dataset, and c is the fraction of output
    # values that are within self.tolerance of the given target values
    def total_error(self, patterns, targets):
        sum_squared_error = 0
        count = 0
        for pattern, target in zip(patterns, targets):
            output = self.propagate(pattern)
            if abs(target - output) <= self.tolerance:
                count += 1
            sum_squared_error += (target - output) ** 2
        correct = count / len(targets)
        return sum_squared_error, correct

    # MODIFIED
    # trains the perceptron to completion on the given set of patterns
    def train(self, patterns, targets):
        error, correct = self.total_error(patterns, targets)
        epoch = 0
        while correct < 1.0:
            # one training epoch
            for pattern, target in zip(patterns, targets):
                self.adjust_weights(pattern, target)
            epoch += 1
            error, correct = self.total_error(patterns, targets)
            print(f"Epoch {epoch:5}: TSS error: {error:.5f}, correct: {correct:.3f}")
        print("All patterns learned")

    # MODIFIED
    # updates the weights and bias for a single pattern/target pair
    def adjust_weights(self, pattern, target):
        output = self.propagate(pattern)
        delta = (output - target) * output * (1 - output)   # Delta Rule
        x1, x2 = pattern
        # compute weight changes
        weight1_change = -self.learning_rate * x1 * delta
        weight2_change = -self.learning_rate * x2 * delta
        bias_change = -self.learning_rate * delta
        # update weights
        self.weight1 += weight1_change
        self.weight2 += weight2_change
        self.bias += bias_change

    # MODIFIED
    # shows the current output of the perceptron on the given patterns
    def test(self, patterns, targets):
        for pattern, target in zip(patterns, targets):
            output = self.propagate(pattern)
            if abs(output - target) <= self.tolerance:
                print(f"{pattern} --> {output}")
            else:
                print(f"{pattern} --> {output}  (WRONG, should be {target})")
        tss_error, correct = self.total_error(patterns, targets)
        print(f"TSS error = {tss_error:.1f}, correct: {correct:.3f}")


# All four input combinations for a two-input Boolean function
patterns = [[a, b] for a in (0, 1) for b in (0, 1)]

# Target outputs for each pattern, derived directly from the truth tables
ANDtargets  = [a & b for a, b in patterns]
ORtargets   = [a | b for a, b in patterns]
NANDtargets = [1 - (a & b) for a, b in patterns]
NORtargets  = [1 - (a | b) for a, b in patterns]
XORtargets  = [a ^ b for a, b in patterns]

p = ContinuousPerceptron()

# may get stuck in a local minimum:
# p.initialize()
# p.train(patterns, ANDtargets)
# p.train(patterns, NANDtargets)