前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Logistic回归与梯度下降法

Logistic回归与梯度下降法

作者头像
bear_fish
发布2018-09-19 16:33:19
5280
发布2018-09-19 16:33:19
举报

http://blog.csdn.net/acdreamers/article/details/44657979

Logistic回归为概率型非线性回归模型,是研究二分类观察结果

与一些影响因素

之间关系的一种

多变量分析方法。通常的问题是,研究某些因素条件下某个结果是否发生,比如医学中根据病人的一些症状来判断它是

否患有某种病。

在讲解Logistic回归理论之前,我们先从LR分类器说起。LR分类器,即Logistic Regression Classifier。

在分类情形下,经过学习后的LR分类器是一组权值

,当测试样本的数据输入时,这组权值与测试数

据按照线性加和得到

这里

是每个样本的

个特征。之后按照Sigmoid函数(又称为Logistic函数)的形式求出

由于Sigmoid函数的定义域为

,值域为

,因此最基本的LR分类器适合对两类目标进行分类。

所以Logistic回归最关键的问题就是研究如何求得

这组权值。此问题用极大似然估计来做。

下面正式地来讲Logistic回归模型

考虑具有

个独立变量的向量

,设条件慨率

为根据观测量相对于某事件

发生

的概率。那么Logistic回归模型可以表示为

其中

,那么在

条件下

不发生的概率为

所以事件发生与不发生的概率之比为

这个比值称为事件的发生比(the odds of experiencing an event),简记为odds

可以看出Logistic回归都是围绕一个Logistic函数来展开的。接下来就讲如何用极大似然估计求分类器的参数。

假设有

个观测样本,观测值分别为

,设

为给定条件下得到

的概率,

同样地,

的概率为

,所以得到一个观测值的概率为

因为各个观测样本之间相互独立,那么它们的联合分布为各边缘分布的乘积。得到似然函数为

然后我们的目标是求出使这一似然函数的值最大的参数估计,最大似然估计就是求出参数

,使

取得最大值,对函数

取对数得到

现在求向量

,使得

最大,其中

这里介绍一种方法,叫做梯度下降法(求局部极小值),当然相对还有梯度上升法(求局部极大值)。

对上述的似然函数求偏导后得到

由于是求局部极大值,所以根据梯度上升法,有

根据上述公式,只需初始化向量

全为零,或者随机值,迭代到指定精度为止。

现在就来用C++编程实现Logistic回归的梯度上升算法。首先要对训练数据进行处理,假设训练数据如下

训练数据:TrainData.txt

