如何使用Numpy将带有多个“if”语句的嵌套“for”循环矢量化?

2024-09-30 20:34:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个简单的2D光线投射程序,当障碍物数量增加时,它会变得非常缓慢

该例行程序由以下部分组成:

  • 2for个循环(外循环在每条光线/方向上迭代,然后内循环在每条线障碍物上迭代)

  • 多个if语句(检查一个值是否大于另一个值或数组是否为空)

问题:如何使用Numpy将所有这些操作压缩成一个向量化指令块

更具体地说,我面临两个问题:

  • 我已设法对内部循环(光线和每个障碍物之间的交点)进行矢量化,但无法同时对所有光线运行此操作

  • 我发现处理if语句的唯一解决方法是使用掩码数组。有些事情告诉我,在这种情况下,处理这些陈述的方式并不恰当(它看起来笨拙、笨重、不和谐)

原始代码:

from math import radians, cos, sin
import matplotlib.pyplot as plt
import numpy as np

N = 10  # dimensions of canvas (NxN)
sides = np.array([[0, N, 0, 0], [0, N, N, N], [0, 0, 0, N], [N, N, 0, N]])
edges = np.random.rand(5, 4) * N # coordinates of 5 random segments (x1, x2, y1, y2)
edges = np.concatenate((edges, sides))
center = np.array([N/2, N/2]) # coordinates of center point
directions = np.array([(cos(radians(a)), sin(radians(a))) for a in range(0, 360, 10)]) # vectors pointing in all directions
intersections = []

# for each direction
for d in directions:
    
    min_dist = float('inf')
    
    # for each edge
    for e in edges:
        p1x, p1y = e[0], e[2]
        p2x, p2y = e[1], e[3]
        p3x, p3y = center
        p4x, p4y = center + d
        
        # find intersection point
        den = (p1x - p2x) * (p3y - p4y) - (p1y - p2y) * (p3x - p4x)
        
        if den:
            t = ((p1x - p3x) * (p3y - p4y) - (p1y - p3y) * (p3x - p4x)) / den
            u = -((p1x - p2x) * (p1y - p3y) - (p1y - p2y) * (p1x - p3x)) / den
            
            # if any:
            if t > 0 and t < 1 and u > 0:
                sx = p1x + t * (p2x - p1x)
                sy = p1y + t * (p2y - p1y)
                
                isec = np.array([sx, sy])           
                dist = np.linalg.norm(isec-center) 
                
                # make sure to select the nearest one (from center)
                if dist < min_dist:
                    min_dist = dist
                    nearest = isec
    
    # store nearest interesection point for each ray
    intersections.append(nearest)

    
# Render
plt.axis('off')

for x, y in zip(edges[:,:2], edges[:,2:]):
    plt.plot(x, y)
    
for isec in np.array(intersections):
    plt.plot((center[0], isec[0]), (center[1], isec[1]), '--', color="#aaaaaa", linewidth=.8)

enter image description here

矢量化版本(尝试)

from math import radians, cos, sin
import matplotlib.pyplot as plt
from scipy import spatial
import numpy as np

N = 10  # dimensions of canvas (NxN)
sides = np.array([[0, N, 0, 0], [0, N, N, N], [0, 0, 0, N], [N, N, 0, N]])
edges = np.random.rand(5, 4) * N # coordinates of 5 random segments (x1, x2, y1, y2)
edges = np.concatenate((edges, sides))
center = np.array([N/2, N/2]) # coordinates of center point
directions = np.array([(cos(radians(a)), sin(radians(a))) for a in range(0, 360, 10)]) # vectors pointing in all directions
intersections = []

# Render edges
plt.axis('off')
for x, y in zip(edges[:,:2], edges[:,2:]):
    plt.plot(x, y)
    
    
