如何使用numpy访问矩阵的相邻元素?

2024-06-13 08:12:11 发布

您现在位置:Python中文网/ 问答频道 /正文

我已经做了一个代码,计算流体的溶解,问题是代码是非常差的,所以我一直在看,与numpy我可以优化它,但我一直没有知道如何做下面的代码使用numpy和辊函数。基本上我有一个矩阵,索引不能超过1024,为此我用%来计算它是什么索引。但这需要很长时间。你知道吗

我试着使用numpy,使用roll,旋转矩阵,然后我就不需要计算模块了。但我不知道如何看待邻居的价值观。你知道吗

def evolve(grid, dt, D=1.0):
  xmax, ymax = grid_shape
  new_grid = [[0.0,] * ymax for x in range(xmax)]
  for i in range(xmax):
    for j in range(ymax):
      grid_xx = grid[(i+1)%xmax][j] + grid[(i-1)%xmax][j] - 2.0 * grid[i][j]
      grid_yy = grid[i][(j+1)%ymax] + grid[i][(j-1)%ymax] - 2.0 * grid[i][j]
      new_grid[i][j] = grid[i][j] + D * (grid_xx + grid_yy) * dt
  return new_grid 

Tags: 函数代码innumpynewfordtrange
1条回答
网友
1楼 · 发布于 2024-06-13 08:12:11

必须使用numpy从(几乎)零重写evolve函数。你知道吗

以下是指导原则:

  • 首先,grid必须是2dnumpy数组,而不是列表列表。你知道吗
  • 你的老师建议使用roll函数:看看它的docs,试着理解它是如何工作的。roll将通过移动(或滚动)矩阵中的一个轴来解决在矩阵中查找相邻项的问题。然后可以在四个方向上创建grid的移位版本并使用它们,而不是搜索邻居。你知道吗
  • 一旦你有了移位的grid,你会发现你将不需要for循环来计算new_grid的每个单元格:你可以使用向量化计算,这是更快的。你知道吗

所以代码如下所示:

def evolve(grid, dt, D=1.0):
    if not isinstance(grid, np.ndarray): #ensuring that is a numpy array.
        grid = np.array(grid)
    u_grid = np.roll(grid, 1, axis=0)
    d_grid = np.roll(grid, -1, axis=0)
    r_grid = np.roll(grid, 1, axis=1)
    l_grid = np.roll(grid, -1, axis=1)
    new_grid = grid + D * (u_grid + d_grid + r_grid + l_grid - 4.0*grid) * dt
    return new_grid

对于1024x1024矩阵,每个numpy evolve需要(在我的机器上)约0.15秒才能返回new_grid。使用for循环的evolve大约需要3.85秒。你知道吗

相关问题 更多 >