PT21-22-Reseau-Neurones/tests/MNISTLearning.py

73 lines
2.1 KiB
Python
Raw Permalink Normal View History

2021-12-22 21:35:06 +01:00
import numpy as np
import gzip
import time
2021-12-22 22:08:20 +01:00
from sys import path
path.insert(1, "..")
from sobek.network import network
2021-12-22 21:35:06 +01:00
print("--- Data loading ---")


def getData(fileName):
    """Return the decompressed contents of a gzip file as a uint8 array.

    Args:
        fileName: path of the ``.gz`` file to read.

    Returns:
        A 1-D, writable ``numpy.ndarray`` of dtype ``uint8`` holding every
        decompressed byte, headers included (callers slice them off).
    """
    with open(fileName, 'rb') as handle:
        compressed = handle.read()
    raw = gzip.decompress(compressed)
    # frombuffer only wraps the immutable bytes object (read-only view);
    # np.array copies it into an independent, writable array.
    return np.array(np.frombuffer(raw, dtype=np.uint8))
# Training images: drop the first 0x10 bytes (the IDX file header), then
# view the remaining pixel bytes as one row of 784 values per image.
tempTrainImages = getData("./MNIST/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 784))
# Scale every pixel into [0, 1) in one vectorized pass instead of a Python
# loop over ~47 million pixels. Dividing a float64 by 256 (a power of two)
# is exact, so each value equals the per-pixel `pixel / 256` result; zero
# pixels need no special case because 0 / 256 == 0.
# NOTE(review): /255 would map the brightest pixel to exactly 1.0; /256 is
# kept to preserve the existing input scaling.
# Result: a list with one float64 array of shape (784,) per image.
trainImages = list(tempTrainImages.astype(np.float64) / 256)
# Training labels: drop the first 8 bytes (the IDX file header); each
# remaining byte is one digit label.
tempTrainLabels = getData("./MNIST/train-labels-idx1-ubyte.gz")[8:]
# One-hot encode all labels at once: row `d` of the 10x10 float64 identity
# matrix is all zeros except 1.0 at column `d`. Integer-array indexing
# copies the selected rows. An out-of-range label still raises IndexError.
# Result: a list with one float64 array of shape (10,) per label.
trainLabels = list(np.eye(10)[tempTrainLabels])
2022-01-06 12:13:22 +01:00
# Evaluation (t10k) images: drop the first 0x10 bytes (the IDX file
# header), then view the pixel bytes as one row of 784 values per image.
tempAccuracyImages = getData("./MNIST/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 784))
# Same scaling as for training data, done in one vectorized pass. Division
# by 256 is exact in float64, so values equal the per-pixel `pixel / 256`.
# NOTE(review): /256 (not /255) is kept so evaluation inputs are scaled
# exactly like the existing training inputs.
# Result: a list with one float64 array of shape (784,) per image.
accuracyImages = list(tempAccuracyImages.astype(np.float64) / 256)
# Evaluation (t10k) labels: drop the first 8 bytes (the IDX file header);
# each remaining byte is one digit label.
tempAccuracyLabels = getData("./MNIST/t10k-labels-idx1-ubyte.gz")[8:]
# Vectorized one-hot encoding: row `d` of the 10x10 float64 identity matrix
# has a single 1.0 at column `d`. An out-of-range label raises IndexError.
# Result: a list with one float64 array of shape (10,) per label.
accuracyLabels = list(np.eye(10)[tempAccuracyLabels])
# Layer widths: 784 inputs (one per pixel), 32, and 10 outputs (one per
# one-hot label slot). Presumably (input, hidden, output) layer sizes —
# confirm against sobek.network.network.
inputSize, hiddenSize, outputSize = 784, 32, 10
myNetwork = network(inputSize, hiddenSize, outputSize)
2021-12-22 21:35:06 +01:00
# Learning rate forwarded to myNetwork.train (presumably the gradient
# descent step size — see sobek.network).
learningRate = 3.0
print("--- Learning ---")
# High-resolution timestamp taken just before training, used to report
# the total learning time afterwards.
startTime = time.perf_counter()
2022-01-06 12:13:22 +01:00
# Train on the full training set. The keyword names suggest mini-batches of
# 10 samples for 30 epochs, with accuracy measured on the t10k images and
# labels — the exact semantics live in sobek.network.train.
myNetwork.train(
    trainImages,
    trainLabels,
    learningRate,
    batchSize=10,
    epochs=30,
    accuracyInputs=accuracyImages,
    accuracyDesiredOutputs=accuracyLabels,
)
2021-12-22 21:35:06 +01:00
endTime = time.perf_counter()
elapsed = endTime - startTime
print(f"Learning time : {elapsed}")
# Spot-check a single training sample: print its one-hot label, then the
# network's output for the same image.
sampleIndex = 121
print(trainLabels[sampleIndex])
print(myNetwork.process(trainImages[sampleIndex]))
2022-01-06 12:13:22 +01:00
# Persist the trained network. How this name maps to a file on disk is
# decided by sobek's network.saveToFile (not visible here).
saveName = "MNISTtest3"
myNetwork.saveToFile(saveName)