首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >在Python中高精度地查找由(x,y)数据给定的两条曲线的交点

在Python中高精度地查找由(x,y)数据给定的两条曲线的交点
EN

Stack Overflow用户
提问于 2017-02-26 11:12:32
回答 1查看 12.3K关注 0票数 8

我有两个数据集:(x,y1)和(x,y2)。我想找出这两条曲线相交的位置。目标类似于这个问题:Intersection of two graphs in Python, find the x value:

然而,这里描述的方法只找到与最近的数据点的交点。我想以比原始数据间距更高的精度找到曲线的交点。一种选择是简单地重新插值到更精细的网格。这是可行的,但精度取决于我为重新插值选择的点数,这是任意的,需要在精度和效率之间进行权衡。

或者,我可以使用scipy.optimize.fsolve来找到数据集的两个样条插值的确切交点。这很有效,但是它不能很容易地找到多个交叉点,需要我提供交叉点的合理猜测,并且可能不能很好地扩展。(最终,我希望找到几千个(x,y1,y2)集合的交集,因此一个有效的算法会很好。)

这是我到目前为止所拥有的。有什么改进的想法吗?

代码语言:javascript
复制
import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate, scipy.optimize

x  = np.linspace(1, 4, 20)
y1 = np.sin(x)
y2 = 0.05*x

plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

interp1 = scipy.interpolate.InterpolatedUnivariateSpline(x, y1)
interp2 = scipy.interpolate.InterpolatedUnivariateSpline(x, y2)

new_x = np.linspace(x.min(), x.max(), 100)
new_y1 = interp1(new_x)
new_y2 = interp2(new_x)
idx = np.argwhere(np.diff(np.sign(new_y1 - new_y2)) != 0)
plt.plot(new_x[idx], new_y1[idx], 'ro', ms=7, label='Nearest data-point method, with re-interpolated data')

def difference(x):
    return np.abs(interp1(x) - interp2(x))

x_at_crossing = scipy.optimize.fsolve(difference, x0=3.0)
plt.plot(x_at_crossing, interp1(x_at_crossing), 'cd', ms=7, label='fsolve method')

plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

plt.savefig('curve crossing.png', dpi=200)
plt.show()

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-04-22 04:31:08

最好的(也是最有效的)答案很可能取决于数据集及其采样方式。但是,对于许多数据集,一个很好的近似是它们在数据点之间几乎是线性的。因此,我们可以通过原始帖子中显示的“最近数据点”方法找到交叉点的大致位置。然后,我们可以使用线性插值来优化最近的两个数据点之间的交点位置。

这个方法非常快,并且适用于2Dnumpy数组,如果你想一次计算多条曲线的交叉点(就像我想在我的应用程序中做的那样)。

(我借用了"How do I compute the intersection point of two lines in Python?“中的代码进行线性插值。)

代码语言:javascript
复制
from __future__ import division 
import numpy as np
import matplotlib.pyplot as plt

def interpolated_intercept(x, y1, y2):
    """Find the intercept of two curves, given by the same x data"""

    def intercept(point1, point2, point3, point4):
        """find the intersection between two lines
        the first line is defined by the line between point1 and point2
        the first line is defined by the line between point3 and point4
        each point is an (x,y) tuple.

        So, for example, you can find the intersection between
        intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)

        Returns: the intercept, in (x,y) format
        """    

        def line(p1, p2):
            A = (p1[1] - p2[1])
            B = (p2[0] - p1[0])
            C = (p1[0]*p2[1] - p2[0]*p1[1])
            return A, B, -C

        def intersection(L1, L2):
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = L1[2] * L2[1] - L1[1] * L2[2]
            Dy = L1[0] * L2[2] - L1[2] * L2[0]

            x = Dx / D
            y = Dy / D
            return x,y

        L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
        L2 = line([point3[0],point3[1]], [point4[0],point4[1]])

        R = intersection(L1, L2)

        return R

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
    xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
    return xc,yc

def main():
    x  = np.linspace(1, 4, 20)
    y1 = np.sin(x)
    y2 = 0.05*x

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

    # new method!
    xc, yc = interpolated_intercept(x,y1,y2)
    plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

    plt.savefig('curve crossing.png', dpi=200)
    plt.show()

if __name__ == '__main__': 
    main()

更新2018-12-13:如果有必要找到几个拦截,以下是执行此操作的代码的修改版本:

代码语言:javascript
复制
from __future__ import division 
import numpy as np
import matplotlib.pyplot as plt

def interpolated_intercepts(x, y1, y2):
    """Find the intercepts of two curves, given by the same x data"""

    def intercept(point1, point2, point3, point4):
        """find the intersection between two lines
        the first line is defined by the line between point1 and point2
        the first line is defined by the line between point3 and point4
        each point is an (x,y) tuple.

        So, for example, you can find the intersection between
        intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)

        Returns: the intercept, in (x,y) format
        """    

        def line(p1, p2):
            A = (p1[1] - p2[1])
            B = (p2[0] - p1[0])
            C = (p1[0]*p2[1] - p2[0]*p1[1])
            return A, B, -C

        def intersection(L1, L2):
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = L1[2] * L2[1] - L1[1] * L2[2]
            Dy = L1[0] * L2[2] - L1[2] * L2[0]

            x = Dx / D
            y = Dy / D
            return x,y

        L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
        L2 = line([point3[0],point3[1]], [point4[0],point4[1]])

        R = intersection(L1, L2)

        return R

    idxs = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    xcs = []
    ycs = []

    for idx in idxs:
        xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
        xcs.append(xc)
        ycs.append(yc)
    return np.array(xcs), np.array(ycs)

def main():
    x  = np.linspace(1, 10, 50)
    y1 = np.sin(x)
    y2 = 0.02*x

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

    # new method!
    xcs, ycs = interpolated_intercepts(x,y1,y2)
    for xc, yc in zip(xcs, ycs):
        plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

    plt.savefig('curve crossing.png', dpi=200)
    plt.show()

if __name__ == '__main__': 
    main()
代码语言:javascript
复制

票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/42464334

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档