前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用人工神经网络预测急诊科患者幸存还是死亡

用人工神经网络预测急诊科患者幸存还是死亡

作者头像
用户2176511
发布2018-05-31 16:40:30
1.3K0
发布2018-05-31 16:40:30

引言

Apache Spark是一个基于集群的开源计算系统,主要用于处理非常大的数据集。并行计算和容错功能是Spark体系结构的内置功能。Spark Core是Spark的主要组件,并通过一组机器提供通用数据处理功能。基于Spark Core构建的其他组件带来更多功能,如机器学习。关于Apache Spark的全面介绍的文档已发布,请参阅Apache Spark官方文档Apache Spark简介Spark中的大数据处理Spark Streaming入门

本文重点介绍Spark MLlib库,它提供了用于实现机器学习和统计计算算法的应用程序接口(API)。我们将讨论因心脏病引起的急诊部(ED)死亡预测的例子,并将其作为二分类问题。我们将尝试用Spark MLlib Java API实现的人工神经网络(ANN)来解决这个问题。

在下一节中,我们将对这个问题进行解释并将其表示为二分类问题,然后描述如何利用ANN来解决这个问题。我们还会利用各种性能指标来评价最终预测结果的正确性。接下来,我们将讨论如何选择解决预测急诊科死亡问题的人工神经网络(ANN)。最后,我们将回顾Java代码并讨论本文的研究成果。

问题描述

国家卫生统计中心是美国卫生和人类服务部的一部分,定期发布国家医院门诊医疗调查(NHAMCS)结果,其中包括医院急诊科(ED)的患者统计数据。我们将根据患者各种特征(如年龄,基本生命测量指标和是否患有心肌梗塞,即心脏病发作)等,尝试预测急诊时因心脏病引起的死亡。

分类问题

简单地说,分类就是输入一组数据,判断系统的输出属于哪一个类别或种类的问题。解决分类问题的算法称为分类器。我们将在这里考虑的特定的分类问题描述如下。假设一个患者由于心脏问题而看了急诊,现在我们将尝试预测该患者在医院(ED或医院病房)是否会死亡。

这可以被表述为二分类问题,对于一组输入变量只可能有两个输出结果(因此称为二分类):患者要么幸存要么死亡。每个结果都是一个类别。每个类都由一个标签唯一标识,总结如下。

标签

说明

0

病人能够幸存下来,即不会在医院里死去。

1

患者在医院死亡(或者在急诊科,或者在急诊科统一之后转入病房)。

表1 标签说明。

每个输入变量称为一个特征。对于这里考虑的问题,特征解释如下。

特征

名称

说明

1

年龄重新编码

患者年龄分组:0 = 15岁以下,1 = 15-24岁,2 = 25-44岁,3 = 45-64岁,4 = 65-74岁,5 = 75-84岁,6 = 85-95岁,7 = 95岁以上

2

温度

体温在正常范围内,定义为97-99 F:0 =正常,1 =异常

3

脉搏血氧仪(百分比)

脉搏血氧饱和度在正常范围内,定义为95%-100%:0 =正常,1 =异常

4

舒张压

舒张压在正常范围内,定义为60-80 mm HG:0 =正常,1 =异常

5

收缩压

收缩压在正常范围内,定义为90-120 mm HG:0 =正常,1 =异常

6

呼吸频率

呼吸频率在正常范围内,定义为12-25次呼吸/分钟之间:0 =正常,1 =异常

7

脉冲

脉冲在正常范围内,定义为60-100次/分钟之间:0 =正常,1 =异常

8

是否有心脏病

患者是否被诊断为心脏病发作:0 =未被诊断为心脏病,1 =诊断为心脏病

表2 特征描述。

