This commit is contained in:
eynard 2022-03-10 15:09:20 +01:00
parent dcddbd017f
commit 5cdc7b52e1
4 changed files with 6 additions and 5 deletions

View File

@ -1,5 +1,6 @@
import random import random
import numpy as np import numpy as np
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.animation as animation import matplotlib.animation as animation
import pickle import pickle
@ -93,7 +94,7 @@ class network:
vizualisationFrame = np.empty((30, 30)) vizualisationFrame = np.empty((30, 30))
for x in range(30): for x in range(30):
for y in range(30): for y in range(30):
vizualisationFrame[x][y] = self.process(np.array([float(x), float(y)])) vizualisationFrame[x][y] = self.process(np.array([float(x)/30, float(y)/30]))
vizualisationData.append([graph.imshow(vizualisationFrame, animated=True)]) vizualisationData.append([graph.imshow(vizualisationFrame, animated=True)])
inputBatches = [inputs[j:j+batchSize] for j in range(0, len(inputs), batchSize)] inputBatches = [inputs[j:j+batchSize] for j in range(0, len(inputs), batchSize)]
@ -135,7 +136,7 @@ class network:
print(self.accuracy(accuracyInputs, accuracyDesiredOutputs)) print(self.accuracy(accuracyInputs, accuracyDesiredOutputs))
if (visualize): if (visualize):
ani = animation.ArtistAnimation(fig, vizualisationData, interval=100) ani = animation.ArtistAnimation(fig, vizualisationData, interval=100, repeat_delay=1000)
plt.show() plt.show()
def __Error(self, layer, neuron): def __Error(self, layer, neuron):

Binary file not shown.

View File

@ -12,6 +12,6 @@ trainLabels = data[1]
myNetwork = network(2, 16, 1) myNetwork = network(2, 16, 1)
learningRate = 5.0 learningRate = 3.0
myNetwork.train(trainPoints, trainLabels, learningRate, batchSize=10, epochs=1000, visualize=True) myNetwork.train(trainPoints, trainLabels, learningRate, batchSize=100, epochs=3000, visualize=True)

View File

@ -9,7 +9,7 @@ trainLabels = []
random.seed(1216513) random.seed(1216513)
for i in range(100): for i in range(1000):
x = random.randint(-50, 50) x = random.randint(-50, 50)
y = random.randint(-50, 50) y = random.randint(-50, 50)