混合密度网络

2024-06-26 10:44:44 发布

您现在位置:Python中文网/ 问答频道 /正文

在混合密度网络中,有一个函数用于对多个预测进行采样,即:

y_test = model.predict(x_test)
y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=1.0)

所以,这里这个函数只返回一个样本,而我应该得到样本数=N_混合。关于如何获得多个样本有什么想法吗


Tags: sample函数fromtest网络modelnppredict
1条回答
网友
1楼 · 发布于 2024-06-26 10:44:44

在init.py的第251行中,更改如下:

sample = np.random.multivariate_normal(mus_vector, cov_matrix, n_samp)

将第224行中的n_samp作为参数传递为:

def sample_from_output(params, output_dim, num_mixes, n_samp=100, temp=1.0, sigma_temp=1.0):

相关问题 更多 >