Using np.select to create a new column on a MultiIndex DataFrame

Posted 2024-10-02 00:26:28

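I am trying to determine whether my data has crossed a line, and in which direction. Using np.select works on a single-index DataFrame, but when I try the same thing on a MultiIndex DataFrame, I get all NaN.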

Here is my helper function:

def calc_crossings_helper(df, line):
    # define crossing conditions - corresponding choices are [1,-1] to denote direction, otherwise NaN
    line_crossed_cond = [(df['Close'] < df[line]) & (df['Close'].shift(1) > df[line].shift(1)),
                         (df['Close'] > df[line]) & (df['Close'].shift(1) < df[line].shift(1))]
    return np.select(line_crossed_cond, [1, -1], default = np.nan)

It is called like this:

df['Hcross'] = df.groupby(level=0, group_keys=False).apply(calc_crossings_helper, ('highBound'))

The groupby call with the helper returns one array per symbol:

Symbol
AAPL    [nan, nan, -1.0, nan, nan, 1.0, nan, -1.0, nan...
AMZN    [nan, nan, nan, nan, nan, nan, nan, -1.0, nan,...

But the df['Hcross'] column ends up filled entirely with NaN:

                    Close   Hcross
Symbol Date                    
AAPL   2019-12-02   264.16  NaN
       2019-12-03   259.45  NaN
       2019-12-04   261.74  NaN
       2019-12-05   265.58  NaN
       2019-12-06   270.71  NaN
       2019-12-09   266.92  NaN
       2019-12-10   268.48  NaN
       2019-12-11   270.77  NaN
       2019-12-12   271.46  NaN
       2019-12-13   275.15  NaN
AMZN   2019-12-02  1781.60  NaN
       2019-12-03  1769.96  NaN
       2019-12-04  1760.69  NaN
       2019-12-05  1740.48  NaN
       2019-12-06  1751.60  NaN
       2019-12-09  1749.51  NaN
       2019-12-10  1739.21  NaN
       2019-12-11  1748.72  NaN
       2019-12-12  1760.33  NaN
       2019-12-13  1760.94  NaN

I think I need to somehow flatten the arrays returned by the helper function, but I can't figure out how.


1 Answer
User
Answer 1 · Posted 2024-10-02 00:26:28

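A simple fix is to return a Series that is indexed like the DataFrame. np.select returns a bare array with the same length as the DataFrame passed in, so wrapping that array in a Series with the same index gives pandas what it needs to align the result correctly.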

def calc_crossings_helper(df, line):
    # define crossing conditions - corresponding choices are [1,-1] to denote direction, otherwise NaN
    line_crossed_cond = [(df['Close'] < df[line]) & (df['Close'].shift(1) > df[line].shift(1)),
                         (df['Close'] > df[line]) & (df['Close'].shift(1) < df[line].shift(1))] 

    return pd.Series(np.select(line_crossed_cond, [1, -1], default = np.nan), index=df.index)

Now the groupby returns a Series with the same MultiIndex as the DataFrame:

df.assign(highbound=265).groupby(level=0, group_keys=False).apply(calc_crossings_helper, ('highbound'))

Symbol  Date      
AAPL    2019-12-02    NaN
        2019-12-03    NaN
        2019-12-04    NaN
        2019-12-05   -1.0
        2019-12-06    NaN
        2019-12-09    NaN
        2019-12-10    NaN
        2019-12-11    NaN
        2019-12-12    NaN
        2019-12-13    NaN
AMZN    2019-12-02    NaN
        2019-12-03    NaN
        2019-12-04    NaN
        2019-12-05    NaN
        2019-12-06    NaN
        2019-12-09    NaN
        2019-12-10    NaN
        2019-12-11    NaN
        2019-12-12    NaN
        2019-12-13    NaN
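
As a rough end-to-end check, the sketch below rebuilds the first five rows per symbol from the question, applies the same kind of Series-returning helper, and assigns the result back to the frame. The names `demo` and `crossings` and the constant `highbound` of 265 are illustrative assumptions, not part of the original data.

```python
import numpy as np
import pandas as pd

# Illustrative sample: first five Close values per symbol from the question.
dates = pd.to_datetime(
    ["2019-12-02", "2019-12-03", "2019-12-04", "2019-12-05", "2019-12-06"]
)
index = pd.MultiIndex.from_product(
    [["AAPL", "AMZN"], dates], names=["Symbol", "Date"]
)
demo = pd.DataFrame(
    {"Close": [264.16, 259.45, 261.74, 265.58, 270.71,
               1781.60, 1769.96, 1760.69, 1740.48, 1751.60]},
    index=index,
)
demo["highbound"] = 265.0  # assumed constant line, as in the answer's example


def crossings(group, line):
    """Return 1 / -1 / NaN per row, indexed like the group it came from."""
    prev_close = group["Close"].shift(1)
    prev_line = group[line].shift(1)
    conds = [
        (group["Close"] < group[line]) & (prev_close > prev_line),  # crossed below
        (group["Close"] > group[line]) & (prev_close < prev_line),  # crossed above
    ]
    # Wrapping the array in a Series keeps the (Symbol, Date) labels for alignment.
    return pd.Series(np.select(conds, [1, -1], default=np.nan), index=group.index)


hcross = demo.groupby(level="Symbol", group_keys=False).apply(crossings, "highbound")
demo["Hcross"] = hcross  # aligns on the shared MultiIndex
print(demo)
```

Because the returned Series carries the same (Symbol, Date) labels as the frame, the assignment aligns row by row and no flattening of arrays is required.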

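Better yet, as long as the DataFrame is sorted so that each symbol's rows are contiguous, groupby.apply() is not needed at all. You can shift the Symbol level and compare it with itself to add a same-group condition, so a single np.select call is enough.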

line = 'highbound'
# Series b/c there is no pd.Index.shift method
s = pd.Series(df.index.get_level_values('Symbol'), index=df.index)

line_crossed_cond = [(s.eq(s.shift()) 
                      & (df['Close'] < df[line]) 
                      & (df['Close'].shift(1) > df[line].shift(1))),
                     (s.eq(s.shift())
                      & (df['Close'] > df[line]) 
                      & (df['Close'].shift(1) < df[line].shift(1)))]

df['Hcross'] = np.select(line_crossed_cond, [1, -1], default = np.nan)
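
If the rows for each symbol might not be contiguous, one possible variant (not part of the original answer) is to take the previous values with a grouped shift instead of comparing the Symbol level with its shifted self. The sketch below uses a hypothetical toy frame `toy` with a constant line of 100; the first row of each symbol gets a NaN previous value, so no crossing is reported across a symbol boundary.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: symbol A ends above the line, symbol B starts below it,
# which would look like a crossing if the previous row came from symbol A.
index = pd.MultiIndex.from_product(
    [["A", "B"], [1, 2, 3]], names=["Symbol", "Step"]
)
toy = pd.DataFrame({"Close": [90.0, 110.0, 120.0, 95.0, 105.0, 98.0]}, index=index)
toy["highbound"] = 100.0

# A grouped shift only looks back within the same symbol.
by_symbol = toy.groupby(level="Symbol")
prev_close = by_symbol["Close"].shift(1)
prev_line = by_symbol["highbound"].shift(1)

conds = [
    (toy["Close"] < toy["highbound"]) & (prev_close > prev_line),  # crossed below
    (toy["Close"] > toy["highbound"]) & (prev_close < prev_line),  # crossed above
]
# Comparisons against NaN are False, so each symbol's first row falls to the default.
toy["Hcross"] = np.select(conds, [1, -1], default=np.nan)
print(toy)
```

This still needs only one np.select call, at the cost of two grouped shifts, and it still assumes the rows within each symbol are in chronological order.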
