<p>Get the top-k values in each row:</p>
<pre><code>values, _ = tf.math.top_k(test, 2)
</code></pre>
<pre><code>&lt;tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[3, 2],
       [4, 3],
       [5, 4],
       [8, 7]])&gt;
</code></pre>
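<p>To see what <code>tf.math.top_k</code> computes here without TensorFlow, the following NumPy sketch reproduces its values and indices along the last axis for this input. The helper name <code>top_k_values</code> is hypothetical. The stable descending sort is an assumption meant to mimic <code>top_k</code> placing the lower index first when two values are equal. It is not TensorFlow's implementation.</p>

<pre><code>import numpy as np

def top_k_values(x, k):
    """Hypothetical NumPy sketch of tf.math.top_k(x, k) along the last axis."""
    x = np.asarray(x)
    # Descending order per row; "stable" keeps the lower index first on ties.
    order = np.argsort(-x, axis=1, kind="stable")[:, :k]
    return np.take_along_axis(x, order, axis=1), order

test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
values, indices = top_k_values(test, 2)
print(values)
# [[3 2]
#  [4 3]
#  [5 4]
#  [8 7]]
</code></pre>

<p>The second return value plays the role of the indices that the answer discards with <code>_</code>.</p>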
<p>Shuffle the values within each row:</p>
<pre><code>shuffled = tf.map_fn(tf.random.shuffle, values)
</code></pre>
<pre><code>&lt;tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[2, 3],
       [4, 3],
       [4, 5],
       [7, 8]])&gt;
</code></pre>
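<p><code>tf.map_fn</code> is needed because <code>tf.random.shuffle</code> shuffles only along the first dimension, so calling it on <code>values</code> directly would reorder whole rows instead of the values inside each row. A NumPy sketch of the same distinction, assuming NumPy 1.20 or newer (for <code>Generator.permuted</code>) and an arbitrary seed:</p>

<pre><code>import numpy as np

values = np.array([[3, 2], [4, 3], [5, 4], [8, 7]])
rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only

# Analogous to tf.random.shuffle(values): rows are reordered, each row stays intact.
rows_reordered = rng.permutation(values)

# Analogous to tf.map_fn(tf.random.shuffle, values): each row is shuffled independently.
shuffled = rng.permuted(values, axis=1)
print(shuffled)
</code></pre>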
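<p>Take the first column of the shuffled tensor, which gives one randomly chosen top-k value per row (indexing with the list <code>[0]</code> keeps the result two-dimensional):</p>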
<pre><code>tf.gather(shuffled, [0], axis=1)
</code></pre>
<pre><code>&lt;tf.Tensor: shape=(4, 1), dtype=int32, numpy=
array([[2],
       [4],
       [4],
       [7]])&gt;
</code></pre>
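<p>Shuffling a row and keeping its first entry selects each of the row's k values with probability 1/k. A quick Monte Carlo check in NumPy, with an assumed seed and trial count, counts how often the largest value of each row ends up in position 0; for k = 2 that frequency should be close to 0.5.</p>

<pre><code>import numpy as np

values = np.array([[3, 2], [4, 3], [5, 4], [8, 7]])
rng = np.random.default_rng(0)  # assumed seed
trials = 10_000  # assumed number of repetitions

hits = np.zeros(values.shape[0])
for _ in range(trials):
    # Per-row shuffle, then keep column 0 (like tf.gather(shuffled, [0], axis=1)).
    first = rng.permuted(values, axis=1)[:, 0]
    # Count the rows where the largest top-k value was the one selected.
    hits += first == values[:, 0]

freq = hits / trials
print(freq)  # each entry should be near 1 / k = 0.5
</code></pre>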
<p>Copy-pasteable code:</p>
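<pre><code>import tensorflow as tf
import numpy as np

test = np.array([
    [1, 2, 3],
    [2, 3, 4],
    [5, 4, 3],
    [8, 7, 2]])

# Keep the 2 largest values in each row (sorted in descending order).
values, _ = tf.math.top_k(test, 2)
# Shuffle the values within each row independently.
shuffled = tf.map_fn(tf.random.shuffle, values)
# Take the first column: one randomly chosen top-2 value per row, shape (4, 1).
result = tf.gather(shuffled, [0], axis=1)
print(result)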
</code></pre>
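<p>Because only one element per row is kept, shuffling the whole row is not strictly necessary: drawing one random column index per row gives the same uniform choice among the top-k values. Below is a NumPy sketch of that alternative. <code>sample_one_of_top_k</code> is a hypothetical helper, and it assumes a signed integer or floating-point input (negating unsigned values would not reverse their order).</p>

<pre><code>import numpy as np

def sample_one_of_top_k(x, k, rng):
    """Hypothetical helper: pick one of the k largest values per row, uniformly.

    Samples a column index instead of shuffling each row.
    """
    x = np.asarray(x)
    n = x.shape[0]
    # Indices of the k largest entries per row, largest first (stable on ties).
    top_idx = np.argsort(-x, axis=1, kind="stable")[:, :k]
    top_vals = np.take_along_axis(x, top_idx, axis=1)
    # One random position in [0, k) for every row.
    pick = rng.integers(0, k, size=n)
    # Keep an (n, 1) shape to match tf.gather(shuffled, [0], axis=1).
    return top_vals[np.arange(n), pick][:, None]

test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
result = sample_one_of_top_k(test, 2, np.random.default_rng(42))
print(result)
</code></pre>

<p>This avoids one shuffle per row, which may matter for wide rows or large k; for k = 2 on a 4-row input the difference is negligible.</p>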