[cpp] view plain copy

  1. 1 0 0 1 0 1  
  2. 0 0 1 2 0 0  
  3. 1 0 0 1 1 0  
  4. 0 0 0 0 1 0  
  5. 0 0 1 0 0 0  
  6. 0 0 1 0 1 0  
  7. 0 0 1 2 1 0  
  8. 1 0 0 0 0 0  
  9. 0 0 1 0 1 0  
  10. 1 0 1 0 0 0  
  11. 0 0 1 0 1 0  
  12. 0 0 1 0 0 0  
  13. 0 0 1 0 1 0  
  14. 1 0 0 1 0 0  
  15. 1 0 0 0 1 0  
  16. 2 0 0 0 1 0  
  17. 1 0 0 2 1 0  
  18. 2 0 0 0 1 0  
  19. 2 0 1 0 0 0  
  20. 0 0 1 0 1 0  
  21. 0 0 1 2 0 0  
  22. 0 0 0 0 0 0  
  23. 0 0 1 0 1 0  
  24. 1 0 1 0 1 1  
  25. 0 0 1 2 1 0  
  26. 1 0 1 0 0 0  
  27. 0 0 1 0 0 0  
  28. 0 0 0 2 0 0  
  29. 1 0 0 0 1 0  
  30. 2 0 1 0 0 0  
  31. 2 0 1 1 1 0  
  32. 1 0 1 1 0 0  
  33. 1 0 1 2 0 0  
  34. 1 0 0 1 1 0  
  35. 0 0 0 0 1 0  
  36. 1 1 0 0 1 0  
  37. 1 0 1 2 1 0  
  38. 0 0 0 0 1 0  
  39. 0 0 1 0 0 0  
  40. 1 0 1 1 1 0  
  41. 1 0 1 0 1 0  
  42. 2 0 1 2 0 0  
  43. 0 0 1 2 1 0  
  44. 0 0 1 0 1 0  
  45. 2 0 1 0 1 0  
  46. 0 0 1 0 1 0  
  47. 1 0 0 0 0 0  
  48. 1 0 0 0 1 0  
  49. 0 0 0 0 1 0  
  50. 0 0 1 2 1 0  
  51. 0 1 1 0 0 0  
  52. 0 1 0 0 1 0  
  53. 2 1 0 0 0 0  
  54. 2 1 0 0 0 0  
  55. 1 1 0 2 0 0  
  56. 1 1 0 0 0 1  
  57. 0 1 0 0 0 0  
  58. 2 1 0 0 1 0  
  59. 0 1 0 0 1 0  
  60. 2 1 0 2 1 0  
  61. 2 1 0 2 1 0  
  62. 1 1 0 2 1 0  
  63. 0 1 0 0 0 1  
  64. 2 1 1 0 1 0  
  65. 2 1 0 1 1 0  
  66. 1 1 0 0 0 1  
  67. 2 1 0 0 0 0  
  68. 1 1 0 0 1 0  
  69. 1 1 0 0 0 0  
  70. 2 1 0 1 1 0  
  71. 1 1 0 0 1 0  
  72. 1 0 1 1 0 1  
  73. 2 1 0 1 1 0  
  74. 0 1 0 0 1 0  
  75. 1 0 1 0 0 0  
  76. 0 0 1 0 0 1  
  77. 1 0 0 0 0 0  
  78. 0 0 0 2 1 0  
  79. 1 0 1 2 0 1  
  80. 1 0 0 1 1 0  
  81. 2 0 1 2 1 0  
  82. 2 0 0 0 1 0  
  83. 1 0 0 1 1 0  
  84. 1 0 1 0 1 0  
  85. 0 0 1 0 0 0  
  86. 1 0 0 2 1 0  
  87. 2 0 1 1 1 0  
  88. 0 0 1 0 1 0  
  89. 0 0 0 0 1 0  
  90. 2 0 0 1 0 1  
  91. 0 0 1 0 0 0  
  92. 0 0 0 0 0 0  
  93. 1 0 1 1 1 1  
  94. 2 0 1 0 1 0  
  95. 0 0 0 0 0 0  
  96. 1 0 1 0 1 0  
  97. 0 0 0 0 1 0  
  98. 0 0 0 2 0 0  
  99. 0 0 0 0 0 0  
  100. 0 0 1 2 0 0  
  101. 0 0 1 0 1 0  
  102. 0 0 1 0 0 1  
  103. 0 0 0 2 1 0  
  104. 1 0 1 1 1 0  
  105. 1 0 0 1 1 0  
  106. 0 0 1 0 1 0  
  107. 1 0 0 0 0 0  
  108. 1 0 1 0 1 0  
  109. 2 0 0 0 1 0  
  110. 1 0 0 0 1 0  
  111. 2 0 0 1 1 0  
  112. 0 0 1 2 1 0  
  113. 1 0 1 2 0 0  
  114. 0 0 1 2 1 0  
  115. 1 0 0 0 0 0  
  116. 0 0 1 0 1 0  
  117. 0 0 0 1 1 0  
  118. 1 0 0 0 1 0  
  119. 2 0 0 1 1 0  
  120. 1 0 0 1 1 0  
  121. 1 0 1 0 0 0  
  122. 1 1 0 1 1 0  
  123. 2 1 0 0 1 0  
  124. 0 1 0 0 0 0  
  125. 1 1 0 1 0 1  
  126. 1 1 0 2 1 0  
  127. 0 1 0 0 0 0  
  128. 1 1 0 2 0 0  
  129. 0 1 0 0 1 0  
  130. 1 1 0 0 1 1  
  131. 1 1 0 2 1 0  
  132. 1 0 0 2 1 0  
  133. 2 1 1 1 1 0  
  134. 0 1 0 0 1 0  
  135. 0 1 0 0 1 0  
  136. 2 1 0 0 0 1  
  137. 1 1 0 2 1 0  
  138. 1 1 0 0 1 0  
  139. 1 1 1 0 0 0  
  140. 2 1 0 2 1 0  
  141. 2 1 1 1 0 0  
  142. 0 1 0 0 1 0  
  143. 1 1 0 2 1 0  
  144. 0 1 0 0 1 0  
  145. 1 1 0 1 1 0  
  146. 0 1 0 0 1 0  
  147. 0 1 0 0 0 0  
  148. 1 1 0 0 0 0  
  149. 1 1 0 2 1 0  
  150. 1 1 0 0 0 0  
  151. 0 1 1 2 0 0  
  152. 2 1 0 0 1 0  
  153. 2 0 1 0 0 1  
  154. 0 0 1 0 1 0  
  155. 1 0 1 0 0 0  
  156. 0 0 1 2 1 0  
  157. 0 0 1 0 0 0  
  158. 1 0 1 0 1 0  
  159. 0 0 1 0 1 0  
  160. 0 0 1 0 1 0  
  161. 1 0 1 0 1 0  
  162. 0 0 0 0 0 1  
  163. 0 0 1 2 1 0  
  164. 0 0 1 0 1 0  
  165. 0 0 1 0 1 0  
  166. 0 0 1 0 0 0  
  167. 0 0 1 0 0 1  
  168. 0 0 1 2 1 0  
  169. 2 0 1 2 1 0  
  170. 0 0 1 0 1 0  
  171. 0 0 1 0 1 0  
  172. 0 0 1 0 1 0  
  173. 1 0 0 0 0 0  
  174. 2 0 1 1 1 0  
  175. 0 0 1 0 0 1  
  176. 1 0 1 0 0 0  
  177. 1 0 1 1 1 0  
  178. 1 0 1 1 0 0  
  179. 0 0 1 0 0 0  
  180. 1 0 1 1 1 0  
  181. 1 0 1 2 0 0  
  182. 2 0 0 0 1 0  
  183. 0 0 1 0 0 1  
  184. 0 0 1 0 1 0  
  185. 0 0 1 0 1 0  
  186. 1 0 1 0 0 0  
  187. 0 0 1 0 0 0  
  188. 2 0 1 1 0 0  
  189. 0 0 1 2 0 0  
  190. 1 0 0 1 1 1  
  191. 0 0 0 0 1 0  
  192. 0 0 0 0 0 1  
  193. 0 0 1 0 1 0  
  194. 2 0 1 2 1 0  
  195. 1 0 0 1 0 0  
  196. 0 0 1 0 0 0  
  197. 2 0 0 1 1 1  
  198. 0 0 1 0 0 0  
  199. 0 0 1 0 1 0  
  200. 2 0 1 0 1 0  
  201. 0 0 1 0 1 0  
  202. 2 0 0 0 1 0  
  203. 1 0 1 0 1 0  
  204. 1 0 0 0 1 0  
  205. 0 0 1 0 0 1  
  206. 2 0 0 0 0 0  
  207. 2 0 0 1 1 0  
  208. 0 0 1 0 1 0  
  209. 0 0 0 0 1 0  
  210. 2 0 1 0 0 0  
  211. 1 0 1 0 1 0  
  212. 0 0 0 0 1 0  
  213. 1 0 1 0 1 0  
  214. 0 0 1 0 0 0  
  215. 1 0 1 0 1 0  
  216. 1 0 1 0 1 0  
  217. 1 0 1 0 1 0  
  218. 0 0 1 2 0 0  
  219. 2 0 1 0 1 1  
  220. 0 0 1 0 1 0  
  221. 0 0 1 2 1 0  
  222. 0 0 0 0 0 0  
  223. 0 0 1 0 1 0  
  224. 1 0 1 0 1 0  
  225. 0 0 1 0 1 0  
  226. 1 0 1 0 0 0  
  227. 0 0 1 0 1 0  
  228. 0 0 1 0 0 0  
  229. 1 0 1 0 0 0  
  230. 0 0 1 0 1 0  
  231. 0 0 1 0 1 0  
  232. 1 0 0 0 1 0  
  233. 0 0 1 0 0 0  
  234. 0 0 0 0 1 0  
  235. 1 0 1 1 1 0  
  236. 0 0 0 2 0 0  
  237. 0 0 1 0 1 0  
  238. 0 0 1 0 1 0  
  239. 0 0 1 0 1 0  
  240. 1 0 0 1 1 0  
  241. 2 0 0 0 1 0  
  242. 1 0 0 0 0 0  
  243. 2 0 0 2 1 0  
  244. 0 0 1 2 1 0  
  245. 1 0 1 0 0 1  
  246. 0 0 1 2 1 0  
  247. 0 0 1 2 1 0  
  248. 0 0 1 0 1 0  
  249. 1 0 1 2 1 0  
  250. 0 0 0 2 0 0  
  251. 1 0 0 0 0 0  
  252. 0 0 0 2 1 0  
  253. 0 0 1 0 1 0  
  254. 2 0 0 0 1 0  
  255. 1 0 0 0 0 0  
  256. 1 0 0 1 1 0  
  257. 1 0 1 1 1 0  
  258. 1 0 1 0 1 1  
  259. 0 0 1 0 1 0  
  260. 1 1 0 2 1 0  
  261. 1 1 0 1 0 0  
  262. 2 1 0 2 1 0  
  263. 1 1 1 0 0 0  
  264. 0 1 1 0 0 0  
  265. 0 1 1 0 0 1  
  266. 0 1 0 0 1 0  
  267. 1 1 1 0 0 0  
  268. 1 1 1 0 1 0  
  269. 0 1 0 0 1 0  
  270. 0 1 1 0 0 1  
  271. 1 1 1 1 1 0  
  272. 1 1 0 2 1 0  
  273. 0 1 0 2 0 0  
  274. 1 1 0 2 1 0  
  275. 0 0 1 2 1 0  
  276. 2 1 1 1 1 0  
  277. 0 1 0 0 1 0  
  278. 0 0 1 0 1 0  
  279. 2 1 0 1 1 0  
  280. 0 1 0 0 1 0  
  281. 1 1 0 0 0 0  
  282. 1 1 0 0 1 0  
  283. 0 1 0 0 0 0  
  284. 0 1 1 0 0 0  
  285. 2 1 0 0 1 0  
  286. 2 1 0 0 0 0  
  287. 1 1 0 0 1 0  
  288. 2 1 0 1 1 0  

