前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >1,Kalman滤波器参数如何选取

1,Kalman滤波器参数如何选取

原创
作者头像
雷科
修改2020-02-17 11:53:27
2.9K0
修改2020-02-17 11:53:27
举报
文章被收录于专栏:用户6953650的专栏

新冠居家封闭期间,对参考文献中估计常数的例子,初次使用python的NumPy库进行仿真,深入理解Kalman滤波器的参数对滤波性能的影响。

设计5组参数,生成图片的如下

模拟数据直方图统计.png
模拟数据直方图统计.png
状态.png
状态.png
滤波值的方差.png
滤波值的方差.png
新息的统计距离.png
新息的统计距离.png
新息的统计距离的统计信息.png
新息的统计距离的统计信息.png

结论

1.1,增加Q,增益增加,即观测值在状态更新方程中的权重变大,滤波器更加灵敏,反之亦然。

1.2,增加R,增益减小,即观测值在状态更新方程中的权重变小,滤波器反应迟钝,反之亦然。

2.1,参数R表示观测值的方差,应尽可能准确。

2.2,综合考虑滤波器在随机性和惯性等方面的表现,参数Q的取值在r/9 - r/4 较合适?

心得

Matlab真心不好下载不好用,Python确实好用多了。

代码如下

代码语言:txt
复制
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''=========================================================================================================
@Project -> File   :active -> kalmanEstimateConstant
@IDE    :PyCharm
@Author :Mr. Ddc
@Date   :2020/2/12 15:34
@Desc   :对参考文献中的例子(估计一个常数)进行仿真,深入理解Kalman滤波器的参数对滤波性能的影响。
@Ref    :Welch & Bishop, An Introduction to the Kalman Filter, UNC-Chapel Hill, TR 95-041, July 24, 2006
=========================================================================================================='''

import sys
import traceback
import getopt
import numpy as np
import matplotlib.pyplot as plt


# 模拟观测数据
def simData(xTruth, sigma, N):
    data = np.random.normal(xTruth, sigma, N)
    # np.savetxt("SimData.txt", data)  # 数据存盘
    return data

# 1维数据滤波
def filter1D(z, Q, R):
    x = np.zeros(z.shape)  # 滤波值
    pkk = np.zeros(z.shape)  # 滤波协方差
    dis = np.zeros(z.shape)  # 新息统计距离

    # 初始化
    x[0] = z[0]
    pkk[0] = R

    for i in range(1, z.size):
        xki = x[i-1]  # 状态的一步预测
        pki = pkk[i-1] + Q  # 协方差的一步预测
        v = z[i] - xki  # 新息
        S =  pki + R  # 新息的协方差
        K = pki / S  # 增益
        x[i] = xki + K * v  # 状态的更新
        # pkk[i] = pki - K*S*K  # 协方差的更新  # 两者是相同的
        pkk[i] = (1 - K) * pki  # 协方差的更新
        dis[i] = v / S * v  # 新息统计距离

    return x, pkk, dis

# 画图
def display(Q, R, xTruth, z, x, pkk, dis, ivSta):
    N = z.size

    # 颜色表
    colors = []
    colors.append('b')  # 参数1
    colors.append('g')  # 参数2
    colors.append('r')  # 参数3
    colors.append('y')  # 参数4
    colors.append('c')  # 参数5
    colors.append('m')  # 观测值用
    colors.append('orange')  # 真值用

    # markers
    markers = []
    markers.append('o')  # 参数1
    markers.append('v')  # 参数2
    markers.append('d')  # 参数3
    markers.append('x')  # 参数4
    markers.append('+')  # 参数5

    # 画 真值、观测值、滤波值
    title = '状态'
    plt.figure(title)
    plt.plot(np.arange(N) + 1, xTruth * np.ones(N), color = colors[len(colors)-1], label = '真值')  # plot
    plt.plot(np.arange(N) + 1, z, '.-', color = colors[len(colors)-2], label = '观测值')  # plot
    for i in range(0, len(R)):
        plt.plot(np.arange(N) + 1, x[i, :], 'o-',
                 marker = markers[i], color = colors[i], label = f'({i+1}):Q = {Q[i]: .6f}, R = {R[i]:.4f}')  # plot

    plt.legend()
    plt.xlabel('时间')  # 设置x轴标签
    plt.ylabel('电压')  # 设置y轴标签
    # plt.title(title)
    plt.savefig(title)

    # 画 滤波值的方差
    title = '滤波值的方差'
    fig = plt.figure(title)
    # 主图 ax1
    left, bottom, width, height = 0.15, 0.11, 0.8, 0.8
    ax1 = fig.add_axes([left, bottom, width, height])
    for i in range(0, len(R)):
        ax1.plot(np.arange(N)+1, pkk[i], '.-',
                 marker = markers[i], color = colors[i], label = f'({i+1}):Q = {Q[i]: .6f}, R = {R[i]:.4f}')  # plot

    ax1.legend(loc = 'upper right')
    ax1.grid()
    ax1.set_ylabel('方差')  # 设置y轴标签  斜体
    ax1.set_xlabel('时间')  # 设置x轴标签
    # ax1.set_ylabel('$\mathregular{Voltage^2}$')  # 设置y轴标签 去除斜体
    ax1.set_ylim([-0.002, 0.02])
    # ax1.set_title(title)
    # 子图,由于参数3对应的方差太大,用子图另画
    left, bottom, width, height = 0.25, 0.6, 0.25, 0.25
    ax2 = fig.add_axes([left, bottom, width, height])
    ax2.plot(np.arange(N)+1, pkk[3], '.-', marker = markers[3],  color = colors[3])
    ax2.grid()
    ax2.set_ylabel('方差', fontsize = 'small')  # 设置y轴标签
    ax2.set_xlabel('时间', fontsize = 'small')  # 设置x轴标签
    ax2.set_title('参数4', fontsize = 'small')
    ax2.patch.set_visible(False)  # 设置子图透明
    plt.savefig(title)

    # 画 新息的统计距离
    title = '新息的统计距离'
    fig = plt.figure(title)
    # 主图 ax1
    left, bottom, width, height = 0.12, 0.11, 0.8, 0.8
    ax1 = fig.add_axes([left, bottom, width, height])
    for i in range(0, len(R)):
        ax1.plot(np.arange(N)+1, dis[i, :], '.-',
                 marker = markers[i], color = colors[i], label = f'({i+1}):Q = {Q[i]: .6f}, R = {R[i]:.4f}')  # plot

    ax1.legend(loc = 'upper right')
    ax1.set_xlabel('时间')  # 设置x轴标签
    ax1.set_ylabel('统计距离')  # 设置y轴标签
    ax1.set_ylim([-0.1, 7])
    ax1.grid()
    # 子图,由于参数4对应的方差太大,用子图另画
    left, bottom, width, height = 0.23, 0.6, 0.25, 0.25
    ax2 = fig.add_axes([left, bottom, width, height])
    ax2.plot(np.arange(N)+1, dis[4], '.-', marker = markers[4], color = colors[4])
    ax2.set_xlabel('时间', fontsize = 'small')  # 设置x轴标签
    ax2.set_ylabel('统计距离', fontsize = 'small')  # 设置y轴标签
    ax2.grid()
    ax2.set_title('参数5', fontsize = 'small')
    ax2.patch.set_visible(False)  # 设置子图透明
    plt.savefig(title)

    # 画 新息的统计距离的统计信息
    txt = ('最小值', '最大值', '平均值', '标准差')
    title = '新息的统计距离的统计信息'
    fig = plt.figure(title)
    # 主图 ax1
    left, bottom, width, height = 0.12, 0.11, 0.8, 0.8
    ax1 = fig.add_axes([left, bottom, width, height])
    for i in range(0, ivSta.shape[1]):
        ax1.plot(np.arange(ivSta.shape[0])+1, ivSta[:, i], '.-', label = txt[i])  # plot

    ax1.legend(loc = 'upper right')
    ax1.set_xticks(np.arange(ivSta.shape[0])+1)
    ticks = []
    for i in range(ivSta.shape[0]):
        ticks.append('参数' + str(i+1))
    ax1.set_xticklabels(ticks)
    ax1.set_ylim([-0.3, 9])
    ax1.grid()
    # 子图,由于参数4对应的值太大,用子图另画
    left, bottom, width, height = 0.23, 0.6, 0.25, 0.25
    ax2 = fig.add_axes([left, bottom, width, height])
    ax2.bar(range(0, 3), ivSta[-1, 1:], alpha = 0.5)
    ax2.set_xticks(range(0, 3))
    ax2.set_xticklabels(txt[1:], rotation = 30, fontsize = 'small')
    ax2.patch.set_visible(False)  # 设置子图透明
    # for item in [fig, ax2]:  # 设置子图透明
    #     item.patch.set_visible(False)
    ax2.grid()
    ax2.set_title('参数5', fontsize = 'small')
    plt.savefig(title)

    plt.show()

# Works 实现主要功能
def Works():
    N = 100  # 观测次数
    xTruth = -0.37727  # 真值
    sigma = 0.1  # 叠加正态分布噪声的标准差
    r = sigma ** 2  # 观测方差

    # 1,模拟一次数据
    # simData(xTruth, sigma, N)  # 模拟观测数据
    # z = np.loadtxt("SimData.txt")  # 从硬盘读取数据
    z = simData(xTruth, sigma, N)  # 模拟观测数据

    plt.figure("模拟数据直方图统计")
    count, bins, ignored = plt.hist(z, z.size)
    plt.plot(bins, 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp( - (bins - xTruth)**2 / (2 * sigma**2) ), linewidth = 2, color = 'r')
    plt.savefig("模拟数据直方图统计")

    # 五组参数
    R = [r,     r,      r,      r*100,  r/100]
    Q = [1e-5,  r/9,    2*r,    1e-5,   1e-5]
    x = np.zeros((len(R), z.size))  # 滤波值,各行依次对应一组参数
    pkk = np.zeros((len(R), z.size))  # 滤波协方差,各行依次对应一组参数
    dis = np.zeros((len(R), z.size))  # 新息统计距离,各行依次对应一组参数

    for idxR in range(0, len(R)):
        r = R[idxR]
        q = Q[idxR]
        x[idxR, :], pkk[idxR, :], dis[idxR, :] = filter1D(z, q, r)  # 滤波

    # 统计新息的最小值/最大值/平均值/标准差
    ivSta = np.zeros((len(R), 4))  # 各行依次对应一组参数
    for idxR in range(0, len(R)):
        ivSta[idxR, :] = (np.min(dis[idxR, :]), np.max(dis[idxR, :]), np.mean(dis[idxR, :]), np.std(dis[idxR, :], ddof = 1) )

    display(Q, R, xTruth, z, x, pkk, dis, ivSta)  # 画图

    return 0


# 异常处理函数
class Usage(Exception):
    def __init__(self, msg):
        self.msg = msg


# main()函数
def main(argv=None):
    if argv is None:
        argv = sys.argv
    try:
        try:
            opts, args = getopt.getopt(argv[1:], "h", ["help"])
        except(getopt.error, msg):
            raise (Usage(msg))

        # 使plt支持中文
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False

        # Works 实现主要功能
        Works()
    except Usage as err:  # 待处理新增代码中的异常
        print(sys.stderr, err.msg)
        print(sys.stderr, "for help use --help")
        return 1
    except Exception as e:
        print("print_exception()")
        exc_type, exc_value, exc_tb = sys.exc_info()
        print('the exc type is:', exc_type)
        print('the exc value is:', exc_value)
        print('the exc tb is:', exc_tb)
        traceback.print_exception(exc_type, exc_value, exc_tb)
    finally:
        print('exit main')


# 直接运行的入口
if "__main__" == __name__:
    sys.exit(main())

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 设计5组参数,生成图片的如下
  • 结论
  • 心得
  • 代码如下
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档