重塑输出尺寸以适合Keras模型

2024-09-28 01:32:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个用于cat数据集关键点检测的Keras模型。对于每个彩色图像,有3个关键点和相应的3个热图。 模型的输入图像为64,64,3 相应的输出为形状3,64,64

我正在使用以下功能准备裁剪和调整大小的热图:

def crop_heatmaps():

   dataset['cropped_heatmaps'] = []
   
   for i in range(len(dataset['heatmaps'])):

        cropped_heats = []
        heatmaps = dataset['heatmaps'][i]
        bb = dataset['bbs'][i]
        

        x1 = max(bb[0] - 20, 0) #avoid negative coordinates of the extended bounding box
        y1 = max(bb[1] - 20, 0)
        x2 = bb[2] + 20
        y2 = bb[3] + 20

        for heat in heatmaps:

            cropped_heat = heat[y1:y2, x1:x2]
            resized_heat = cv2.resize(cropped_heat, (64, 64))
            #plt.imshow(resized_heat)

            cropped_heats.append(resized_heat)

            cropped_heatmaps = np.array(cropped_heats)
            
            dataset['cropped_heatmaps'].append(cropped_heats)

我为输入和输出创建了2个DataImageGenerator,并将它们压缩在一起

train_generator = zip(img_train_generator, heatmaps_train_generator)

history = model.fit((pair for pair in train_generator),
                    epochs=30,
                    validation_data=(),
                    verbose=1
                  )

在拟合模型时,我遇到以下错误:不兼容的形状:[128,64,3,64]与[128,3,64,64]

模型如下所示:

输入_1(输入层)[(无、64、64、3)]0


区块1_conv1(Conv2D)(无、64、64、64)1792


区块1_conv2(Conv2D)(无、64、64、64)36928


block1_池(MaxPoolig2D)(无、32、32、64)0


区块2_conv1(Conv2D)(无、32、32、128)73856


区块2_conv2(Conv2D)(无、32、32、128)147584


瓶颈_1(Conv2D)(无、32、32、160)5243040


瓶颈_2(Conv2D)(无、32、32、160)25760


上样本_1(conv2dtranpse)(无,64,64,3)1920

我试过了

np.reshape(cropped_heatmaps,(64,64,3))

但它没有起作用。如何重塑热图以获得64,64,3?(3个频道)


Tags: in模型fortrain区块generatordataset热图
1条回答
网友
1楼 · 发布于 2024-09-28 01:32:26

如果您想更改轴1和2 ip,可以使用: np.moveaxis(x,1,2)

样本:

import numpy as np

x = np.zeros((128,64,3,32))
print(x.shape)

y = np.moveaxis(x,1,2)
print(y.shape)

输出:

(128, 64, 3, 32)
(128, 3, 64, 32)
>>> 

带(64,64,3)到(3,64,64)

可以使用:

import numpy as np

x = np.zeros((64,64,3))
print(x.shape)

y = np.moveaxis(x,-1,0)
print(y.shape)

输出:

(64, 64, 3)
(3, 64, 64)
>>> 

相关问题 更多 >

    热门问题