TensorFlow Lite Model Maker --- Image Classification, with Source Code

Author: XianxinMao
Last modified: 2021-10-11 10:32:18
Column: Deep Learning Frameworks

TFLite_tutorials

The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.

Interpretation: what we want here is a model in .tflite format that can be deployed on mobile or embedded devices.

The table below lists the task types that TFLite Model Maker currently supports.

Supported Tasks | Task Utility
Image Classification: tutorial, api | Classify images into predefined categories.
Object Detection: tutorial, api | Detect objects in real time.
Text Classification: tutorial, api | Classify text into predefined categories.
BERT Question Answer: tutorial, api | Find the answer in a certain context for a given question with BERT.
Audio Classification: tutorial, api | Classify audio into predefined categories.
Recommendation: demo, api | Recommend items based on the context information for on-device scenario.

If your tasks are not supported, please first use TensorFlow to retrain a TensorFlow model with transfer learning (following guides like images, text, audio) or train it from scratch, and then convert it to TensorFlow Lite model.

Interpretation: if the model you want to train does not fit any of the task types above, train a TensorFlow model first and then convert it to TFLite.

To use TensorFlow Lite Model Maker, we first need to install it:

pip install tflite-model-maker

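In essence this is a classification task. We swap in different models and compare the final accuracy, as well as the TFLite model size, inference speed, memory usage, and CPU usage.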

The following snippet downloads the dataset:

import os
import tensorflow as tf

# Download and extract the archive, then point image_path at the extracted folder.
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

The dataset is structured as follows:

flower_photos
|__ daisy
    |__ 100080576_f52e8ee070_n.jpg
    |__ 14167534527_781ceb1b7a_n.jpg
    |__ ...
|__ dandelion
    |__ 10043234166_e6dd915111_n.jpg
    |__ 1426682852_e62169221f_m.jpg
    |__ ...
|__ roses
    |__ 102501987_3cdb8e5394_n.jpg
    |__ 14982802401_a3dfb22afb.jpg
    |__ ...
|__ sunflowers
    |__ 12471791574_bb1be83df4.jpg
    |__ 15122112402_cafa41934f.jpg
    |__ ...
|__ tulips
    |__ 13976522214_ccec508fe7.jpg
    |__ 14487943607_651e8062a1_m.jpg
    |__ ...

Load the dataset and split it:

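from tflite_model_maker.image_classifier import DataLoader

data = DataLoader.from_folder(image_path)
# 80% for training; the remaining 20% is halved into validation and test sets.
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)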
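To make the split arithmetic concrete, here is a minimal sketch in plain Python. It assumes that `split(fraction)` puts the first `int(size * fraction)` elements into the first part and the rest into the second; that is an assumption about Model Maker's behavior, and `split_sizes` is a hypothetical helper, not a library function. The total of 3670 images is also an assumption, chosen because it is consistent with the 367 test images reported later in this article.

```python
def split_sizes(total, fraction):
    """Return (first, second) part sizes for a split at `fraction`."""
    first = int(total * fraction)
    return first, total - first


total_images = 3670  # assumed dataset size, not read from disk

# Same two-step split as in the article: 0.8 first, then 0.5 of the remainder.
train_size, rest_size = split_sizes(total_images, 0.8)
validation_size, test_size = split_sizes(rest_size, 0.5)

print(train_size, validation_size, test_size)
```

Under these assumptions the three parts correspond to roughly 80%, 10% and 10% of the data.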
assert tf.__version__.startswith('2')

This checks that the TensorFlow version string starts with '2', i.e. that TensorFlow 2.x is installed.

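Training results: train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210. Overall this matches the usual generalization pattern: accuracy drops slightly from the training set to the validation set to the test set.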

import os
import time

import numpy as np
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt

# Require TensorFlow 2.x.
assert tf.__version__.startswith('2')

# Download and extract the flower photos dataset.
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

# Load the images; images in the same subdirectory share one label.
data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
# Split into training data and the rest.
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
#     print(batch)
#     break

# Halve the rest into validation and test data.
validation_data, test_data = rest_data.split(0.5)

# Retrain an efficientnet_lite0 classifier for 20 epochs.
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('efficientnet_lite0'), epochs=20)

# Evaluate the retrained model on the test data.
loss, accuracy = model.evaluate(test_data)

# Export the TFLite model together with the label file.
model.export(export_dir='./testTFlite', export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

# Evaluate the exported TFLite model and time the whole evaluation.
start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)

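Judging from the output log above, the model loses very little accuracy after quantization, and the quantized model is 4.0 MB (efficientnet_lite0). The accompanying screenshot shows that inference runs on a single CPU; test_data contains 367 images and the total elapsed time is 273.43 s.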

config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite', tflite_filename='model_fp16.tflite', quantization_config=config, export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

If the model is exported as fp16, its size is 6.8 MB (efficientnet_lite0) and inference takes 5.54 s, which is much faster.
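As a rough illustration of what float16 storage means, the sketch below round-trips a value through IEEE 754 half precision with Python's `struct` module. It only demonstrates the 2-byte versus 4-byte storage and the resulting rounding. It is not how the TFLite converter is implemented, and `to_fp16` is a hypothetical helper introduced for this example.

```python
import struct


def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', value))[0]


weight = 0.1  # made-up weight value

print(len(struct.pack('<f', weight)))  # bytes used by one float32 value
print(len(struct.pack('<e', weight)))  # bytes used by one float16 value
print(to_fp16(weight))                 # nearest half-precision value to 0.1
```

Values that are exactly representable in half precision, such as 0.5, survive the round trip unchanged; others are rounded slightly.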

model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('mobilenet_v2'), epochs=20)

After switching the model to mobilenet_v2, the exported fp16 model is 4.6 MB and inference takes 4.36 s.

inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=inception_v3_spec, epochs=20)

After switching the model to inception_v3, the exported fp16 model is 43.8 MB (inception_v3) and inference takes 25.31 s.
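The measurements reported so far can be put side by side with a few lines of plain Python. The sizes and total times below are copied from the text. The sketch assumes that every run evaluated the same 367 test images, which the article states only for the first run, and `per_image_ms` is a hypothetical helper.

```python
NUM_TEST_IMAGES = 367  # test set size reported for the first run

# (model, export type, size in MB, total inference time in s), as reported above.
# 'default' stands for the first export, which the article describes as quantized.
results = [
    ('efficientnet_lite0', 'default', 4.0, 273.43),
    ('efficientnet_lite0', 'fp16', 6.8, 5.54),
    ('mobilenet_v2', 'fp16', 4.6, 4.36),
    ('inception_v3', 'fp16', 43.8, 25.31),
]


def per_image_ms(total_seconds, num_images=NUM_TEST_IMAGES):
    """Average latency per image in milliseconds."""
    return total_seconds / num_images * 1000.0


for name, export_type, size_mb, seconds in results:
    print(f'{name:20s} {export_type:8s} {size_mb:6.1f} MB '
          f'{per_image_ms(seconds):8.2f} ms/image')

# How much faster the fp16 export ran than the default export of the same model.
speedup = results[0][3] / results[1][3]
print(f'efficientnet_lite0 fp16 vs default: {speedup:.1f}x faster')
```

Per-image latency makes the models easier to compare than total time, because it no longer depends on the size of the test set.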

"""Common Dataset used for tasks."""

class DataLoader(object):
  """This class provides generic utilities for loading customized domain data that will be used later in model retraining.

  For different ML problems or tasks, such as image classification, text
  classification etc., a subclass is provided to handle task-specific data
  loading requirements.
  """

  def __init__(self, dataset, size):
    """Init function for class `DataLoader`.

    In most cases, one should use helper functions like `from_folder` to create
    an instance of this class.

    Args:
      dataset: A tf.data.Dataset object that contains a potentially large set of
        elements, where each element is a pair of (input_data, target). The
        `input_data` means the raw input data, like an image, a text etc., while
        the `target` means some ground truth of the raw input data, such as the
        classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset doesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = dataset
    self._size = size

  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a shared and batched tf.data.Dataset for training/evaluation."""
    # (The excerpt in the original post ends here.)
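The `gen_dataset` signature above takes `batch_size` and `drop_remainder`, and the class stores an explicit `size` because the dataset itself cannot report its length. The sketch below shows how these three values determine the batches. It assumes the parameters behave like those of `tf.data.Dataset.batch`, and `batch_sizes` is a hypothetical helper, not part of Model Maker.

```python
def batch_sizes(num_examples, batch_size, drop_remainder=False):
    """Return the size of each batch produced from `num_examples` items."""
    full_batches, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full_batches
    # The last, smaller batch is kept unless drop_remainder is set.
    if remainder and not drop_remainder:
        sizes.append(remainder)
    return sizes


# 367 test images with the default evaluation batch size of 32.
print(len(batch_sizes(367, 32)))
print(len(batch_sizes(367, 32, drop_remainder=True)))
```

With `drop_remainder=True` the trailing partial batch is discarded, so some examples are never seen in that pass.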
"""Image dataloader."""

class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""

  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Image analysis for image classification load images with labels.

    Assume the image data of the same label are in the same subdirectory.

    Args:
      filename: Name of the file.
      shuffle: boolean, if shuffle, random shuffle data.

    Returns:
      ImageDataset containing images and labels and other related info.
    """

  @classmethod
  def from_tfds(cls, name):
    """Loads data from tensorflow_datasets."""
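The docstring of `from_folder` says that images with the same label live in the same subdirectory. A minimal sketch of that convention in plain Python follows. It assumes that class names are sorted to obtain stable label indices and that only a few file extensions count as images; both are assumptions, and `list_labeled_images` is a hypothetical helper rather than the library's implementation.

```python
import os
import tempfile

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')  # assumed set of extensions


def list_labeled_images(root):
    """Return (index_to_label, [(path, label_index), ...]) for a class-per-folder tree."""
    index_to_label = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name)))
    examples = []
    for label_index, label in enumerate(index_to_label):
        class_dir = os.path.join(root, label)
        for filename in sorted(os.listdir(class_dir)):
            if filename.lower().endswith(IMAGE_EXTENSIONS):
                examples.append((os.path.join(class_dir, filename), label_index))
    return index_to_label, examples


# Build a tiny stand-in for the flower_photos layout using empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for label, filename in [('daisy', 'a.jpg'), ('roses', 'b.jpg'), ('roses', 'c.jpg')]:
        os.makedirs(os.path.join(root, label), exist_ok=True)
        open(os.path.join(root, label, filename), 'wb').close()
    index_to_label, examples = list_labeled_images(root)

print(index_to_label)
print([label_index for _, label_index in examples])
```

The list `index_to_label` plays the same role as the constructor argument of that name in the classifier classes quoted in this article: position `i` holds the class name for label index `i`.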
"""ImageNet preprocessing."""

import tensorflow as tf


class Preprocessor(object):
  """Preprocessing for image classification."""

  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation

  def __call__(self, image, label, is_training=True):
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)

  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    # preprocess_for_train / preprocess_for_eval are defined elsewhere in the
    # same module and are not part of this excerpt.
    if is_training:
      image = preprocess_for_train(image, image_size)
    else:
      image = preprocess_for_eval(image, image_size)

    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    label = tf.one_hot(label, depth=self.num_classes)
    return image, label

  # TODO(yuqili): Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)

    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
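Both preprocessing paths above end with the same two steps: per-channel normalization `(image - mean_rgb) / stddev_rgb` and one-hot encoding of the label. The sketch below reproduces those two steps for a single pixel in plain Python. The normalization constants are hypothetical placeholders (the real ones come from the model spec), and `normalize_pixel` and `one_hot` are illustrative helpers, not the TensorFlow functions.

```python
def normalize_pixel(rgb, mean_rgb, stddev_rgb):
    """Apply (value - mean) / stddev to each channel of one RGB pixel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean_rgb, stddev_rgb))


def one_hot(label, depth):
    """Return a list of length `depth` with 1.0 at position `label`."""
    return [1.0 if i == label else 0.0 for i in range(depth)]


# Hypothetical constants chosen for illustration only.
mean_rgb = (127.0, 127.0, 127.0)
stddev_rgb = (128.0, 128.0, 128.0)

pixel = normalize_pixel((127.0, 255.0, 0.0), mean_rgb, stddev_rgb)
print(pixel)
print(one_hot(2, depth=5))
```

With these constants, channel values in the range 0 to 255 are mapped to roughly -1 to 1.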
class ImageClassifier(classification_model.ClassificationModel):
  """ImageClassifier class for inference and exporting to tflite."""

  def __init__(self,
               model_spec,
               index_to_label,
               shuffle=True,
               hparams=hub_lib.get_default_hparams(),
               use_augmentation=False,
               representative_data=None):
    """Init function for ImageClassifier class.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      hparams: A namedtuple of hyperparameters. This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        .do_fine_tuning: If true, the Hub module is trained together with the
          classification layer on top.
      use_augmentation: Use data augmentation for preprocessing.
      representative_data:  Representative dataset for full integer
        quantization. Used when converting the keras model to the TFLite model
        with full integer quantization.
    """
    super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
                                          hparams.do_fine_tuning)
    num_classes = len(index_to_label)
    self._hparams = hparams
    self.preprocess = image_preprocessing.Preprocessor(
        self.model_spec.input_image_shape,
        num_classes,
        self.model_spec.mean_rgb,
        self.model_spec.stddev_rgb,
        use_augmentation=use_augmentation)
    self.history = None  # Training history that returns from `keras_model.fit`.
    self.representative_data = representative_data

  def _get_tflite_input_tensors(self, input_tensors):
    """Gets the input tensors for the TFLite model."""
    return input_tensors

  def create_model(self, hparams=None, with_loss_and_metrics=False):
    """Creates the classifier model for retraining."""
    hparams = self._get_hparams_or_default(hparams)

    module_layer = hub_loader.HubKerasLayerV1V2(
        self.model_spec.uri, trainable=hparams.do_fine_tuning)
    self.model = hub_lib.build_model(module_layer, hparams,
                                     self.model_spec.input_image_shape,
                                     self.num_classes)
    if with_loss_and_metrics:
      # Adds loss and metrics in the keras model.
      self.model.compile(
          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
          metrics=['accuracy'])
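The model above is compiled with `CategoricalCrossentropy(label_smoothing=0.1)`. The following plain-Python sketch shows what label smoothing does to a one-hot target, using the formula `y * (1 - label_smoothing) + label_smoothing / num_classes`. The predicted probabilities are made up for illustration, and `smooth_labels` and `cross_entropy` are hypothetical helpers, not the Keras implementation.

```python
import math


def smooth_labels(one_hot_label, label_smoothing):
    """Move `label_smoothing` of the probability mass to a uniform distribution."""
    num_classes = len(one_hot_label)
    return [y * (1.0 - label_smoothing) + label_smoothing / num_classes
            for y in one_hot_label]


def cross_entropy(target, predicted):
    """Categorical cross-entropy between a target and a predicted distribution."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0.0)


hard_target = [0.0, 0.0, 1.0, 0.0, 0.0]
target = smooth_labels(hard_target, label_smoothing=0.1)
predicted = [0.05, 0.05, 0.8, 0.05, 0.05]  # made-up model output

print(target)
print(cross_entropy(hard_target, predicted))
print(cross_entropy(target, predicted))
```

The smoothed target still sums to 1, but a confident prediction is penalized more than with the hard target, which discourages overconfident outputs.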
"""Custom classification model that is already retrained by data."""

class ClassificationModel(custom_model.CustomModel):
  """The abstract base class that represents a Tensorflow classification model."""

  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                           ExportFormat.SAVED_MODEL, ExportFormat.TFJS)

  def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
    """
    super(ClassificationModel, self).__init__(model_spec, shuffle)
    self.index_to_label = index_to_label
    self.num_classes = len(index_to_label)
    self.train_whole_model = train_whole_model

  def evaluate(self, data, batch_size=32):
    """Evaluates the model.

    Args:
      data: Data to be evaluated.
      batch_size: Number of samples per evaluation step.

    Returns:
      The loss value and accuracy.
    """
    ds = data.gen_dataset(
        batch_size, is_training=False, preprocess=self.preprocess)
    return self.model.evaluate(ds)

  def predict_top_k(self, data, k=1, batch_size=32):
    """Predicts the top-k predictions."""
    # (The excerpt in the original post ends here.)
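The body of `predict_top_k` is cut off in the excerpt, so the following is only a sketch of the general idea of top-k prediction, not the library's code. It assumes the model yields one probability per class and that `index_to_label` maps positions to class names; the return format and the helper name `top_k_labels` are hypothetical.

```python
def top_k_labels(probabilities, index_to_label, k=1):
    """Return the k most probable (label, probability) pairs, best first."""
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return [(index_to_label[index], probability) for index, probability in ranked[:k]]


index_to_label = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
probabilities = [0.05, 0.10, 0.60, 0.05, 0.20]  # made-up model output

print(top_k_labels(probabilities, index_to_label, k=2))
```

With `k=1` this reduces to ordinary classification, that is, picking the single most probable class.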
class CustomModel(abc.ABC):
  """The abstract base class that represents a Tensorflow classification model."""

  # Note: without a trailing comma the parentheses do not create a tuple, so
  # this is a single ExportFormat value.
  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
                           ExportFormat.TFJS)

  def __init__(self, model_spec, shuffle):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data should be shuffled.
    """
    self.model_spec = model_spec
    self.shuffle = shuffle
    self.model = None
    # TODO(yuqili): remove this method once preprocess for image classifier is
    # also moved to DataLoader part.
    self.preprocess = None

  @abc.abstractmethod
  def train(self, train_data, validation_data=None, **kwargs):
    return

  def summary(self):
    self.model.summary()

  @abc.abstractmethod
  def evaluate(self, data, **kwargs):
    return
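`CustomModel` relies on Python's `abc` module: subclasses must implement `train` and `evaluate` before they can be instantiated. The sketch below shows that mechanism with a toy model. `BaseModel` and `MajorityClassModel` are hypothetical classes made up for this example and have nothing to do with Model Maker's real subclasses.

```python
import abc


class BaseModel(abc.ABC):
    """Hypothetical stand-in for an abstract model base class."""

    @abc.abstractmethod
    def train(self, train_data, validation_data=None, **kwargs):
        """Train the model."""

    @abc.abstractmethod
    def evaluate(self, data, **kwargs):
        """Evaluate the model."""


class MajorityClassModel(BaseModel):
    """Toy model that always predicts the most frequent training label."""

    def __init__(self):
        self.majority_label = None

    def train(self, train_data, validation_data=None, **kwargs):
        labels = [label for _, label in train_data]
        self.majority_label = max(set(labels), key=labels.count)

    def evaluate(self, data, **kwargs):
        correct = sum(1 for _, label in data if label == self.majority_label)
        return correct / len(data)


model = MajorityClassModel()
model.train([('a', 1), ('b', 1), ('c', 0)])
print(model.evaluate([('d', 1), ('e', 0)]))

# The abstract base class itself cannot be instantiated.
try:
    BaseModel()
except TypeError as error:
    print('cannot instantiate:', error)
```

The benefit of this design is that every task-specific model exposes the same `train` and `evaluate` entry points, so calling code does not need to know which task it is handling.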
def export_tflite(model,
                  tflite_filepath,
                  quantization_config=None,
                  convert_from_saved_model_tf2=False,
                  preprocess=None,
                  supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the retrained model to tflite format and saves it.

  Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
    preprocess: A preprocess function to apply on the dataset.
        # TODO(wangtz): Remove when preprocess is split off from CustomModel.
    supported_ops: A list of supported ops in the converted TFLite file.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")

  if compat.get_tf_behavior() == 1:
    lite = tf.compat.v1.lite
  else:
    lite = tf.lite

  convert_from_saved_model = (
      compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
  with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
    if temp_dir_name:
      save_path = os.path.join(temp_dir_name, 'saved_model')
      model.save(save_path, include_optimizer=False, save_format='tf')
      converter = lite.TFLiteConverter.from_saved_model(save_path)
    else:
      converter = lite.TFLiteConverter.from_keras_model(model)

    if quantization_config:
      converter = quantization_config.get_converter_with_quantization(
          converter, preprocess=preprocess)

    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()

  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)


def get_lite_runner(tflite_filepath, model_spec=None):
  """Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
  # Gets the functions to handle the input & output indexes if exists.
  reorder_input_details_fn = None
  if hasattr(model_spec, 'reorder_input_details'):
    reorder_input_details_fn = model_spec.reorder_input_details

  reorder_output_details_fn = None
  if hasattr(model_spec, 'reorder_output_details'):
    reorder_output_details_fn = model_spec.reorder_output_details

  lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
                           reorder_output_details_fn)
  return lite_runner
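`export_tflite` calls a helper `_create_temp_dir` that is not included in the excerpt. From the way it is used, it appears to yield a temporary directory name when a SavedModel has to be written first, and a falsy value otherwise. The sketch below is one way such a context manager could be written with the standard library; it is a guess at the behavior, not the library's code, and the name `create_temp_dir` is hypothetical.

```python
import contextlib
import os
import tempfile


@contextlib.contextmanager
def create_temp_dir(convert_from_saved_model):
    """Yield a temporary directory path if one is needed, otherwise None."""
    if convert_from_saved_model:
        # The directory is removed automatically when the with-block exits.
        with tempfile.TemporaryDirectory() as temp_dir_name:
            yield temp_dir_name
    else:
        yield None


with create_temp_dir(True) as temp_dir_name:
    needed_dir_exists = os.path.isdir(temp_dir_name)
print(needed_dir_exists, os.path.isdir(temp_dir_name))

with create_temp_dir(False) as temp_dir_name:
    print(temp_dir_name)
```

Yielding `None` in the second case lets the caller use a single `with` statement and branch on `if temp_dir_name:`, as `export_tflite` does.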

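Original content statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.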

In case of infringement, please contact cloudcommunity@tencent.com for removal.
