计算持久性图之间的wasserstein距离

2024-10-02 12:23:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图计算通过Python-Ripser库生成的两个持久性图之间的Wasserstein距离。我在Persim中发现了两个有趣的函数:切片的wasserstein和wasserstein匹配。你知道吗

我的图表生成如下所示:

    data = json.loads(data)
    data = pd.DataFrame.from_dict(data)
    rips = Rips()
    dgms = rips.fit_transform(data)
    for i in dgms:
        print(type(i))
        i.tofile(directory+"diagram.txt")
    plot_diagrams(dgms, show=False)
    plt.savefig("persistence_diagram.png")
    plt.close()

“dgms”是一个包含numpy数组的列表,所以我要在“for”行中将它们取出。你知道吗

我的Wasserstein函数用法如下:

with open(loc) as f:
    img1 = np.fromfile(f)
    f.close()
with open(loc2) as f:
    img2 = np.fromfile(f)
    f.close()
persim.sliced_wasserstein(img1, img2)

我试图将匹配三种数据(diagram in.png、dgms list和np.数组)但我经常得到的只是一个错误“IndexError:数组的索引太多”。 所以我换成了切片的瓦瑟斯坦,在那里我得到了这样的错误:

Traceback (most recent call last):
  File "C:/Users/Patka/PycharmProjects/MGR/Mapper.py", line 26, in <module>
    persim.sliced_wasserstein(img1, img2)
  File "C:\Users\Patka\environmentpython\lib\site-packages\persim\sliced_wasserstein.py", line 53, in sliced_wasserstein
    sw += step * cityblock(sorted(V1), sorted(V2))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

对我来说,一件奇怪的事情是,当我在保存到文件之前打印I.shape时,我得到了两个维度,例如(12,2),但是当我使用numpy.fromfile文件()我得到一个元组(12,)。你知道吗

有人能治好吗?我的最终目标是计算大量图表的距离,并对它们进行聚类,但我一直在比较两个图表。。。你知道吗


Tags: inclosedatawithnp图表数组diagram

热门问题