专栏首页LhWorld哥陪你聊算法【推荐系统篇】--推荐系统之训练模型

【推荐系统篇】--推荐系统之训练模型

一、前述

经过之前的训练数据的构建可以得到所有特征值为1的模型文件,本文将继续构建训练数据特征并构建模型。

二、详细流程

将处理完成后的训练数据导出用做线下训练的源数据(可以用Spark_Sql对数据进行处理) insert overwrite local directory '/opt/data/traindata' row format delimited fields terminated by '\t' select * from dw_rcm_hitop_prepare2train_dm; 注:这里是将数据导出到本地,方便后面再本地模式跑数据,导出模型数据。这里是方便演示真正的生产环境是直接用脚本提交spark任务,从hdfs取数据结果仍然在hdfs,再用ETL工具将训练的模型结果文件输出到web项目的文件目录下,用来做新的模型,web项目设置了定时更新模型文件,每天按时读取新模型文件

三、代码详解

package com.bjsxt.data

import java.io.PrintWriter

import org.apache.log4j.{ Level, Logger }
import org.apache.spark.mllib.classification.{ LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD }
import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.optimization.SquaredL2Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{ SparkContext, SparkConf }

import scala.collection.Map

/**
 * Created by root on 2016/5/12 0012.
 */
class Recommonder {

}

object Recommonder {
  def main(args: Array[String]) {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val conf = new SparkConf().setAppName("recom").setMaster("local[*]")
    val sc = new SparkContext(conf)
    //加载数据,用\t分隔开
    val data: RDD[Array[String]] = sc.textFile("d:/result").map(_.split("\t"))

    println("data.getNumPartitions:" + data.getNumPartitions) //如果文件在本地的话,默认是32M的分片

//    -1	Item.id,hitop_id85:1,Item.screen,screen2:1 一行数据格式
    //得到第一列的值,也就是label
    val label: RDD[String] = data.map(_(0))
    println(label)
    //sample这个RDD中保存的是每一条记录的特征名
    val sample: RDD[Array[String]] = data.map(_(1)).map(x => {
      val arr: Array[String] = x.split(";").map(_.split(":")(0))
      arr
    })
    println(sample)
//    //将所有元素压平,得到的是所有分特征,然后去重,最后索引化,也就是加上下标,最后转成map是为了后面查询用
    val dict: Map[String, Long] = sample.flatMap(x =>x).distinct().zipWithIndex().collectAsMap()
    //得到稀疏向量
    val sam: RDD[SparseVector] = sample.map(sampleFeatures => {
      //index中保存的是,未来在构建训练集时,下面填1的索引号集合
      val index: Array[Int] = sampleFeatures.map(feature => {
        //get出来的元素程序认定可能为空,做一个类型匹配
        val rs: Long = dict.get(feature) match {
          case Some(x) => x
        }
        //非零元素下标,转int符合SparseVector的构造函数
        rs.toInt
      })
      //SparseVector创建一个向量
      new SparseVector(dict.size, index, Array.fill(index.length)(1.0)) //通过这行代码,将哪些地方填1,哪些地方填0
    })
    //mllib中的逻辑回归只认1.0和0.0,这里进行一个匹配转换
    val la: RDD[LabeledPoint] = label.map(x => {
      x match {
        case "-1" => 0.0
        case "1"  => 1.0
      }
      //标签组合向量得到labelPoint
    }).zip(sam).map(x => new LabeledPoint(x._1, x._2))

//    val splited = la.randomSplit(Array(0.1, 0.9), 10)
//
//    la.sample(true, 0.002).saveAsTextFile("trainSet")
//    la.sample(true, 0.001).saveAsTextFile("testSet")
//    println("done")


    //逻辑回归训练,两个参数,迭代次数和步长,生产常用调整参数
     val lr = new LogisticRegressionWithSGD()
    // 设置W0截距
    lr.setIntercept(true)
//    // 设置正则化
//    lr.optimizer.setUpdater(new SquaredL2Updater)
//    // 看中W模型推广能力的权重
//    lr.optimizer.setRegParam(0.4)
    // 最大迭代次数
    lr.optimizer.setNumIterations(10)
    // 设置梯度下降的步长,学习率
    lr.optimizer.setStepSize(0.1)

    val model: LogisticRegressionModel = lr.run(la)

    //模型结果权重
    val weights: Array[Double] = model.weights.toArray
    //将map反转,weights相应下标的权重对应map里面相应下标的特征名
    val map: Map[Long, String] = dict.map(_.swap)
    //模型保存
    //    LogisticRegressionModel.load()
    //    model.save()
    //输出
    val pw = new PrintWriter("model");
    //遍历
    for(i<- 0 until weights.length){
      //通过map得到每个下标相应的特征名
      val featureName = map.get(i)match {
        case Some(x) => x
        case None => ""
      }
      //特征名对应相应的权重
      val str = featureName+"\t" + weights(i)
      pw.write(str)
      pw.println()
    }
    pw.flush()
    pw.close()
  }
}

 model文件截图如下:

