前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >Keras 在fit-generator中获取验证数据的y_true和y_preds

Keras 在fit-generator中获取验证数据的y_true和y_preds

作者头像
为为为什么
发布于 2022-08-05 02:40:22
发布于 2022-08-05 02:40:22
1.4K00
代码可运行
举报
文章被收录于专栏:又见苍岚又见苍岚
运行总次数:0
代码可运行

在Keras网络训练过程中,fit-generator为我们提供了很多便利。调用fit-generator时,每个epoch训练结束后会使用验证数据检测模型性能,Keras使用model.evaluate_generator提供该功能。然而我遇到了需要提取验证集y_pred的需求,在网上没有找到现有的功能实现方法,于是自己对源码进行了微调,实现了可配置提取验证集模型预测结果的功能,记录如下。

原理简介

通过查看源代码,发现Keras调用了model.evaluate_generator验证数据,该函数最终调用的是TensorFlow(我用的后端是tf)的TF_SessionRunCallable函数,封装得很死,功能是以数据为输入,输出模型预测的结果并与真实标签比较并计算评价函数得到结果。 过程中不保存、不返回预测结果,这部分没有办法修改,但可以在评价数据的同时对数据进行预测,得到结果并记录下来,传入到epoch_logs中,随后在回调函数的on_epoch_end中尽情使用

代码修改

  • Keras版本 2.2.4 其他版本不保证一定使用相同的方法,但大体思路不变

model.fit_generator

找到fit_generator函数定义位置,加入控制参数get_predict

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def fit_generator(self, generator,
                      steps_per_epoch=None,
                      epochs=1,
                      verbose=1,
                      callbacks=None,
                      validation_data=None,
                      validation_steps=None,
                      class_weight=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      shuffle=True,
                      initial_epoch=0,
                      get_predict = False):    # 加入 get_predict
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
return training_generator.fit_generator(
            self, generator,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            validation_data=validation_data,
            validation_steps=validation_steps,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            get_predict = get_predict) # 加入 get_predict

training_generator.fit_generator

找到training_generator.fit_generator定义位置,加入get_predict:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def fit_generator(model,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0,
                  get_predict = False): # 加入 get_predict

修改 # Epoch finished. 注释后的模块,可以看到Keras中fit_generator就是用model.evaluate_generator对验证集评估的:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Epoch finished.
if steps_done >= steps_per_epoch and do_validation:
    if val_gen:
        
        if get_predict:
            ## 如果启动获取预测结果功能,那么将get_predict设置为True
            ## 返回值会包括 gts_and_preds 
            val_outs, gts_and_preds = model.evaluate_generator(
                val_enqueuer_gen,
                validation_steps,
                workers=0,
                get_predict=get_predict)
        else:
            val_outs = model.evaluate_generator(
                val_enqueuer_gen,
                validation_steps,
                workers=0)                            
    else:
        # No need for try/except because
        # data has already been validated.
        val_outs = model.evaluate(
            val_x, val_y,
            batch_size=batch_size,
            sample_weight=val_sample_weights,
            verbose=0)
    val_outs = to_list(val_outs)
    # Same labels assumed.
    for l, o in zip(out_labels, val_outs):
        epoch_logs['val_' + l] = o
    
    ## 将返回值 gts_and_preds 保存到 log 中
    if get_predict:
        epoch_logs['val_gts_and_preds'] = gts_and_preds
        
if callback_model.stop_training:
    break

model.evaluate_generator

进入model.evaluate_generator函数,加入get_predict变量:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def evaluate_generator(self, generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0,
                       get_predict=False): # 加入get_predict变量
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
return training_generator.evaluate_generator(
            self, generator,
            steps=steps,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            verbose=verbose,
            get_predict=get_predict) # 加入get_predict变量

training_generator.evaluate_generator

进入training_generator.evaluate_generator,添加get_predict变量,新建三个变量:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def evaluate_generator(model, generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0,
                       get_predict=False):  # 加入get_predict变量
    """See docstring for `Model.evaluate_generator`."""
    model._make_test_function()

    if hasattr(model, 'metrics'):
        for m in model.stateful_metric_functions:
            m.reset_states()
        stateful_metric_indices = [
            i for i, name in enumerate(model.metrics_names)
            if str(name) in model.stateful_metric_names]
    else:
        stateful_metric_indices = []

    steps_done = 0
    wait_time = 0.01
    outs_per_batch = []
    batch_sizes = []
      
    if get_predict:
        preds_dict={} # 新建保存结果的dict
        gt_per_batch = [] # 新建 y_true 的 list
        pr_per_batch = [] # 新建 y_pred 的 list

