有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

python tensorflow java api错误:java。lang.IllegalStateException:张量不是标量

我正在尝试将一个经过预训练的模型(使用python)加载到java项目中

问题是

 Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
    at org.tensorflow.Tensor.scalarFloat(Native Method)
    at org.tensorflow.Tensor.floatValue(Tensor.java:279)

密码

    float[] arr=context.csvintarr(context.getPlayer(playerId));
    float[][] martix={arr};
    try (Graph g=model.graph()){
        try(Session s=model.session()){

            Tensor y=s.runner().feed("input/input", Tensor.create(martix))
            .fetch("out/predict").run().get(0);
            logger.info("a {}",y.floatValue());
        }
    }

用于训练和保存模型的python代码

with tf.Session() as sess:
    with tf.name_scope('input'):
        x=tf.placeholder(tf.float32,[None,bucketlen],name="input")
......
    with tf.name_scope('out'):
        y=tf.tanh(tf.matmul(h,hW)+hb,name="predict")
    builder=tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess,['foo-tag'])
......after the train process

    builder.save()

看来我已经成功加载了模型和图表,因为

  try (Graph g=model.graph()){
        try(Session s=model.session()){
            Operation operation=g.operation("input/input");
            logger.info(operation.name());

        }
    }

成功打印出名称


共 (1) 个答案

  1. # 1 楼答案

    错误消息表明输出张量不是浮点值标量,因此可能是高维张量(向量、矩阵)

    你可以使用^{}或者特别是使用^{}来学习张量的形状。在Python代码中,这将对应于y.shape

    对于非标量,使用^{}获得浮点数组(对于向量),或浮点数组数组(对于矩阵)等

    例如:

    System.out.println(y);
    // If the above printed something like:
    // "FLOAT tensor with shape [1]"
    // then you can get the values using:
    float[] vector = y.copyTo(new float[1]);
    
    // If the shape was something like [2, 3]
    // then you can get the values using:
    float[][] matrix = y.copyTo(new float[2][3]);
    

    有关{}与{}与{}的更多信息,请参见{a4}

    希望有帮助