当从另一个类调用时,Tensorflow会出现“<tensor>不是此图的元素”错误

2024-09-29 23:25:56 发布

您现在位置:Python中文网/ 问答频道 /正文

首先,我已经访问了这里的所有类似主题和其他网站,但没有一个在我的案例中起作用

假设我有一个类处理加载模型和预测:

class MyModel():
  def __init__(self):
    pass

  def load_model(self, model_path):
    self.model = tf.keras.models.load_model(model_path)

  def predict(self, img):
    return self.model.predict(img)

现在,我在另一个文件中有另一个类,它调用MyModel

from mymodel import MyModel
class MyDetector():
    def __init__(self):
        self.detector = MyModel()
        self.detector.load_model('mymodel.h5')

    def detect(self, img: numpy.ndarray):        
        return self.detector.predict(img)

但是,这会引发一个错误,即<tensor> is not an element of this graph。我已经尝试了所有与tf.Graph.as_default()相关的答案,但没有任何改变。最常见的建议是修改模型加载和预测部分,如下所示:

def load_model(self, model_path):
    global model
    model = tf.keras.models.load_model(model_path)
    global graph
    graph = tf.get_default_graph() 

def predict(self, img):
    with graph.as_default():
      preds = model.predict(img)
    return preds

这仍然没有帮助,因为所有其他的建议也可以在:https://github.com/keras-team/keras/issues/6462

我认为我的案例与那些已经解决类似案例的案例不同,因为我试图从一个完全不同的类文件调用模型类。我的Tensorflow版本是2.6.0。有没有更好的办法解决这个问题

更新

实际情况是,我正在使用gRPC与远程服务器通信以进行模型推断。为此,我使用了基于gRPC的非常简单的客户机-服务器通信。我的客户机代码定义如下(client.py):

import cv2
import grpc 
import pybase64

import protos.mydetector_pb2 as mydetector_pb2
import protos.mydetector_pb2_grpc as mydetector_pb2_grpc 

# open a gRPC channel
channel = grpc.insecure_channel('[::]:50051')
stub = mydetector_pb2_grpc.MyDetectionServiceStub(channel)

img = cv2.imread('test.jpg')
retval, buffer = cv2.imencode('.jpg', img)
b64img = pybase64.b64encode(buffer)

print('\nSending single request to port 50051...')
request = mydetector_pb2.MyDetectionRequest(image=b64img)

response = stub.detect(request)

然后,在接收服务器端,主服务器实现如下(server.py):

import grpc
from concurrent import futures
import protos.mydetector_pb2_grpc as reid_grpc
import MyDetectionService

MAX_MESSAGE_IN_MB = 10

options = [
    ('grpc.max_send_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024),
    ('grpc.max_receive_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024)
]

server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), options=options)
reid_grpc.add_MyDetectionServiceServicer_to_server(MyDetectionService(), server)

print('Starting server. Listening on port 50051.')
server.add_insecure_port('[::]:50051')
server.start()

MyDetectionService类的实现如下:

import protos.mydetector_pb2 as mydetector
import protos.mydetector_pb2_grpc as mydetector_grpc
from mydetector.service.detector import MyDetector
from mydetector.utils import img_converter

import cv2
import numpy as np

class MyDetectionService(mydetector_grpc.MyDetectionServiceServicer):
def __init__(self):
    self.detector = MyDetector()

def detect(self, request, context):
    print('detecting on received image...')
    encoded_img = request.image
    img = img_converter(encoded_img)
    img = cv2.resize(img, (240, 240))
    img2 = np.expand_dims(img, axis=0)
    result = self.detector.detect(img2)
    return mydetector.MyDetectionResponse(ans=result)

其中,MyDetector类的实现如上图所示

我发现,如果我不使用基于gRPC的服务器-客户端通信,而是从任何其他常规的外部类调用MyDetector,那么一切都会顺利进行。但是,当我通过gRPC从客户端发送图像时,它成功地在MyDetector类中加载了模型(我可以调用model.summary()来获取模型的完整描述),但在detect函数中失败

重要提示:根据可用信息here,我相信每次发送gRPC请求时,它都会使用自己的Tensorflow会话创建新线程,这就是这里的主要问题所在。然而,我仍然无法使它工作,即使在遵循了所有的说明在该网站上描述


Tags: 模型importselfimggrpcmodelserverdef
2条回答

服务器在不同的图形/会话中加载模型,而不是在用于接收客户端请求和响应的图形/会话中加载模型。将MyModel类修改为以下内容:

class MyModel():
  def __init__(self):
    pass

  def load_model(self, model_path):
    self.graph = tf.compat.v1.get_default_graph()

    with self.graph.as_default():
      self.model = tf.keras.models.load_model(model_path)

    self.sess = tf.compat.v1.keras.backend.get_session()

  def predict(self, img):
    with self.graph.as_default():
      try:
        preds= self.model.predict(img)
      except tf.errors.FailedPreconditionError:
        tf.compat.v1.keras.backend.set_session(self.sess)
        preds= self.model.predict(img)

    return preds

我在https://colab.research.google.com/drive/1OaH7ZoAsY_V1sMUmc1NumWmJPnNr_54F?usp=sharing有一个问题意图的工作复制。colab成功地使用MyDetector预测MNIST图像

作为这个练习的一部分,我看到这里发生了一些事情:

  1. 未定义MyModel.model\u路径。即使在MyModel.load\u model中提供了model\u path作为参数,但它是未使用的。换言之,我猜在load_模型部分或问题描述中有一个输入错误

此外,以下是一些想法:

  1. tf.get_default_graph()在TensorFlow 2.6.0中不起作用。TF2.6有一个类似的TF.compat.v1.get\u default\u graph()。我强烈建议运行tf.version,以确认执行的代码实际上使用的是2.6.0

  2. 如果可以,将MyDetector添加到MyModel文件中。如果它能工作,那么您知道问题与某些代码位于单独的文件中这一事实有关,这可能有助于进行故障排除

基于以上所述,我建议在启用“急切执行”的情况下调试该问题,并查看是否能够以这种方式工作

相关问题 更多 >

    热门问题