在核心循环while steps_done < steps:中加入预测变量的内容:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
while steps_done < steps:
            generator_output = next(output_generator)
            if not hasattr(generator_output, '__len__'):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' +
                                 str(generator_output))
            if len(generator_output) == 2:
                x, y = generator_output
                sample_weight = None
            elif len(generator_output) == 3:
                x, y, sample_weight = generator_output
            else:
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' +
                                 str(generator_output))
            outs = model.test_on_batch(x, y, sample_weight=sample_weight)
            outs = to_list(outs)
            outs_per_batch.append(outs)
            
            ## 加入预测功能,保存preds和y_true
            if get_predict:
                preds = model.predict_on_batch(x)
                gt_per_batch.append(y.tolist())
                pr_per_batch.append(preds.tolist())

            if x is None or len(x) == 0:
                # Handle data tensors support when no input given
                # step-size = 1 for data tensors
                batch_size = 1
            elif isinstance(x, list):
                batch_size = x[0].shape[0]
            elif isinstance(x, dict):
                batch_size = list(x.values())[0].shape[0]
            else:
                batch_size = x.shape[0]
            if batch_size == 0:
                raise ValueError('Received an empty batch. '
                                 'Batches should contain '
                                 'at least one item.')
            steps_done += 1
            batch_sizes.append(batch_size)
            if verbose == 1:
                progbar.update(steps_done)
        ## 将结果保存到dict中
        if get_predict:
            preds_dict['y_true'] = gt_per_batch
            preds_dict['y_pred'] = pr_per_batch

修改返回值:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
if get_predict:
    return unpack_singleton(averages), preds_dict

else:
    return unpack_singleton(averages)

至此核心的功能已经实现,但还有一个小问题。

keras.callbacks.TensorBoard._write_logs

Keras的Tensorboard会记录logs中的内容,但是他只认识 int, float 等数值格式,我们保存在log中的复杂字典他没办法写入tesnorboard,需要对_write_logs做微小的调整:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def _write_logs(self, logs, index):
        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            if isinstance(value, np.ndarray):
                summary_value.simple_value = value.item()
            ## 跳过我们生成的字典
            elif isinstance(value, dict):
                pass
            else:
                summary_value.simple_value = value
            summary_value.tag = name
            self.writer.add_summary(summary, index)
        self.writer.flush()

大功告成!

测试

随便写个带on_epoch_end的回调函数,将get_predict设置为True,测试logs中是否有我们想要的数据:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model.fit_generator(
       generator = train_data_generator,
       steps_per_epoch = 10,
       epochs = config.Epochs,
       verbose = 1,
       use_multiprocessing=False,
       validation_data=val_data_generator,
       validation_steps=10,
       callbacks = callbacks,
       get_predict= True
   )      

