我训练了一个tensorflow模型来预测输入文本的下一个单词。我将它保存为一个.h5文件。
我可以在另一个python代码中使用该模型来预测word,如下所示:
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
model = load_model('model.h5')
model.compile(
loss = "categorical_crossentropy",
optimizer = "adam",
metrics = ["accuracy"]
)
data = open("dataset.txt").read()
corpus = data.lower().split("\n")
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
seed_text = input()
sequence_text = tokenizer.texts_to_sequences([seed_text])[0]
padded_sequence = np.array(pad_sequences([sequence_text], maxlen = 11 -1))
predicted = np.argmax(model.predict(padded_sequence))
是否有一种方法可以直接使用颤振内部的模型,从TextField()获取输入,并按下按钮显示预测的单词??
发布于 2021-03-28 21:09:44
步骤
.tflite
模型。# https://www.tensorflow.org/lite/convert/#convert_a_savedmodel_recommended_
import tensorflow as tf
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()
# Save the model.
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
assets/
目录中。android/
assets/
model.tflite
ios/
lib/
pubspec.yaml
中dependencies:
flutter:
sdk: flutter
tflite: ^1.0.5
.
.
labels.txt
是包含类的文本文件:import 'package:tflite/tflite.dart';
.
.
.
class _MyAppState extends State<MyApp> {
. . .
@override
void initState() {
super.initState();
_loading = true;
loadModel().then((value) {
setState(() {
_loading = false;
});
});
}
classifyImage(File image) async {
var output = await Tflite.runModelOnImage(
path: image.path,
numResults: 2,
threshold: 0.5,
imageMean: 127.5,
imageStd: 127.5,
);
setState(() {
_loading = false;
_outputs = output;
});
}
loadModel() async {
await Tflite.loadModel(
model: "assets/model_unquant.tflite",
labels: "assets/labels.txt",
);
}
@override
void dispose() {
Tflite.close();
super.dispose();
}
. . .
}
SideNote
tflite插件不支持文本分类AFAIK,如果您想专门进行文本分类,我建议使用tflite_flutter
插件。下面是使用插件进行文本分类的文章的链接。
发布于 2021-03-28 20:34:03
不能在颤振中直接使用.h5文件。您需要将其转换为.tflite文件并使用该文件,或者创建REST .。
将其转换为.tflite文件是最简单的。有关更多细节,您可以查看以下文章:https://medium.com/analytics-vidhya/run-cnn-model-in-flutter-10c944cadcba
如果您想创建一个REST ,请查看本文:https://medium.com/analytics-vidhya/deploy-ml-models-using-flask-as-rest-api-and-access-via-flutter-app-7ce63d5c1f3b
https://stackoverflow.com/questions/66474583
复制相似问题