在Tensorflow中是否可以输入动态形状矩阵?

2024-10-03 02:41:03 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个非静态形状的数据集,比如(batch_size, None, None, None, 92)

for x in x_data:
    print(x.shape)
(4, 4, 8, 92)
(3, 3, 7, 92)
(4, 4, 8, 92)
(3, 3, 7, 92)
(4, 4, 8, 92)
(4, 4, 7, 92)
(3, 3, 7, 92)
(4, 4, 8, 92)
(4, 4, 8, 92)
(3, 3, 8, 92)

但是当我试图把这个x\u数据提供给我的x占位符时,我遇到了错误

X = tf.placeholder(tf.float32, [None, None, None, None, 92])
with tf.Session() as sess:
    c, _ = sess.run([cost, optimizer], feed_dict={X: x_data, Y: y_data})

这个错误可能是由于输入数据的形状不稳定造成的。
错误消息

Traceback (most recent call last):
  File "C:/Users/bsjun/Documents/GitHub/CCpyNN/CCpyNN/Inception_v.2.py", line 274, in <module>
    c, hy, _ = sess.run([cost, logit_layer, optimizer], feed_dict={X: batch_x[i], Y: batch_y[i], keep_prob: 0.8})
  File "C:\Users\bsjun\AppData\Local\conda\conda\envs\tf_normal\lib\site-packages\tensorflow\python\client\session.py", line 929, in run
    run_metadata_ptr)
  File "C:\Users\bsjun\AppData\Local\conda\conda\envs\tf_normal\lib\site-packages\tensorflow\python\client\session.py", line 1121, in _run
    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
  File "C:\Users\bsjun\AppData\Local\conda\conda\envs\tf_normal\lib\site-packages\numpy\core\numeric.py", line 501, in asarray
    return array(a, dtype, copy=False, order=order)
ValueError: setting an array element with a sequence.

是否不可能输入动态形状矩阵?你知道吗

下面是我的代码摘要。你知道吗

import tensorflow as tf
import numpy as np

shapes = [(4, 4, 8, 92), (3, 3, 7, 92), (4, 4, 8, 92), (3, 3, 7, 92)]
x_data = []
for s in shapes:
    x = np.zeros(shape=s)
    print(x.shape)
    x_data.append(x)

X = tf.placeholder(tf.float32, [None, None, None, None, 92])
with tf.Session() as sess:
    sess.run(X, feed_dict={X: x_data})

Tags: runinpynonedatatfasnp
1条回答
网友
1楼 · 发布于 2024-10-03 02:41:03

它对您提供的数据不起作用,但是有一些方法可以解决这个问题。你知道吗

为什么不起作用

线路

 x = tf.placeholder(tf.float32, [None, None, None, 92])

意味着输入数组的形状事先未知。但它必须是可以转换为numpy数组的对象。因为您的输入数据是一系列不同形状的numpy数组,所以它不会被转换。你知道吗

如何处理

<强>1。为您的模型提供单独的输入。您可能可以修改模型的代码,为其提供两种输入:

  • 以“最大”形状输入数据。这意味着,序列中的任何数组都将符合此形状
  • 输入数组的实际形状。你知道吗

例如,在tf.nn.dynamic_rnn()中使用这种方法,其中一个参数是实际数据,另一个参数是sequence_length-每个序列的长度。你知道吗

<强>2。不同批次的数据形状不同。另一种选择是在每个批次上输入不同形状的数组。例如,您将batch_size形状数组(4,4,8,92)分组到一个批中,并将其输入到模型中。然后采用batch_size形状数组(3,3,8,92),再进行一次传递,依此类推。因此,形状在数据集中可能会有所不同,但在单个批处理中应该是不变的。你知道吗

相关问题 更多 >