
Algorithm Reproduction | Classifying Indian Ocean Typhoon Tracks with the K-means Algorithm

Author: 郭好奇同学
Published: 2022-11-15 09:16:09


The full code and data for this article have been published on the Heywhale community; copy the link below to fork the project and run it end to end:

https://www.heywhale.com/mw/project/6302faacf31025b7777230c9

Following the clustering method in the paper《K-均值聚类法用于西北太平洋热带气旋路径分类》(K-means clustering for classifying western North Pacific tropical cyclone tracks), this article applies cluster analysis to typhoon tracks over the Indian Ocean. The core idea is to compute, for each track, the zonal and meridional centroids plus the zonal, meridional, and cross (diagonal) variances, and to feed these five features to the K-means algorithm for classification. The classified results are then visualized.
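As a concrete illustration of the five features (with made-up numbers, not data from the article), here is the weighted-centroid and weighted-variance computation for a single toy track, where `w` is the square-root-of-wind-speed weight used later in the article:

```python
import numpy as np

# Toy track: longitudes x, latitudes y, weights w = sqrt(wind speed)
x = np.array([85.0, 86.0, 87.5])
y = np.array([10.0, 11.0, 12.5])
w = np.array([3.0, 4.0, 5.0])

w_sum = w.sum()
x_mean = (w * x).sum() / w_sum                            # weighted zonal centroid
y_mean = (w * y).sum() / w_sum                            # weighted meridional centroid
x_var = (w * (x - x_mean) ** 2).sum() / w_sum             # zonal variance
y_var = (w * (y - y_mean) ** 2).sum() / w_sum             # meridional variance
xy_var = (w * (x - x_mean) * (y - y_mean)).sum() / w_sum  # cross (diagonal) variance

features = [x_mean, y_mean, x_var, y_var, xy_var]         # one 5-feature row per track
```

Each track reduces to one such 5-element row; K-means then clusters those rows.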

Import the required libraries

import pandas as pd
import numpy as np
import glob
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeat
from cartopy.mpl.ticker import LongitudeFormatter,LatitudeFormatter
from cartopy.io.shapereader import Reader, natural_earth
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.image import imread
import shapely.geometry as sgeom
import cartopy.io.shapereader as shpreader
import matplotlib.lines as mlines
from sklearn.cluster import KMeans

Read the data and preprocess it

files = glob.glob('/home/mw/input/india_typhoon6005/indian/*/*')
df_list = []
i = 1
for f in files:
    # Rows can be ragged, so size the column range by the longest line in the file
    with open(f, 'r') as fp:
        lines = fp.readlines()
    max_line_num = len(max(lines, key=len).split(','))
    # on_bad_lines='skip' replaces the deprecated error_bad_lines=False (pandas >= 1.3)
    df = pd.read_csv(f, sep=',', header=None, on_bad_lines='skip', names=range(max_line_num))

    df['count'] = i  # track id; moved to the first column below
    df = df[df.columns.drop('count').insert(0, 'count')]
    df = df.loc[:, :8].drop(columns=[0, 3, 4, 5])

    df.columns.values[:6] = ['count', 'num', 'time', 'lat', 'lon', 'speed']

    df['time'] = pd.to_datetime(df['time'], format='%Y%m%d%H')
    df['year'] = df['time'][0].year
    df['month'] = df['time'][0].month
    df['day'] = df['time'][0].day

    # Drop the trailing letter and convert tenths of a degree to degrees
    df['y'] = df['lat'].apply(lambda x: float(x[:-1]) / 10)
    df['x'] = df['lon'].apply(lambda x: float(x[:-1]) / 10)
    df['speed'] = df['speed'] * 0.514  # knots -> m/s
    df['w'] = df['speed'] ** 0.5       # weight = sqrt(wind speed)

    # Keep only tracks lasting at least one day that reach tropical-storm strength (>= 17.2 m/s)
    if (df['time'].iloc[-1] - df['time'].iloc[0]).days >= 1 and df['speed'].max() >= 17.2:
        df_list.append(df)
        i = i + 1

data = pd.concat(df_list)
data['w'] = data['w'].replace(0, np.nan)
data.dropna(axis='rows', how='any', inplace=True)
data.drop(columns=['lat', 'lon'], inplace=True)
data
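The `lat`/`lon` tokens above are parsed by dropping the final character and dividing by 10. Assuming raw values like `'123N'` (tenths of a degree plus a hemisphere letter, a common best-track convention; the example value is hypothetical), a minimal standalone check:

```python
def parse_tenths(token):
    # Drop the trailing character (e.g. a hemisphere letter), then
    # convert tenths of a degree to degrees.
    return float(token[:-1]) / 10

print(parse_tenths('123N'))  # 12.3
```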

Construct the features described in the paper

tc = data
tc['wx'] = tc['w'] * tc['x']
tc['wy'] = tc['w'] * tc['y']
tc
tc['wx_sum'] = tc.groupby(['count'])['wx'].transform('sum')
tc['wy_sum'] = tc.groupby(['count'])['wy'].transform('sum')
tc['w_sum'] = tc.groupby(['count'])['w'].transform('sum')
tc['x_mean'] = tc['wx_sum'] / tc['w_sum']
tc['y_mean'] = tc['wy_sum'] / tc['w_sum']
tc['x_var'] = (tc['x'] - tc['x_mean']) ** 2 * tc['w']
tc['y_var'] = (tc['y'] - tc['y_mean']) ** 2 * tc['w']
tc['xy_var'] = (tc['y'] - tc['y_mean']) * (tc['x'] - tc['x_mean']) * tc['w']
tc['x_var_sum'] = tc.groupby(['count'])['x_var'].transform('sum')
tc['y_var_sum'] = tc.groupby(['count'])['y_var'].transform('sum')
tc['xy_var_sum'] = tc.groupby(['count'])['xy_var'].transform('sum')
tc['x_var_mean'] = tc['x_var_sum'] / tc['w_sum']
tc['y_var_mean'] = tc['y_var_sum'] / tc['w_sum']
tc['xy_var_mean'] = tc['xy_var_sum'] / tc['w_sum']
tc
tc_group = tc.groupby('count')[['x_mean', 'y_mean', 'x_var_mean', 'y_var_mean', 'xy_var_mean']].mean()
tc_group
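One caveat the article does not address: K-means uses Euclidean distance, so features with larger numeric ranges (such as `x_mean`, roughly 60–90 in the cluster centers below) dominate those with smaller ones (such as `xy_var_mean`, roughly -11–0). A common remedy is to standardize the feature matrix before fitting; a minimal sketch using `sklearn.preprocessing.StandardScaler` on synthetic columns (whether the paper standardizes its features is not verified here):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(1)
# Two columns with very different scales, mimicking x_mean vs xy_var_mean
X = np.column_stack([rng.uniform(60, 90, 50), rng.uniform(-11, 0, 50)])
X_std = StandardScaler().fit_transform(X)
# After standardization every column has zero mean and unit variance,
# so no single feature dominates the Euclidean distance.
```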

Run the algorithm

kmeans = KMeans(n_clusters=4, random_state=0)
kmeans.fit(tc_group)
KMeans(n_clusters=4, random_state=0)
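After fitting, the model can also assign new tracks to the nearest centroid via `KMeans.predict`. A self-contained sketch with a synthetic stand-in for the 5-feature matrix (not the article's data):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
X = rng.normal(size=(40, 5))  # stand-in for the 5-feature track matrix
km = KMeans(n_clusters=4, random_state=0, n_init=10).fit(X)
new_labels = km.predict(X[:3])  # nearest-centroid assignment for (new) rows
```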

Inspect the algorithm's output

print(kmeans.labels_)
[2 2 0 0 2 3 0 2 3 0 3 2 0 2 2 2 2 2 2 3 3 2 2 3 2 2 1 2 3 2 2 2 2 2 3 2 2
 3 2 2 0 0 2 3 2 2 2 2 2 3 2 2 0 2 2 0 2 3 2 2 0 2 2 1 2 0 0 1 2 2 3 2 0 0
 3 2 2 2 3 2 2 1 2 2 2 2 2 0 0 2 3 3 3 3 2 2 3 2 3 2 2 2 2 2 1 2 3 3 3 3 2
 2 2 2 2 3 2 3 2 3 3 0 2 2 2 2 2 2 0 2 2 3 2 2 2 3 0 2 3 2 2 2 2 2 2 0 2 2
 2 3 2 2 3 3 2 0 3 2 3 3 0 2 3 2 2 3 2 2 0 2 0 2 2 3 2 3 2 3 2 3 3 3 2 3 2
 2 1 2 2 3 2 3 3 2 2 3 2 0 2 2 3 3 3 3 3 3 2 3 3 2 2]
print(kmeans.cluster_centers_)
[[ 73.8071703   11.65878343  39.74780613   7.47642322 -11.02116106]
 [ 78.00734193  10.96868728 112.7133901    4.30306716  -2.60074766]
 [ 86.76754499  14.82524246   6.22982994   7.68770382  -0.54250094]
 [ 63.71074018  15.72378535   8.58366153   3.80791174  -1.27377464]]
tc_group['label'] = kmeans.labels_ + 1
tc_group['label'].value_counts()
3    121
4     59
1     25
2      6
Name: label, dtype: int64
tc_group
tc.set_index(['count'], inplace=True)
tc['label'] = tc_group['label']
tc
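The assignment `tc['label'] = tc_group['label']` above works because `tc` was re-indexed by `count`: pandas aligns the one-label-per-track Series to every row sharing that index value. A tiny standalone demonstration of this index-aligned broadcast (toy values, not the article's data):

```python
import pandas as pd

per_track = pd.Series({1: 'A', 2: 'B'}, name='label')      # one label per track id
rows = pd.DataFrame({'v': [10, 20, 30]}, index=[1, 1, 2])  # many rows per track id
rows['label'] = per_track  # aligned on the index, not by position
print(list(rows['label']))  # ['A', 'A', 'B']
```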

Visualize the results

# Color by intensity grade (the grades match the legend labels built below)
def get_color(w):
    if w <= 13.8:
        color = '#FFFF00'  # D: Depression
    elif 13.9 <= w <= 17.1:
        color = '#6495ED'  # DD: Deep Depression
    elif 17.2 <= w <= 24.4:
        color = '#3CB371'  # CS: Cyclonic Storm
    elif 24.5 <= w <= 32.6:
        color = '#FFA500'  # SCS: Severe Cyclonic Storm
    elif 32.7 <= w <= 61.3:
        color = '#FF00FF'  # VSCS: Very Severe Cyclonic Storm
    else:
        color = '#DC143C'  # SuperCS: Super Cyclonic Storm
    return color
def create_map(title, extent):
    fig = plt.figure(figsize=(12, 8), dpi=400)
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    ax.imshow(
        imread('/home/mw/input/natural_earth8967/NE1_50M_SR_W.tif'),
        origin='upper', 
        transform=ccrs.PlateCarree(),
        extent=[-180, 180, -90, 90]
    )
    ax.set_extent(extent,crs=ccrs.PlateCarree())

    gl = ax.gridlines(draw_labels=False, linewidth=1, color='k', alpha=0.5, linestyle='--')
    gl.top_labels = gl.right_labels = False  
    ax.set_xticks(np.arange(extent[0], extent[1]+5, 5))
    ax.set_yticks(np.arange(extent[2], extent[3]+5, 5))
    ax.xaxis.set_major_formatter(LongitudeFormatter())
    ax.xaxis.set_minor_locator(plt.MultipleLocator(1))
    ax.yaxis.set_major_formatter(LatitudeFormatter())
    ax.yaxis.set_minor_locator(plt.MultipleLocator(1))
    ax.tick_params(axis='both', labelsize=10, direction='out')

    province = shpreader.Reader('/home/mw/input/cn_shp3641/Province_9.shp')
    ax.add_geometries(province.geometries(), crs=ccrs.PlateCarree(), linewidths=0.1,edgecolor='k',facecolor='none')

    a = mlines.Line2D([],[],color='#FFFF00',marker='o',markersize=7, label='D',ls='')
    b = mlines.Line2D([],[],color='#6495ED', marker='o',markersize=7, label='DD',ls='')
    c = mlines.Line2D([],[],color='#3CB371', marker='o',markersize=7, label='CS',ls='')
    d = mlines.Line2D([],[],color='#FFA500', marker='o',markersize=7, label='SCS',ls='')
    e = mlines.Line2D([],[],color='#FF00FF', marker='o',markersize=7, label='VSCS',ls='')
    f = mlines.Line2D([],[],color='#DC143C', marker='o',markersize=7, label='SuperCS',ls='')
    ax.legend(handles=[a,b,c,d,e,f], numpoints=1, handletextpad=0, loc='upper left', shadow=True)
    plt.title(f'{title} Typhoon Track', fontsize=15)
    return ax

for label in tc_group['label'].value_counts().index:
    tc_number = tc_group['label'].value_counts()[label]
    print(f'label:{label}, tc number:{tc_number}')

    one_type = tc[tc['label'] == label].reset_index()
    ax = create_map(f'Type {label}', [40, 110, 0, 30])

    for num in one_type['count'].value_counts().index:
        df = one_type[one_type['count'] == num]
        xs, ys = list(df['x']), list(df['y'])
        # Mark each track's genesis point
        ax.scatter(xs[0], ys[0], marker='o', s=10, color='k')

        # Draw each segment, colored by the wind speed at its end point
        for i in range(len(df) - 1):
            pointA = xs[i], ys[i]
            pointB = xs[i + 1], ys[i + 1]
            ax.add_geometries([sgeom.LineString([pointA, pointB])],
                              crs=ccrs.PlateCarree(), linewidth=2,
                              edgecolor=get_color(list(df['speed'])[i + 1]),
                              facecolor='none')
    plt.savefig(f'track_type{label}_typhoon.png')
label:3, tc number:121
label:4, tc number:59
label:1, tc number:25
label:2, tc number:6

Discussion

This reproduction is not fully complete: the number of typhoon classes was simply picked arbitrarily, whereas the paper actually provides a method for determining the number of clusters.

This is offered as a starting point; interested readers are welcome to fork the project and add the missing steps.
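The paper's own criterion is not reproduced here, but one common generic way to choose the number of clusters is the elbow method: fit K-means for a range of k and look for where the inertia (within-cluster sum of squares) stops dropping sharply. A sketch on synthetic data (not the typhoon features):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Three well-separated synthetic clusters in a 5-dimensional feature space
X = np.vstack([rng.normal(loc=c, scale=0.3, size=(30, 5)) for c in (0.0, 5.0, 10.0)])

inertias = {k: KMeans(n_clusters=k, random_state=0, n_init=10).fit(X).inertia_
            for k in range(1, 7)}
# Inertia drops sharply until k reaches the true cluster count, then flattens;
# the "elbow" in that curve suggests the k to use.
```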

❝ Reference: K-均值聚类法用于西北太平洋热带气旋路径分类 ❞

This article was shared from the WeChat official account 好奇心Log via the Tencent Cloud self-media sharing program. Originally published 2022-09-07.
