我在here的CapsNet上工作 ,它是在具有10个数字的MNIST数据集上实现的,但我已将代码更改为使用具有三个类的数据集。模型训练和测试工作正常,但操纵潜在函数会导致错误:
def manipulate_latent(model, data, args):
x_test, y_test = data
index = np.argmax(y_test, 1) == args.digit
print(index)
number = np.random.randint(low=0, high=sum(index) - 1)
x, y = x_test[index][number], y_test[index][number]
x, y = np.expand_dims(x, 0), np.expand_dims(y, 0)
noise = np.zeros([1, 3, 16])
x_recons = []
for dim in range(16):
for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]:
tmp = np.copy(noise)
tmp[:,:,dim] = r
x_recon = model.predict([x, y, tmp])
x_recons.append(x_recon)
x_recons = np.concatenate(x_recons)
img = combine_images(x_recons, height=16)
image = img*255
Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)
输出为:
number=np.random.randint(低=0,高=sum(索引)-1) ValueError:低>;=高
函数调用:
model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
routings=args.routings)
manipulate_latent(manipulate_model, (x_test, y_test), args)
这是因为您正在使用
sum()
而不是len()
输出
注意,
False
数组的sum()
等于0。而len()
是数组的大小相关问题 更多 >
编程相关推荐