在matplotlib的子图上正确使用Lasso

4 投票
1 回答
2050 浏览
提问于 2025-04-18 04:37

我想创建一个散点图矩阵,这个矩阵会由几个小图组成。我从一个.txt文件中提取了我的数据,并创建了一个形状为 (x,y,z,p1,p2,p3,p4) 的数组。数组的前三列代表了这些数据来源于原始图像的x、y、z坐标,最后四列(p1, p2, p3, p4)是一些其他参数。因此,在数组的每一行中,参数p1、p2、p3、p4都有相同的坐标(x,y,z)。在散点图中,我想要把每个p_i(例如p1)参数与其他p_i(例如p2、p3、p4)参数进行可视化。

我想在每个小图中画一个感兴趣区域(ROI),并突出显示每个小图中包含在ROI内的点。在每个小图中可视化不同的参数(例如p1与p2),但在每个小图中有一个点的x、y、z坐标在其他小图中也有对应的点。我通过使用 matplotlib 的一个示例 Lasso 来实现ROI的绘制。以下图展示了这个代码实现的一个示例。Figure 1

但是我的实现出现了问题。我可以在每个小图中绘制lasso,但只有在特定小图中绘制lasso时,点才会被突出显示,而这个特定小图对应于我代码中第一次调用的 LassoManager 函数(在我的代码中是 selector1)。如下一幅图所示,给不同小图中可以绘制的lasso设置了初始值,但只有与selector 1对应的id被使用,这导致了代码的故障,无论我在哪个小图中绘制ROI。

Figure 2

这是我的代码:

"""
Show how to use a lasso to select a set of points and get the indices
of the selected points.  A callback is used to change the color of the
selected points

This is currently a proof-of-concept implementation (though it is
usable as is).  There will be some refinement of the API.
"""


from matplotlib.widgets import Lasso
from matplotlib.colors import colorConverter
from matplotlib.collections import RegularPolyCollection
from matplotlib import path


import matplotlib.pyplot as plt
import numpy as np

class Datum(object):
      colorin = colorConverter.to_rgba('red')
      colorout = colorConverter.to_rgba('blue')
      def __init__(self, x, y, include=False):
          self.x = x
          self.y = y
          if include: self.color = self.colorin
          else: self.color = self.colorout

class LassoManager(object):
    #class for highlighting region of points within a Lasso
      def __init__(self, ax, data):


          self.axes = ax
          self.canvas = ax.figure.canvas
          self.data = data
          self.Nxy = len(data)

          facecolors = [d.color for d in data]
          self.xys = [(d.x, d.y) for d in data]
          fig = ax.figure
          self.collection = RegularPolyCollection(
             fig.dpi, 6, sizes=(1,),
             facecolors=facecolors,
             offsets = self.xys,
             transOffset = ax.transData)

          ax.add_collection(self.collection)

          self.cid = self.canvas.mpl_connect('button_press_event', self.onpress)

      def callback(self, verts):
          facecolors = self.collection.get_facecolors()
          print "The id of this lasso is", id(self)


          p = path.Path(verts)
          ind = p.contains_points(self.xys)
          #ind prints boolean array of points in subplot where true means that the point is included

          for i in range(len(self.xys)):


              if ind[i]:

                # facecolors[i] = Datum.colorin
                axes[0][0].plot(x[i], y[i], 'ro',  ls='',  picker=3)
                axes[2][0].plot(x[i], y1[i], 'ro',  ls='',  picker=3)
                axes[1][0].plot(x[i], x1[i], 'ro',  ls='',  picker=3)
                axes[1][4].plot(y[i], x1[i], 'ro',  ls='',  picker=3)
                axes[2][5].plot(x1[i], y1[i], 'ro',  ls='',  picker=3)
                axes[2][6].plot(y[i], y1[i], 'ro',  ls='',  picker=3)
                # print ind[i], x[i], y[i], x1[i], y1[i]

              else:

                # facecolors[i] = Datum.colorout
                axes[0][0].plot(x[i], y[i], 'bo',  ls='',  picker=3)
                axes[2][0].plot(x[i], y1[i], 'bo',  ls='',  picker=3)
                axes[1][0].plot(x[i], x1[i], 'bo',  ls='',  picker=3)
                axes[1][7].plot(y[i], x1[i], 'bo',  ls='',  picker=3)
                axes[2][8].plot(x1[i], y1[i], 'bo',  ls='',  picker=3)
                axes[2][9].plot(y[i], y1[i], 'bo',  ls='',  picker=3)

          plt.draw()


          self.canvas.draw_idle()
          self.canvas.widgetlock.release(self.lasso)
          del self.lasso
          # noinspection PyArgumentList


      def onpress(self, event):
          if self.canvas.widgetlock.locked(): return
          if event.inaxes is None: return

          self.lasso = Lasso(event.inaxes, (event.xdata, event.ydata), self.callback)

          # acquire a lock on the widget drawing
          self.canvas.widgetlock(self.lasso)




