numpy怎么能比我的Fortran程序快这么多呢?

2024-05-19 21:14:27 发布

您现在位置:Python中文网/ 问答频道 /正文

我得到了一个512^3的数组,表示一个模拟的温度分布(用Fortran编写)。数组存储在一个大约1/2G大小的二进制文件中。我需要知道这个数组的最小值、最大值和平均值,因为我很快就需要理解Fortran代码,所以我决定尝试一下,并提出了以下非常简单的例程。

  integer gridsize,unit,j
  real mini,maxi
  double precision mean

  gridsize=512
  unit=40
  open(unit=unit,file='T.out',status='old',access='stream',&
       form='unformatted',action='read')
  read(unit=unit) tmp
  mini=tmp
  maxi=tmp
  mean=tmp
  do j=2,gridsize**3
      read(unit=unit) tmp
      if(tmp>maxi)then
          maxi=tmp
      elseif(tmp<mini)then
          mini=tmp
      end if
      mean=mean+tmp
  end do
  mean=mean/gridsize**3
  close(unit=unit)

在我使用的机器上,每个文件大约需要25秒。我觉得这段时间很长,所以我继续用Python做了以下工作:

    import numpy

    mmap=numpy.memmap('T.out',dtype='float32',mode='r',offset=4,\
                                  shape=(512,512,512),order='F')
    mini=numpy.amin(mmap)
    maxi=numpy.amax(mmap)
    mean=numpy.mean(mmap)

现在,我当然希望这会更快,但我真的被吹走了。在相同的条件下不到一秒钟。平均值与我的Fortran例程找到的值(我也用128位浮点运算,所以我更相信它)不同,但只在7位左右有效。

努比怎么能这么快?我的意思是你必须查看数组的每个条目才能找到这些值,对吧?我在Fortran程序中做了一些愚蠢的事情,让它花了这么长时间吗?

编辑:

回答评论中的问题:

  • 是的,我还运行了带有32位和64位浮点的Fortran例程,但它对性能没有影响。
  • 我使用了提供128位浮点的^{}
  • 使用32位浮点运算,我的平均值有点偏离,所以精度确实是个问题。
  • 我以不同的顺序在不同的文件上运行这两个例程,所以缓存应该是公平的,我想比较一下吧?
  • 我试过打开MP,但同时从不同的位置读取文件。读了你的评论和答案之后,这听起来真的很蠢,这也让你的日常工作花费了更长的时间。我可能会尝试一下数组操作,但可能根本不需要。
  • 这些文件的大小实际上是1/2G,这是一个打字错误,谢谢。
  • 我现在将尝试数组实现。

编辑2:

我实现了@Alexander Vogt和@casey在他们的答案中所建议的,而且速度和numpy一样快,但是现在我遇到了一个精确的问题,正如@Luaan指出的那样。使用32位浮点数组,由sum计算的平均值是20%的折扣。做

...
real,allocatable :: tmp (:,:,:)
double precision,allocatable :: tmp2(:,:,:)
...
tmp2=tmp
mean=sum(tmp2)/size(tmp)
...

解决了这个问题,但是增加了计算时间(虽然不是很大,但是很明显)。 有没有更好的方法来解决这个问题?我找不到直接从文件中读取单打到双打的方法。 如何避免这种情况?

谢谢你到目前为止的帮助。


Tags: 文件numpyreadunit数组mean例程tmp
2条回答

Fortran实现有两个主要缺点:

  • 您可以混合IO和计算(并逐项读取文件项)。
  • 不使用向量/矩阵运算。

此实现执行的操作与您的相同,在我的计算机上速度快20倍:

program test
  integer gridsize,unit
  real mini,maxi,mean
  real, allocatable :: tmp (:,:,:)

  gridsize=512
  unit=40

  allocate( tmp(gridsize, gridsize, gridsize))

  open(unit=unit,file='T.out',status='old',access='stream',&
       form='unformatted',action='read')
  read(unit=unit) tmp

  close(unit=unit)

  mini = minval(tmp)
  maxi = maxval(tmp)
  mean = sum(tmp)/gridsize**3
  print *, mini, maxi, mean

end program

其思想是一次性将整个文件读入一个数组tmp。然后,我可以直接在数组上使用函数^{}^{}^{}


对于精度问题:只需使用双精度值,并在运行时进行转换

mean = sum(real(tmp, kind=kind(1.d0)))/real(gridsize**3, kind=kind(1.d0))

只会略微增加计算时间。我试着以片段的方式执行操作元素,但这只增加了默认优化级别所需的时间。

-O3处,按元素添加比数组操作执行大约3%的操作。在我的机器上,双精度操作和单精度操作之间的差异平均不到2%(单个运行的偏差要大得多)。


下面是一个使用LAPACK的快速实现:

program test
  integer gridsize,unit, i, j
  real mini,maxi
  integer  :: t1, t2, rate
  real, allocatable :: tmp (:,:,:)
  real, allocatable :: work(:)
!  double precision :: mean
  real :: mean
  real :: slange

  call system_clock(count_rate=rate)
  call system_clock(t1)
  gridsize=512
  unit=40

  allocate( tmp(gridsize, gridsize, gridsize), work(gridsize))

  open(unit=unit,file='T.out',status='old',access='stream',&
       form='unformatted',action='read')
  read(unit=unit) tmp

  close(unit=unit)

  mini = minval(tmp)
  maxi = maxval(tmp)

!  mean = sum(tmp)/gridsize**3
!  mean = sum(real(tmp, kind=kind(1.d0)))/real(gridsize**3, kind=kind(1.d0))
  mean = 0.d0
  do j=1,gridsize
    do i=1,gridsize
      mean = mean + slange('1', gridsize, 1, tmp(:,i,j),gridsize, work)
    enddo !i
  enddo !j
  mean = mean / gridsize**3

  print *, mini, maxi, mean
  call system_clock(t2)
  print *,real(t2-t1)/real(rate)

end program

这在矩阵列上使用单精度矩阵1-范数^{}。运行时甚至比使用单精度数组函数的方法更快,而且不显示精度问题。

numpy的速度更快,因为您用python编写的代码效率更高(而且numpy后端的大部分代码都是用优化的Fortran和C编写的),而Fortran编写的代码效率非常低。

看看你的python代码。一次加载整个数组,然后调用可以对数组进行操作的函数。

看看你的fortran代码。一次只读取一个值并对其执行分支逻辑。

您的大部分差异是您用Fortran编写的零碎IO。

你可以像编写python一样编写Fortran,这样你会发现它运行得更快。

program test
  implicit none
  integer :: gridsize, unit
  real :: mini, maxi, mean
  real, allocatable :: array(:,:,:)

  gridsize=512
  allocate(array(gridsize,gridsize,gridsize))
  unit=40
  open(unit=unit, file='T.out', status='old', access='stream',&
       form='unformatted', action='read')
  read(unit) array    
  maxi = maxval(array)
  mini = minval(array)
  mean = sum(array)/size(array)
  close(unit)
end program test

相关问题 更多 >