PT21-22-Reseau-Neurones/tests/testLearning.py

50 lines
1.0 KiB
Python
Raw Normal View History

import numpy as np
import random
from pathlib import Path
from sys import path

# Make the project root (the parent of this tests/ directory) importable so
# that `sobek` resolves regardless of the working directory the script is
# launched from. The previous `path.insert(1, "..")` was resolved against the
# current working directory, so it only worked when run from inside tests/.
# NOTE(review): assumes the `sobek` package sits in the repository root next
# to tests/, which is what ".." pointed to when run from tests/.
path.insert(1, str(Path(__file__).resolve().parent.parent))

from sobek.network import network

# Seed without an argument: Python uses OS randomness (or the current time),
# so every run draws a different training sequence.
random.seed()
2021-12-18 20:57:44 +01:00
# Network under test. The training data below uses 10-element vectors on both
# sides; TODO confirm what the two constructor arguments mean in sobek.network.
myNetwork = network(10, 10)

learningRate = 3

EPOCHS = 1000        # number of calls to train()
BATCH_SIZE = 10      # samples passed to each train() call
VECTOR_SIZE = 10     # length of every one-hot input / target vector
PROGRESS_EVERY = 50  # print the epoch counter every this many epochs


def _make_sample(digit, size=VECTOR_SIZE):
    """Return an (input, target) pair of one-hot vectors for ``digit``.

    The input has its 1.0 at index ``digit``; the target has it at the
    mirrored index ``size - 1 - digit``, i.e. the network is trained to
    reverse the position of the hot element.
    """
    sample = np.zeros(size)
    sample[digit] = 1.0
    target = np.zeros(size)
    target[size - 1 - digit] = 1.0
    return sample, target


for j in range(EPOCHS):
    if j % PROGRESS_EVERY == 0:
        print(j)

    # Draw integer positions directly. The previous version stored
    # randrange(10) / 10 and recovered the index with int(x * 10), which only
    # worked because float multiplication happened to land on (or just above)
    # each integer. The RNG is consumed exactly as before: BATCH_SIZE draws
    # per epoch.
    digits = [random.randrange(VECTOR_SIZE) for _ in range(BATCH_SIZE)]
    samples = [_make_sample(d) for d in digits]
    inputs = [sample for sample, _ in samples]
    desiredOutputs = [target for _, target in samples]

    myNetwork.train(inputs, desiredOutputs, learningRate)
2021-12-17 08:36:45 +01:00
# Probe the trained network with two one-hot inputs (hot element at index 1
# and at index 5). The training targets mirror the hot index (9 - i), so a
# well-trained network should peak at indices 8 and 4 respectively.
test = []
for hot in (1, 5):
    probe = np.zeros(10)
    probe[hot] = 1.0
    test.append(probe)

for probe in test:
    print(probe)
    print(myNetwork.process(probe))

print("Save and load test :")
# Round-trip the network through a file and check that the reloaded copy
# produces the same output. The previous check, `a.all() == b.all()`,
# collapsed each output to a single truthiness flag before comparing, so two
# different outputs could still print True. np.array_equal compares shape and
# every element.
# NOTE(review): if saveToFile stores weights with reduced precision, switch to
# np.allclose; confirm the on-disk format.
myNetwork.saveToFile("test")
myNetwork2 = network.networkFromFile("test")
print(np.array_equal(myNetwork.process(test[0]), myNetwork2.process(test[0])))