if __name__ == '__main__':

   dat = np.loadtxt(r"parameters.txt")
   x, y = dat[:, 3], dat[:, 4]  #p1,p2
   x1, y1 = dat[:, 5], dat[:, 6]  #p3,p4

   a = np.array([x,y])  #p1,p2
   a = a.transpose()

   b = np.array([x,y1])  #p1,p4
   b = b.transpose()

   c = np.array([x,x1])  #p1,p3
   c = c.transpose()

   d = np.array([y,x1])  #p3,p2
   d = d.transpose()

   e = np.array([x1,y1])  #p3,p4
   e = e.transpose()

   f = np.array([y,y1])  ##p2, p4
   f = f.transpose()


   data = []

   data0 = [Datum(*xy) for xy in a]   #p1,p2
   data.append(data0)
   data1 = [Datum(*xy) for xy in b]   #p1,p4
   data.append(data1)
   data2 = [Datum(*xy) for xy in c]   #p1,p3
   data.append(data2)
   data3 = [Datum(*xy) for xy in d]   #p3,p2
   data.append(data3)
   data4 = [Datum(*xy) for xy in e]   #p3,p4
   data.append(data4)
   data5 = [Datum(*xy) for xy in f]   #p2, p4
   data.append(data5)

   #print data
   #print len(data)

   fig, axes = plt.subplots(ncols=3, nrows=3)

   axes[0][0].plot(x, y, 'bo',  ls='',  picker=3)
   axes[0][0].set_xlabel('p1')
   axes[0][0].set_ylabel('p2')
   axes[0][0].set_xlim((min(x)-50, max(x)+50))
   axes[0][0].set_ylim((min(y)-50, max(y)+50))
   selector1 = LassoManager(axes[0][0], data[0])
   print "selector1 is", id(selector1)      #lman.append(l1)

   #p1 vs p4
   axes[2][0].plot(x, y1, 'bo',  ls='',  picker=3)
   axes[2][0].set_xlabel('p1')
   axes[2][0].set_ylabel('p4')
   axes[2][0].set_xlim((min(x)-50, max(x)+50))
   axes[2][0].set_ylim((min(y1)-40, max(y1)+50))
   selector2 = LassoManager(axes[2][0], data[1])
   print "selector2 is", id(selector2)


   #p1 vs p3
   axes[1][0].plot(x, x1, 'bo',  ls='',  picker=3)
   axes[1][0].set_xlabel('p1')
   axes[1][0].set_ylabel('p3')
   axes[1][0].set_xlim((min(x)-50, max(x)+50))
   axes[1][0].set_ylim((min(x1)-40, max(x1)+50))
   selector3 = LassoManager(axes[1][0], data[2])
   print "selector3 is", id(selector3)

   #p2 vs p3
   axes[1][10].plot(y, x1, 'bo',  ls='',  picker=3)
   axes[1][11].set_xlabel('p2')
   axes[1][12].set_ylabel('p3')
   axes[1][13].set_xlim((min(y)-50, max(y)+50))
   axes[1][14].set_ylim((min(x1)-40, max(x1)+50))
   selector4 =  LassoManager(axes[1][15], data[3])
   print "selector4 is", id(selector4)




   #p2 vs p4
   axes[2][16].plot(y, y1, 'bo',  ls='',  picker=3)
   axes[2][17].set_xlabel('p2')
   axes[2][18].set_ylabel('p4')
   axes[2][19].set_xlim((min(y)-50, max(y)+50))
   axes[2][20].set_ylim((min(y1)-40, max(y1)+50))
   selector5 = LassoManager(axes[2][21], data[5])
   print "selector5 is", id(selector5)


   #p3 vs p4
   axes[2][22].plot(x1, y1, 'bo',  ls='',  picker=3)
   axes[2][23].set_xlabel('p3')
   axes[2][24].set_ylabel('p4')
   axes[2][25].set_xlim((min(x1)-50, max(x1)+50))
   axes[2][26].set_ylim((min(y1)-40, max(y1)+50))
   selector6 = LassoManager(axes[2][27], data[4])
   print "selector6 is", id(selector6)


   #non-visible subplots
   axes[0][28].plot(x,x)
   axes[0][29].set_visible(False)
   axes[0][30].plot(y,y)
   axes[0][31].set_visible(False)
   axes[1][32].plot(x1,x1)
   axes[1][33].set_visible(False)

   plt.subplots_adjust(left=0.1, right=0.95, wspace=0.6, hspace=0.7)

   plt.show()

