专栏首页相约机器人使用PyTorch进行表格数据的深度学习

使用PyTorch进行表格数据的深度学习

作者 | Aakanksha NS

来源 | Medium

编辑 | 代码医生团队

使用表格数据进行深度学习的最简单方法是通过fast-ai库,它可以提供非常好的结果,但是对于试图了解幕后实际情况的人来说,它可能有点抽象。因此在本文中,介绍了如何在Pytorch中针对多类分类问题构建简单的深度学习模型来处理表格数据。

Pytorch是一个流行的开源机器库。它像Python一样易于使用和学习。使用PyTorch的其他一些优势是其多GPU支持和自定义数据加载器。

代码

https://jovian.ml/aakanksha-ns/shelter-outcome

数据集

https://www.kaggle.com/c/shelter-animal-outcomes/data

它是一个表格数据集,由训练集中的约26k行和10列组成。除以外的所有列DateTime都是分类的。

训练样本数据

问题陈述

根据保护动物的某些特征(例如年龄,性别,肤色,品种),预测其结果。

有5种可能的结果:Return_to_owner, Euthanasia, Adoption, Transfer, Died。期望找到动物的结局属于5类中每一种的概率。

数据预处理

尽管此步骤很大程度上取决于特定的数据和问题,但仍需要遵循两个必要的步骤:

摆脱Nan价值观:

Nan(不是数字)表示数据集中缺少值。该模型不接受Nan值,因此必须删除或替换它们。

对于数字列,一种常见的处理这些值的方法是使用剩余数据的0,均值,中位数,众数或其他某种函数来估算它们。缺失值有时可能表示数据集中的基础特征,因此人们经常创建一个新的二进制列,该列与具有缺失值的列相对应,以记录数据是否缺失。

对于分类列,Nan可以将值视为自己的类别!

标签编码所有分类列:

由于模型只能接受数字输入,因此将所有分类元素都转换为数字。这意味着使用数字代替使用字符串来表示类别。选择用来表示列中任何类别的数字并不重要,因为稍后将使用分类嵌入来进一步编码这些类别。这是标签编码的一个简单示例:

使用了LabelEncoderscikit-learn库中的类对分类列进行编码。可以定义一个自定义类来执行此操作并跟踪类别标签,因为也需要它们对测试数据进行编码。

标签编码目标:

如果目标具有字符串条目,还需要对目标进行标签编码。另外请确保维护一个字典,将编码映射到原始值,因为将需要它来找出模型的最终输出。

保护成果问题特有的数据处理:

除了上述步骤,还对示例问题进行了更多处理。

  • 删除了该AnimalID列,因为它是唯一的,不会对训练有所帮助。
  • 删除了该OutcomeSubtype列,因为它是目标的一部分,但并没有要求对其进行预测。
  • 已删除DateTime列,因为输入记录的确切时间戳似乎不是一项重要功能。实际上,首先尝试将其拆分为单独的月份和年份列,但后来意识到完全删除该列会带来更好的结果!
  • 已删除Name列,因为该列中的Nan值太多(缺少10k以上)。同样,在确定动物的结局方面,这似乎不是一个非常重要的特征。

注意:在NoteBook中,堆叠了train和test列,然后进行了预处理以避免基于测试集上的train set标签进行标签编码(因为这将涉及维护编码标签到实际值的字典) 。可以在此处进行堆栈和处理,因为没有数字列(因此无需进行插补),并且每列的类别数是固定的。实际上,绝对不能这样做,因为它可能会将某些数据从测试/验证集中泄漏到训练数据中,并导致模型评估不准确。例如如果数字列中缺少值,例如age 并决定使用平均值来推算该平均值,则平均值应仅在训练集合(而不是堆叠的训练测试有效集合)上计算,并且该值也应用于推算验证和测试集中的缺失值。

分类嵌入

分类嵌入与NLP中常用的词嵌入非常相似。基本思想是在列中具有每个类别的固定长度矢量表示。这与单次编码的不同之处在于,使用嵌入而不是使用稀疏矩阵,而是为每个类别获得了一个密集矩阵,其中相似类别的值在嵌入空间中彼此接近。因此,此过程不仅节省了内存(因为具有太多类别的列的一键编码实际上会炸毁输入矩阵,而且它是非常稀疏的矩阵),而且还揭示了分类变量的内在属性。

例如,如果有一列颜色,并且找到了它的嵌入,则可以期望red并且pink在嵌入空间中,该距离比red和更近。blue

分类嵌入层等效于每个单编码输入上方的额外层:

资料来源:分类变量的实体嵌入研究论文

对于保护所结果问题,只有分类列,但将考虑少于3个值的列为连续列。为了确定每一列嵌入向量的长度,从fast-ai库中获取了一个简单的函数:

#categorical embedding for columns having more than two values
emb_c = {n: len(col.cat.categories) for n,col in X.items() if len(col.cat.categories) > 2}
emb_cols = emb_c.keys() # names of columns chosen for embedding
emb_szs = [(c, min(50, (c+1)//2)) for _,c in emb_c.items()] #embedding sizes for the chosen columns

Pytorch数据集和DataLoader

扩展了DatasetPytorch提供的(抽象)类,以便在训练时更轻松地访问数据集并有效使用DataLoader模块来管理批次。这涉及根据特定数据集覆盖__len__和__getitem__方法。

由于只需要嵌入分类列,因此将输入分为两部分:数字部分和分类部分。

class ShelterOutcomeDataset(Dataset):
    def __init__(self, X, Y, emb_cols):
        X = X.copy()
        self.X1 = X.loc[:,emb_cols].copy().values.astype(np.int64) #categorical columns
        self.X2 = X.drop(columns=emb_cols).copy().values.astype(np.float32) #numerical columns
        self.y = Y
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X1[idx], self.X2[idx], self.y[idx]

然后,选择批处理大小,并将其与数据集一起馈入DataLoader。深度学习通常是分批进行的。DataLoader帮助在训练之前有效地管理这些批次并重新整理数据。

#creating train and valid datasets
train_ds = ShelterOutcomeDataset(X_train, y_train, emb_cols)
valid_ds = ShelterOutcomeDataset(X_val, y_val, emb_cols)
 
batch_size = 1000
train_dl = DataLoader(train_ds, batch_size=batch_size,shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=batch_size,shuffle=True)

要进行健全性检查,可以遍历创建的DataLoader以查看每个批次:

模型

数据分为连续的和分类的部分。首先根据先前确定的大小将分类部分转换为嵌入向量,然后将它们与连续部分连接起来,以馈送到网络的其余部分。这张照片演示了使用的模型:

class ShelterOutcomeModel(nn.Module):
    def __init__(self, embedding_sizes, n_cont):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(categories, size) for categories,size in embedding_sizes])
        n_emb = sum(e.embedding_dim for e in self.embeddings) #length of all embeddings combined
        self.n_emb, self.n_cont = n_emb, n_cont
        self.lin1 = nn.Linear(self.n_emb + self.n_cont, 200)
        self.lin2 = nn.Linear(200, 70)
        self.lin3 = nn.Linear(70, 5)
        self.bn1 = nn.BatchNorm1d(self.n_cont)
        self.bn2 = nn.BatchNorm1d(200)
        self.bn3 = nn.BatchNorm1d(70)
        self.emb_drop = nn.Dropout(0.6)
        self.drops = nn.Dropout(0.3)
        
 
    def forward(self, x_cat, x_cont):
        x = [e(x_cat[:,i]) for i,e in enumerate(self.embeddings)]
        x = torch.cat(x, 1)
        x = self.emb_drop(x)
        x2 = self.bn1(x_cont)
        x = torch.cat([x, x2], 1)
        x = F.relu(self.lin1(x))
        x = self.drops(x)
        x = self.bn2(x)
        x = F.relu(self.lin2(x))
        x = self.drops(x)
        x = self.bn3(x)
        x = self.lin3(x)
        return x

训练

现在在训练集上训练模型。使用了Adam优化器来优化交叉熵损失。训练非常简单:遍历每批,进行前向遍历,计算梯度,进行梯度下降,并根据需要重复此过程。可以看一下NoteBook以了解代码。

https://jovian.ml/aakanksha-ns/shelter-outcome

测试输出

由于有兴趣查找测试输入的每个类别的概率,因此在模型输出上应用Softmax函数。还进行了Kaggle提交,以查看此模型的性能如何:

仅进行了很少的功能工程和数据探索,并使用了非常基础的深度学习架构,但模型完成了约50%的解决方案。这表明使用神经网络对表格数据建模的这种方法非常强大!

本文分享自微信公众号 - 相约机器人(xiangyuejiqiren),作者:代码医生

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-01-13

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 在点云上进行深度学习:在Google Colab中实现PointNet

    3D数据对于自动驾驶汽车,自动驾驶机器人,虚拟现实和增强现实至关重要。与以像素阵列表示的2D图像不同,它可以表示为多边形网格,体积像素网格,点云等。

    代码医生工作室
  • 结合知识图谱实现基于电影的推荐系统

    知识图谱(Knowledge Graph,KG)可以理解成一个知识库,用来存储实体与实体之间的关系。知识图谱可以为机器学习算法提供更多的信息,帮助模型更好地完成...

    代码医生工作室
  • 如何构建PyTorch项目

    自从开始训练深度神经网络以来,一直在想所有Python代码的结构是什么。理想情况下,良好的结构应支持对该模型进行广泛的试验,允许在一个紧凑的框架中实现各种不同的...

    代码医生工作室
  • Day22psutil&图形界面

    psutil 用Python来编写脚本简化日常的运维工作是Python的一个重要用途。 在Python中获取系统信息的一个好办法是使用psutil这个第三方模...

    林清猫耳
  • AlphaGo Zero你也来造一只,PyTorch实现五脏俱全| 附代码

    遥想当年,AlphaGo的Master版本,在完胜柯洁九段之后不久,就被后辈AlphaGo Zero (简称狗零) 击溃了。

    量子位
  • 【NLP保姆级教程】手把手带你HAN文本分类(附代码)

    今天来看看网红Attention的效果,来自ACL的论文Hierarchical Attention Networks for Document Classif...

    NewBeeNLP
  • 中文NLP笔记:11. 基于 LSTM 生成古诗

      在每行末尾加上 ] 符号是为了标识这首诗已经结束,说明 ] 符号之前的语句和之后的语句是没有关联关系的,后面会舍弃掉包含 ] 符号的训练数据。

    杨熹
  • python3 函数迭代器

    迭代器协议,是指对象(实例)能够使用next函数获取下一项数据,在没有下一项数据之前触发一个StopIteration异常来终止迭代

    py3study
  • 散列表结构 字典与集合

    散列表(Hash Table)结构是字典(Dictionary)和集合(Set)的一种实现方式。散列算法的作用是尽可能快地在数据结构中找到一个值。在散列表上插入...

    py3study
  • python pyqt5 pandas处理数据

    """ Module implementing MainWindow. """

    用户5760343

扫码关注云+社区

领取腾讯云代金券