一个简单的pandas图产生预期的输出,图例上有一个圆标记:
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
matplotlib.pyplot.show()
请注意,"CrudeRate“图例项是一条带圆标记的直线,这是正确的。
但是,如果我为Holt线性指数平滑函数添加了一些额外的绘图,则图例将丢失圆形标记:
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
fit.fittedvalues.plot(color=color, style="--")
title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
forecast = fit.forecast(predict).rename("${}$".format(title))
forecast.plot(color=color, legend=True, style="--")
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.show()
请注意,"CrudeRate“图例项只是一条没有圆标记的直线。
是什么原因导致第二个案例中的图例丢失了主图的圆形标记?
发布于 2019-02-20 17:24:32
在matplotlib.pyplot.show()
之前使用matplotlib.pyplot.legend()
可以解决您的问题。
由于您绘制了3个图表,而据我所知,您只需要在图例中包含2个标签,因此我们将label='_nolegend_'
传递给fit.fittedvalues.plot()
。如果我们不这样做,我们将在图形图例中有一个值为None
的第三个标签。
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
fit.fittedvalues.plot(color=color, style="--", label='_nolegend_')
title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
forecast = fit.forecast(predict).rename("${}$".format(title))
forecast.plot(color=color, legend=True, style="--")
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.legend()
matplotlib.pyplot.show()
另外,为了使您更容易编写代码,作为跟随import matplotlib.pyplot as plt
导入matplotlib.pyplot
是一种很好的实践。
https://stackoverflow.com/questions/54791323
复制相似问题