这就是我的问题:我实现了一个简单的函数,它返回以矩阵形式组织的信号峰值
@tf.function
def get_peaks(X, X_err):
prominence = 0.9
# X shape (B, N, 1)
max_pooled = tf.nn.pool(X, window_shape=(20, ), pooling_type='MAX', padding='SAME')
maxima = tf.equal(X, max_pooled) #shape (1, N, 1)
maxima = tf.cast(maxima, tf.float32)
peaks = tf.squeeze(X * maxima) #shape (1, N, 1) ==> shape (N,)
peaks_err = X_err * tf.squeeze(maxima)
peaks_idxs, idxs = tf.math.top_k(peaks, k=2)
return peaks_idxs, idxs
如您所见,输入具有shape(B, N, 1)
,即批处理样本,每个样本都是N个元素的一维向量。
返回的idxs
和peaks_idxs
都是正确的,它们具有形状(B,2),即批次中每个样本的两个最大值的位置(和峰值)
问题是,我还想取与idxs
对应的peak_err
。对于numpy
,我将使用:
np.take_along_axis(peaks_err, idxs, axis=1)
它实际上返回了一个正确的(B, 2)
形状的矩阵。我怎样才能用tf做同样的事情?
实际上,我已经尝试使用tf.gather
:
tf.gather(peaks_err, idxs, axis=1)
但它不起作用,结果是不正确的形状(B,B,2),和大量的零。 你知道我怎么解决吗?谢谢
我已经解决了添加三行的问题:
它起作用了! 如果你能找到/有一个更聪明/更快的解决方案,我将不胜感激
相关问题 更多 >
编程相关推荐