将Python函数应用于输出向量的所有元素(从R转换)

2024-09-28 03:24:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直在尝试将R中的一些代码转换为Python来绘制曲线,但遇到了一些错误,主要与将函数rss(残差平方和)应用于Beta2s有关,在原始R代码中,这是通过sapply()完成的。我尝试过使用map(),但它在Matplotlib中运行不好,因为我得到的错误是does not support generators as input。我已经完成了list(map()),得到了'int' object is not iterable的错误。非常感谢您的帮助

以下是R中的代码:

rss <- function(Beta0,Beta1,Beta2){
  r <- y - (Beta0+Beta1*tt+Beta2*tt^2)
  sum(r^2)
}

Beta2s <- seq(-10,0,len=100)
RSS <- sapply(Beta2s, rss, Beta0=55, Beta1=0)
plot(Beta2s,RSS,Type="l")

以下是我在Python中的尝试:

def rss(Beta0, Beta1, Beta2):
    r = y - (Beta0 + Beta1*t + Beta2*t**2)
    return np.sum(r**2)

Beta2s = np.linspace(-10, 0, num = 100)
Beta0 = 55
Beta1 = 0
RSS = rss(Beta2s) #<-----------------Need help here
plt.plot(Beta2s, RSS)
plt.show()

Tags: 代码mapplot错误npnotrsssum
2条回答

在R中:

tt <- seq(1,10,length.out=100)
y <- seq(1,10,length.out=100)
Beta0 = 55
Beta1 = 0
Beta2s <- seq(-10,0,len=100)
RSS <- sapply(Beta2s, rss, Beta0=55, Beta1=0)

head(RSS)
[1] 19223571 18806870 18394761 17987243 17584318 17185985

在python中:

Beta2s = np.linspace(-10, 0, num = 100)
Beta0 = 55
Beta1 = 0
y = np.linspace(1,10,100)
t = np.linspace(1,10,100)

对于打印,您可以使用列表:

plt.plot(Beta2s,[rss(Beta0,Beta1,i) for i in Beta2s])

或者将函数矢量化:

RSS = np.vectorize(rss)(Beta0,Beta1,Beta2s)

RSS[:5]
array([19223570.88655147, 18806869.74602632, 18394760.55678168,
       17987243.31881757, 17584318.03213398])

使用地图:

# Import and initialise the packages required in session: 
import numpy as np
import matplotlib.pyplot as plt

# Define rss function: rss => function 
def rss(Beta0, Beta1, Beta2):
    r = y - (Beta0 + Beta1*t + Beta2*t**2)
    return np.sum(r**2)

# Generate data: Beta2s => numpy ndarray
Beta2s = np.linspace(-10, 0, num = 100)

# Store constant values as scalars to be applied over array: Beta0 => int, Beta1 => int
Beta0, Beta1 = 55, 0

# Generate y & t: y => numpy ndarray, t => numpy ndarray
y, t = np.linspace(1,10,100), np.linspace(1,10,100)

# Use map and a lambda function to plot the function: plt => stdout 
plt.plot(Beta2s,[*map(lambda x: rss(Beta0,Beta1,x), Beta2s)])
plt.show()

相关问题 更多 >

    热门问题