回调函数设断点,输出logs:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
logs['val_gts_and_preds']
{'y_pred': [[[2.5419962184969336e-05, 0.9999746084213257],
             [0.6694663763046265, 0.33053362369537354],
             [0.3561754524707794, 0.643824577331543]],
            [[5.548826155499231e-12, 1.0],
             [2.701560219975363e-08, 1.0],
             [4.0011427699937485e-06, 0.9999959468841553]],
            [[7.97858723533551e-11, 1.0],
             [2.3924835659272503e-06, 0.999997615814209],
             [3.359668880875688e-07, 0.9999996423721313]],
            [[0.06622887402772903, 0.9337711930274963],
             [4.1211248458239425e-07, 0.9999996423721313],
             [8.561290087527595e-06, 0.9999914169311523]],
            [[9.313887403550325e-07, 0.9999990463256836],
             [2.614793537247806e-08, 1.0],
             [8.66139725985704e-06, 0.9999912977218628]],
            [[7.047830763440288e-09, 1.0],
             [0.010548637248575687, 0.9894513487815857],
             [1.8744471252940542e-10, 1.0]],
            [[8.760089875714527e-11, 1.0],
             [0.0015734446933493018, 0.9984265565872192],
             [1.5642463040421717e-06, 0.9999984502792358]],
            [[0.004750440828502178, 0.9952495098114014],
             [6.984401466070267e-07, 0.9999992847442627],
             [0.00013592069444712251, 0.9998641014099121]],
            [[7.22906318140204e-11, 1.0],
             [2.402198795437016e-08, 1.0],
             [9.673745138272238e-10, 1.0]],
            [[3.1848256298872e-07, 0.9999996423721313],
             [0.0035940599627792835, 0.9964058995246887],
             [1.9458911912351162e-11, 1.0]]],
 'y_true': [[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
            [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]]}

之后这些结果任君处置了;

get_predict设为 False 时则屏蔽了我们做出的所有修改,与原始Keras代码完全相同; 目前没有发现其他的问题,有任何不对头可以随时交流。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020年6月10日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
想用Python爬小姐姐图片?那你得先搞定分布式进程
导读:分布式进程指的是将Process进程分布到多台机器上,充分利用多台机器的性能完成复杂的任务。我们可以将这一点应用到分布式爬虫的开发中。
IT阅读排行榜
2019/06/18
4640
想用Python爬小姐姐图片?那你得先搞定分布式进程
python爬虫 | 一文搞懂分布式进程爬虫
今天咱们来扯一扯分布式进程爬虫,对爬虫有所了解的都知道分布式爬虫这个东东,今天我们来搞懂一下分布式这个概念,从字面上看就是分开来布置,确实如此它是可以分开来运作的。
Python数据科学
2019/06/10
7490
一篇文章带你了解Python的分布式进程接口
在Thread和Process中,应当优选Process,因为Process更稳定,而且,Process可以分布到多台机器上,而Thread最多只能分布到同一台机器的多个CPU上。
Go进阶者
2021/05/24
3520
一篇文章带你了解Python的分布式进程接口
python网络爬虫(2)回顾Python编程
把内存中的数据变为可保存和共享,实现状态保存。cPickle使用C语言编写,效率高,优先使用。如果不存在则使用pickle。pickle使用dump和dumps实现序列化。
嘘、小点声
2019/07/31
6520
python 性能的优化
NumPy的创始人Travis,创建了CONTINUUM,致力于将Python大数据处理方面的应用。 推出的Numba项目能够将处理NumPy数组的Python函数JIT编译为==机器码执行==,从而上百倍的提高程序的运算速度。
Tim在路上
2020/08/04
1.1K0
Python Windows下分布式进程的坑(分布式进程的一个简单例子)
下面这个例子基于”廖雪峰的Python教程:分布式进程”原例在Linux上运行,直接在Windows上运行会出现错误,下面是针对原例进行的改进,使之能成功运行。 https://www.liaoxuefeng.com/wiki/0014316089557264a6b348958f449949df42a6d3a2e542c000/001431929340191970154d52b9d484b88a7b343708fcc60000#0 博主也对代码注释作了更精确的改进。 原例在Wi
Steve Wang
2018/02/05
2.2K0
Python学习笔记(十)·进程和线程
很多同学都听说过,现代操作系统比如Mac OS X,UNIX,Linux,Windows等,都是支持“多任务”的操作系统。
公爵
2022/09/28
5340
Python学习笔记(十)·进程和线程
Python3.6学习笔记(四)
程序运行中,可能会遇到BUG、用户输入异常数据以及其它环境的异常,这些都需要程序猿进行处理。Python提供了一套内置的异常处理机制,供程序猿使用,同时PDB提供了调试代码的功能,除此之外,程序猿还应该掌握测试的编写,确保程序的运行符合预期。
大江小浪
2018/07/24
7760
python网络爬虫(3)python爬虫遇到的各种问题(python版本、进程等)
Python3中,import cookielib改成 import http.cookiejar
嘘、小点声
2019/07/31
1.3K0
Python学习—pyhton中的进程
进程: 进程就是一个程序在一个数据集上的一次动态执行过程。进程一般由程序、数据、进程控制块(pcb)三部分组成。 (1)我们编写的程序用来描述进程要完成哪些功能以及如何完成; (2)数据则是程序在执行过程中所需要使用的资源; (3)进程控制块用来记录进程的所有信息。系统可以利用它来控制和管理进程,它是系统感知进程存在的唯一标志。
py3study
2020/01/06
5630
7-并发编程
对于CPU计算密集型的任务,python的多线程跟单线程没什么区别,甚至有可能会更慢,但是对于IO密集型的任务,比如http请求这类任务,python的多线程还是有用处。在日常的使用中,经常会结合多线程和队列一起使用,比如,以爬取simpledestops 网站壁纸为例:
py3study
2020/01/02
3450
Python 分布式进程Master
from multiprocessing.managers import BaseManager
py3study
2020/01/15
4640
Python使用Manager对象实现不同机器上的进程跨网络传输数据
本文主要演示不同机器上的进程之间如何通过网络进行数据交换。 (1)首先编写程序文件multiprocessing_server.py,启动服务器进程,创建可共享的队列对象。 from multiprocessing.managers import BaseManager from queue import Queue q = Queue() class QueueManager(BaseManager): pass QueueManager.register('get_queue', callable=l
Python小屋屋主
2018/04/16
1.9K0
python网络爬虫(10)分布式爬虫爬取静态数据
爬虫应该能够快速高效的完成数据爬取和分析任务。使用多个进程协同完成一个任务,提高了数据爬取的效率。
嘘、小点声
2019/07/31
6200
Python多进程并行编程实践:以multiprocessing模块为例
專 欄 ❈Pytlab,Python 中文社区专栏作者。主要从事科学计算与高性能计算领域的应用,主要语言为Python,C,C++。熟悉数值算法(最优化方法,蒙特卡洛算法等)与并行化 算法(MPI,OpenMP等多线程以及多进程并行化)以及python优化方法,经常使用C++给python写扩展。 blog:http://ipytlab.com github:https://github.com/PytLab ❈— 前言 并行计算是使用并行计算机来减少单个计算问题所需要的时间,我们可以通过利用编程语言显
Python中文社区
2018/01/31
2.7K0
Python多进程并行编程实践:以multiprocessing模块为例
Python3 与 C# 并发编程之~ 进程实战篇
之前说过 Queue:在 Process之间使用没问题,用到 Pool,就使用 Manager().xxx, Value和 Array,就不太一样了:
逸鹏
2018/09/07
9480
Python3 与 C# 并发编程之~ 进程实战篇
《Python分布式计算》 第3章 Python的并行计算 (Distributed Computing with Python)多线程多进程多进程队列一些思考总结
我们在前两章提到了线程、进程,还有并发编程。我们在很高的层次,用抽象的名词,讲了如何组织代码,已让其部分并发运行,在多个CPU上或在多台机器上。 本章中,我们会更细致的学习Python是如何使用多个CPU进行并发编程的。具体目标是加速CPU密集型任务,提高I/O密集型任务的反馈性。 好消息是,使用Python的标准库就可以进行并发编程。这不是说不用第三方的库或工具。只是本章中的代码仅仅利用到了Python的标准库。 本章介绍如下内容: 多线程 多进程 多进程队列 多线程 Python从1.4版本开始就支持多
SeanCheney
2018/04/24
1.6K0
《Python分布式计算》 第3章 Python的并行计算 (Distributed Computing with Python)多线程多进程多进程队列一些思考总结
一篇文章搞定Python多进程(全)
前面写了三篇关于python多线程的文章,大概概况了多线程使用中的方法,文章链接如下:
南山烟雨
2019/05/05
6450
一篇文章搞定Python多进程(全)
Python 多进程
上面的代码开启了5个子进程去执行函数,我们可以观察结果,是同时打印的,这里实现了真正的并行操作,就是多个CPU同时执行任务。我们知道进程是python中最小的资源分配单元,也就是进程中间的数据,内存是不共享的,每启动一个进程,都要独立分配资源和拷贝访问的数据,所以进程的启动和销毁的代价是比较大了,所以在实际中使用多进程,要根据服务器的配置来设定。
为为为什么
2022/08/05
3920
一篇文章搞定Python多进程
Python中的多进程是通过multiprocessing包来实现的,和多线程的threading.Thread差不多,它可以利用multiprocessing.Process对象来创建一个进程对象。这个进程对象的方法和线程对象的方法差不多也有start(), run(), join()等方法,其中有一个方法不同Thread线程对象中的守护线程方法是setDeamon,而Process进程对象的守护进程是通过设置daemon属性来完成的。
程序员鑫港
2022/01/05
5750
推荐阅读
相关推荐
想用Python爬小姐姐图片?那你得先搞定分布式进程
更多 >
LV.0
这个人很懒,什么都没有留下~
目录
  • 原理简介
  • 代码修改
    • model.fit_generator
    • training_generator.fit_generator
    • model.evaluate_generator
    • training_generator.evaluate_generator
    • keras.callbacks.TensorBoard._write_logs
  • 测试
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档