我们使用了NHAMCS急诊部提供的2010年,2011年和2012年的公用微型数据文件,它们可从官方下载网站获取。这些都是固定长度的ASCII文件,每行数据都属于一个单独的患者。上述提及的特征在数据文件中都有固定的位置。我们通过为95岁以上的患者增加一个年龄组来扩展年龄分组记录。(在年龄记录的初始定义中,第6组涵盖所有85岁或以上的患者)。我们将抵达急诊科(ED)后死亡的患者排除在外。在每个病人的数据文件中,最多有三个诊断记录。由于我们只考虑那些由于心脏问题而到急诊科(ED)就诊过的患者,因此我们要求诊断记录中至少有一项的ICD9代码在410 - 414之间。(这些ICD9代码及其扩展码涵盖冠状动脉疾病的所有诊断。)否则,丢弃患者记录。最终的数据文件有915例(行),其中888例存活(第 0 类),27例死亡(第1类)。

对于是否患有心脏病,我们继续如下处理。如果三个诊断中的任何一个具有ICD9代码410或其扩展码之一,即410.0-410.9(急性心肌梗塞),则我们认为存在心脏病,反之没有。

人工神经网络

人工神经网络是一种具有多种的科学和技术应用的数学模型。特别地,人工神经网络可以用来解决上面介绍的分类问题。人工神经网络有很多中。多层感知器是一种特殊类型的人工神经网络。Spark MLlib库为建立在多层感知器上的称为多层感知分类器(MLPC)的分类提供了一个API。在我们的例子中将要用到的多层感知分类器(MLPC)由多个输入和一个单独的输出组成,示意图如下图所示。

图1. MLPC的示意图。分类器的输入对应特征,其输出对应标签。

每个圆圈代表一个神经元,它是一个计算单位,即数学函数,它接受输入(输入箭头)并产生输出(输出箭头)。每个计算单元中的数学函数的模型已经确定,但是函数中各种参数的初始值未确定。在我们的例子中,数学函数使得对于任何输入来说,输出是0或1(受到近似值的影响,这实际上没有任何意义的)。ANN的实现死亡预测的思想是基于一组已知的输入(特征)和相应的输出(标签)来“训练”该ANN以确定数学函数中的参数。一旦人工神经网络得到训练,就应该学习原系统的行为,以便有新的(没有用于训练的)输入时,人工神经网络应该产生与原系统相同的输出。

所谓的“隐藏”层的名字由此而来,因为这些隐层的数量与特征或标签的数量无直接关联。每一层中计算单元的数目可以不同。随着层数和计算单元数量的增加,通过训练确定的参数数量也增加。参数越多,所训练的ANN就越灵活,能更好地学习原系统的行为。另一方面,Hastie 等人指出,当计算单元的数量增加超过一定的限制,人工神经网络开始出现过拟合,即在除训练集之外的数据上没有较好的泛化能力。该参考文献还指出“隐层数的选择取决于背景知识和实验。”许多研究将ANN用于诊断预测目的的医学科学,例如中风诊断肺癌检测。除了医学科学,人工神经网络还有许多其他应用,例如一般决策功能。

性能评价

训练完模型后,我们应该能够针对测试数据定量测量其性能,测试数据和训练数据是分开的。然后,在不同的模型中,我们选择对测试数据具有最佳性能的模型。下面我们讨论将混淆矩阵精度以及召回率作为性能指标。

混淆矩阵

在二分类中,混淆矩阵是一个每项都为非负整数的2*2的矩阵。第一行和第二行分别代表标签0和1。第一列和第二列分别表示预测的标签0和1。对于特定的某一行,所有列的数字的和就是数据集中某个特定标签的实例的数量。对于特定的列来说,所有行的数字的和为模型预测的某个特定标签的次数。举个例子,考虑下面的混淆矩阵。

69

3

4

1

在数据集中,标签为0的实例有72(= 69 + 3)个,标签为1的实例有5(= 4 + 1)个。即,72个患者存活,5个患者死亡。该模型正确预测了69名幸存的患者,然而,它错误地将3名幸存的病人预测为死亡。另一方面,该模型正确地预测了1例死亡患者,但是错误地将4名死亡的患者预测为幸存。

准确率和召回率

