首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >基于Pandas + GIF可视化的梯度下降算法

基于Pandas + GIF可视化的梯度下降算法
EN

Code Review用户
提问于 2020-01-02 05:15:37
回答 1查看 463关注 0票数 3

大家新年快乐,下面是我在Python中使用吴家祥的课程实现的梯度下降算法(单变量线性回归),matplotlib具有可选的可视化功能,它创建了一个GIF。欢迎任何优化/建议。

对于那些不知道梯度下降算法是什么的人,梯度下降是一种求函数局部极小的一阶迭代优化算法。若要使用梯度下降求函数的局部极小值,则在当前点采取与函数的梯度(或近似梯度)的负值成正比的步骤。相反,如果采取与梯度正成比例的步骤,则接近该函数的局部最大值。您可以查看维基百科上的定义

Ex1 GIF:

Ex2 GIF:

所用数据集的链接:

ex1:

https://drive.google.com/open?id=1tztcXVillZTrbPeeCd28djRooM5nkiBZ

ex2:

https://drive.google.com/open?id=17ZQ4TLA7ThtU-3J-G108a1fzCH72nSFp

代码语言:javascript
运行
复制
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
import imageio
import os


def compute_cost(b, m, data):
    """
    Compute cost function for univariate linear regression using mean squared error(MSE).
    Args:
        b: Intercept (y = mx + b).
        m: Slope (y = mx + b).
        data: A pandas df with x and y data.
    Return:
         Cost function value.
    """
    data['Mean Squared Error(MSE)'] = (data['Y'] - (m * data['X'] + b)) ** 2
    return data['Mean Squared Error(MSE)'].sum().mean()


def adjust_gradient(b_current, m_current, data, learning_rate):
    """
    Adjust Theta parameters for the univariate linear equation y^ = hθ(x) = θ0 + θ1x
    Args:
        b_current: Current Intercept (y = mx + b) or [Theta(0) θ0].
        m_current: Current Slope (y = mx + b) or [Theta(1) θ1].
        data: A pandas df with x and y data.
        learning_rate: Alpha value.
    Return:
        Adjusted Theta parameters.
    """
    data['b Gradient'] = -(2 / len(data)) * (data['Y'] - ((m_current * data['X']) + b_current))
    data['m Gradient'] = -(2 / len(data)) * data['X'] * (data['Y'] - ((m_current * data['X']) + b_current))
    new_b = b_current - (data['b Gradient'].sum() * learning_rate)
    new_m = m_current - (data['m Gradient'].sum() * learning_rate)
    return new_b, new_m


def gradient_descent(data, b, m, learning_rate, max_iter, visual=False):
    """
    Optimize Theta values for the univariate linear regression equation y^ = hθ(x) = θ0 + θ1x.
    Args:
        data: A pandas df with x and y data.
        b: Starting b (θ0) value.
        m: Starting m (θ1) value.
        learning_rate: Alpha value.
        max_iter: Maximum number of iterations.
        visual: If True, a GIF progression will be generated.
    Return:
        Optimized values for θ0 and θ1.
    """
    line = np.arange(len(data))
    folder_name = None
    if visual:
        folder_name = str(random.randint(10 ** 6, 10 ** 8))
        os.mkdir(folder_name)
        os.chdir(folder_name)
    for i in range(max_iter):
        b, m = adjust_gradient(b, m, data, learning_rate)
        if visual:
            data['Line'] = (line * m) + b
            data.plot(kind='scatter', x='X', y='Y', figsize=(8, 8), marker='x', color='r')
            plt.plot(data['Line'], color='b')
            plt.grid()
            plt.title(f'y = {m}x + {b}\nCurrent cost: {compute_cost(b, m, data)}\nIteration: {i}\n'
                      f'Alpha = {learning_rate}')
            fig_name = ''.join([str(i), '.png'])
            plt.savefig(fig_name)
            plt.close()
    if visual:
        frames = os.listdir('.')
        frames.sort(key=lambda x: int(x.split('.')[0]))
        frames = [imageio.imread(frame) for frame in frames]
        imageio.mimsave(folder_name + '.gif', frames)
    return b, m


if __name__ == '__main__':
    data = pd.read_csv('data.csv')
    data.columns = ['X', 'Y']
    learning = 0.00001
    initial_b, initial_m = 0, 0
    max_it = 350
    b, m = gradient_descent(data, initial_b, initial_m, learning, max_it, visual=True)
EN

回答 1

Code Review用户

回答已采纳

发布于 2020-01-02 17:26:03

什么都没发生!

当我运行脚本时,我的第一个想法是它挂起来了,因为什么都没发生。显然,我的电脑计算速度太慢了。

为了让用户放心,程序正在工作,有一些输出表明进展。下面是我添加它的方式,但是您可以通过进度条和其他东西使它变得更时尚:

代码语言:javascript
运行
复制
n_reports = 20
for i in range(max_iter):
    if max_iter >= n_reports and i % (max_iter // n_reports) == 1:
        print(f'{(i * 100 // max_iter)}% ', end = '', flush = True)
    ...
print('100%')

n_reports是打印完成百分比的次数。不耐烦的用户希望打印得更频繁,而耐心的用户(或计算机速度更快的用户)希望打印得更少。

临时目录

Python有一个临时文件和目录模块,您应该使用这些模块,而不是创建自己的文件和目录。此外,除非有必要,否则永远不要更改进程当前工作目录( chdir调用)。它会引起非常奇怪的问题。

tempfilepathlib模块简化了文件处理:

代码语言:javascript
运行
复制
if visual:
    save_dir = pathlib.Path(tempfile.mkdtemp())
...
for i in range(max_iter):
    ...
    if visual:
        ...
        fig_name = f'{i:05}.png'
        plt.savefig(save_dir / fig_name)
...
if visual:
    frames = sorted([f.resolve() for f in save_dir.iterdir()])
    frames = [imageio.imread(frame) for frame in frames]
    image_file = save_dir.name + '.gif'
    imageio.mimsave(image_file, frames)
    print(f'Saved image as {image_file}')

我也是这样做的,所以文件名中填充了零。这样,你就可以依赖字典顺序,而不必编写自己的关键函数。

票数 2
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/234936

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档