为什么我的张量流而循环不工作

2024-09-28 20:47:55 发布

您现在位置:Python中文网/ 问答频道 /正文

import math
import numpy as np
import tensorflow as tf

myx=np.array([2,4,5])
myy=np.array([10,3,7,8,6,4,11,18,1])


Xxx=np.transpose(np.repeat(myx[:, np.newaxis], myy.size , axis=1))
Yyy=np.repeat(myy[:, np.newaxis], myx.size , axis=1)


X = tf.placeholder(tf.float64, shape=(myy.size,myx.size))
Y = tf.placeholder(tf.float64, shape=(myy.size,myx.size))
calp=tf.constant(1)

with tf.device('/cpu:0'):

    #minCord=tf.argmin(tfslic,0)

    dist = tf.abs(tf.subtract(X,Y))

    i =  tf.placeholder(dtype='int32')

    def condition(i):
        return i < 2

    def b(i):

        dist = tf.abs(tf.subtract(X,Y))
        tfslic=tf.slice(dist,[0,i],[myy.size,1])
        minVal=tf.reduce_min(tfslic,0)
        y = tf.cond(tf.less_equal(minVal, 1), lambda: tf.argmin(tfslic,0), lambda: 99999)

        return i+1, y


i, r = tf.while_loop(condition, b, [i])


sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
dmat=sess.run(i, feed_dict={X:Xxx, Y: Yyy, i:0})
sess.close()

print(dmat)

我一直收到以下错误:

^{pr2}$

有人能帮我解决这个错误吗?我试着让这个张量流“while”循环起作用。在

基本上,我尝试用张量流框架来做一个贪婪的1:1匹配数组“myx”和“myy”。在


Tags: importsizedisttfasnparrayplaceholder
1条回答
网友
1楼 · 发布于 2024-09-28 20:47:55

^{}函数要求pred是标量(“秩0”)张量。在程序中,它是长度为1的向量(“秩1”)张量。在

有很多方法可以解决这个问题。例如,可以使用tf.reduce_min()而不指定轴来计算tfslic的全局最小值作为标量:

minVal = tf.reduce_min(tfslic)

…或者可以显式使用tf.reshape()将参数设为tf.cond()标量:

^{pr2}$

我冒昧地稍微修改了一下你的程序,得到了一个有效的版本。按照注释查看哪些地方需要更改:

with tf.device('/cpu:0'):

    dist = tf.abs(tf.subtract(X,Y))

    # Use an explicit shape for `i`.
    i = tf.placeholder(dtype='int32', shape=[])

    # Add a second unused argument to `condition()`.
    def condition(i, _):
        return i < 2

    # Add a second unused argument to `b()`.
    def b(i, _):
        dist = tf.abs(tf.subtract(X,Y))

        # Could use `tfslic = dist[0:myy.size, i]` here to avoid later reshapes.
        tfslic = tf.slice(dist, [0,i], [myy.size,1])

        # Drop the `axis` argument from `tf.reduce_min()`
        minVal=tf.reduce_min(tfslic)

        y = tf.cond(
            tf.less_equal(minVal, 1),
            # Reshape the output of `tf.argmin()` to be a scalar.
            lambda: tf.reshape(tf.argmin(tfslic, 0), []),
            # Explicitly convert the false-branch value to `tf.int64`.
            lambda: tf.constant(99999, dtype=tf.int64))

        return i+1, y

# Add a dummy initial value for the second loop variable.
# Rename the first return value to `i_out` to avoid clashing with `i` above.
i_out, r = tf.while_loop(condition, b, [i, tf.constant(0, dtype=tf.int64)])

sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Fetch the value of `i_out`.
dmat = sess.run(i_out, feed_dict={X:Xxx, Y: Yyy, i:0})

相关问题 更多 >