TensorFlow con中的秩不匹配错误

2024-09-19 23:32:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直在尝试做一个简单的2层神经网络。我学习了TensorFlowAPI和官方教程,我做了一层模型,但在神经网络方面有问题。以下是导致错误的代码部分:

with graph.as_default():
    tf_train_dataset = tf.placeholder(tf.float32, shape=(batch_size, image_size * image_size))
    tf_train_labels = tf.placeholder(tf.int32, shape=(batch_size, num_labels))
    tf_valid_dataset = tf.constant(valid_dataset)
    tf_test_dataset = tf.constant(test_dataset)

    weights0 = tf.Variable(tf.truncated_normal([image_size**2, num_labels]))
    biases0 = tf.Variable(tf.zeros([num_labels]))

    hidden1 = tf.nn.relu(tf.matmul(tf_test_dataset, weights0) + biases0)

    weights1 = tf.Variable(tf.truncated_normal([num_labels, image_size * image_size]))
    biases1 = tf.Variable(tf.zeros([image_size**2]))

    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights1) + biases1)


    logits = tf.matmul(hidden2, weights0) + biases0

    labels = tf.expand_dims(tf_train_labels, 1)

    indices = tf.expand_dims(tf.range(0, batch_size), 1)

    concated = tf.concat(1, [indices, tf.cast(labels,tf.int32)])

    onehot_labels = tf.sparse_to_dense(concated, tf.pack([batch_size, num_labels]), 1.0, 0.0)


    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, onehot_labels))

    optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    train_prediction = tf.nn.softmax(logits)
    valid_prediction = tf.nn.softmax(tf.matmul(tf.nn.relu(tf.matmul(tf.nn.relu(tf.matmul(tf_valid_dataset,weights0) + biases0),weights1)+biases1),weights0)+biases0)
    test_prediction = tf.nn.softmax(tf.matmul(tf.nn.relu(tf.matmul(tf.nn.relu(tf.matmul(tf_test_dataset,weights0) + biases0),weights1)+biases1),weights0)+biases0)

错误是:

^{pr2}$

这是完整的代码:http://pastebin.com/sX7RqbAf

我使用了TensorFlow和Python2.7。我对神经网络和机器学习还很陌生,所以请原谅我的任何错误,提前谢谢。在


Tags: testimagesizelabelstfbatchtrainnn
1条回答
网友
1楼 · 发布于 2024-09-19 23:32:53

在您的例子中:

  • tf_train_labels具有形状[batch_size, num_labels]
  • 因此labels具有[batch_size, 1, num_labels]的形状
  • indices具有形状[batch_size, 1]

因此,当你写下:

concated = tf.concat(1, [indices, tf.cast(labels,tf.int32)])

它抛出一个错误,因为labelsindices的三维不同。labels的三维尺寸为num_labels(大概为10),而{}没有第三维度。在

相关问题 更多 >