为什么我的代码会出现这种情况?代码没有错误,但它没有正确工作。任何帮助都将不胜感激!!

1 个回答

2

根据我的理解,问题在于每次调用 init 时,你都在用一个新的事件处理程序替换画布上的 button_press_event

很可能你需要用一个 button_press_event 的回调函数来处理所有的坐标轴,因为它们都是通过同一个画布对象进行交互的。

解决办法

下面是一个可以正常工作的例子,基于官方文档中的套索示例。

我尝试的方法是只创建一个 LassoManager(因为每个图形只与一个画布交互),但让坐标轴、数据等成为每个子图的列表。

然后,回调函数通过访问 current_axis 成员来确定当前哪个坐标轴是活跃的。

"""
Show how to use a lasso to select a set of points and get the indices
of the selected points.  A callback is used to change the color of the
selected points

This is currently a proof-of-concept implementation (though it is
usable as is).  There will be some refinement of the API.
"""
from matplotlib.widgets import Lasso
from matplotlib.colors import colorConverter
from matplotlib.collections import RegularPolyCollection
from matplotlib import path

import matplotlib.pyplot as plt
from numpy import nonzero
from numpy.random import rand

class Datum(object):
    colorin = colorConverter.to_rgba('red')
    colorout = colorConverter.to_rgba('blue')
    def __init__(self, x, y, include=False):
        self.x = x
        self.y = y
        if include: self.color = self.colorin
        else: self.color = self.colorout


class LassoManager(object):
    def __init__(self, ax, data):
        self.axes = [ax]
        self.canvas = ax.figure.canvas
        self.data = [data]

        self.Nxy = [ len(data) ]

        facecolors = [d.color for d in data]
        self.xys = [ [(d.x, d.y) for d in data] ]
        fig = ax.figure
        self.collection = [ RegularPolyCollection(
            fig.dpi, 6, sizes=(100,),
            facecolors=facecolors,
            offsets = self.xys[0],
            transOffset = ax.transData)]

        ax.add_collection(self.collection[0])

        self.cid = self.canvas.mpl_connect('button_press_event', self.onpress)

    def callback(self, verts):

        axind = self.axes.index(self.current_axes)
        facecolors = self.collection[axind].get_facecolors()
        p = path.Path(verts)
        ind = p.contains_points(self.xys[axind])
        for i in range(len(self.xys[axind])):
            if ind[i]:
                facecolors[i] = Datum.colorin
            else:
                facecolors[i] = Datum.colorout

        self.canvas.draw_idle()
        self.canvas.widgetlock.release(self.lasso)
        del self.lasso

    def onpress(self, event):
        if self.canvas.widgetlock.locked(): return
        if event.inaxes is None: return
        self.current_axes = event.inaxes

        self.lasso = Lasso(event.inaxes, (event.xdata, event.ydata), self.callback)
        # acquire a lock on the widget drawing
        self.canvas.widgetlock(self.lasso)

    def add_axis(self, ax,  data):
        self.axes.append(ax)
        self.data.append(data)

        self.Nxy.append( len(data) )

        facecolors = [d.color for d in data]
        self.xys.append( [(d.x, d.y) for d in data] )
        fig = ax.figure
        self.collection.append( RegularPolyCollection(
            fig.dpi, 6, sizes=(100,),
            facecolors=facecolors,
            offsets = self.xys[-1],
            transOffset = ax.transData))

        ax.add_collection(self.collection[-1])



if __name__ == '__main__':

    data = [Datum(*xy) for xy in rand(100, 2)]
    data2 = [Datum(*xy) for xy in rand(100, 2)]

    ax = plt.subplot(1,2,1)
    lman = LassoManager(ax, data)
    ax2 = plt.subplot(1,2,2)
    lman.add_axis(ax2, data2)
    plt.show()

撰写回答