基于pyspark的随机梯度下降算法求解最优参数

2024-10-01 04:54:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在pyspark中实现一个算法,使用随机梯度下降和梯度下降来获得代价函数/误差值。我使用numpy模块在python中实现了以下代码:

from numpy import *
from numpy.random import *
import sys

# Gradient descent function
def GDFunc(weight): 
    return sum((x.dot(weight) - y)**2 for x, y in data_points) / len(data_points)

def dGDFunc(weight):
    return sum(2 * (x.dot(weight) - y) * x for x, y in data_points) / len(data_points)

# Stochastic gradient descent function
def SGDFunc(weight, i):
    x, y = data_points[i]
    return (x.dot(weight) - y)**2

def dSGDFunc(weight, i):
    x, y = data_points[i]
    return 2 * (x.dot(weight) - y) * x


# Data generation
true_weights = array([1, 2, 3, 4, 5])
d = len(true_weights)
data_points = []
for _ in range(1000):
    x = randn(d)
    y = x.dot(true_weights) + randn()
    data_points.append((x, y))

# Find gradient descent optimal parameters
def gradientDescent(GDFunc, dGDFunc, d):
    weight = zeros(d)
    alpha = 0.01
    for t in range(100):
        value = GDFunc(weight)
        gradient = dGDFunc(weight)
        print 'iteration %d: weight = %s, GDFunc(weight) = %s' % (t, weight, value)
        weight = weight - alpha * gradient

def stochasticGD(SGDFunc, dSGDFunc, d, n):
    weight = zeros(d)
    numUpdates = 0
    oldValue=0
    for t in range(100):
        for i in range(n):  # Nested Loop to process each sample individually
            value = SGDFunc(weight, i)
            gradient = dSGDFunc(weight, i)
            numUpdates += 1
            alpha = 1.00 / numUpdates  # Decreasing step size
            weight = weight - alpha * gradient
        print 'iteration %d: Cost: %s.  weight = %s' % (t,value, weight)
        if t>5 and value>oldValue:
            print "Found Local Minima"
            sys.exit()
        oldValue=value


gradientDescent(GDFunc, dGDFunc, d)
stochasticGD(SGDFunc, dSGDFunc, d, len(data_points))

如果我在上面的代码中增加迭代次数,那么在执行过程中会花费大量时间,因此我正在寻找一种使用pyspark实现相同的方法。你知道吗

虽然我知道mllib pyspark库提供了一个类LinearRegressionWithSGD来训练模型,但我不清楚如何从中获得最佳参数。你知道吗


Tags: infordatalenreturnvaluedefdot