PT21-22-Reseau-Neurones/MNISTLoadTest.py

30 lines
877 B
Python
Raw Normal View History

2021-12-22 21:35:06 +01:00
import numpy as np
from sobek.network import network
import gzip
print("--- Data loading ---")


def getData(fileName):
    """Return the decompressed contents of a gzip file as a uint8 array.

    The whole file at ``fileName`` is read into memory and decompressed with
    ``gzip.decompress``. No format header is stripped; callers slice off any
    header bytes themselves.

    Returns:
        A writable 1-D ``np.uint8`` array that owns its data. ``np.frombuffer``
        alone would return a read-only view over the ``bytes`` object, so the
        view is copied.
    """
    with open(fileName, 'rb') as compressed_file:
        compressed = compressed_file.read()
    decompressed = gzip.decompress(compressed)
    byte_view = np.frombuffer(decompressed, dtype=np.uint8)  # read-only view
    return byte_view.copy()
# NOTE: the "train" variable names below are historical. The files loaded are
# the MNIST t10k *test* split. The names are kept because they are this
# script's module-level globals.
IMAGES_PATH = "./MNIST/t10k-images-idx3-ubyte.gz"
LABELS_PATH = "./MNIST/t10k-labels-idx1-ubyte.gz"
NETWORK_PATH = "MNIST30epoch"

# IDX magic numbers: a big-endian uint32 stored in the first 4 bytes of a file.
_IDX_IMAGES_MAGIC = 0x00000803  # unsigned-byte data, 3 dimensions
_IDX_LABELS_MAGIC = 0x00000801  # unsigned-byte data, 1 dimension
_NUM_CLASSES = 10


def _idx_header(raw, n_fields, expected_magic, path):
    """Parse and validate the leading big-endian uint32 fields of an IDX file.

    Args:
        raw: 1-D uint8 array holding the whole decompressed file.
        n_fields: number of 4-byte header fields (magic + one per dimension).
        expected_magic: value the first field (the magic number) must equal.
        path: file name, used only in error messages.

    Returns:
        The header fields as a list of Python ints, magic number first.

    Raises:
        ValueError: if the data is shorter than the header, or the magic
            number does not match (wrong or corrupt file).
    """
    header_size = 4 * n_fields
    if raw.size < header_size:
        raise ValueError(f"{path}: too short to contain an IDX header")
    fields = [int(field) for field in raw[:header_size].view(">u4")]
    if fields[0] != expected_magic:
        raise ValueError(
            f"{path}: IDX magic number {fields[0]:#010x}, "
            f"expected {expected_magic:#010x}"
        )
    return fields


def _load_images(path):
    """Load an IDX3 image file as a list of flattened float64 pixel vectors.

    Each image becomes a 1-D array of rows*cols pixels. Every pixel is divided
    by 256, so values lie in [0, 255/256].

    Raises:
        ValueError: on a bad header, or if the pixel payload size does not
            match the header's count * rows * cols.
    """
    raw = getData(path)
    _, count, rows, cols = _idx_header(raw, 4, _IDX_IMAGES_MAGIC, path)
    pixels = raw[16:]  # skip magic, count, rows, cols (4 bytes each)
    expected = count * rows * cols
    if pixels.size != expected:
        raise ValueError(
            f"{path}: expected {expected} pixel bytes, got {pixels.size}"
        )
    # NOTE(review): dividing by 256 (not 255) is intentional here. It
    # presumably matches the scaling used when the saved network was
    # trained — confirm before changing.
    scaled = pixels.reshape((count, rows * cols)).astype(np.float64) / 256
    return list(scaled)


def _load_labels(path):
    """Load an IDX1 label file as a list of one-hot float64 vectors of length 10.

    Raises:
        ValueError: on a bad header, or if the label count does not match it.
        IndexError: if a label value is >= 10.
    """
    raw = getData(path)
    _, count = _idx_header(raw, 2, _IDX_LABELS_MAGIC, path)
    labels = raw[8:]  # skip magic and count (4 bytes each)
    if labels.size != count:
        raise ValueError(f"{path}: expected {count} labels, got {labels.size}")
    # Row i of the identity matrix is the one-hot encoding of class i.
    return list(np.eye(_NUM_CLASSES)[labels])


trainImages = _load_images(IMAGES_PATH)
trainLabels = _load_labels(LABELS_PATH)
if len(trainImages) != len(trainLabels):
    # Mismatched files would silently pair images with the wrong labels.
    raise ValueError(
        f"{len(trainImages)} images but {len(trainLabels)} labels"
    )

print("--- Testing ---")
myNetwork = network.networkFromFile(NETWORK_PATH)
# Prints whatever sobek's network.accuracy returns for this data set.
print(myNetwork.accuracy(trainImages, trainLabels))