# for each direction
for d in directions:

    p1x, p1y = edges[:,0], edges[:,2]
    p2x, p2y = edges[:,1], edges[:,3]
    p3x, p3y = center
    p4x, p4y = center + d

    # denominator
    den = (p1x - p2x) * (p3y - p4y) - (p1y - p2y) * (p3x - p4x)

    # first 'if' statement -> if den > 0
    mask = den > 0
    den = den[mask]
    p1x = p1x[mask]
    p1y = p1y[mask]
    p2x = p2x[mask]
    p2y = p2y[mask]

    t = ((p1x - p3x) * (p3y - p4y) - (p1y - p3y) * (p3x - p4x)) / den
    u = -((p1x - p2x) * (p1y - p3y) - (p1y - p2y) * (p1x - p3x)) / den

    # second 'if' statement -> if (t>0) & (t<1) & (u>0)
    mask2 = (t > 0) & (t < 1) & (u > 0)
    t = t[mask2]
    p1x = p1x[mask2]
    p1y = p1y[mask2]
    p2x = p2x[mask2]
    p2y = p2y[mask2]
    
    # x, y coordinates of all intersection points in the current direction
    sx = p1x + t * (p2x - p1x)
    sy = p1y + t * (p2y - p1y)
    pts = np.c_[sx, sy]
    
    # if any:
    if pts.size > 0:
        
        # find nearest intersection point
        tree = spatial.KDTree(pts)
        nearest = pts[tree.query(center)[1]]

        # Render
        plt.plot((center[0], nearest[0]), (center[1], nearest[1]), '--', color="#aaaaaa", linewidth=.8)

Tags: inforifnppltarraycenteredges
1条回答
网友
1楼 · 发布于 2024-09-30 20:34:56

问题的重新表述–查找线段和线光线之间的交点

qq2为段(障碍物)的端点。为了方便起见,让我们定义一个类来表示平面中的点和向量。除了通常的操作外,向量乘法由u × v = u.x * v.y - u.y * v.x定义

Caution: here Coord(2, 1) * 3 returns Coord(6, 3) while Coord(2, 1) * Coord(-1, 4) outputs 9. To avoid this confusion it might have been possible to restrict * to the scalar multiplication and use ^ via __xor__ for the vector multiplication.

class Coord:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    @property
    def radius(self):
        return np.sqrt(self.x ** 2 + self.y ** 2)

    def _cross_product(self, other):
        assert isinstance(other, Coord)
        return self.x * other.y - self.y * other.x

    def __mul__(self, other):
        if isinstance(other, Coord):
            # 2D "cross"-product
            return self._cross_product(other)
        elif isinstance(other, int) or isinstance(other, float):
            # scalar multiplication
            return Coord(self.x * other, self.y * other)
    
    def __rmul__(self, other):
        return self * other

    def __sub__(self, other):
        return Coord(self.x - other.x, self.y - other.y)

    def __add__(self, other):
        return Coord(self.x + other.x, self.y + other.y)

    def __repr__(self):
        return f"Coord({self.x}, {self.y})"

现在,我发现在极坐标系中处理光线更容易:对于给定的角度theta(方向),目标是确定它是否与线段相交,如果是,则确定相应的半径。这里有一个函数可以找到它。请参见here了解原因和方式的解释。我尝试使用与上一个链接中相同的变量名

def find_intersect_btw_ray_and_sgmt(q, q2, theta):
    """
    Args:
        q (Coord): first endpoint of the segment
        q2 (Coord): second endpoint of the segment
        theta (float): angle of the ray

    Returns:
        (float): np.inf if the ray does not intersect the segment,
                 the distance from the origin of the intersection otherwise
   """
    assert isinstance(q, Coord) and isinstance(q2, Coord)
    s = q2 - q
    r = Coord(np.cos(theta), np.sin(theta))
    cross = r * s  # 2d cross-product
    t_num = q * s
    u_num = q * r
    ## the intersection point is roughly at a distance t_num / cross
    ## from the origin. But some cases must be checked beforehand.

    ## (1) the segment [PQ2] is aligned with the ray
    if np.isclose(cross, 0) and np.isclose(u_num, 0):
        return min(q.radius, q2.radius)
    ## (2) the segment [PQ2] is parallel with the ray
    elif np.isclose(cross, 0):
        return np.inf
   
    t, u = t_num / cross, u_num / cross
    ## There is actually an intersection point
    if t >= 0 and 0 <= u <= 1:
        return t
   
    ## (3) No intersection point
    return np.inf

例如find_intersect_btw_ray_and_sgmt(Coord(1, 2), Coord(-1, 2), np.pi / 2)应该返回2

Note that here for simplicity, I only considered the case where the origin of the rays is at Coord(0, 0). This can be easily extended to the general case by setting t_num = (q - origin) * s and u_num = (q - origin) * r.

让我们把它矢量化

这里非常有趣的是Coord类中定义的操作也适用于x和y是numpy数组的情况!因此,对Coord(np.array([1, 2, 0]), np.array([2, -1, 3]))应用任何已定义的操作都等于将其元素级应用于点(1, 2), (2, -1) and (0, 3)。因此Coord的操作已经矢量化了。构造函数可以修改为:

    def __init__(self, x, y):
        x, y = np.array(x), np.array(y)
        assert x.shape == y.shape
        self.x, self.y = x, y
        self.shape = x.shape

