HMM学习GMMHMM

2024-09-27 21:32:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试初始化几个GMM,以便与GMMHMM的gmms_u属性一起使用。每个GMM实例具有不同的平均值、权重和协方差,并作为GMMHMM的5组分混合物的组成部分。均值、权重和协方差是根据我想要拟合的数据集的(5-聚类)k-均值算法确定的,其中均值是每个聚类的中心,权重是每个聚类的权重,协方差是每个聚类的协方差。在

下面是一个代码片段:

X_clusters = cls.KMeans(n_clusters=5)
fitted_X = X_clusters.fit(X)
means = fitted_X.cluster_centers_
cluster_arrays = extract_feat(X, fitted_X.labels_)
print ('Means: {0}'.format(means))

total_cluster = float(len(X)) 
all_GMM_params = []
for cluster in cluster_arrays:
    GMM_params = []
    weight = float(len(cluster))/total_cluster
    covar = np.cov(cluster)
    GMM_params.append(weight)
    GMM_params.append(covar)
    all_GMM_params.append(GMM_params)

for i in range(len(means)):
    all_GMM_params[i].append(means[i])


model = GMMHMM(n_components=4, covariance_type="diag", n_iter=1000,
            n_mix = 5, algorithm='map')

for i in range(len(all_GMM_params)):
    GMM_n = mix.GMM(init_params = '')
    GMM_n.weights_ = np.array(all_GMM_params[i][0])
    GMM_n.covars_ = np.array(all_GMM_params[i][1])
    GMM_n.means_ = np.array(all_GMM_params[i][2])
    model.gmms_.append(GMM_n)

model.fit(X)

但是,当我试图拟合模型时,我得到了以下错误:

^{pr2}$

我以前从未见过这样的错误,这是我第一次与sklearn和HMMlearn一起工作。我如何着手修复这个错误?在


Tags: lennp聚类paramsallmeans均值权重
1条回答
网友
1楼 · 发布于 2024-09-27 21:32:41

我能够用一个随机样本从一个双组分高斯混合物中重现这个问题:

import numpy as np

X = np.append(np.random.normal(0, size=1024),
              np.random.normal(4, size=1024))[:, np.newaxis]

下面是我对为什么你的代码不起作用的看法。^{}将给定数组的每一行视为变量。因此,对于形状(N, 1)的数组,输出必然是形状(N, N)。显然,这不是你想要的,因为一维高斯的协方差矩阵只是一个标量。在

解决方案是在将cluster传递给np.cov之前将其转置:

^{pr2}$

在切换到3D之后,X我又发现了两个问题:

  • n_mixGMM中的组分数目,而{}表示马尔可夫链状态的数目(或等效的混合物数目)。请注意,您将n_components=4传递给GMMHMM构造函数,然后将5GMM实例附加到model.gmms_。在
  • 而且,GMMHMM预先填充了model.gmms_,所以最终得到了n_components + 5而不是4个混合物,这解释了(9, )不匹配。在

更新代码:

#      the updated parameter value.
#              vvvvvvvvvvvvvv
model = GMMHMM(n_components=5, covariance_type="diag", n_iter=1000,
               n_mix=5, algorithm='map')
#              ^^^^^^^
#  doesn't have to match n_components

for i, GMM_n in enumerate(model.gmms_):
    GMM_n.weights_ = ...
    # Change the attributes of an existing instance 
    # instead of appending a new one to ``model.gmms_``.

相关问题 更多 >

    热门问题