前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >运用伪逆矩阵求最小二乘解

运用伪逆矩阵求最小二乘解

作者头像
为为为什么
发布2023-04-09 10:27:13
1.6K0
发布2023-04-09 10:27:13
举报
文章被收录于专栏:又见苍岚又见苍岚

之前分析过最小二乘的理论,记录了 Scipy 库求解的方法,但无法求解多元自变量模型,本文记录更加通用的伪逆矩阵求解最小二乘解的方法。

背景

我已经反复研习很多关于最小二乘的内容,虽然朴素但是着实花了一番功夫:

已经有工具可以解很多最小二乘的模型参数了,但是几个专用的最小二乘方法最多支持一元函数的求解,难以计算多元函数最小二乘解,此时就可以用伪逆矩阵求解了。

多元多项式形式模型

这个概念可能不够准确,我要描述的是形如如下函数的一类模型:

f( {\bf x} )=\sum _{i=1}^{n}a_if_i(x_i)

其中模型

最小二乘的损失函数为:

L= \sum_{i=1}\left(f\left(x_{i}\right)-y_{i}\right){2}

对于上述模型,可以利用伪逆求最小二乘解的方法可以用于求解类似线性多项式形式的模型参数,这样就可以求解多元、更加复杂的模型参数。

  • 本质上来说,就是因为这种形式的模型可以凑出形如 A x=b 的矩阵表示,因此可以用这种方法求解。

伪逆求解

在介绍伪逆的文章中其实已经把理论说完了,这里搬运结论:

  • 方程组 A x=b 的最佳最小二乘解为 x=A^{+} b ,并且最佳最小二乘解是唯一的。

实例应用

Python 求逆矩阵
矩阵求逆
代码语言:javascript
复制
import numpy as np

a  = np.array([[1, 2], [3, 4]])  # 初始化一个非奇异矩阵(数组)
print(np.linalg.inv(a))  # 对应于MATLAB中 inv() 函数

# 矩阵对象可以通过 .I 更方便的求逆
A = np.matrix(a)
print(A.I)


-->
[[-2.   1. ]
 [ 1.5 -0.5]]
[[-2.   1. ]
 [ 1.5 -0.5]]

矩阵求伪逆
代码语言:javascript
复制
import numpy as np

# 定义一个奇异阵 A
A = np.zeros((4, 4))
A[0, -1] = 1
A[-1, 0] = -1
A = np.matrix(A)
print(A)
# print(A.I)  将报错,矩阵 A 为奇异矩阵,不可逆
print(np.linalg.pinv(A))   # 求矩阵 A 的伪逆(广义逆矩阵),对应于MATLAB中 pinv() 函数


-->
[[ 0.  0.  0.  1.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [-1.  0.  0.  0.]]
[[ 0.  0.  0. -1.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 1.  0.  0.  0.]]

应用示例

假设我们需要拟合一个多元的复杂的但是参数为多项式形式的模型参数,模型为:

f( {\bf x} )=p_1\frac{e^{x_1}}{\sqrt{x_1}}+p_2 x_2^{1.5} + p_3 \sin x_3

模型真实参数为

代码语言:javascript
复制
import numpy as np

# 定义函数
def f1(x):
    return (np.e ** x) / (x ** 0.5)

def f2(x):
    return (x ** 1.5)

def f3(x):
    return np.sin(x)

# 真实参数
gt_p = [7, 3, 12]

# 真实模型
def f(x1, x2, x3):
    return gt_p[0] * f1(x1) + gt_p[1] * f2(x2) + gt_p[2] * f3(x3)

# 三组自变量数据
X1 = np.arange(1, 3, 0.1)
X2 = X1 * 3
X3 = X1 ** 2

# 生成带噪声的观测值 b
b = np.matrix(f(X1, X2, X3) + (np.random.rand(len(X1)) - 0.5)).T

# 生成矩阵 A
A0 = f1(X1)
A1 = f2(X2)
A2 = f3(X3)

A = np.matrix(np.vstack([A0, A1, A2]).T)

# 逆矩阵求解
para = np.linalg.pinv(A) * b

# 输出结果
print(f"ground truth: {gt_p}")
print(f"got: {para.tolist()}")

输出结果:

代码语言:javascript
复制
ground truth: [7, 3, 12]
got: [[7.046011821943054], [2.9831510054344843], [11.989895579628328]]

参考资料

文章链接: https://cloud.tencent.com/developer/article/2260562

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2023年4月8日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 背景
  • 多元多项式形式模型
  • 伪逆求解
  • 实例应用
    • Python 求逆矩阵
      • 矩阵求逆
      • 矩阵求伪逆
    • 应用示例
    • 参考资料
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档