现在,我们希望函数find_intersect_btw_ray_and_sgmt能够处理参数qq2包含端点序列的情况。在进行健全性检查之前,所有操作都正常工作,因为正如我们所提到的,它们已经矢量化了。正如您提到的,条件语句可以使用掩码“矢量化”。以下是我的建议:

def find_intersect_btw_ray_and_sgmts(q, q2, theta):
    assert isinstance(q, Coord) and isinstance(q2, Coord)
    assert q.shape == q2.shape
    EPS = 1e-14
    s = q2 - q
    r = Coord(np.cos(theta), np.sin(theta))
    cross = r * s
   
    cross_sign = np.sign(cross)
    cross = cross * cross_sign
    t_num = (q * s) * cross_sign
    u_num = (q * r) * cross_sign

    radii = np.zeros_like(t_num)
    mask = ~np.isclose(cross, 0) & (t_num >= -EPS) & (-EPS <= u_num) & (u_num <= cross + EPS)

    radii[~mask] = np.inf  # no intersection
    radii[mask] = t_num[mask] / cross[mask]  # intersection
    return radii

Note that cross, t_num and u_num are multiplied by the sign of cross to ensure that the division by cross keeps the sign of the dividends. Hence conditions of the form ((t_num >= 0) & (cross >= 0)) | ((t_num <= 0) & (cross <= 0)) can be replaced by (t_num >= 0).

For simplicity, we omitted the case (1) where the radius and the segment were aligned ((cross == 0) & (u_num == 0)). This could be incorporated by carefully adding a second mask.

对于给定的θ值,我们可以确定相应的光线是否同时与多个线段相交

## Some useful functions
def polar_to_cartesian(r, theta):
    return Coord(r * np.cos(theta), r * np.sin(theta))

def plot_segments(p, q, *args, **kwargs):
    plt.plot([p.x, q.x], [p.y, q.y], *args, **kwargs)

def plot_rays(radii, thetas, *args, **kwargs):
    endpoints = polar_to_cartesian(radii, thetas)
    n = endpoints.shape
    origin = Coord(np.zeros(n), np.zeros(n))
    plot_segments(origin, endpoints, *args, **kwargs)

## Data generation
M = 5  # size of the canvas
N = 10  # number of segments
K = 16  # number of rays

q = Coord(*np.random.uniform(-M/2, M/2, size=(2, N)))
p = q + Coord(*np.random.uniform(-M/2, M/2, size=(2, N)))
thetas = np.linspace(0, 2 * np.pi, K, endpoint=False)

## For each ray, find the minimal distance of intersection
## with all segments
plt.figure(figsize=(5, 5))
plot_segments(p, q, "royalblue", marker=".")

for theta in thetas:
    radii = find_intersect_btw_ray_and_sgmts(p, q, theta)
    radius = np.min(radii)
    if not np.isinf(radius):
        plot_rays(radius, theta, color="orange")
    else:
        plot_rays(2*M, theta, ':', c='orange')

plt.plot(0, 0, 'kx')
plt.xlim(-M, M)
plt.ylim(-M, M)

这还不是全部!多亏了python的传播,才有可能避免对θ值的迭代。例如,回想一下np.array([1, 2, 3]) * np.array([[1], [2], [3], [4]])生成一个大小为4×3的成对乘积矩阵。以同样的方式Coord([[5],[7]], [[5],[1]]) * Coord([2, 4, 6], [-2, 4, 0])输出一个2×3矩阵,其中包含向量(5,5)、(7,1)和(2,-2)、(4,4)、(6,0)之间的所有成对叉积

最后,可通过以下方式确定交叉点:

radii_all = find_intersect_btw_ray_and_sgmts(p, q, np.vstack(thetas))
# p and q have a shape of (N,) and np.vstack(thetas) of (K, 1)
# this radii_all have a shape of (K, N)
# radii_all[k, n] contains the distance from the origin of the intersection
# between k-th ray and n-th segment (or np.inf if there is no intersection point)
radii = np.min(radii_all, axis=1)
# radii[k] contains the distance from the origin of the closest intersection
# between k-th ray and all segments


do_intersect = ~np.isinf(radii)
plot_rays(radii[do_intersect], thetas[do_intersect], color="orange")
plot_rays(2*M, thetas[~do_intersect], ":", color="orange")

results

相关问题 更多 >