前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >写给开发者的机器学习指南(十)

写给开发者的机器学习指南(十)

作者头像
哒呵呵
发布2018-08-06 17:47:49
3390
发布2018-08-06 17:47:49
举报
文章被收录于专栏:鸿的学习笔记鸿的学习笔记

An attempt at rank prediction for topselling books using text regression

在基于高度和性别预测权重的示例中,我们引入了线性回归的概念。但是,有时人们会想要对非数字数据(如文本)应用回归。在这个例子中,我们将展示如何通过试图预测O'Reilly的前100本销售书籍来完成文本回归。 此外,通过此示例,我们还将显示,对于这种特殊情况,使用文本回归是无效的。 原因是数据根本不包含我们的测试数据的信号。然而,这不会使此示例无用,因为在实际使用的数据中可能存在实际信号,然后可以使用此处解释的文本回归检测。

我们在这个例子中使用的数据文件可以在这里下载。 除了Smile库之外,在这个例子中,我们还将使用Scala-csv库处理csv包含逗号的字符串。让我们从获取我们需要的数据开始:

代码语言:javascript
复制
object TextRegression {
  def main(args:Array[String]): Unit = {
    //Get theexample data
      val basePath= "/users/.../TextRegression_Example_4.csv"
      val testData= getDataFromCSV(new File(basePath))
  }
  defgetDataFromCSV(file: File) : List[(String,Int,String)]= {
    val reader =CSVReader.open(file)
    val data =reader.all()
    val documents =data.drop(1).map(x => (x(1),x(3)toInt,x(4)))
    returndocuments
  }
}

我们现在有O'Reilly的前100名销售书的标题,等级和长长的描述。然而,当我们想做某种形式的回归时,我们需要数值数据。 这就是为什么我们将构建一个文档术语矩阵(DTM)。 请注意,此DTM类似于我们在垃圾邮件分类示例中构建的术语文档矩阵(TDM)。 它的不同之处在于,我们存储包含该文档中的术语的文档记录,与存储词语的记录的TDM相反,其中包含该词语可用的文档的列表。我们自己实现了如下:

代码语言:javascript
复制
import java.io.File
import scala.collection.mutable
class DTM {
  var records:List[DTMRecord] = List[DTMRecord]()
  var wordList:List[String] = List[String]()
  defaddDocumentToRecords(documentName: String, rank: Int, documentContent: String)= {
    //Find a recordfor the document
    val record =records.find(x => x.document == documentName)
    if(record.nonEmpty) {
      throw newException("Document already exists in the records")
    }
    var wordRecords= mutable.HashMap[String, Int]()
    valindividualWords = documentContent.toLowerCase.split(" ")
   individualWords.foreach { x =>
      valwordRecord = wordRecords.find(y => y._1 == x)
      if(wordRecord.nonEmpty) {
        wordRecords+= x -> (wordRecord.get._2 + 1)
      }
      else {
        wordRecords+= x -> 1
        wordList =x :: wordList
      }
    }
    records = newDTMRecord(documentName, rank, wordRecords) :: records
  }
  defgetStopWords(): List[String] = {
    val source =scala.io.Source.fromFile(newFile("/Users/.../stopwords.txt"))("latin1")
    val lines =source.mkString.split("\n")
    source.close()
    returnlines.toList
  }
  defgetNumericRepresentationForRecords(): (Array[Array[Double]], Array[Double]) = {
    //First filterout all stop words:
    val StopWords =getStopWords()
    wordList =wordList.filter(x => !StopWords.contains(x))
    var dtmNumeric= Array[Array[Double]]()
    var ranks =Array[Double]()
    records.foreach{ x =>
      //Add therank to the array of ranks
      ranks = ranks:+ x.rank.toDouble
      //And createan array representing all words and their occurrences 
      //for thisdocument:
      vardtmNumericRecord: Array[Double] = Array()
     wordList.foreach { y =>
        valtermRecord = x.occurrences.find(z => z._1 == y)
        if(termRecord.nonEmpty) {
         dtmNumericRecord = dtmNumericRecord :+ termRecord.get._2.toDouble
        }
        else {
         dtmNumericRecord = dtmNumericRecord :+ 0.0
        }
      }
      dtmNumeric =dtmNumeric :+ dtmNumericRecord
    }
    return(dtmNumeric, ranks)
  }
}
class DTMRecord(val document : String,
                valrank : Int,
                varoccurrences : mutable.HashMap[String,Int]
                )

