前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >简单易学的机器学习算法——Softmax Regression

简单易学的机器学习算法——Softmax Regression

作者头像
felixzhao
发布2018-03-19 10:42:05
1.2K0
发布2018-03-19 10:42:05
举报
文章被收录于专栏:null的专栏

一、Softmax Regression简介

        Softmax Regression是Logistic回归的推广,Logistic回归是处理二分类问题的,而Softmax Regression是处理多分类问题的。Logistic回归是处理二分类问题的比较好的算法,具有很多的应用场合,如广告计算等。Logistic回归利用的是后验概率最大化的方式去计算权重。

二、Logistic回归的回顾

    在Logistic回归中比较重要的有两个公式,一个是阶跃函数:

另一个是对应的损失函数

最终,Logistic回归需要求出的是两个概率:

具体的Logistic回归的过程可参见“简单易学的机器学习算法——Logistic回归”。

三、Logistic回归的推广——Softmax Regression

    在Logistic回归需要求解的是两个概率:

,而在Softmax Regression中将不是两个概率,而是

个概率,

表示的是分类的个数。我们需要求出以下的概率值:

此时的损失函数为

其中

是一个指示性函数,意思是大括号里的值为真时,该函数的结果为1,否则为0。下面就这几个公式做个解释:

1、损失函数的由来

   概率函数可以表示为

其似然函数为

似然为

我们要最大化似然函数,即求

。再转化成损失函数。

2、对 

似然(或者是损失函数)求偏导

   为了简单,我们仅取一个样本,则可简单表示为

下面对

求偏导:

其中,

表示第

维。如Logistic回归中一样,可以使用基于梯度的方法来求解这样的最大化问题。基于梯度的方法可以参见“优化算法——梯度下降法”。

四、实验

1、训练数据

   从图上我们可以看到分为4类。

2、测试数据

在区间上随机生成了4000个点,这样比较直观地看到分类边界。

3、Matlab源码

主程序

代码语言:javascript
复制
clear all;
clc;

%% 导入数据
data = load('SoftInput.txt');
[m,n] = size(data);
labels = unique(data(:,3));
labelLen = length(labels);%划分的种类
dataMat(:,2:3) = data(:,1:2);
dataMat(:,1) = 1;%做好数据集,添加一列为1
labelMat(:,1) = data(:,3)+1;%分类的标签
%% 画图
figure;
hold on
for i = 1:m
    if labelMat(i,:) == 1
        plot(data(i,1),data(i,2),'.m');%粉红色
    elseif labelMat(i,:) == 2
        plot(data(i,1),data(i,2),'.b');%蓝色
    elseif labelMat(i,:) == 3
        plot(data(i,1),data(i,2),'.r');%红色
    else
        plot(data(i,1),data(i,2),'.k');%黑色
    end
end
title('原始数据集');
hold off

%% 初始化一些参数
M = m;%数据集的行
N = n;%数据集的列
K = labelLen;%划分的种类
alpha = 0.001;%学习率
weights = ones(N, K);%初始化权重

%% 利用随机梯度修改权重
weights = stochasticGradientAscent(dataMat, labelMat, M, weights, alpha);

%% 测试数据集(主要在区间里随机生成)
size = 4000;
[testDataSet, testLabelSet] = testData(weights, size, N);
%% 画出最终的分类图
figure;
hold on
for i = 1:size
    if testLabelSet(i,:) == 1
        plot(testDataSet(i,2),testDataSet(i,3),'.m');
    elseif testLabelSet(i,:) == 2
        plot(testDataSet(i,2),testDataSet(i,3),'.b');
    elseif testLabelSet(i,:) == 3
        plot(testDataSet(i,2),testDataSet(i,3),'.r');
    else
        plot(testDataSet(i,2),testDataSet(i,3),'.k');
    end
end
title('测试数据集');
hold off

随机梯度法

代码语言:javascript
复制
%% 随机梯度下降法(这里要用上升法)
function [ weights ] = stochasticGradientAscent( dataMat, labelMat, M, weights, alpha )
    for step = 1:500
        for i = 1:M%对每一个样本
            pop = exp(dataMat(i,:)*weights);%计算概率
            popSum = sum(pop);%分母
            pop = -pop/popSum;%求好概率
            pop(:,labelMat(i)) = pop(:,labelMat(i))+1;%加1的操作
            weights = weights + alpha*dataMat(i,:)'*pop;
        end
    end
end

生成测试数据

代码语言:javascript
复制
%% 计算测试数据集
function [ testDataSet, testLabelSet ] = testData( weights, m, n)
    testDataSet = ones(m,n);%构建了全1的矩阵
    testLabelSet = zeros(m,1);
    for i = 1:m
        testDataSet(i,2) = rand()*6-3;
        testDataSet(i,3) = rand()*15;
    end
    
    %% 计算测试数据的所属分类
    for i = 1:m
        testResult = testDataSet(i,:)*weights;
        [C,I] = max(testResult);
        testLabelSet(i,:) = I;
    end
end       
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、Softmax Regression简介
  • 二、Logistic回归的回顾
  • 三、Logistic回归的推广——Softmax Regression
    • 1、损失函数的由来
      • 2、对 
      • 四、实验
        • 1、训练数据
          • 2、测试数据
            • 3、Matlab源码
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档