有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

java在tensorflow中沿轴选择随机数目的点

我有一个大小为(X,Y,N)的掩码的三维数组,每个值为1或0。当值为1时,我们希望在最后一个维度中收集索引

X, Y, N = 10, 10, 10
# Each point in ar is 1 or 0
ar = tf.random.uniform((X, Y, N), maxval=2, seed=1, name=None, dtype=tf.int32)
# We now want to collect 4 point indices along last dimension if the corresponding value is 1

当相应的值为1时,我想沿着第三维对n(=4)个索引进行采样。我如何在tensorflow中做到这一点?我的函数的输出应该是形状(X,Y,4)

if output[x, y] = [n1, n2, n3, n4] then
ar[x,y, n1] = 1
ar[x,y, n2] = 1
ar[x,y, n3] = 1
ar[x,y, n4] = 1
...
...
...

共 (1) 个答案

  1. # 1 楼答案

    我发现你可能在找^{}。当值为1时,保持概率相等,当值为0时,保持概率为tf.log(0.0)

    import tensorflow as tf
    
    samples = tf.multinomial(tf.log([[1., 1., 0.0]]), 4)
    
    with tf.Session() as sess:
        print(sess.run(samples))
    
    #print
    [[0 1 1 1]]
    

    你还需要使用tf.map_fn来包装tf.multinomial,因为tf.multinomial的对数概率需要二维张量

    import tensorflow as tf
    
    X, Y, N = 10, 10, 10
    ar = tf.random.uniform((X, Y, N), maxval=2, seed=1, name=None, dtype=tf.int32)
    
    samples = tf.map_fn(lambda x:tf.multinomial(tf.log(x), 4,output_dtype=tf.int32)
                        ,tf.cast(ar,tf.float32)
                        ,dtype=tf.int32)
    
    with tf.Session() as sess:
        val1,val2 = sess.run([ar,samples])
        print('ar[0]: \n',val1[0])
        print('samples[0]: \n',val2[0])
    
    ar[0]: 
     [[1 1 1 0 1 0 0 1 1 0]
     [0 1 1 0 0 1 1 1 1 0]
     [0 1 0 1 1 1 1 1 0 1]
     [0 1 0 0 1 0 0 1 1 0]
     [0 0 0 1 0 0 0 1 1 0]
     [1 1 0 0 1 0 1 1 0 1]
     [0 0 1 1 1 0 0 1 1 1]
     [1 1 1 1 1 1 0 0 1 1]
     [1 1 1 1 1 0 0 1 1 0]
     [0 1 1 1 0 1 1 1 1 0]]
    samples[0]: 
     [[0 1 0 4]
     [6 7 5 2]
     [4 4 4 7]
     [1 8 7 4]
     [8 7 7 3]
     [7 7 6 0]
     [7 4 3 9]
     [1 5 0 3]
     [4 1 1 7]
     [7 7 1 6]]
    

    请注意tf.multinomial将在新版本中删除。更新说明:使用tf。随机的而是直截了当的