From bfd06b4f297de68873290df002cf54c76c1e374a Mon Sep 17 00:00:00 2001 From: eynard Date: Fri, 17 Dec 2021 08:36:45 +0100 Subject: [PATCH] changements tests --- testLearning.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/testLearning.py b/testLearning.py index 8f2234c..55ede76 100644 --- a/testLearning.py +++ b/testLearning.py @@ -4,37 +4,42 @@ from sobek.network import network random.seed() -myNetwork = network(1, 10) +myNetwork = network(10, 10) learningRate = 1 -for j in range(100000): +for j in range(10000): inputs = [] + inputs2 = [] desiredOutputs = [] if (j%50 == 0): print(j) for i in range(1000): - inputs.append([random.randrange(10)]) + inputs.append([(random.randrange(10)/10)]) inputs = np.array(inputs, dtype=object) for i in range(1000): desiredOutputs.append([0]*10) - desiredOutputs[i][9 - inputs[i][0]] = 1.0 + desiredOutputs[i][9 - int(inputs[i][0]*10)] = 1.0 desiredOutputs = np.array(desiredOutputs, dtype=object) + + for i in range(1000): + inputs2.append([0]*10) + inputs2[i][int(inputs[i][0]*10)] = 1.0 + inputs2 = np.array(inputs2, dtype=object) if (j%10000 == 0): learningRate*= 0.1 - myNetwork.train(inputs, desiredOutputs, learningRate) + + myNetwork.train(inputs2, desiredOutputs, learningRate) -print(myNetwork.process(np.array([0.0], dtype=object))) -print(myNetwork.process(np.array([1.0], dtype=object))) -print(myNetwork.process(np.array([2.0], dtype=object))) -print(myNetwork.process(np.array([3.0], dtype=object))) -print(myNetwork.process(np.array([4.0], dtype=object))) -print(myNetwork.process(np.array([5.0], dtype=object))) -print(myNetwork.process(np.array([6.0], dtype=object))) -print(myNetwork.process(np.array([7.0], dtype=object))) -print(myNetwork.process(np.array([8.0], dtype=object))) -print(myNetwork.process(np.array([9.0], dtype=object))) \ No newline at end of file +test = [] +test.append([0]*10) +test.append([0]*10) +test[0][1] = 1.0 +test[1][8] = 1.0 +test = np.array(test, dtype=object) +print(myNetwork.process(test[0])) +print(myNetwork.process(test[1])) \ No newline at end of file