(编辑以包括数据集和模型代码)
我正在训练Keras CNN 2d矩阵。我正在创建自己的训练数据集,其中每个矩阵单元的形状为[[list], int]
。单元格的第一个列表项是我转换为列表的字符串类的产物(使用tf.keras.utils.to_categorical
):
cell[0] = to_categorical(
rnd_type-1, num_classes=num_types)
第二个是简单的int:
cell[1] = random.randint(0, max_val)
数据集创建函数:
def make_data(num_of_samples, num_types, max_height, grid_x, grid_y):
grids_list = []
target_list = []
target = 0
for _ in range(num_of_samples):
# create empty grid
grid = [[[[],0] for i in range(grid_y)] for j in range(grid_x)]
for i in range(grid_x):
for j in range(grid_y):
rnd_type = random.randint(
0, num_types)
# get random class
# and convert to cat list
cat = to_categorical(
rnd_type-1, num_classes=num_types)
# get random type
rnd_height = random.randint(0, max_height)
# inject the two values into the cell
grid[i][j] = [cat, rnd_height]
# get some target value
target += rnd_type * 5 + random.random()*5
target_list.append(target)
grids_list.append(grid)
# make np arrs out of the lists
t = np.array(target_list)
g = np.array(grids_list)
return t, g
我的模型是使用model = models.create_cnn(grid_size, grid_size, 2, regress=True)
创建的,其中Input
深度是2
模型创建代码:
num_types = 20
max_height = 50
num_of_samples = 10
grid_size = 10
epochs = 5000
# get n results of X x Y grid with target
targets_list, grids_list = datasets.make_data(
num_of_samples, num_types, max_height, grid_size, grid_size)
split = train_test_split(targets_list, grids_list,
test_size=0.25, random_state=42)
(train_attr_X, test_attr_X, train_grids_X, test_grids_X) = split
# find the largest value in the training set and use it to
# scale values to the range [0, 1]
max_target = train_attr_X.max()
train_attr_Y = train_attr_X / max_target
test_attr_Y = test_attr_X / max_target
model = models.create_cnn(grid_size, grid_size, 2, regress=True)
但是,我无法训练它,因为存在以下错误:ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type list).
回答我自己的问题:
model
只能接受int
作为深度。因此,我的矩阵的深度必须由intlen
列表决定,而不是2D矩阵。因此,将类数据与连续字段rnd_height
合并的方法是:cat = to_categorical
cell = np.append(cat, [rnd_height])
这样,
cat
列表就添加了rnd_height
值。 整个dataset函数现在如下所示:相关问题 更多 >
编程相关推荐