ValueError:无法在Tensorflow中将NumPy数组转换为Tensor(不支持的对象类型NumPy.ndarray)

2024-05-02 00:20:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在创建一个新的Pandas Dataframe列,使用4个其他列应用自定义定义的函数行

下面是应用该函数的列的结构

enter image description here

创建的新列如下所示

enter image description here

我编写的函数如下:

def convert_credit_rows(row):
  return np.asarray([row['A'], row['B'], row['C'], row['D']], dtype=np.float32)

X_train['credit_balance'] = X_train.apply(convert_credit_rows, axis=1)
X_test['credit_balance'] = X_test.apply(convert_credit_rows, axis=1)

我将此数据集提供给一个简单的神经网络,如下所示:

def CreditBalanceDetector():

  X_train_credit_balance = X_train['credit_balance']
  X_test_credit_balance = X_test['credit_balance']

  model = Sequential()
  model.add(Dense(20, activation='relu'))
  model.add(Dense(10, activation='relu'))
  model.add(Dense(6, activation='softmax'))

  model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=0.0005), 
  metrics=['accuracy'])
  early_stop = EarlyStopping(monitor='val_loss',patience=3)
  model.fit(X_train_credit_balance, y_train, epochs=50, validation_data=
  (X_test_credit_balance, y_test), callbacks=[early_stop])

但在尝试训练模型时,我得到了以下错误

enter image description here

虽然在StackOverflow中有两个类似的问题,但解决方案表明这些问题对我不起作用

如果有人能找出我哪里出了问题,我将不胜感激。多谢各位


Tags: 函数testaddconvertmodeldefnptrain
1条回答
网友
1楼 · 发布于 2024-05-02 00:20:41

我可以找出我代码中的错误。为将来可能受益的人在此发布

在上面的代码中,我提供了一个Pandas Series对象作为X_train_credit_balanceX_test_credit_balance的数据类型,其中model.fit()函数需要一个数组。如果我们检查X_train_credit_balance的单个元素,如下所示

print(X_train_credit_balance[0])

它将给出以下不需要的输出:

array([30., 30.], dtype=float32)

正确的代码

可以通过如下修改convert_credit_rows(row)函数来纠正上述行为:

credit_list = []

def convert_credit_rows(row):
  credit_list.append(np.asarray([row['A'], row['B'], row['C'], row['D']], dtype=np.float32))

 X_train_credit_balance = np.array(credit_list)

在本例中,convert_credit_rows函数将应用于每一行,创建(m,n)维数组的列表-credit_list。然后,作为下一步,我们可以通过np.array(credit_list)credit_list转换为一个数组。如果我们在操作结束时打印出credit_list,我们可以看到格式正确的数组,如下所示:

[[1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]]

现在,如果我们打印出X_train_credit_balance的类型,它将是<class 'numpy.ndarray'>,而不是熊猫系列对象

相关问题 更多 >