上面训练数据中,每一行代表一组训练数据,每组有7个数组,第1个数字代表ID,可以忽略之,2~6代表这组训

练数据的特征输入,第7个数字代表输出,为0或者1。每个数据之间用一个空格隔开。

首先我们来研究如何一行一行读取文本,在C++中,读取文本的一行用getline()函数。

getline()函数表示读取文本的一行,返回的是读取的字节数,如果读取失败则返回-1。用法如下:

[cpp] view plain copy

  1. #include <iostream>
  2. #include <string.h>
  3. #include <fstream>
  4. #include <string>
  5. #include <stdio.h>
  6. using namespace std;  
  7. int main()  
  8. {  
  9.     string filename = "data.in";  
  10.     ifstream file(filename.c_str());  
  11. char s[1024];  
  12. if(file.is_open())  
  13.     {  
  14. while(file.getline(s,1024))  
  15.         {  
  16. int x,y,z;  
  17.             sscanf(s,"%d %d %d",&x,&y,&z);  
  18.             cout<<x<<" "<<y<<" "<<z<<endl;  
  19.         }  
  20.     }  
  21. return 0;  
  22. }  

拿到每一行后,可以把它们提取出来,进行系统输入。 Logistic回归的梯度上升算法实现如下

代码:

[cpp] view plain copy

  1. #include <iostream>
  2. #include <string.h>
  3. #include <fstream>
  4. #include <stdio.h>
  5. #include <math.h>
  6. #include <vector>
  7. #define Type double
  8. #define Vector vector
  9. using namespace std;  
  10. struct Data  
  11. {  
  12.     Vector<Type> x;  
  13.     Type y;  
  14. };  
  15. void PreProcessData(Vector<Data>& data, string path)  
  16. {  
  17.     string filename = path;  
  18.     ifstream file(filename.c_str());  
  19. char s[1024];  
  20. if(file.is_open())  
  21.     {  
  22. while(file.getline(s, 1024))  
  23.         {  
  24.             Data tmp;  
  25.             Type x1, x2, x3, x4, x5, x6, x7;  
  26.             sscanf(s,"%lf %lf %lf %lf %lf %lf %lf", &x1, &x2, &x3, &x4, &x5, &x6, &x7);  
  27.             tmp.x.push_back(1);  
  28.             tmp.x.push_back(x2);  
  29.             tmp.x.push_back(x3);  
  30.             tmp.x.push_back(x4);  
  31.             tmp.x.push_back(x5);  
  32.             tmp.x.push_back(x6);  
  33.             tmp.y = x7;  
  34.             data.push_back(tmp);  
  35.         }  
  36.     }  
  37. }  
  38. void Init(Vector<Data> &data, Vector<Type> &w)  
  39. {  
  40.     w.clear();  
  41.     data.clear();  
  42.     PreProcessData(data, "TrainData.txt");  
  43. for(int i = 0; i < data[0].x.size(); i++)  
  44.         w.push_back(0);  
  45. }  
  46. Type WX(const Data& data, const Vector<Type>& w)  
  47. {  
  48.     Type ans = 0;  
  49. for(int i = 0; i < w.size(); i++)  
  50.         ans += w[i] * data.x[i];  
  51. return ans;  
  52. }  
  53. Type Sigmoid(const Data& data, const Vector<Type>& w)  
  54. {  
  55.     Type x = WX(data, w);  
  56.     Type ans = exp(x) / (1 + exp(x));  
  57. return ans;  
  58. }  
  59. Type Lw(const Vector<Data>& data, Vector<Type> w)  
  60. {  
  61.     Type ans = 0;  
  62. for(int i = 0; i < data.size(); i++)  
  63.     {  
  64.         Type x = WX(data[i], w);  
  65.         ans += data[i].y * x - log(1 + exp(x));  
  66.     }  
  67. return ans;  
  68. }  
  69. void Gradient(const Vector<Data>& data, Vector<Type> &w, Type alpha)  
  70. {  
  71. for(int i = 0; i < w.size(); i++)  
  72.     {  
  73.         Type tmp = 0;  
  74. for(int j = 0; j < data.size(); j++)  
  75.             tmp += alpha * data[j].x[i] * (data[j].y - Sigmoid(data[j], w));  
  76.         w[i] += tmp;  
  77.     }  
  78. }  
  79. void Display(int cnt, Type objLw, Type newLw, Vector<Type> w)  
  80. {  
  81.     cout<<"第"<<cnt<<"次迭代:  ojLw = "<<objLw<<"  两次迭代的目标差为: "<<(newLw - objLw)<<endl;  
  82.     cout<<"参数w为: ";  
  83. for(int i = 0; i < w.size(); i++)  
  84.         cout<<w[i]<<" ";      cout<<endl;  
  85.     cout<<endl;  
  86. }  
  87. void Logistic(const Vector<Data>& data, Vector<Type> &w)  
  88. {  
  89. int cnt = 0;  
  90.     Type alpha = 0.1;  
  91.     Type delta = 0.00001;  
  92.     Type objLw = Lw(data, w);  
  93.     Gradient(data, w, alpha);  
  94.     Type newLw = Lw(data, w);  
  95. while(fabs(newLw - objLw) > delta)  
  96.     {  
  97.         objLw = newLw;  
  98.         Gradient(data, w, alpha);  
  99.         newLw = Lw(data, w);  
  100.         cnt++;  
  101.         Display(cnt,objLw,newLw, w);  
  102.     }  
  103. }  
  104. void Separator(Vector<Type> w)  
  105. {  
  106.     Vector<Data> data;  
  107.     PreProcessData(data, "TestData.txt");  
  108.     cout<<"预测分类结果:"<<endl;  
  109. for(int i = 0; i < data.size(); i++)  
  110.     {  
  111.         Type p0 = 0;  
  112.         Type p1 = 0;  
  113.         Type x = WX(data[i], w);  
  114.         p1 = exp(x) / (1 + exp(x));  
  115.         p0 = 1 - p1;  
  116.         cout<<"实例: ";  
  117. for(int j = 0; j < data[i].x.size(); j++)  
  118.             cout<<data[i].x[j]<<" ";  
  119.         cout<<"所属类别为:";  
  120. if(p1 >= p0) cout<<1<<endl;  
  121. else cout<<0<<endl;  
  122.     }  
  123. }  
  124. int main()  
  125. {  
  126.     Vector<Type> w;  
  127.     Vector<Data> data;  
  128.     Init(data, w);  
  129.     Logistic(data, w);  
  130.     Separator(w);  
  131. return 0;  
  132. }  

测试数据:TestData.txt

[cpp] view plain copy

  1. 10009 1 0 0 1 0 1  
  2. 10025 0 0 1 2 0 0  
  3. 20035 0 0 1 0 0 1  
  4. 20053 1 0 0 0 0 0  
  5. 30627 1 0 1 2 0 0  
  6. 30648 2 0 0 0 1 0 
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2016年07月25日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档