前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于keras实现VGG-19网络的音频分类

基于keras实现VGG-19网络的音频分类

作者头像
深度学习与Python
发布2019-06-18 20:58:12
1.2K0
发布2019-06-18 20:58:12
举报

介绍

在这篇文章中,我将针对音频分类的问题。我将根据音频波形训练VGG-19的音频分类器。下边是整个项目的步骤和代码:

网络结构

VGG是牛津大学Visual Geometry Group研究机构的缩写。VGG在AlexNet基础上做了改进,整个网络都使用了同样大小的3*3卷积核尺寸和2*2最大池化尺寸,网络结构简洁。本次采用的VGG-19的详细说明可以参见其论文,具体结构如下图所示:

数据下载

首先从Youtube下载音频文件,我选择了我想要音频的youtube视频,然后我使用下面的代码来下载.mp3格式的音频文件。

from __future__ import unicode_literals 
import youtube_dl 

ydl_opts = { 'format' : 'bestaudio/best' , 
'postprocessors' : [{
'key' : 'FFmpegExtractAudio','preferredcodec':'mp3',
'preferredquality' : '192',}]} 
with youtube_dl.YoutubeDL(ydl_opts) as ydl: 
  ydl.download([<youtube video link>]) 
  # for bike sounds : https://www.youtube.com/watch?v=sRdRwHPjJPk 
  #    for    car sounds : https://www.youtube.com/watch?v=PPdNb-XQXR8 

将mp3转为wav格式

在下载完音频后,我们先将其转换为wav格式,方便我们后续的处理。

  from pydub import AudioSegment 
  sound = AudioSegment.from_mp3( "car.mp3" ) 
  sound.export( "car.wav" , format= "wav" ) 

特征提取

首先我们将音频切分成15s的音频块,具体代码如下:

  from pydub import AudioSegment 
  import os 
  if not os.path.exists( "bike" ): 
  os.makedirs( "bike" ) 
  count=1 
  for i in range(1,1000,15): 
     t1 = i * 1000 #Works in milliseconds 
     t2 = (i+15) * 1000 
     newAudio = AudioSegment.from_wav( "bikes.wav" ) 
     newAudio = newAudio[t1:t2] 
     newAudio.export( 'bike/' +str(count)+ '.wav' , format= "wav" ) #Exports to a wav file in the current path. 
     print(count) 
     count+=1 

然后我们将这些15s的音频块绘制出幅值波形图,并将其保存为图片为后续模型分类做好准备,具体代码如下:

  from scipy.io.wavfile import read 
  import matplotlib.pyplot as plt 
  from os import walk 
  import os 
  if not os.path.exists( "carPlots" ): 
     os.makedirs( "carPlots" ) 
     car_wavs = [] 
  for (_,_,filenames) in walk( 'car' ): 
     car_wavs.extend(filenames) 
     break 
  for car_wav in car_wavs: 
  # read audio samples 
     input_data = read( "car/" + car_wav) 
     audio = input_data[1] 
  # plot the first 1024 samples 
     plt.plot(audio) 
  # label the axes 
     plt.ylabel( "Amplitude" ) 
     plt.xlabel( "Time" ) 
  # set the title 
  # plt.title("Sample Wav") 
  # display the plot 
    plt.savefig( "carPlots/" + car_wav.split( '.' )[0] + '.png' ) 
  # plt.show() 
    plt.close( 'all' ) 

模型建立

在上一步中,我们已经提取好了特征,接下来就是搭建模型框架,本次我们使用的是VGG-19网络,具体网络结构参见上边网络可视化图。具体代码如下:

  import os 
  from keras.applications.vgg19 import VGG19 
  from keras.preprocessing import image 
  from keras.applications.vgg19 import preprocess_input 
  from keras.models import Model 
  import numpy as np 

  base_model = VGG19(weights= 'imagenet' ) 
  model = Model(inputs=base_model.input, outputs=base_model.get_layer( 'flatten' ).output) 

  def get_features(img_path): 
     img = image.load_img(img_path, target_size=(224, 224)) 
     x = image.img_to_array(img) 
     x = np.expand_dims(x, axis=0) 
     x = preprocess_input(x) 
     flatten = model.predict(x) 
     return list(flatten[0]) 

  X = [] 
  y = [] 

  car_plots = [] 
  for (_,_,filenames) in os.walk( 'carPlots' ): 
     car_plots.extend(filenames) 
     break 

  for cplot in car_plots: 
     X.append(get_features( 'carPlots/' + cplot)) 
     y.append(0) 

  bike_plots = [] 
  for (_,_,filenames) in os.walk( 'bikePlots' ): 
     bike_plots.extend(filenames) 
     break 

  for cplot in bike_plots: 
     X.append(get_features( 'bikePlots/' + cplot)) 
     y.append(1) 

训练与测试

现在我们已经到了模型的最后一步,利用我们处理的特征以及搭建的网络框架对模型进行训练,这里强烈推荐大家学习pandas、sklearn以及keras库,你会发现在机器学习中不可避免的会用到这几个库。

  from sklearn.model_selection import train_test_split 
  from sklearn.svm import LinearSVC 
  from sklearn.metrics import accuracy_score 

  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=42, stratify=y) 

  clf = LinearSVC(random_state=0, tol=1e-5) 
  clf.fit(X_train, y_train) 

  predicted = clf.predict(X_test) 

  # get the accuracy 
  print (accuracy_score(y_test, predicted)) 

总结

虽然这个网络结构比较简单,但仍然有97%的准确率。一方面是数据特征处理较好,另外也说明keras神经网络框架的强大。在我们已经训练的模型的基础上,如果我们能创建一个chrome扩展,在网页上实时对视频中的音频进行分类,感兴趣大家可以试一下。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-04-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 深度学习与python 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档