擅长:python、mysql、java
<p>对于已知3个类的特定场景,这应该可以更快地工作</p>
<pre><code>def one_hot2(img):
class1 = [255,0,0]
class2 = [0,0,255]
class3 = [255,255,255]
label = np.zeros_like(img)
label[np.sum(img==np.array([[class2]]), 2)==3] = 1
label[np.sum(img==np.array([[class3]]), 2)==3] = 2
onehot = np.eye(3)[label]
return onehot
</code></pre>