Why does my TensorFlow model lose accuracy after loading?

Posted 2024-09-30 22:21:02


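I'm training a model on the Fashion-MNIST dataset with the code below. On the first run it trains from scratch and reports a reasonable accuracy. On the second run, when it should load the model from the saved files instead, the accuracy drops sharply. Is there a problem with my code, or is there some practice I'm not following?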

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from os import environ, sep
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
MODELFILENAME = 'TF_ZTH_02_model' + sep
labels = {
    0:'T-shirt/Top',
    1:'Trouser',
    2:'Pullover',
    3:'Dress',
    4:'Coat',
    5:'Sandal',
    6:'Shirt',
    7:'Sneaker',
    8:'Bag',
    9:'Ankle Boot'
}

def main():
    fashionmnist = keras.datasets.fashion_mnist
    (trainimages, trainlabels), (testimages, testlabels) = fashionmnist.load_data()
    trainimages, testimages = trainimages/255., testimages/255.
    
    try:
        # try to load a previously saved model
        model = keras.models.load_model(MODELFILENAME)
    
    # loading failed (e.g. nothing has been saved yet), so train a new model
    except Exception:
        
        # activation functions
        # relu - rectified linear unit - returns the value if it is greater than 0, otherwise 0
        # softmax - turns the 10 outputs into probabilities that sum to 1
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28)), # size of each image
            keras.layers.Dense(128, activation=tf.nn.relu), 
            keras.layers.Dense(10, activation=tf.nn.softmax) # ten clothing classes
        ])
        ])

        model.compile(
            optimizer = 'adam',
            loss = 'sparse_categorical_crossentropy',
            metrics = 'accuracy'
        )

        model.fit(trainimages, trainlabels, epochs=5)
        
        #save to file
        model.save(MODELFILENAME)
    
    testloss, testacc = model.evaluate(testimages, testlabels)
    print('\nEvaluation, loss and accuracy : ', testloss, testacc)
    predictions = model.predict(testimages)
#    predictions = model.predict(np.asarray([testimages[0]]))
    
    while True:
        x = int(input('\nEnter image number (<%d) : '%len(testimages)))
        print('\nPredictions : ',
              predictions[x],
              predictions[x].argmax(),
              labels[predictions[x].argmax()]
             )
        print('Actual : ', testlabels[x], labels[testlabels[x]])
        plt.ioff()
        plt.imshow(testimages[x])
        plt.title(labels[predictions[x].argmax()])
        plt.show()
    
    # note: this dataset has the objects centered in each image;
    # for an unprocessed dataset you would need to detect features,
    # for example with convolutional networks

try:
    main()
except Exception as e:
    print(e)
finally:
    input()

Output on First Run

Output on Second Run
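A quick way to narrow this down is to recompute the accuracy directly from the predictions on both runs, instead of relying only on the number printed by `model.evaluate`. The sketch below is a minimal illustration, not part of the original code: `manual_accuracy` and the `example_*` values are hypothetical names and data made up for the example, and it uses plain Python so it does not depend on TensorFlow.

```python
def manual_accuracy(predictions, labels):
    """Return the fraction of rows whose top-scoring class matches the integer label.

    `predictions` is a sequence of per-class score rows (for example the output
    of model.predict); `labels` is a sequence of integer class ids.
    """
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if len(labels) == 0:
        raise ValueError("need at least one example")

    correct = 0
    for row, label in zip(predictions, labels):
        scores = list(row)
        predicted_class = scores.index(max(scores))  # same idea as argmax
        if predicted_class == int(label):
            correct += 1
    return correct / len(labels)


# Tiny hand-made example (hypothetical data): three classes, four samples.
example_predictions = [
    [0.1, 0.7, 0.2],  # highest score at index 1
    [0.8, 0.1, 0.1],  # highest score at index 0
    [0.2, 0.3, 0.5],  # highest score at index 2
    [0.6, 0.3, 0.1],  # highest score at index 0, but the label below is 1
]
example_labels = [1, 0, 2, 1]
print(manual_accuracy(example_predictions, example_labels))  # 0.75
```

In the script above, this could be called as `manual_accuracy(predictions, testlabels)` right after `model.predict(testimages)`; with NumPy arrays, `(predictions.argmax(axis=1) == testlabels).mean()` computes the same quantity. If the manual figure stays close to the first-run accuracy while `model.evaluate` reports a much lower one, the loaded weights are probably fine and the difference more likely lies in how the `'accuracy'` metric is set up on the loaded model. That is only a hypothesis to check, not a confirmed diagnosis.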
