从numpy数组内存效率中删除元素

2024-09-28 19:22:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我目前正在制作一个需要培训和验证数据的机器学习模型。我首先将所有数据加载到image_set(输入)和{}(输出)中。我要做的是编写一个python脚本,它获取10%(min:1个数据点)的数据,并分别存储在validata_image和{}中。问题是,当加载这些图像时,我的计算机内存不足,所以我不断遇到内存错误。下面是我对这个过程的代码:

# "image_set" and "array_set" initialized above
length = image_set.shape[0]
numValData = 1
if (math.floor(length * .1) > 1):
    numValData = math.floor(length * .1)
    #Have the Validation data take up 10% of total data

validata_image = np.zeros((numValData,1000,1000,3), dtype='float32')
validata_array = np.zeros((numValData,1000,1000,1), dtype='float32')
#Each image is 1000x1000x3 (R, G, and B channels), while each array is (1000x1000x1)
#This large size is the reason why the data takes so much memory.

for i in range(numValData):
    rand = random.randint(0, image_set.shape[0] - 1)
    a = image_set[rand]
    b = array_set[rand]
    image_set = np.delete(image_set, rand, axis = 0)
    array_set = np.delete(array_set, rand, axis = 0)
    validata_image[i] = a
    validata_array[i] = b

#Rest of program executes here...

我相信很多内存使用来自np.delete()行。This问题与我的问题相似,但似乎完整的问题从来没有真正回答过。deepcopy()函数看起来很有前途,但是没有提供一种方法来从初始列表中删除元素。转换为常规数组并使用“pop()”似乎在内存使用方面也相当低效,所以真的不知道最好的方法是什么。如果有人能帮我解决这个问题,我将不胜感激。在


Tags: andthe数据内存imagedataisnp