首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >是否可以根据损失函数停止scipy.optimize.curve_fit?

是否可以根据损失函数停止scipy.optimize.curve_fit?
EN

Stack Overflow用户
提问于 2021-07-06 14:45:49
回答 1查看 249关注 0票数 1

我正在尝试最小化两个有界函数之间的均方误差,curve_fit做得很好,但当两个函数之间的均方误差小于0.1时,我想停止计算。下面是一个简单的示例代码

代码语言:javascript
运行
复制
import numpy as np
from scipy import optimize, integrate

def sir_model(y, x, beta, gamma):
    sus = -beta * y[0] * y[1] / N
    rec = gamma * y[1]
    inf = -(sus + rec)
    return sus, inf, rec

def fit_odeint(x, beta, gamma):
    return integrate.odeint(sir_model, (sus0, inf0, rec0), x, args=(beta, gamma))[:,1]

population = float(1000)
xdata = np.arange(0,335,dtype = float)
upper_bounds = np.array([1,0.7])

N = population
inf0 = 10
sus0 = N - inf0
rec0 = 0.0

#curve to approximate
ydata = fit_odeint(xdata, beta = 0.258, gamma = 0.612)

popt, pcov = optimize.curve_fit(fit_odeint, xdata, ydata,bounds=(0, upper_bounds))

问题是,真正的问题更难解决。所以我想用一个固定的容差(mse = 0.1)来停止函数curve_fit。我尝试使用ftol,但它似乎不起作用。

EN

回答 1

Stack Overflow用户

发布于 2021-07-07 15:13:53

如果我理解正确的话,您希望在底层最小二乘优化问题的目标是<= 0.1时立即终止优化。不幸的是,curve_fitleast_squares都不支持对象值的回调或容差。然而,scipy.optimize.minimize做到了。因此,让我们使用它。

为此,我们必须将您的曲线拟合问题表示为最小化问题:

代码语言:javascript
运行
复制
min ||ydata - fit_odeint(xdata, *coeffs)||**2

s.t. lb <= coeffs <= ub

然后,我们通过minimize解决这个问题,并使用一个回调函数,在目标函数值为<= 0.1时立即终止算法:

代码语言:javascript
运行
复制
from scipy.optimize import minimize
from numpy.linalg import norm

# the objective function
def obj(coeffs):
    return norm(ydata - fit_odeint(xdata, *coeffs))**2

# bounds
bnds = [(0, 1), (0, 0.7)]

# initial point
x0 = np.zeros(2)

# xk is the current parameter vector and state is an OptimizeResult object
def my_callback(xk, state):
    if state.fun <= 0.1:
        return True

# call the solver (res.x contains your coefficients)
res = minimize(obj, x0=x0, bounds=bnds, method="trust-constr", callback=my_callback)

这给了我:

代码语言:javascript
运行
复制
 barrier_parameter: 0.1
 barrier_tolerance: 0.1
          cg_niter: 3
      cg_stop_cond: 4
            constr: [array([0.08205584, 0.44233162])]
       constr_nfev: [0]
       constr_nhev: [0]
       constr_njev: [0]
    constr_penalty: 1635226.9491785716
  constr_violation: 0.0
    execution_time: 0.09222197532653809
               fun: 0.007733264340965375
              grad: array([-3.99185467,  4.04015523])
               jac: [<2x2 sparse matrix of type '<class 'numpy.float64'>'
    with 2 stored elements in Compressed Sparse Row format>]
   lagrangian_grad: array([-0.03385145,  0.19847042])
           message: '`callback` function requested termination.'
            method: 'tr_interior_point'
              nfev: 12
              nhev: 0
               nit: 4
             niter: 4
              njev: 4
        optimality: 0.1984704226037759
            status: 3
           success: False
         tr_radius: 7.0
                 v: [array([ 3.95800322, -3.8416848 ])]
                 x: array([0.08205584, 0.44233162])

请注意,回调的签名只对'trust-constr‘方法有效,对于其他方法,签名是callback(xk) -> bool,即您需要在回调中自己计算目标函数值。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68265753

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档