标签的精确率是正确预测为某个标签的次数除以任何标签被预测为此标签的次数。标签的召回率(又名灵敏度)是指正确预测为某个标签的次数除以某个标签的实例数。混淆矩阵可以用来计算准确率和召回率。在上例中,标签为0的准确率为69 /(69 + 4)= 0.945,标签为0的召回率为69 /(69 + 3)= 0.958。

精确度和召回率都是介于0和1之间的数字。当它们都接近1时,模型的性能就越好; 当它们其中任何一个接近0时,模型的性能就会下降。在最理想的情况下,当模型完美地预测每个标签时,混淆矩阵在非对角线上的项为0。

请注意,二分类是一个多分类问题的一个特例。混淆矩阵,准确率和召回率的定义都可以可以扩展到多分类,其中涉及问题多于两类。

解决方案

在本节中,我们总结了如何得到包含MLPC的最佳数学模型。

  1. 选择一组候选特征。
  2. 定义隐层的数量和每层中计算单元的数量。(从一个简单的模型开始。)
  3. 使用k重交叉验证技术获得基于候选特征的训练集数据和测试集数据。(将会有k个这样的对)对于每个这样的对,使用训练数据集训练一个不同的模型,并根据测试数据集测量其性能。
  4. 比较所有的模型并选择最佳性能的一个模型。
  5. 如果最佳性能模型的结果令人满意,则停止。除此以外:
    • 如果观察到模型的性能得到改善,则转到步骤3,通过增加具有更多计算单元和/或隐层数,增加模型的复杂度。
    • 如果模型的性能得到没有进一步的改进,则转到步骤1重新定义特征(全部重新开始)。

图2. 模型选择过程。

NHAMCS数据文件包含超过500项数据,包括患者人口统计数据,重要测量数据,诊断数据,慢性病症数据,家族病史数据以及患者访问的特定医院的统计数据。在根据领域知识去除大多数数据项后,我们最初确定了一组候选特征并且生成了一个LIBSVM格式的数据文件。这是机器学习应用中常用的格式。

我们从一个简单的模型开始,该模型有2个隐层,每层5个计算单元。我们应用k = 10的k重交叉验证来获得10对训练数据集和测试数据集。性能指标表明没有任何一个模型的的预测结果是成功的。特别是,有的模型未能预测死亡患者,即标签为1的召回率非常接近0。

