Mandelbrot Numba/Numpy矢量化?

2024-09-27 09:36:03 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用kivy在Python中编写了一个交互式mandelbrot渲染器,您可以使用鼠标指针进行缩放,我正在尽可能地优化它。我目前使用此实现来渲染集/缩放(这是一个小片段,仅用于渲染它的两个函数):

import numba as nb
import numpy as np


@nb.njit(cache= True, parallel = True)
def mandelbrot(c_r, c_i,maxIt): #mandelbrot function
        z_r = 0 
        z_i = 0
        z_r2 = 0
        z_i2= 0
        for x in nb.prange(maxIt):
            z_i = 2 * z_r * z_i + c_i
            z_r = z_r2 - z_i2 + c_r
            z_r2 = z_r * z_r
            z_i2 = z_i * z_i
            if z_r2 + z_i2 > 4:
                return x
        return maxIt

@nb.njit(cache= True, parallel = True)
def DrawSet(W, H, xStart, xDist, yStart, yDist, maxIt):
        array = np.zeros((H, W, 3), dtype=np.uint8) #array that holds 'hsv' tuple for every pixel
        for x in nb.prange(0, W):
            c_r = (x/W)* xDist + xStart #some math to calculate real part
            for y in range (0, H):
                c_i = -((y/H) * yDist + yStart) #some more math to calculate imaginary part
                cIt = mandelbrot(c_r, c_i, maxIt) 
                color = int((255 * cIt) / maxIt)
                array[y,x] = (color, 255, 255) #adds hue value 
        return array #returns hsv array, gets later displayed using PIL

我目前的表现相当不错。它可以在大约0.08-0.09秒的时间内,以300次迭代,渲染出一个500x500的区域,其中每个点都有边界(因此基本上是一张黑色图片,最坏的情况)。我正在使用Numba JIT和并行范围函数“prange()”,这非常有帮助

然而,我听说矢量化通常是绘制此类分形的最快方法。经过大量研究(我对矢量化非常陌生),我成功地将这个实现组合在一起:

import numba as nb
import numpy as np

def DrawSet(W, H, xStart, xEnd, yStart, yEnd, maxIt):

    array = np.zeros((H,W,3), dtype = np.uint8) # 3D array containing 'hsv' tuple (hue,saturation,value) of each pixel

    x = np.linspace(xStart, xEnd, W).reshape((1, W)) #scaling horizontal pixels to x-axis
    y = np.linspace(yStart, yEnd, H).reshape((H, 1)) #scaling vertical pixels to y-axis
    c = x + 1j * y #creating complex plane out of x axis (real) and y axis (imaginary)
    z = np.zeros(c.shape, dtype= np.complex128)
    div_time = np.zeros(z.shape, dtype= int)
    m = np.full(c.shape, True, dtype= bool)

    div_time = loop(z, c, div_time, m, maxIt)
    
    array[:,:,0] = (div_time/maxIt) * 255 -20 #adding 'hue' value
    array[:,:,1] = 255 #adding 'saturation' value
    array[:,:,2] = 255 #adding 'value'
    
    return array


@nb.vectorize(nb.int64[:,:](nb.complex128[:,:], nb.complex128[:,:], nb.int64[:,:], nb.boolean[:,:], nb.int64))
def loop(z, c, div_time, m, maxIt):

    for i in range(maxIt):
        z[m] = z[m]**2 + c[m]
        diverged = np.greater(np.abs(z), 2, out=np.full(c.shape, False), where=m)
        div_time[diverged] = i      
        m[np.abs(z) > 2] = False
    return div_time

没有@nb.vectorize decorator,它的运行速度非常慢。(500x500、300 It的最坏情况为4秒。)。使用@nb.vectorize decorator,我得到以下错误:


Traceback (most recent call last):
   File "MandelBrot.py", line 13, in <module>
     from test import DrawSet
   File "C:\Users\User\Documents\Code\Python\Mandelbrot-GUI\test.py", line 25, in <module>
     def loop(z, c, div_time, m, maxIt):
   File "C:\Users\User\AppData\Local\Programs\Python\Python38\lib\site-packages\numba\np\ufunc\decorators.py", line 119, in wrap
     for sig in ftylist:
 TypeError: 'Signature' object is not iterable

