Logits and labels must have the same shape ((294, 6) vs (6, 1))

Posted 2024-05-19 16:25:24


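I am new to ML. I am trying to implement a multi-label classifier using TensorFlow.

I have searched online, but to no avail. Maybe I have gotten the basics wrong.

The code I implemented is shown below: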

import pandas as pd
import tensorflow as tf

train=pd.read_csv('data/train.csv',header=None)
print(train.columns)
columns=[294,295,296,297,298,299]
train_x=train.drop(columns,axis=1)
train_y=train[columns]
train_x=train_x.values
train_y=train_y.values
print(train_x.shape)
print(train_y.shape)
print(train_y)
train_dataset= tf.data.Dataset.from_tensor_slices((train_x,train_y))
print(train_dataset)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=294,activation='relu'),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=6,activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])


model.fit(train_dataset, epochs=10)

Output:

Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            290, 291, 292, 293, 294, 295, 296, 297, 298, 299],
           dtype='int64', length=300)
(1438, 294)
(1438, 6)
[[0 0 0 1 0 1]
 [0 0 0 1 0 1]
 [0 0 0 1 0 1]
 ...
 [1 0 0 0 0 1]
 [1 0 0 0 0 1]
 [1 0 0 0 0 1]]
Traceback (most recent call last):
  File "/home/shivam/Desktop/tftest/load.py", line 25, in <module>
    model.fit(train_dataset, epochs=10)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 819, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 342, in fit
    total_epochs=epochs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 128, in run_one_epoch
    batch_outs = execution_function(iterator)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 98, in execution_function
    distributed_function(input_fn))
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 85, in distributed_function
    per_replica_function, args=args)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 763, in experimental_run_v2
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1819, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 2164, in _call_for_each_replica
    return fn(*args, **kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/autograph/impl/api.py", line 292, in wrapper
    return func(*args, **kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 433, in train_on_batch
    output_loss_metrics=model._output_loss_metrics)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_eager.py", line 312, in train_on_batch
    output_loss_metrics=output_loss_metrics))
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_eager.py", line 253, in _process_single_batch
    training=training))
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_eager.py", line 167, in _model_loss
    per_sample_losses = loss_fn.call(targets[i], outs[i])
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/losses.py", line 221, in call
    return self.fn(y_true, y_pred, **self._fn_kwargs)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/losses.py", line 994, in binary_crossentropy
    K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/keras/backend.py", line 4615, in binary_crossentropy
    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
  File "/home/shivam/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/ops/nn_impl.py", line 170, in sigmoid_cross_entropy_with_logits
    (logits.get_shape(), labels.get_shape()))
ValueError: logits and labels must have the same shape ((294, 6) vs (6, 1))

Process finished with exit code 1

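Can someone tell me what I am doing wrong? I also tried SparseCategoricalCrossentropy, but that did not help either. Maybe the shape of my dataset is wrong.

Is this a problem with the model?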


1 Answer
User
Answer 1 · Posted 2024-05-19 16:25:24

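You are solving a multi-class problem, but the loss function you are using is meant for binary classification. Change the compile statement as follows: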

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

Note: if your targets are one-hot encoded, use categorical_crossentropy. Example of one-hot encoding:

[1,0,0]
[0,1,0]
[0,0,1]

But if your targets are integers, use sparse_categorical_crossentropy. Example of integer encoding:

1
2
3
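To make the difference between the two encodings concrete, here is a minimal sketch with made-up probabilities (the array values and variable names are illustrative assumptions, not taken from the question). It uses zero-based class indices, because sparse_categorical_crossentropy expects integer labels in the range 0 to num_classes - 1. Note also that a row such as [0 0 0 1 0 1] from the question contains more than one 1, so it is a multi-label target rather than a one-hot one.

```python
import numpy as np
import tensorflow as tf

# Integer-encoded targets for three samples (zero-based class indices).
y_int = np.array([0, 1, 2])

# The same targets, one-hot encoded: exactly one 1 per row.
y_onehot = tf.one_hot(y_int, depth=3).numpy()

# Made-up predicted class probabilities; each row sums to 1.
probs = np.array(
    [[0.7, 0.2, 0.1],
     [0.1, 0.8, 0.1],
     [0.2, 0.2, 0.6]],
    dtype="float32",
)

# categorical_crossentropy pairs with one-hot targets,
# sparse_categorical_crossentropy with integer targets.
loss_onehot = tf.keras.losses.categorical_crossentropy(y_onehot, probs).numpy()
loss_int = tf.keras.losses.sparse_categorical_crossentropy(y_int, probs).numpy()

print(loss_onehot)
print(loss_int)
```

Both calls should give the same per-sample losses, since they only differ in how the target class is written down.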

You can find the different loss functions available in Keras at the link below:

https://www.tensorflow.org/api_docs/python/tf/keras/losses
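One more thing that may be worth checking, given that the asker reports that switching the loss did not help: the shapes in the error, (294, 6) vs (6, 1), look like what you would get from an unbatched dataset. tf.data.Dataset.from_tensor_slices yields one row at a time, so Keras may be treating the 294 features of a single sample as 294 samples with one feature each. The sketch below assumes that reading is correct and uses random stand-in data with the same shapes as in the question. The batch size of 32 is an arbitrary choice. The targets in the question have several 1s per row, so the sketch keeps the sigmoid output with binary_crossentropy, a common pairing for multi-label targets.

```python
import numpy as np
import tensorflow as tf

# Random stand-in data with the shapes printed in the question (an assumption,
# not the asker's real data): 1438 rows, 294 features, 6 binary labels.
rng = np.random.default_rng(0)
train_x = rng.random((1438, 294)).astype("float32")
train_y = rng.integers(0, 2, size=(1438, 6)).astype("float32")

# Unbatched: each element is a single row, features (294,) and labels (6,).
unbatched = tf.data.Dataset.from_tensor_slices((train_x, train_y))
print(unbatched.element_spec)

# Batched: each element now has a leading batch axis, (None, 294) and (None, 6).
batched = unbatched.shuffle(len(train_x)).batch(32)
print(batched.element_spec)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=294, activation="relu"),
    tf.keras.layers.Dense(units=128, activation="relu"),
    tf.keras.layers.Dense(units=6, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# One epoch is enough to confirm that the shapes line up.
history = model.fit(batched, epochs=1, verbose=0)
print(history.history["loss"])
```

If the dataset in the question is batched the same way, model.fit should receive logits and labels that both have shape (batch_size, 6).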
