前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >“电视、新闻、报纸”-简单机器学习预测未来销售额

“电视、新闻、报纸”-简单机器学习预测未来销售额

作者头像
用户6719124
发布2019-11-17 23:47:12
6910
发布2019-11-17 23:47:12
举报

机器学习是python使用的一大方向,本文以简单的三种不同销售方式对最终销额的影响为例子,采用MSE均方差进行分析。

机器学习与深度学习一样,仍需要4部法进行分析。即读取数据、构建框架、训练、测试。

下面观察下部分数据。

由图可见,采用三种不同的销售方式,分别对三种不同销售方式进行投资会产生不同的最终销售额。那么是否有一种方法能通过这些样本数据来预测出销售额呢。答案是肯定的。下面开始分步进行介绍。

1. 数据读取部分

本文采用两种不同的方式进行数据读取

首先是pandas法读取,先引入pandas工具包

import pandas as pd

开始数据读取(将csv文件放在创建的.py文件)

将数据文件设定为path

path = '..\\Advertising.csv'
data = pd.read_csv(path)    # TV、Radio、Newspaper、Sales
    x = data[['TV', 'Radio', 'Newspaper']]
    y = data['Sales']

输出读取后的x,y文件查看

print(x)
print(y)

输出的x为

        TV  Radio  Newspaper
0    230.1   37.8       69.2
1     44.5   39.3       45.1
2     17.2   45.9       69.3
3    151.5   41.3       58.5
4    180.8   10.8       58.4
5      8.7   48.9       75.0
6     57.5   32.8       23.5
7    120.2   19.6       11.6
8      8.6    2.1        1.0
9    199.8    2.6       21.2
10    66.1    5.8       24.2

输出的y为

0      22.1
1      10.4
2       9.3
3      18.5
4      12.9
5       7.2
6      11.8
7      13.2
8       4.8
9      10.6
10      8.6

同时为防止出现乱码,在代码中可加入

mpl.rcParams['font.sans-serif'] = ['simHei']
mpl.rcParams['axes.unicode_minus'] = False

也可以用python自带库进行读取,

f = file(path, 'r')
print f
d = csv.reader(f)
for line in d:
    print line
f.close()

但推荐使用第一种方法。

接下来开始绘制图像

首先将三种销售方式放在同一张图像上

plt.figure(facecolor='w')
# 图像背景为白色
plt.plot(data['TV'], y, 'ro', label='TV')
# 画出电视广告销售方式的图像,以红色圆圈表示
plt.plot(data['Radio'], y, 'g^', label='Radio')
# 画出广播销售方式的图像,以绿色上三角表示
plt.plot(data['Newspaper'], y, 'mv', label='Newspaer')
# 画出报纸销售方式的图像,以紫色下三角表示
plt.legend(loc='lower right')
# 注释放在下右侧
plt.xlabel('广告花费', fontsize=16)
# 设定x轴标识,16号字体
plt.ylabel('销售额', fontsize=16)
# 设定y轴标识,16号字体
plt.title('广告花费与销售额对比数据', fontsize=18)
# 设定标题标识,18号字体
plt.grid(b=True, ls=':')
# 设置网格线
plt.show()

输出图像为

由图可见在广播和报纸销售方式上进行高投入却不会带来高收益,而销售额的增长却随着电视广告花费的增长而增长,呈正相关。

也可以分别绘制图像

plt.figure(facecolor='w', figsize=(9, 10))
plt.subplot(311)
# 绘制出三行,一列,第一个图像
plt.plot(data['TV'], y, 'ro')
plt.title('TV')
plt.grid(b=True, ls=':')
plt.subplot(312)
# 绘制出三行,一列,第二个图像
plt.plot(data['Radio'], y, 'g^')
plt.title('Radio')
plt.grid(b=True, ls=':')
plt.subplot(313)
plt.plot(data['Newspaper'], y, 'b*')
plt.title('Newspaper')
plt.grid(b=True, ls=':')
plt.tight_layout()
# 紧凑布局
plt.show()

绘制出的图像为:

三种图像的x轴坐标不同,在进行销售额对比方面明显不如第一个图。但这种图像适合单独进行分析,看其整体分布情况。

下面开始训练和测试部分

x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.8, random_state=1)
# 训练部分占0.8(80%),打乱随机取样

这里输出一下模型类型和模型的shape

print(type(x_test))
print(x_train.shape, y_train.shape)
<class 'pandas.core.frame.DataFrame'>
(160, 3) (160,)

由于共有200组数据,80%即160组数据。

开始构建模型

linreg = LinearRegression()
# 这里使用线性回归做模型
model = linreg.fit(x_train, y_train)
# 将x和y的训练样本放入到模型中去训练

输出一下模型和模型相关系数

print(model)
print(linreg.coef_, linreg.intercept_)
# 分别输出系数和截距
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
         normalize=False)
[0.0468431  0.17854434 0.00258619] 2.9079470208164278

随后为使样本输出图像更好看,进行重新排序

order = y_test.argsort(axis=0)
y_test = y_test.values[order]
# 转移到test上面去重排
x_test = x_test.values[order, :]
# 多行元素放到order上去选

开始进行预测

y_hat = linreg.predict(x_test)
# 进行预测

输出计算系数

mse = np.average((y_hat - np.array(y_test)) ** 2)
# 使用均方差进行分析
rmse = np.sqrt(mse)  # Root Mean Squared Error
print('MSE = ', mse, end=' ')
print('RMSE = ', rmse)
print('R2 = ', linreg.score(x_train, y_train))
# score api就是用来算R2的
print('R2 = ', linreg.score(x_test, y_test))
MSE =  1.9918855518287881 RMSE =  1.4113417558581578
R2 =  0.8959372632325174
R2 =  0.8927605914615385

将结果可视化

plt.figure(facecolor='w')
t = np.arange(len(x_test))
plt.plot(t, y_test, 'r-', linewidth=2, label='真实数据')
plt.plot(t, y_hat, 'g-', linewidth=2, label='预测数据')
plt.legend(loc='upper left')
plt.title('线性回归预测销量', fontsize=18)
plt.grid(b=True, ls=':')
plt.show()

结果图像为

这里注意有时模型样本不是越多越好,更具相关性的特征在进行训练时,效果会更好些。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-09-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档