首先,我已经访问了这里的所有类似主题和其他网站,但没有一个在我的案例中起作用
假设我有一个类处理加载模型和预测:
class MyModel():
def __init__(self):
pass
def load_model(self, model_path):
self.model = tf.keras.models.load_model(model_path)
def predict(self, img):
return self.model.predict(img)
现在,我在另一个文件中有另一个类,它调用MyModel
:
from mymodel import MyModel
class MyDetector():
def __init__(self):
self.detector = MyModel()
self.detector.load_model('mymodel.h5')
def detect(self, img: numpy.ndarray):
return self.detector.predict(img)
但是,这会引发一个错误,即<tensor> is not an element of this graph
。我已经尝试了所有与tf.Graph.as_default()
相关的答案,但没有任何改变。最常见的建议是修改模型加载和预测部分,如下所示:
def load_model(self, model_path):
global model
model = tf.keras.models.load_model(model_path)
global graph
graph = tf.get_default_graph()
def predict(self, img):
with graph.as_default():
preds = model.predict(img)
return preds
这仍然没有帮助,因为所有其他的建议也可以在:https://github.com/keras-team/keras/issues/6462
我认为我的案例与那些已经解决类似案例的案例不同,因为我试图从一个完全不同的类文件调用模型类。我的Tensorflow版本是2.6.0。有没有更好的办法解决这个问题
更新
实际情况是,我正在使用gRPC与远程服务器通信以进行模型推断。为此,我使用了基于gRPC的非常简单的客户机-服务器通信。我的客户机代码定义如下(client.py
):
import cv2
import grpc
import pybase64
import protos.mydetector_pb2 as mydetector_pb2
import protos.mydetector_pb2_grpc as mydetector_pb2_grpc
# open a gRPC channel
channel = grpc.insecure_channel('[::]:50051')
stub = mydetector_pb2_grpc.MyDetectionServiceStub(channel)
img = cv2.imread('test.jpg')
retval, buffer = cv2.imencode('.jpg', img)
b64img = pybase64.b64encode(buffer)
print('\nSending single request to port 50051...')
request = mydetector_pb2.MyDetectionRequest(image=b64img)
response = stub.detect(request)
然后,在接收服务器端,主服务器实现如下(server.py
):
import grpc
from concurrent import futures
import protos.mydetector_pb2_grpc as reid_grpc
import MyDetectionService
MAX_MESSAGE_IN_MB = 10
options = [
('grpc.max_send_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024),
('grpc.max_receive_message_length', MAX_MESSAGE_IN_MB * 1024 * 1024)
]
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), options=options)
reid_grpc.add_MyDetectionServiceServicer_to_server(MyDetectionService(), server)
print('Starting server. Listening on port 50051.')
server.add_insecure_port('[::]:50051')
server.start()
MyDetectionService
类的实现如下:
import protos.mydetector_pb2 as mydetector
import protos.mydetector_pb2_grpc as mydetector_grpc
from mydetector.service.detector import MyDetector
from mydetector.utils import img_converter
import cv2
import numpy as np
class MyDetectionService(mydetector_grpc.MyDetectionServiceServicer):
def __init__(self):
self.detector = MyDetector()
def detect(self, request, context):
print('detecting on received image...')
encoded_img = request.image
img = img_converter(encoded_img)
img = cv2.resize(img, (240, 240))
img2 = np.expand_dims(img, axis=0)
result = self.detector.detect(img2)
return mydetector.MyDetectionResponse(ans=result)
其中,MyDetector
类的实现如上图所示
我发现,如果我不使用基于gRPC的服务器-客户端通信,而是从任何其他常规的外部类调用MyDetector
,那么一切都会顺利进行。但是,当我通过gRPC从客户端发送图像时,它成功地在MyDetector
类中加载了模型(我可以调用model.summary()
来获取模型的完整描述),但在detect
函数中失败
重要提示:根据可用信息here,我相信每次发送gRPC请求时,它都会使用自己的Tensorflow会话创建新线程,这就是这里的主要问题所在。然而,我仍然无法使它工作,即使在遵循了所有的说明在该网站上描述
服务器在不同的图形/会话中加载模型,而不是在用于接收客户端请求和响应的图形/会话中加载模型。将
MyModel
类修改为以下内容:我在https://colab.research.google.com/drive/1OaH7ZoAsY_V1sMUmc1NumWmJPnNr_54F?usp=sharing有一个问题意图的工作复制。colab成功地使用MyDetector预测MNIST图像
作为这个练习的一部分,我看到这里发生了一些事情:
此外,以下是一些想法:
tf.get_default_graph()在TensorFlow 2.6.0中不起作用。TF2.6有一个类似的TF.compat.v1.get\u default\u graph()。我强烈建议运行tf.version,以确认执行的代码实际上使用的是2.6.0
如果可以,将MyDetector添加到MyModel文件中。如果它能工作,那么您知道问题与某些代码位于单独的文件中这一事实有关,这可能有助于进行故障排除
基于以上所述,我建议在启用“急切执行”的情况下调试该问题,并查看是否能够以这种方式工作
相关问题 更多 >
编程相关推荐