将元素插入到numpy数组中,以便最小间距是任意的

2024-09-24 10:29:49 发布

您现在位置:Python中文网/ 问答频道 /正文

给定一个有序的numpy浮点数组(从最小到最大),我需要确保元素之间的间距小于我调用的step的任意浮点。你知道吗

这是我的代码,据我所知,它可以工作,但我想知道是否有一个更优雅的方式来做到这一点:

import numpy as np

def slpitArr(arr, step=3.):
    """
    Insert extra elements into array so that the maximum spacing between
    elements is 'step'.
    """
    # Keep going until no more elements need to be added
    while True:
        flagExit = True
        for i, v in enumerate(arr):
            # Catch last element in list
            try:
                if abs(arr[i + 1] - v) > step:
                    new_v = (arr[i + 1] + v) / 2.
                    flagExit = False
                    break
            except IndexError:
                pass
        if flagExit:
            break
        # Insert new element
        arr = np.insert(arr, i + 1, new_v)

    return arr


aa = np.array([10.08, 14.23, 19.47, 21.855, 24.34, 25.02])

print(aa)
print(slpitArr(aa))

结果是:

[10.08  14.23  19.47  21.855 24.34  25.02 ]
[10.08  12.155 14.23  16.85  19.47  21.855 24.34  25.02 ]

Tags: innumpytruenewstepnpelementselement
2条回答

这里有一个单程解决方案

1)计算连续点之间的差值d

2)ceil将d分步得到m

2a)可选地将m四舍五入到最接近的二次方

3)将d除以m并重复结果m

4)形成累计和

这是密码。技术说明:d的第一个元素不是差,而是“锚”,因此它等于数据的第一个元素。你知道吗

def fill(data, step, force_power_of_two=True):
    d = data.copy()
    d[1:] -= data[:-1]
    if force_power_of_two:
        m = 1 << (np.frexp(np.nextafter(d / step, -1))[1]).clip(0, None)
    else:
        m = -(d // -step).astype(int)
    m[0] = 1
    d /= m
    return np.cumsum(d.repeat(m))

运行示例:

>>> inp
array([10.08 , 14.23 , 19.47 , 21.855, 24.34 , 25.02 ])
>>> fill(inp, 3)
array([10.08 , 12.155, 14.23 , 16.85 , 19.47 , 21.855, 24.34 , 25.02 ])

对于有序阵列:

def slpitArr(arr, step=3.):
    d = np.ediff1d(arr)
    n = (d / step).astype(dtype=np.int)
    idx = np.flatnonzero(n)
    indices = np.repeat(idx, n[idx]) + 1
    values = np.concatenate(
        [np.linspace(s1, s2, i+1, False)[1:] for s1, s2, i in zip(arr[:-1], arr[1:], n)])
    return np.insert(arr, indices, values)

那么

>>> aa = np.array([10.08, 14.23, 19.47, 21.855, 24.34, 25.02])
>>> print(slpitArr(aa))
[10.08  12.155 14.23  16.85  19.47  21.855 24.34  25.02 ]

>>> print(slpitArr(aa, 2.5))
[10.08       12.155      14.23       15.97666667 17.72333333 19.47
 21.855      24.34       25.02      ]

相关问题 更多 >