# Code developed in class Wednesday, March 4 (week 7)

# Basic Perceptron implementation

import random

class Perceptron:
    """A single two-input perceptron with a hard-threshold (step) activation.

    Implements the classic perceptron learning rule.  Weights and bias are
    initialized to small random values; `train` repeats epochs until every
    pattern is classified correctly (or an optional epoch cap is reached).
    """

    def __init__(self):
        """Initializes weights and bias to small random values in [-0.1, 0.1]."""
        self.weight1 = random.uniform(-0.1, 0.1)
        self.weight2 = random.uniform(-0.1, 0.1)
        self.bias = random.uniform(-0.1, 0.1)
        self.learning_rate = 0.1

    def set_weights(self, w1, w2, b):
        """Sets the two weights and the bias to the given values."""
        self.weight1 = w1
        self.weight2 = w2
        self.bias = b

    def show_weights(self):
        """Prints the current weight and bias values."""
        print(f"{self.weight1:+.4f} {self.weight2:+.4f} {self.bias:+.4f}")

    def propagate(self, pattern):
        """Returns the output (0 or 1) for a 2-element input pattern.

        Fires (returns 1) when the weighted sum of the inputs plus the
        bias reaches the threshold of 0.
        """
        x1, x2 = pattern
        total = self.weight1 * x1 + self.weight2 * x2 + self.bias
        return 1 if total >= 0 else 0

    def total_error(self, patterns, targets):
        """Returns the total sum-squared error over the given patterns/targets."""
        return sum((target - self.propagate(pattern)) ** 2
                   for pattern, target in zip(patterns, targets))

    def train(self, patterns, targets, max_epochs=None):
        """Trains until all patterns are classified correctly.

        max_epochs: optional cap on the number of training epochs.  The
        default of None reproduces the original behavior (loop until the
        error reaches 0).  Supplying a cap guarantees termination on
        non-linearly-separable data such as XOR, where the perceptron
        learning rule never converges.
        """
        epoch = 0
        error = self.total_error(patterns, targets)
        while error > 0:
            if max_epochs is not None and epoch >= max_epochs:
                print(f"Stopped after {max_epochs} epochs, error = {error}")
                return
            # one training epoch: one weight update per pattern
            for pattern, target in zip(patterns, targets):
                self.adjust_weights(pattern, target)
            epoch += 1
            error = self.total_error(patterns, targets)
        print("All patterns learned")

    def adjust_weights(self, pattern, target):
        """Applies the perceptron learning rule for one pattern/target pair.

        Leaves the weights unchanged when the output already matches the
        target; otherwise nudges each weight in proportion to its input
        and the signed error.
        """
        output = self.propagate(pattern)
        if output != target:
            error = target - output
            x1, x2 = pattern
            # perceptron rule: delta_w = learning_rate * input * error
            self.weight1 += self.learning_rate * x1 * error
            self.weight2 += self.learning_rate * x2 * error
            self.bias += self.learning_rate * error
            print(f"pattern: {pattern}, new weights: {self.weight1:+.4f} " +
                  f"{self.weight2:+.4f} {self.bias:+.4f}")

    def test(self, patterns, targets):
        """Prints the current output for each pattern, flagging wrong answers,
        followed by the total sum-squared error."""
        for pattern, target in zip(patterns, targets):
            output = self.propagate(pattern)
            if output == target:
                print(f"{pattern} --> {output}")
            else:
                print(f"{pattern} --> {output}  (WRONG, should be {target})")
        tss_error = self.total_error(patterns, targets)
        print(f"TSS error = {tss_error:.1f}")

#------------------------------------------------------------------
# some input and target patterns

# all 2-bit input patterns, in ascending binary order
patterns = [[a, b] for a in (0, 1) for b in (0, 1)]

# targets for 2-bit input patterns, derived from the corresponding
# boolean operation applied to each pattern's two bits
ANDtargets  = [a & b for a, b in patterns]
ORtargets   = [a | b for a, b in patterns]
NANDtargets = [1 - (a & b) for a, b in patterns]
NORtargets  = [1 - (a | b) for a, b in patterns]
XORtargets  = [a ^ b for a, b in patterns]

