rien
This commit is contained in:
parent
dcddbd017f
commit
5cdc7b52e1
@ -1,5 +1,6 @@
|
|||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.animation as animation
|
import matplotlib.animation as animation
|
||||||
import pickle
|
import pickle
|
||||||
@ -93,7 +94,7 @@ class network:
|
|||||||
vizualisationFrame = np.empty((30, 30))
|
vizualisationFrame = np.empty((30, 30))
|
||||||
for x in range(30):
|
for x in range(30):
|
||||||
for y in range(30):
|
for y in range(30):
|
||||||
vizualisationFrame[x][y] = self.process(np.array([float(x), float(y)]))
|
vizualisationFrame[x][y] = self.process(np.array([float(x)/30, float(y)/30]))
|
||||||
vizualisationData.append([graph.imshow(vizualisationFrame, animated=True)])
|
vizualisationData.append([graph.imshow(vizualisationFrame, animated=True)])
|
||||||
|
|
||||||
inputBatches = [inputs[j:j+batchSize] for j in range(0, len(inputs), batchSize)]
|
inputBatches = [inputs[j:j+batchSize] for j in range(0, len(inputs), batchSize)]
|
||||||
@ -135,7 +136,7 @@ class network:
|
|||||||
print(self.accuracy(accuracyInputs, accuracyDesiredOutputs))
|
print(self.accuracy(accuracyInputs, accuracyDesiredOutputs))
|
||||||
|
|
||||||
if (visualize):
|
if (visualize):
|
||||||
ani = animation.ArtistAnimation(fig, vizualisationData, interval=100)
|
ani = animation.ArtistAnimation(fig, vizualisationData, interval=100, repeat_delay=1000)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def __Error(self, layer, neuron):
|
def __Error(self, layer, neuron):
|
||||||
|
Binary file not shown.
@ -12,6 +12,6 @@ trainLabels = data[1]
|
|||||||
|
|
||||||
myNetwork = network(2, 16, 1)
|
myNetwork = network(2, 16, 1)
|
||||||
|
|
||||||
learningRate = 5.0
|
learningRate = 3.0
|
||||||
|
|
||||||
myNetwork.train(trainPoints, trainLabels, learningRate, batchSize=10, epochs=1000, visualize=True)
|
myNetwork.train(trainPoints, trainLabels, learningRate, batchSize=100, epochs=3000, visualize=True)
|
||||||
|
@ -9,7 +9,7 @@ trainLabels = []
|
|||||||
|
|
||||||
random.seed(1216513)
|
random.seed(1216513)
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(1000):
|
||||||
x = random.randint(-50, 50)
|
x = random.randint(-50, 50)
|
||||||
y = random.randint(-50, 50)
|
y = random.randint(-50, 50)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user