擅长:python、mysql、java
<p>它难道不应该表现得更像下面这样吗？</p>
<p>此版本会保留 <code>selected_indices</code> 中各索引的顺序和重复次数，因此同一行可以被多次选出：</p>
<pre><code>import tensorflow as tf
# Eager execution is the default from TensorFlow 2.0 on, and the top-level
# `tf.enable_eager_execution` no longer exists there (calling it raises
# AttributeError). Only call it on TF 1.x builds that still expose it.
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()
def sparse_gather(indices, values, selected_indices, axis=0):
    """Gather entries of a coordinate-list sparse tensor by coordinate value.

    For every value ``s`` in ``selected_indices`` (taken in order), all stored
    entries whose coordinate in column ``axis`` of ``indices`` equals ``s``
    are emitted. Order and multiplicity of ``selected_indices`` are
    preserved, so the same row can be selected several times; a selected
    value with no matching entry contributes nothing.

    Args:
        indices: 2-D integer tensor (or nested list) of coordinates, one row
            per stored entry: [[idx_ax0, idx_ax1, ..., idx_axk], ...].
        values: tensor of entry values aligned with the rows of ``indices``:
            [value_1, ..., value_n].
        selected_indices: coordinate values to select. Flattened to 1-D and
            cast to ``indices.dtype`` (``tf.equal`` requires matching dtypes,
            e.g. int32 selections vs. int64 ``SparseTensor.indices``).
        axis: which column of ``indices`` to match against.

    Returns:
        Tuple ``(gathered_indices, gathered_values)``.
    """
    indices = tf.convert_to_tensor(indices)
    selected_indices = tf.cast(
        tf.reshape(tf.convert_to_tensor(selected_indices), [-1]),
        indices.dtype,
    )
    # mask[i, j] is True when entry j's coordinate in column `axis` equals
    # selected_indices[i]; shape [num_selected, num_entries]. Memory use is
    # therefore proportional to num_selected * num_entries.
    mask = tf.equal(
        indices[:, axis][tf.newaxis, :],
        selected_indices[:, tf.newaxis],
    )
    # Single-argument tf.where returns the coordinates of True cells in
    # row-major order, i.e. grouped by selection order; column 1 holds the
    # position of the matching stored entry.
    to_select = tf.where(mask)[:, 1]
    return (
        tf.gather(indices, to_select, axis=0),
        tf.gather(values, to_select, axis=0),
    )
# Demo: four stored entries whose first coordinate (the row) is 1, 2, 3, 7,
# each holding a value equal to its row number.
indices = tf.constant(
    [[1, 0], [2, 0], [3, 0], [7, 0]],
    dtype=tf.int32,
)
values = tf.constant([1, 2, 3, 7], dtype=tf.float32)
# Rows are requested out of order and with repeats; row 1 is never requested.
needed_row_ids = tf.constant([7, 3, 2, 2, 3, 7], dtype=tf.int32)
slice_indices, slice_values = sparse_gather(
    indices,
    values,
    selected_indices=needed_row_ids,
)
print(slice_indices, slice_values)
</code></pre>