各个特征下面对应的权重:

将模型文件和用户历史数据,和商品表数据加载到redis中去。

 代码如下:

# -*- coding=utf-8 -*-
import redis

pool = redis.ConnectionPool(host='node05', port='6379',db=2)
r = redis.Redis(connection_pool=pool)
f1 = open('../data/ModelFile.txt')
f2 = open('../data/UserItemsHistory.txt')
f3 = open('../data/ItemList.txt')
for i in list:
    lines = i.readlines(100)
    if not lines:
        break
    for line in lines:
        kv = line.split('\t')
        if i==f1:
          r.hset("rcmd_features_score", kv[0], kv[1])
        if i == f2:
          r.hset('rcmd_user_history', kv[0], kv[1])
        if i==f3:
          r.hset('rcmd_item_list', kv[0], line[:-2])
f1.close()

 最终redis文件中截图如下:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 【机器学习】--xgboost初始之代码实现分类

    默认可以通过pip安装,若是安装不上可以通过https://www.lfd.uci.edu/~gohlke/pythonlibs/网站下载相关安装包,将安装包拷...

    LhWorld哥陪你聊算法
  • Spark算子篇 --Spark算子之combineByKey详解

    第二个函数:一开始a是初始值,b是分组内的元素值,比如A[1_],因为没有b值所以不能调用combine函数,第二组因为函数内元素值是[2_,3]调用combi...

    LhWorld哥陪你聊算法
  • 【Storm篇】--Storm中的同步服务DRPC

    Drpc(分布式远程过程调用)是一种同步服务实现的机制,在Storm中客户端提交数据请求之后,立刻取得计算结果并返回给客户端。同时充分利用Storm的计算能力实...

    LhWorld哥陪你聊算法
  • Apex(配置)

    在本教程中,我们将使用Salesforce的Developer Edition。 在开发人员版本中,您不能选择创建沙盒组织。 Sandbox功能在其他版本的Sa...

    Zero--周
  • 百家号爬取(3)

    Centy Zhao
  • 安装ext3grep

    依赖包的安装 [root@localhost .unison]# rpm -qa |  grep  e2fsprogs e2fsprogs-devel-...

    py3study
  • 实战分享 | 你知道这个死锁是怎么产生的吗?

    | 作者 王文安,腾讯CSIG数据库专项的数据库工程师,主要负责腾讯云数据库 MySQL 的相关的工作,热爱技术,欢迎留言进行交流。 ---- Part1 背...

    腾讯云数据库 TencentDB
  • RoIPooling与RoIAlign的区别

    RolPooling可以使生成的候选框region proposal映射产生固定大小的feature map,先贴出一张图,接着通过这图解释RoiPooling...

    于小勇
  • CentOS7yum安装PHP7.2的操作方法

    以上这篇CentOS7yum安装PHP7.2的操作方法就是小编分享给大家的全部内容了,希望能给大家一个参考。

    砸漏
  • MySQL 案例:Update 死锁详解

    锁作为 MySQL 知识体系的主要部分之一,是每个 DBA 都需要学习和掌握的知识。锁保证了数据库在并发的场景下数据的一致性,同时锁冲突也是影响数据库性能的因素...

    王文安@DBA

扫码关注云+社区

领取腾讯云代金券