如果你注意一下这个实现,你会看到有一个方法叫做getgetNumericRepresentationForRecords():(Array [Array[Double]],Array [Double])。 此方法返回一个以第一个参数为一个元组的矩阵,其中每行代表一个文档,每个列代表DTM文档的完整词汇表中的一个单词。 注意,第一个表中的双精度表示单词的出现次数。第二个参数是包含属于来自第一个表的记录的所有等级的数组。我们现在可以扩展我们的主代码,使得我们得到所有文档的数字表示如下:

代码语言:javascript
复制
val documen3)tTermMatrix = new DTM()
testData.foreach(x =>documentTermMatrix.addDocumentToRecords(x._1,x._2,x._)

通过从文本到数值的转换,我们可以打开我们的回归工具箱了。我们在预测基于身高的体重的示例中使用了普通最小二乘法(OLS),但是这次我们将使用最小绝对收缩和选择算子(Lasso)回归。 这是因为我们可以给这个回归方法一个特定的lambda,代表一个惩罚值。 该惩罚值允许LASSO算法选择相关特征(字),同时丢弃一些其他特征(字)。

在我们的案例中,Lasso执行的这个特征选择非常有用,因为文档描述中使用了大量的词。 Lasso将尝试使用这些单词的理想子集作为特征,而当应用OLS时,将使用所有单词,并且运行时间将是非常长的。此外,smile的OLS实现检测出秩很低。 这是维度诅咒之一。

然而,我们需要找到一个最佳的lambda,因此,我们应该尝试使用交叉验证几个lambda。 我们将这样做:

代码语言:javascript
复制
for (i <- 0 until cv.k) {
      //Split offthe training datapoints and classifiers from the dataset
      valdpForTraining = numericDTM
        ._1
        .zipWithIndex
        .filter(x=> cv
                   .test(i)
                   .toList
                   .contains(x._2)
                )
        .map(y=> y._1)
      valclassifiersForTraining = numericDTM
        ._2
       .zipWithIndex
        .filter(x=> cv
                   .test(i)
                   .toList
                   .contains(x._2)
                )
        .map(y=> y._1)
      //And thecorresponding subset of data points and their classifiers for testing
      val dpForTesting= numericDTM
        ._1
       .zipWithIndex
        .filter(x=> !cv
                   .test(i)
                   .contains(x._2)
                )
        .map(y=> y._1)
      valclassifiersForTesting = numericDTM
        ._2
       .zipWithIndex
        .filter(x=> !cv
                   .test(i)
                   .contains(x._2)
                )
        .map(y=> y._1)
      //These arethe lambda values we will verify against
      val lambdas:Array[Double] = Array(0.1, 0.25, 0.5, 1.0, 2.0, 5.0)
     lambdas.foreach { x =>
        //Define anew model based on the training data and one of the lambda's
        val model =new LASSO(dpForTraining, classifiersForTraining, x)
        //Compute the RMSE for this model withthis lambda
        val results= dpForTesting.map(y => model.predict(y)) zip classifiersForTesting
        val RMSE =Math
           .sqrt(results
                   .map(x => Math.pow(x._1 - x._2, 2)).sum /
                                 results.length
                       )
       println("Lambda: " + x + " RMSE: " + RMSE)
      }
    }

多次运行此代码使得RMSE在36和51之间变化。这意味着我们将执行的排名预测将至少缺少36个等级。 考虑到我们试图预测前100个排名的事实,它表明该算法执行得很差。 在这种情况下,lambda的差异不明显。但是在实际使用时,在选择lambda值时应该小心:选择的lambda越高,算法的要素数量就越少。这就是为什么交叉验证是重要的,因为要看看算法如何在不同的lambda上执行的。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2016-11-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 鸿的学习笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档