在sklearn预测器上使用dask Parallel post fit(Parallel post fit包装器)

2024-10-03 02:42:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试评估一个sklearn预测器,它是我在一个大于内存的dask输入数组上生成的。我已经阅读了并行post-fit文档https://dask-ml.readthedocs.io/en/latest/modules/generated/dask_ml.wrappers.ParallelPostFit.html,仍然有一些问题。以下代码说明了我遇到的问题:

from dask.base import tokenize
import numpy as np
import dask.array as da
from dask.array import Array
from sklearn.linear_model import LinearRegression
from dask_ml.wrappers import ParallelPostFit
"""
for stack overflow question
"""
x = np.linspace(0,100,100,dtype=np.int32)
y = np.linspace(0,100,100,dtype=np.int32)
z = np.linspace(0,100,100,dtype=np.int32)

Y = np.random.normal(size=(100,))
X = np.stack([x,y,z],axis=1)
reg = LinearRegression().fit(X,Y)

#now try to compute on dask arrays over the whole space
x= da.linspace(0,100,100,chunks=(10,)).astype(np.int32)
y= da.linspace(0,100,100,chunks=(10,)).astype(np.int32)
z= da.linspace(0,100,100,chunks=(10,)).astype(np.int32)
x,y,z = da.meshgrid(x,y,z,sparse=False,indexing='ij')
stacked = da.stack([x.flatten(),y.flatten(),z.flatten()],axis=1)
clf = ParallelPostFit(estimator=reg)
clf.predict(stacked)

执行clf.预测抛出一个值错误不能删除一个块数超过1的轴。请改用atop。在

我不知道怎么改正。 谢谢你的帮助。在


Tags: fromimportstacknpmlchunksdaskda