如何用卷积神经网络训练ATC码

2024-09-29 23:28:55 发布

您现在位置:Python中文网/ 问答频道 /正文

我尝试使用一维卷积来输入数据,但精度总是很低。我以前只使用神经网络来训练图像,而我没有训练这种数据。我做错了什么

让我先解释一下我的数据。我的数据是ATC代码。我将把它分成五个部分。如果它是A10BA02,我会把它分成'A',10','B','A',02,我会把字符转换成整数,所以它会变成[65,10,66,65,2]。每个ATC将对应一个数字。在我的示例中,对应于A10BA02的数字是7

xtrain: [[65, 10, 66, 65, 2], [65, 2, 66, 65, 2], [78, 3, 65, 69, 1], [78, 2, 66, 69, 1], [65, 10, 66, 68, 20]]

ytrain: [7, 7, 7, 7, 7]

以上只是一个例子,我的数据有4000多个数据

这是我的密码

with open('DATA.csv') as t:
    tr = csv.reader(fix_nulls(t))
    datas = list(tr)

temp = []
label = []
for i in range(1,4048):
    data = []
    data.append(ord(datas[i][4][0]))
    data.append((int)(datas[i][4][1] + datas[i][4][2]))
    data.append(ord(datas[i][4][3]))
    data.append(ord(datas[i][4][4]))
    if len(datas[i][4]) == 5:
        data.append(0)
    else:
        data.append((int)(datas[i][4][5] + datas[i][4][6]))
    temp.append(data)
    label.append((int)(datas[i][12]))

ndata = np.array(data)
nlabel = np.array(label)
xtrain = []
ytrain = []
xtest = []
ytest = []
for i in range(len(temp)):
    if i%5 != 0:
        xtrain.append(temp[i])
        ytrain.append(label[i])
    else:
        xtest.append(temp[i])
        ytest.append(label[i])
nxtrain = np.array(xtrain)
nytrain = np.array(ytrain)
nxtest = np.array(xtest)
nytest = np.array(ytest)
max_words = 450
nxtrain = sequence.pad_sequences(nxtrain, maxlen=max_words)
nxtest = sequence.pad_sequences(nxtest, maxlen=max_words)

model = Sequential()      # initilaizing the Sequential nature for CNN model
model.add(Embedding(90, 32, input_length=max_words))
model.add(Conv1D(5, 3, activation='relu'))
model.add(MaxPooling1D())
model.add(Flatten())
model.add(Dense(45, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

# Fitting the data onto model
model.fit(nxtrain, nytrain, validation_data=(nxtest, nytest), epochs=100, batch_size=128, verbose=2)
# Getting score metrics from our model
scores = model.evaluate(nxtest, nytest, verbose=0)

这是我第一次问有关堆栈溢出的问题,如果有任何错误,请原谅我


Tags: 数据adddatamodelnparraytemplabel

热门问题