首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >用NumPy将MATLAB代码转换为Python的问题

用NumPy将MATLAB代码转换为Python的问题
EN

Stack Overflow用户
提问于 2022-02-01 10:23:45
回答 2查看 64关注 0票数 0

有人能帮我吗?

Matlab代码

代码语言:javascript
复制
test = 300
f = @(u) 2+3*u-2*u.^3+u.^6-4*u.^8;
rng(12354)
u_test = unifrnd(-1,1,[test,1]);
y_test = f(u_test);
xt0 = 1*ones(test, 1);
xt1 = u_test;
xt2 = u_test.^2;
xt3 = u_test.^3;
xt4 = u_test.^4;
xt5 = u_test.^5;
xt6 = u_test.^6;
xt7 = u_test.^7;
xt8 = u_test.^8;
xt9 = u_test.^9;
x_test = [xt0 xt1 xt2 xt3 xt4 xt5 xt6 xt7 xt8 xt9];

theta_test = (x_test'*x_test)^-1*x_test'*y_test;

Python代码

代码语言:javascript
复制
import numpy as np
test = 3
min_test = -1
max_test = 1
f = lambda u: 2+3*u-2*u**3+u**6-4*u**8
np.random.seed(12345)
u_test = np.random.uniform(min_test, max_test, [test, 1])

y_test = f(u_test)

xt0 = np.ones([test, 1])
xt1 = u_test
xt2 = u_test ** 2
xt3 = u_test ** 3
xt4 = u_test ** 4
xt5 = u_test ** 5
xt6 = u_test ** 6
xt7 = u_test ** 7
xt8 = u_test ** 8
xt9 = u_test ** 9
x_test = np.array([xt0, xt1, xt2, xt3, xt4, xt5, xt6, xt7, xt8, xt9])

theta_test = (x_test.T * x_test) ** -1 * x_test.T * y_test

MATLAB程序中θ的最终答案,也是正确的答案,是一个10 *1矩阵。它总共有10个数字,但Python代码中的θ答案是10 * 30矩阵。我们的输出有300个数字。

如何用Numpy进行乘法,使最终答案与MATLAB代码答案相同?

EN

回答 2

Stack Overflow用户

发布于 2022-02-01 16:40:11

正如对您的问题的注释所述,最好使用数组而不是变量列表来编写这段代码。

尽管如此,问题在于您使用的是*,它在Python中表示的是内向乘法,而不是矩阵乘法。你应该尝试的是

代码语言:javascript
复制
theta_test = (x_test.T @ x_test) ** -1 * x_test.T @ y_test
票数 0
EN

Stack Overflow用户

发布于 2022-02-01 17:33:25

您的numpy代码生成:

代码语言:javascript
复制
In [3]: x_test.shape
Out[3]: (10, 3, 1)
In [4]: y_test.shape
Out[4]: (3, 1)

去掉最后一个x_test维度:

代码语言:javascript
复制
In [11]: x1 = x_test[:,:,0]
In [12]: (x1.T@x1).shape
Out[12]: (3, 3)
In [13]: np.linalg.inv(x1.T@x1)
Out[13]: 
array([[ 0.34307216, -0.63513713,  0.36347551],
       [-0.63513713,  8.44935183, -6.36062978],
       [ 0.36347551, -6.36062978,  5.43319377]])

切换转置产生一个(10,10)数组,但它是奇异的

代码语言:javascript
复制
In [14]: np.linalg.inv(x1@x1.T)
Traceback (most recent call last):
  ...
LinAlgError: Singular matrix

(10,3)可以将(3,1)乘以产生(10,1)

代码语言:javascript
复制
In [18]: (x1@y_test).shape
Out[18]: (10, 1)

我有麻烦运行你的MATLAB代码在八度,所以不能生成它的等价物,看看你想要复制什么。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70938912

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档