Numpy中稀疏双邻接矩阵的高效构造

2024-09-30 00:31:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图将这个CSV文件加载到一个稀疏的numpy矩阵中,这个矩阵将表示这个用户的双相邻矩阵到subreddit二部图:http://figshare.com/articles/reddit_user_posting_behavior/874101

下面是一个示例:

603,politics,trees,pics
604,Metal,AskReddit,tattoos,redditguild,WTF,cocktails,pics,funny,gaming,Fitness,mcservers,TeraOnline,GetMotivated,itookapicture,Paleo,trackers,Minecraft,gainit
605,politics,IAmA,AdviceAnimals,movies,smallbusiness,Republican,todayilearned,AskReddit,WTF,IWantOut,pics,funny,DIY,Frugal,relationships,atheism,Jeep,Music,grandrapids,reddit.com,videos,yoga,GetMotivated,bestof,ShitRedditSays,gifs,technology,aww

共有876961行(每个用户一个)和15122个子编,共有8495597个用户到子编程序的关联。在

下面是我现在掌握的代码,在我的MacBook Pro上运行需要20分钟:

^{pr2}$

似乎很难相信这是如此之快。。。将82MB文件加载到列表列表中需要5秒,但构建稀疏矩阵需要200倍。我该怎么做才能加快速度?有没有一些文件格式,我可以转换成这个CSV在不到20分钟,将导入更快?我在这里做的手术显然很昂贵,不好吗?我尝试过构建一个稠密矩阵,并尝试创建一个lil_matrix和一个dok_matrix,一次分配一个1的矩阵,但速度并不快。在


Tags: 文件csv用户numpycom列表矩阵matrix
2条回答

首先,您可以将内部for替换为以下内容:

reddit_idx = np.nonzero(np.in1d(reddits_list,row))[0]
sl = slice(i,i+len(reddit_idx))
cols[sl] = user_idx
rows[sl] = reddit_idx
i = sl.stop

使用nonzero(in1d())查找匹配项看起来不错,但我还没有探索其他替代方法。另一种通过切片赋值的方法是extend列表,但这可能较慢,尤其是对于许多行。在

构建行时,cols是迄今为止最慢的部分。对csr_matrix的调用是次要的。在

由于行(用户)比subreddit多很多,因此可能值得为每个subreddit收集一个用户id列表。您已经在一个集合中收集了subreddits。相反,您可以在默认字典中收集它们,然后从中构建矩阵。当在你的3条线上测试时,它的速度明显更快。在

^{pr2}$

睡不着,最后一件事。。。我可以用这种方式把时间缩短到10秒,最后:

import numpy as np
from scipy.sparse import csr_matrix 

user_ids = []
subreddit_ids = []
subreddits = {}
i=0
with open("reddit_user_posting_behavior.csv", 'r') as f:
    for line in f:
        for sr in line.rstrip().split(",")[1:]: 
            if sr not in subreddits: 
                subreddits[sr] = len(subreddits)
            user_ids.append(i)
            subreddit_ids.append(subreddits[sr])
        i+=1

adj = csr_matrix( 
    ( np.ones((len(userids),)), (np.array(subreddit_ids),np.array(user_ids)) ), 
    shape=(len(subreddits), i) )

相关问题 更多 >

    热门问题