为什么这个布尔值出现在这个贝叶斯分类器中?(Python问题?)

2024-09-30 22:16:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在学习GANs(我是python的初学者),在前面的练习中我发现这部分代码我不理解。具体地说,我不明白为什么要用第9行的布尔值(Xk=X[Y==k]),原因如下

class BayesClassifier:
  def fit(self, X, Y):
    # assume classes are numbered 0...K-1
    self.K = len(set(Y))

    self.gaussians = []
    self.p_y = np.zeros(self.K)
    for k in range(self.K):
      Xk = X[Y == k]
      self.p_y[k] = len(Xk)
      mean = Xk.mean(axis=0)
      cov = np.cov(Xk.T)
      g = {'m': mean, 'c': cov}
      self.gaussians.append(g)
    # normalize p(y)
    self.p_y /= self.p_y.sum()
  1. 根据Y的真实性,该布尔值返回0或1== k、 因此Xk总是X列表的第一个或第二个值。你不觉得这有什么用。你知道吗
  2. 在第10行中,len(Xk)总是1,为什么用这个参数而不是一个1呢?你知道吗
  3. 下一行的平均值和协方差每次仅用一个值计算。你知道吗

我觉得我不懂一些很基本的东西。你知道吗


Tags: 代码selflendefnp原因meancov
2条回答

您应该考虑到X, Y, k是NumPy数组,而不是标量,并且有些操作符对它们重载。尤其是==和基于布尔的索引。==将是元素比较,而不是整个数组比较。你知道吗

了解其工作原理:

In [9]: Y = np.array([0,1,2])                                                                                        
In [10]: k = np.array([0,1,3])                                                                                       
In [11]: Y==k                                                                                                        

Out[11]: array([ True,  True, False])

因此,==的结果是一个布尔数组。你知道吗

In [12]: X=np.array([0,2,4])                                                                                         
In [13]: X[Y==k]                                                                                                     

Out[13]: array([0, 2])

当条件为True时,结果是一个包含从X中选择的元素的数组

因此len(Xk)将是Xk之间匹配元素的数量。你知道吗

谢谢,阿特姆

你说得对。我从另一个渠道找到了另一个答案,这是:

It's a Numpy array - it's a special feature of NumPy arrays called boolean indexing that lets you filter out only the values in the array where the filter returns True:

https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html?fbclid=IwAR3sGlgSwhv3i7IETsIxp4ROu9oZvNaaaBxZS01DrM5ShjWWRz22ShP2rIg#boolean-or-mask-index-arrays

import numpy as np

a = np.array([1, 2, 3, 4, 5]) filter = a > 3

print(filter)

[False, False, False, True, True]

打印(一个[过滤器])

[4, 5]

相关问题 更多 >