擅长:python、mysql、java
<p>我找了一位对这个领域了解更多的工程师,他告诉我的是:</p>
<p>我不确定我们是否有一个有效的实现,但这里有一个使用动态分区和聚集操作的不太理想的实现。在</p>
<pre><code>def sparse_slice(indices, values, needed_row_ids):
num_rows = tf.shape(indices)[0]
partitions = tf.cast(tf.equal(indices[:,0], needed_row_ids), tf.int32)
rows_to_gather = tf.dynamic_partition(tf.range(num_rows), partitions, 2)[1]
slice_indices = tf.gather(indices, rows_to_gather)
slice_values = tf.gather(values, rows_to_gather)
return slice_indices, slice_values
with tf.Session().as_default():
indices = tf.constant([[0,0], [1, 0], [2, 0], [2, 1]])
values = tf.constant([1.0, 1.0, 0.3, 0.7], dtype=tf.float32)
needed_row_ids = tf.constant([1])
slice_indices, slice_values = sparse_slice(indices, values, needed_row_ids)
print(slice_indices.eval(), slice_values.eval())
</code></pre>
<p>更新:</p>
<p>工程师还发送了一个示例来帮助处理多行,感谢您指出这一点!在</p>
^{pr2}$