python我的Tflite模型在jupyter笔记本中运行良好。然而,当我试图在Android studio中运行它时,它向我显示了错误的输出
因此,我根据患者输入的症状创建了用于疾病识别的自定义模型。它在python上运行得非常好,但是在我将其转换为tflite并在安卓 studio上运行之后,无论输入值如何,它都会给我错误的预测。 我已经在python上测试了tflite模型,它运行良好
以下代码显示了我如何将模型转换为tflite:-
from tensorflow.keras.models import load_model
saved_model= load_model('Disease-predictor_ANN.h5')
#Convert to tflite model
from tensorflow import lite
converter= lite.TFLiteConverter.from_keras_model(saved_model)
tflite_model= converter.convert()
#Save the model
with open('dpredictor.tflite','wb') as f:
f.write(tflite_model)
以下代码显示了我在安卓 studio中实现tflite模型的活动文件:-
package com.example.prototype_idoc;
import 安卓x.annotation.NonNull;
import 安卓x.appcompat.app.AppCompatActivity;
import 安卓.app.Activity;
import 安卓.content.Context;
import 安卓.content.pm.PackageManager;
import 安卓.content.res.AssetFileDescriptor;
import 安卓.content.res.AssetManager;
import 安卓.database.sqlite.SQLiteOpenHelper;
import 安卓.os.Bundle;
import 安卓.util.Log;
import 安卓.view.View;
import 安卓.widget.ArrayAdapter;
import 安卓.widget.Button;
import 安卓.widget.MultiAutoCompleteTextView;
import 安卓.widget.TextView;
import 安卓.widget.Toast;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
import java.io.FileInputStream;
import java.io.IOException;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
public class diseaseDiagnosis extends AppCompatActivity {
private MultiAutoCompleteTextView multiAutoCompleteTextView;
private Button button;
private TextView textView;
public diseaseDiagnosis() throws PackageManager.NameNotFoundException {
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_disease_diagnosis);
MapdatabaseHelper mapdatabaseHelper= new MapdatabaseHelper(this);
try {
mapdatabaseHelper.createDatabase();
mapdatabaseHelper.openDatabase();
} catch (IOException e) {
e.printStackTrace();
}
ArrayList<String> symptomList= (ArrayList<String>) mapdatabaseHelper.getEverySymptom(this);
ArrayList<String> diseaseList= (ArrayList<String>) mapdatabaseHelper.getEveryDisease(this);
String[] symptoms= symptomList.toArray(new String[symptomList.size()]);
String[] diseases= diseaseList.toArray(new String[diseaseList.size()]);
mapdatabaseHelper.close();
multiAutoCompleteTextView= findViewById(R.id.multiAutoView);
button= findViewById(R.id.checkDisease);
textView= findViewById(R.id.diseaseName);
ArrayAdapter<String> arrayAdapter= new ArrayAdapter<String>(this, 安卓.R.layout.simple_list_item_1,symptoms);
multiAutoCompleteTextView.setAdapter(arrayAdapter);
multiAutoCompleteTextView.setThreshold(1);
multiAutoCompleteTextView.setTokenizer(new MultiAutoCompleteTextView.CommaTokenizer());
button.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
float[] model_input= new float[132];
Arrays.fill(model_input, new Float(0));
String itemsSelected= multiAutoCompleteTextView.getText().toString();
int startIndex=0;
for(int i=0;i<itemsSelected.length();i++){
if(itemsSelected.charAt(i)==','){
String s=itemsSelected.substring(startIndex,i);
int currIndex=symptomList.indexOf(s);
Log.v("CURR INDEX:",s+ ":"+currIndex);
model_input[currIndex]= 1.0F;
startIndex=i+2;
}
}
Context context = null;
try {
context = createPackageContext("com.example.prototype_idoc", 0);
} catch (PackageManager.NameNotFoundException e) {
e.printStackTrace();
}
AssetManager assetManager = context.getAssets();
Interpreter interpreter=null;
try {
interpreter= new Interpreter(loadModelFile(assetManager,"dpredictor.tflite"));
} catch (IOException e) {
e.printStackTrace();
}
ByteBuffer inputBuffer=ByteBuffer.allocateDirect(132*4).order(ByteOrder.nativeOrder());
inputBuffer.rewind();
for(int i=0;i<model_input.length;i++){
inputBuffer.putFloat(model_input[i]);
}
float[] outputArray= new float[41];
TensorBuffer outputBuffer= TensorBuffer.createFixedSize(new int[]{1,41},DataType.FLOAT32);
interpreter.run(inputBuffer,outputBuffer.getBuffer());
outputArray= outputBuffer.getFloatArray();
//
int modelOutputIndex= 0;
float maxNum= outputArray[0];
for(int i=0;i<outputArray.length;i++){
if(maxNum>outputArray[i]){
modelOutputIndex=i;
maxNum=outputArray[i];
}
}
Log.v("TAG:", "onClick: "+ outputArray[15] );
Log.v("TAG:", "onClick: "+ outputArray[25] );
textView.setText(diseases[modelOutputIndex]);
textView.setVisibility(View.VISIBLE);
}
});
}
private MappedByteBuffer loadModelFile(AssetManager assetManager, String modelPath) throws IOException {
AssetFileDescriptor fileDescriptor = assetManager.openFd(modelPath);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
共 (0) 个答案