np.random.randint导致ValueError:低>=高

2024-09-29 23:28:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我在here的CapsNet上工作 ,它是在具有10个数字的MNIST数据集上实现的,但我已将代码更改为使用具有三个类的数据集。模型训练和测试工作正常,但操纵潜在函数会导致错误:

  def manipulate_latent(model, data, args):
        x_test, y_test = data
        index = np.argmax(y_test, 1) == args.digit
        print(index)
        number = np.random.randint(low=0, high=sum(index) - 1)
        x, y = x_test[index][number], y_test[index][number]
        x, y = np.expand_dims(x, 0), np.expand_dims(y, 0)
        noise = np.zeros([1, 3, 16])
        x_recons = []
        for dim in range(16):
            for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]:
                tmp = np.copy(noise)
                tmp[:,:,dim] = r
                x_recon = model.predict([x, y, tmp])
                x_recons.append(x_recon)
        x_recons = np.concatenate(x_recons)
        img = combine_images(x_recons, height=16)
        image = img*255
        Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)

输出为:

number=np.random.randint(低=0,高=sum(索引)-1) ValueError:低>;=高

函数调用:

model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
                                                  n_class=len(np.unique(np.argmax(y_train, 1))),
                                                  routings=args.routings)
manipulate_latent(manipulate_model, (x_test, y_test), args)

Tags: 数据testnumberdataindexmodelnpargs

热门问题