# PT21-22-Reseau-Neurones/tests/MNISTDrawingPrediction.py
# 45 lines, 1.6 KiB, Python
#
# 2021-12-22 21:35:06 +01:00
import tkinter
from PIL import Image, ImageDraw
import numpy as np
# 2021-12-22 22:08:20 +01:00
from sys import path
path.insert(1, "..")
from sobek.network import network
# 2021-12-22 21:35:06 +01:00
class Sketchpad(tkinter.Canvas):
    """Drawing canvas that classifies a hand-drawn digit with a neural network.

    Dragging with the left mouse button paints round black dabs both on the
    visible Tk canvas and on an off-screen PIL image of the same size. After
    every dab the PIL image is downscaled, normalised and fed to the network
    loaded from "MNIST30epoch"; the argmax of the network output is written
    into ``predictionLabel``. Right-clicking clears the drawing and resets the
    label.
    """

    # Side length, in pixels, of the off-screen image (the canvas is expected
    # to be created with the same width/height).
    IMAGE_SIZE = 560
    # Downscale factor: 560 / 20 = 28 pixels per side.
    REDUCE_FACTOR = 20
    # Value of untouched background pixels in the "F" (32-bit float) image.
    BACKGROUND = 100
    # Radius of the round brush, in pixels.
    BRUSH_RADIUS = 32
    # Label text shown when nothing has been drawn (matches the initial label).
    DEFAULT_LABEL = "Predicted number :"

    def __init__(self, parent, predictionLabel, **kwargs):
        """Create the canvas, bind mouse handlers and load the network.

        Args:
            parent: Tk container widget for the canvas.
            predictionLabel: widget whose ``'text'`` option receives the
                prediction.
            **kwargs: forwarded unchanged to ``tkinter.Canvas``.
        """
        super().__init__(parent, **kwargs)
        self.bind("<Button-3>", self.test)
        self.bind("<B1-Motion>", self.add_line)
        self._reset_image()
        # NOTE(review): relative path, presumably resolved against the current
        # working directory by networkFromFile; confirm.
        self.MNISTNN = network.networkFromFile("MNIST30epoch")
        self.predictionLabel = predictionLabel

    def _reset_image(self):
        """(Re)create the blank off-screen image and its drawing context."""
        self.PILImage = Image.new("F", (self.IMAGE_SIZE, self.IMAGE_SIZE), self.BACKGROUND)
        self.draw = ImageDraw.Draw(self.PILImage)

    def _predict(self):
        """Return the index of the largest network output for the current drawing."""
        smallerImage = self.PILImage.reduce(self.REDUCE_FACTOR)
        # Flat, row-major array of the downscaled pixels (28 * 28 = 784 values
        # for the default sizes). Presumably this matches the network's input
        # layer; confirm.
        imageAsArray = np.array(smallerImage.getdata())
        # Map background (100) -> 0.0 and black ink (0) -> 1.0, so ink pixels
        # are high-valued.
        imageAsArray = (self.BACKGROUND - imageAsArray) / self.BACKGROUND
        return np.argmax(self.MNISTNN.process(imageAsArray))

    def add_line(self, event):
        """Paint a brush dab at the cursor (left-drag) and refresh the prediction."""
        r = self.BRUSH_RADIUS
        box = (event.x - r, event.y - r, event.x + r, event.y + r)
        # Draw on the visible canvas and mirror it on the image fed to the network.
        self.create_oval(box, fill="black")
        self.draw.ellipse(list(box), fill="black")
        self.predictionLabel['text'] = ("Predicted number : " + str(self._predict()))

    def test(self, event=None):
        """Clear the drawing (bound to right-click) and reset the prediction label.

        ``event`` is unused; it defaults to ``None`` so the method can also be
        called directly.
        """
        self._reset_image()
        self.delete("all")
        # Without this the label would keep showing the digit of the erased drawing.
        self.predictionLabel['text'] = self.DEFAULT_LABEL
def main():
    """Build the "Number guesser" window and run the Tk event loop.

    Left-drag on the canvas to draw a digit; the prediction appears in the
    label below it. Right-click clears the canvas.
    """
    window = tkinter.Tk()
    window.title("Number guesser")
    window.resizable(False, False)
    window.columnconfigure(0, weight=1)
    window.rowconfigure(0, weight=1)
    predictionLabel = tkinter.Label(window, text="Predicted number :")
    # 560x560 matches the size of Sketchpad's off-screen drawing image.
    sketch = Sketchpad(window, predictionLabel, width=560, height=560)
    sketch.grid(column=0, row=0, sticky=(tkinter.N, tkinter.W, tkinter.E, tkinter.S))
    predictionLabel.grid(column=0, row=1)
    window.mainloop()


# Only open the window when run as a script, not when imported.
if __name__ == "__main__":
    main()