擅长：Python、MySQL、Java
<p>如果想用纯 TF 2.x 的方式实现，还可以这样做：</p>
<pre><code>import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
@tf.function  # Graph mode is fine: tensors passed to tf.print still show values.
def get_one_hot():
    """Build a per-label boolean mask and report which label ids occur.

    Each id in ``label_ids`` is compared element-wise against ``mask_orig``.
    The comparisons are stacked along a new last axis, giving one boolean
    channel per id. That stack is then reduced over the first two axes to
    tell whether each id appears anywhere in ``mask_orig``.

    Results are printed with ``tf.print``; the function returns nothing.
    """
    label_ids = [0, 5, 10]
    mask_orig = tf.constant([[0, 10], [0, 10]], dtype=tf.float32)  # [2, 2]
    # One boolean channel per label id -> [2, 2, len(label_ids)] == [2, 2, 3].
    # tf.stack(axis=-1) is equivalent to concatenating expand_dims(..., -1).
    mask_onehot = tf.stack(
        [tf.math.equal(mask_orig, label_id) for label_id in label_ids],
        axis=-1,
    )
    # True for each label id present anywhere in mask_orig -> [3].
    mask_label_present = tf.reduce_any(mask_onehot, axis=[0, 1])

    # Tensors are handed to tf.print as separate arguments instead of being
    # formatted with str.format. Inside a tf.function, str.format runs only
    # once at trace time and would print the symbolic Tensor repr, not values.
    # sep='' keeps the same layout the original format strings produced.
    tf.print(f'\n - label_ids:{label_ids}')
    tf.print('\n - mask_orig:\n', mask_orig, '\n', sep='')
    for id_, label_id in enumerate(label_ids):
        tf.print(f' - mask_onehot:{label_id}\n', mask_onehot[:, :, id_], sep='')
    tf.print('\n - mask_label_present:\n ', mask_label_present)


get_one_hot()
</code></pre>