diff --git a/sobek/network.py b/sobek/network.py index 8d02041..c5575cd 100755 --- a/sobek/network.py +++ b/sobek/network.py @@ -1,5 +1,6 @@ import random import numpy as np +import matplotlib import matplotlib.pyplot as plt import matplotlib.animation as animation import pickle @@ -93,7 +94,7 @@ class network: vizualisationFrame = np.empty((30, 30)) for x in range(30): for y in range(30): - vizualisationFrame[x][y] = self.process(np.array([float(x), float(y)])) + vizualisationFrame[x][y] = self.process(np.array([float(x)/30, float(y)/30])) vizualisationData.append([graph.imshow(vizualisationFrame, animated=True)]) inputBatches = [inputs[j:j+batchSize] for j in range(0, len(inputs), batchSize)] @@ -135,7 +136,7 @@ class network: print(self.accuracy(accuracyInputs, accuracyDesiredOutputs)) if (visualize): - ani = animation.ArtistAnimation(fig, vizualisationData, interval=100) + ani = animation.ArtistAnimation(fig, vizualisationData, interval=100, repeat_delay=1000) plt.show() def __Error(self, layer, neuron): diff --git a/tests/flowerGardenData b/tests/flowerGardenData index ac5a67b..d965c87 100755 Binary files a/tests/flowerGardenData and b/tests/flowerGardenData differ diff --git a/tests/flowerGardenLearningVisualization.py b/tests/flowerGardenLearningVisualization.py index 1ebc9d1..e78ab0d 100755 --- a/tests/flowerGardenLearningVisualization.py +++ b/tests/flowerGardenLearningVisualization.py @@ -12,6 +12,6 @@ trainLabels = data[1] myNetwork = network(2, 16, 1) -learningRate = 5.0 +learningRate = 3.0 -myNetwork.train(trainPoints, trainLabels, learningRate, batchSize=10, epochs=1000, visualize=True) +myNetwork.train(trainPoints, trainLabels, learningRate, batchSize=100, epochs=3000, visualize=True) diff --git a/tests/generateSobekFlowerGarden.py b/tests/generateSobekFlowerGarden.py index dc9a57c..fb6e5ee 100755 --- a/tests/generateSobekFlowerGarden.py +++ b/tests/generateSobekFlowerGarden.py @@ -9,7 +9,7 @@ trainLabels = [] random.seed(1216513) -for i in range(100): +for i in range(1000): x = random.randint(-50, 50) y = random.randint(-50, 50)