我做错了什么?我是否以正确的方式定义了所有的NUMA签名? 这种矢量化方法会超过我目前的实现吗

我将感谢每一个建议!先谢谢你


Tags: inimportdivtrueforreturntimevalue
1条回答
网友
1楼 · 发布于 2024-09-27 09:36:03

您的实现已经矢量化了

矢量化的思想是创建universal functions在数组上按元素操作。您只需要定义在单个元素上执行的操作,矢量化机制将允许使用数组调用您的函数

此函数用于计算单点c:

def mandelbrot_point(c, max_it):
    z = 0j
    for i in range(max_it):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return 0

可以使用Numpy将其矢量化:

@np.vectorize
def mandelbrot_numpy(c, max_it):
    z = 0j
    for i in range(max_it):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return 0

或者你可以使用Numba对其进行矢量化。请注意,函数的签名描述了如何处理单个点:

@nb.vectorize([nb.int64(nb.complex128, nb.int64)])
def mandelbrot_numba(c, max_it):
    z = 0j
    for i in range(max_it):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return 0

然后,可以使用标量或任意维数的数组调用矢量化函数:

>>> p = 0.4+0.4j
>>> mandelbrot_point(p, 99)
8
>>> mandelbrot_numpy(p, 99)
array(8)
>>> mandelbrot_numba(p, 99)
8

>>> x = np.linspace(-2, 2, 11)
>>> mandelbrot_numpy(x, 99)
array([0, 0, 0, 0, 0, 0, 6, 2, 1, 1, 1])
>>> mandelbrot_numba(x, 99)
array([0, 0, 0, 0, 0, 0, 6, 2, 1, 1, 1])

>>> x = np.atleast_2d(x)
>>> y = x.T
>>> c = x + 1j * y
>>> mandelbrot_numpy(c, 99)
array([[ 0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0],
       [ 0,  0,  1,  1,  1,  1,  1,  1,  1,  0,  0],
       [ 0,  1,  1,  2,  2,  2,  1,  1,  1,  1,  0],
       [ 0,  2,  2,  3,  5, 17,  3,  1,  1,  1,  0],
       [ 0,  2,  6,  6,  0,  0,  8,  2,  1,  1,  0],
       [ 0,  0,  0,  0,  0,  0,  6,  2,  1,  1,  1],
       [ 0,  2,  6,  6,  0,  0,  8,  2,  1,  1,  0],
       [ 0,  2,  2,  3,  5, 17,  3,  1,  1,  1,  0],
       [ 0,  1,  1,  2,  2,  2,  1,  1,  1,  1,  0],
       [ 0,  0,  1,  1,  1,  1,  1,  1,  1,  0,  0],
       [ 0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0]])
>>> mandelbrot_numba(c, 99)
array([[ 0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0],
       [ 0,  0,  1,  1,  1,  1,  1,  1,  1,  0,  0],
       [ 0,  1,  1,  2,  2,  2,  1,  1,  1,  1,  0],
       [ 0,  2,  2,  3,  5, 17,  3,  1,  1,  1,  0],
       [ 0,  2,  6,  6,  0,  0,  8,  2,  1,  1,  0],
       [ 0,  0,  0,  0,  0,  0,  6,  2,  1,  1,  1],
       [ 0,  2,  6,  6,  0,  0,  8,  2,  1,  1,  0],
       [ 0,  2,  2,  3,  5, 17,  3,  1,  1,  1,  0],
       [ 0,  1,  1,  2,  2,  2,  1,  1,  1,  1,  0],
       [ 0,  0,  1,  1,  1,  1,  1,  1,  1,  0,  0],
       [ 0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0]])

Numpy的vectorize极大地简化了代码,但正如文档所说,它主要是为了方便而不是为了性能。该实现本质上是一个for循环

根据我的测量,Numpy矢量化版本比原始实现稍微快一点,而Numba矢量化版本快一个数量级

相关问题 更多 >

    热门问题