我是ML的新手,为了我的项目,我正在尝试使用神经网络制作一个数字分类器。 我制作了一个图形用户界面,在那里你可以画数字,它会把NumPy数组传递给神经网络。 我用mnist数字数据集训练了我的神经网络,模型准确率为97.70%,但它无法预测输入的数字
#CODE FOR NEURAL NETWORK
class mltest():
def __init__(self):
self.model = keras.Sequential([ keras.layers.Dense(120,input_shape = (784,),activation = 'relu'),
keras.layers.Dense(10,activation = 'softmax')])
def train(self):
(x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data()
x_train = x_train/255
x_test = x_test/255
x_train = x_train.reshape(len(x_train),28*28)
x_test = x_test.reshape(len(x_test),28*28)
self.model.compile(optimizer='adam',loss='SparseCategoricalCrossentropy',metrics=['accuracy'])
self.model.fit(x_train,y_train,epochs=12,batch_size=200)
self.model.evaluate(x_test,y_test)
def test(self,value):
y_predicted = self.model.predict(value)
print(np.argmax(y_predicted))
if __name__ =='__main__':
obj = mltest()
obj.train()
Epoch 12/12 300/300 [==============================] - 1s 2ms/step - loss: 0.0338 - accuracy: 0.9908 313/313 [==============================] - 0s 776us/step - loss: 0.0763 - accuracy: 0.9770
GUI代码
class Window(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Ai")
self.setGeometry(300,300,300,300)
self.image = QImage(self.size(), QImage.Format_RGB32)
self.image.fill(Qt.white)
self.drawing = False
self.brushSize = 25
self.brushColor = Qt.black
self.lastPoint = QPoint()
self.object = mltest()
#To send the picture data to neural network when enter key is pressed
def keyPressEvent(self, event):
if event.key() == 16777220:
screen = QtWidgets.QApplication.primaryScreen()
screenshot = screen.grabWindow(QtWidgets.QWidget.winId(self))
bufffer = QBuffer()
bufffer.open(bufffer.ReadWrite)
screenshot = screenshot.save(bufffer,'PNG')
image = Image.open(io.BytesIO(bufffer.data()))
image = image.convert('P')
image = np.array(image)
image =cv2.resize(image,dsize = (28,28))
image = cv2.blur(image,(2,2))
image = cv2.bitwise_not(image)
image[image<=30] = 0
print(image)
plt.matshow(image)
plt.show()
image = image.reshape(1,28*28)
image = image/255
self.object.test(image)
#To draw
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
self.drawing = True
self.lastPoint = event.pos()
def mouseMoveEvent(self, event):
if (event.buttons() & Qt.LeftButton) & self.drawing:
painter = QPainter(self.image)
painter.setPen(QPen(self.brushColor, self.brushSize,
Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin))
painter.drawLine(self.lastPoint, event.pos())
self.lastPoint = event.pos()
self.update()
def mouseReleaseEvent(self, event):
if event.button() == Qt.LeftButton:
self.drawing = False
def paintEvent(self, event):
canvasPainter = QPainter(self)
canvasPainter.drawImage(self.rect(), self.image, self.image.rect())
if __name__ == '__main__':
App = QApplication(sys.argv)
window = Window()
window.show()
sys.exit(App.exec())
显然,您在两个文件中使用了
if __name__ =='__main__':
,并且仅当您调用网络的文件时才对网络进行训练,而当您启动GUI应用程序时,您的网络是在未经训练的情况下创建的self.object = mltest()
。除非你调用self.object.train()
,否则它很可能没有经过训练,因此无法做出好的预测相关问题 更多 >
编程相关推荐