然后,我们回到步骤3来增加模型复杂度并增加更多的计算单元,并且还增加了一个隐层。当预测结果仍不理想时,我们总结得出我们选择的特征不合适。然后我们返回到第1步查看是否能简化特征。(当使用ANN解决分类问题时,不相关的特征,即冗余数据,可能会降低预测不准和计算量过大的问题,如O'Dea 等人所说。)我们删除了步骤1中的一些特征,并再次循环步骤2-5,最终得到如上表2中的特征集,和分别由28个和25个计算单元组成的2个隐藏层的ANN。

代码回顾

我们的演示程序将说明如何使用Spark API开始 配置MLPC(即基于ANN的分类器),如下:

  • 初始化Spark配置和上下文。
  • 初始化一个SQL上下文,它是由行和列组成的结构化数据的基础; MLPC运行需要SQL。
代码语言:txt
复制
public class MultilayerPerceptronClassifierDemo {

    public static void main(String[] args) {
        // 设置应用程序名称
        String appName = "MultilayerPerceptronClassifier";
        // 初始化Spark配置和上下文
        SparkConf conf = new SparkConf().setAppName(appName)
                .setMaster("local[1]").set("spark.executor.memory", "1g");
        SparkContext sc = new SparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
  • 接下来,将数据随机分开用来做10重交叉验证。
  • 循环重复10次以下步骤:(i)获得训练和测试数据集(ii)训练模型和测量模型的性能。
  • 最后,停止Spark上下文。这就终止了主程序。
代码语言:txt
复制
        // 从Hadoop导入训练数据和测试数据文件并解析
        String path = "hdfs://localhost:9000/user/konur/ED2010_2011_2012_SVM.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path)
                .toJavaRDD();

        // 得到10组训练集和测试集数据. 用12345作为随机分割数据的种子.
        Tuple2<RDD<LabeledPoint>,RDD<LabeledPoint>>[] myTuple = MLUtils.kFold(data.rdd(), 10, 12345, data.classTag());

        // 对每组数据训练/验证算法一次.
        for(int i = 0; i < myTuple.length; i++){
            JavaRDD<LabeledPoint> trainingData = (new JavaRDD<LabeledPoint>(myTuple[i]._1,data.classTag())).cache();
            JavaRDD<LabeledPoint> testData = new JavaRDD<LabeledPoint>(myTuple[i]._2,data.classTag());
            kRun(trainingData,testData,sqlContext);
        }
        sc.stop();
    }

帮助程序kRun首先准备用于训练和测试的数据结构。然后,定义MLPC的结构。

代码语言:txt
复制
    private static final void kRun(JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testData, SQLContext sqlContext){
        DataFrame train = sqlContext.createDataFrame(trainingData, LabeledPoint.class);
        DataFrame test = sqlContext.createDataFrame(testData, LabeledPoint.class);
        // 输入有8个特征组成;两个隐层分别包含28和25个计算单元
        // 输出为二进制数字
        int[] layers = new int[] {8,  28, 25, 2};

然后我们定义训练器并获得训练好的模型。

代码语言:txt
复制
        // 定义训练器
        MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
          .setLayers(layers)
          .setBlockSize(128)
          .setSeed(1234L)
          .setMaxIter(150);
        // 得到训练后的模型
        MultilayerPerceptronClassificationModel model = trainer.fit(train);

至此,我们获得了我们的模型。接下来,我们将模型应用测试集数据,并计算测试集上的性能指标。这样就完成了kRunF方法的步骤。

代码语言:txt
复制
        // 将模型用于测试集并计算输出
        DataFrame testResult = model.transform(test);
        // 显示输出的性能指标
        displayConfusionMatrix(testResult.collect());
    }

现在让我们回顾帮助程序方法displayConfusionMatrix,该方法能够计算并显示性能指标,以各种变量定义开始。

代码语言:txt
复制
    private static final void displayConfusionMatrix(Row[] rows){
        // 标签0正确预测的次数
        int correctlyPredicted0 = 0;

        // 标签1正确预测的次数
        int correctlyPredicted1 = 0;

        // 将标签1错判为标签0的次数
        int wronglyPredicted0 = 0;

        // 将标签0错判为1的次数
        int wronglyPredicted1 = 0;

转方法的每一行输出对应于特定特使数据的一行,测试数据的第1列和第2列分别对应于实际标签和预测标签。我们遍历所有的行并增加相应的增量。

代码语言:txt
复制
        for(int i=0; i < rows.length; i++){
            Row row = rows[i];
            double label = row.getDouble(1);
            double prediction = row.getDouble(2);

            if(label == 0.0){
                if(prediction == 0.0){
                    correctlyPredicted0++;
                }else{
                    wronglyPredicted1++;
                }
            }else{
                if(prediction == 1.0){
                    correctlyPredicted1++;
                }else{
                    wronglyPredicted0++;
                }
            }
        }

最后显示混淆矩阵并计算标签0和1的准确率和召回率。

代码语言:txt
复制
        float fcorrectlyPredicted0 = correctlyPredicted0 * 1.0f;
        float fcorrectlyPredicted1 = correctlyPredicted1 * 1.0f;
        float fwronglyPredicted0 = wronglyPredicted0 * 1.0f;
        float fwronglyPredicted1 = wronglyPredicted1 * 1.0f;

        System.out.println("************");
        System.out.println(correctlyPredicted0 + "      " + wronglyPredicted1);
        System.out.println(wronglyPredicted0 + "      " + correctlyPredicted1);

        System.out.println("Class 0 precision: " + ((fcorrectlyPredicted0 == 0.0f)?0.0:(fcorrectlyPredicted0 / (fcorrectlyPredicted0 + fwronglyPredicted0))));
        System.out.println("Class 0 recall: " + ((fcorrectlyPredicted0 == 0.0f)?0.0:(fcorrectlyPredicted0 / (fcorrectlyPredicted0 + fwronglyPredicted1))));

        System.out.println("Class 1 precision: " + ((fcorrectlyPredicted1 == 0.0f)?0.0:(fcorrectlyPredicted1 / (fcorrectlyPredicted1 + fwronglyPredicted1))));
        System.out.println("Class 1 recall: " + ((fcorrectlyPredicted1 == 0.0f)?0.0:(fcorrectlyPredicted1 / (fcorrectlyPredicted1 + fwronglyPredicted0))));
        System.out.println("************");
    }

我们在Spark服务器上运行了上述代码,其中包含单个节点的2.7.1版本的Hadoop安装,和1.6.1版本的Spark API,这是撰写文章时的最新版本。完整的Java代码可以从https://github.com/kunyelio/Spark-MLPC下载。

结果讨论

让我们首先看看具有两个隐藏层并且每个隐层有5个计算单元的的模型在测试数据上的混淆矩阵,准确率和召回率。

70

2

4

1

  • 第0类准确率:0.946
  • 第0类召回率:0.972
  • 第1类准确率:0.333
  • 第1类召回率:0.2

尽管模型对第0类(患者存活)具有合理的性能,但对于1类(患者死亡),模型表现不佳。

接下来,让我们展示最佳模型在测试数据上的混淆矩阵,准确率和召回率。它有两个隐含层,分别由28和25个计算单元组成。

89

0

0

1

  • 第0类准确率:1.0
  • 第0类召回率:1.0
  • 第1类准确率:1.0
  • 第1类召回率:1.0

该模型性能绝佳,正确预测所有标签。我们观察到,通过增加计算单元的数量可以提高模型性能。

结论

在本文中,我们使用了Spark机器学习库中的人工神经网络(ANN)作为分类器来预测因心脏病导致的急诊科患者幸存还是死亡的问题。我们讨论了特征选择,选择网络隐层数和计算单元数量等高层次过程。基于这个过程,我们找到了一个在测试数据上取得了非常好的性能的模型。我们观察到Spark MLlib API简单易用,可用于训练分类器并计算其性能指标。参照Hastie等人,我们最终得出一些建议。

  • 当使用ANN作为分类器时,建议特征在数量级保持平衡。
    • 事实上,在我们的例子中,除年龄重新编码外以外的所有特征都是二进制的。年龄重新编码从一组离散的8个值中接受值,这个差异在可接受范围内。
  • 通常情况下,计算单元的数量在5 - 100“之间......随着输入数量和训练集数量的增加,计算单元的数量也增加。”
    • 在我们的例子中,最佳模型的计算单位数是53。
  • 随着计算单元数量的增加,训练模型需要更多的计算时间。
    • 在我们的例子中,对于每层有5个计算单元,共有2个隐层的初始简单模型,对模型进行训练,即MultilayerPerceptronClassificationModel model = trainer.fit(train); 每次拆分平均花费约4秒钟的时间。最终的模型分别有28个和25个计算单元的2个隐藏层,耗时6秒。正如所料,我们观察到计算时间增加了。(因为我们安装单节点Hadoop的Spark服务器,所以计算时间不应该推广到真实场景中。在集群模式下,计算时间比单个节点要小。)
评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 引言
  • 问题描述
    • 分类问题
    • 人工神经网络
      • 性能评价
        • 准确率和召回率
    • 解决方案
    • 代码回顾
    • 结果讨论
    • 结论
    相关产品与服务
    GPU 云服务器
    GPU 云服务器(Cloud GPU Service,GPU)是提供 GPU 算力的弹性计算服务,具有超强的并行计算能力,作为 IaaS 层的尖兵利器,服务于深度学习训练、科学计算、图形图像处理、视频编解码等场景。腾讯云随时提供触手可得的算力,有效缓解您的计算压力,提升业务效率与竞争力。
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档