专栏首页AI异构神经网络架构搜索——可微分搜索(SGAS)​

神经网络架构搜索——可微分搜索(SGAS)​

神经网络架构搜索——可微分搜索(SGAS)

KAUST&Intel发表在CVPR 2020上的NAS工作,针对现有DARTS框架在搜索阶段具有高验证集准确率的架构可能在评估阶段表现不好的问题,提出了分解神经网络架构搜索过程为一系列子问题,SGAS使用贪婪策略选择并剪枝候选操作的技术,在搜索CNN和GCN网络架构均达到了SOTA。

  • Paper: SGAS: Sequential Greedy Architecture Search
  • Code: https://github.com/lightaime/sgas

动机

NAS技术都有一个通病:在搜索过程中验证精度较高,但是在实际测试精度却没有那么高。传统的基于梯度搜索的DARTS技术,是根据block构建更大的超网,由于搜索的过程中验证不充分,最终eval和test精度会出现鸿沟。从下图的Kendall系数来看,DARTS搜出的网络精度排名和实际训练完成的精度排名偏差还是比较大。

"Accuracy GAP"

方法

整体思路

本文使用与DARTS相同的搜索空间,SGAS搜索过程简单易懂,如下图所示。类似DARTS搜索过程为每条边指定参数α,超网训练时通过文中判定规则逐渐确定每条边的具体操作,搜索结束后即可得到最终模型。

SGAS架构示意图

算法伪代码

为了保证在贪心搜索的过程中能尽量保证搜索的全局最优性,进而引入了三个指标两个评估准则

三个指标

边的重要性

非零操作参数对应的softmax值求和,作为边的重要性衡量指标。

S_{E I}^{(i, j)}=\sum_{o \in \mathcal{O}, o \neq z e r o} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}
alphas = []
for i in range(4):
    for n in range(2 + i):
        alphas.append(Variable(1e-3 * torch.randn(8)))
# alphas经过训练后
mat = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach() # mat为14*8维度的二维列表,softmax归一化。 
EI = torch.sum(mat[:, 1:], dim=-1) # EI为14个数的一维列表,去掉none后的7个ops对应alpha值相加
选择的准确性

计算操作分布的标准化熵,熵越小确定性越高;熵越高确定性越小。

\begin{array}{c} p_{o}^{(i, j)}=\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{S_{E I}^{(i, j)} \sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}, o \in \mathcal{O}, o \neq z e r o \\ S_{S C}^{(i, j)}=1-\frac{-\sum_{o \in \mathcal{O}, o \neq z e r o} p_{o}^{(i, j)} \log \left(p_{o}^{(i, j)}\right)}{\log (|\mathcal{O}|-1)} \end{array}
import torch.distributions.categorical as cate
probs = mat[:, 1:] / EI[:, None]
entropy = cate.Categorical(probs=probs).entropy() / math.log(probs.size()[1])
SC = 1-entropy
选择的稳定性

将历史信息纳入操作分布评估,使用直方图交叉核计算平均选择稳定性。直方图交叉核的原理详见(https://blog.csdn.net/hong__fang/article/details/50550656)。

S_{S S}^{(i, j)}=\frac{1}{K} \sum_{t=T-K}^{T-1} \sum_{o_{t} \in \mathcal{O}, o_{t} \neq z e r o} \min \left(p_{o_{t}}^{(i, j)}, p_{o_{T}}^{(i, j)}\right)
def histogram_intersection(a, b):
  c = np.minimum(a.cpu().numpy(),b.cpu().numpy())
  c = torch.from_numpy(c).cuda()
  sums = c.sum(dim=1)
  return sums

def histogram_average(history, probs):
  histogram_inter = torch.zeros(probs.shape[0], dtype=torch.float).cuda()
  if not history:
    return histogram_inter
  for hist in history:
    histogram_inter += utils.histogram_intersection(hist, probs)
  histogram_inter /= len(history)
  return histogram_inter

probs_history = []

probs_history.append(probs)
if (len(probs_history) > args.history_size):
  probs_history.pop(0)
  
histogram_inter = histogram_average(probs_history, probs)

SS = histogram_inter

两种评估准则

评估准则1:

选择具有高边缘重要性和高选择确定性的操作

S_{1}^{(i, j)}=\text { normalize }\left(S_{E I}^{(i, j)}\right) * \text { normalize }\left(S_{S C}^{(i, j)}\right)
def normalize(v):
  min_v = torch.min(v)
  range_v = torch.max(v) - min_v
  if range_v > 0:
    normalized_v = (v - min_v) / range_v
  else:
    normalized_v = torch.zeros(v.size()).cuda()

  return normalized_v

score = utils.normalize(EI) * utils.normalize(SC)
评估准则2:

在评估准则1的基础上,加入考虑选择稳定性

S_{2}^{(i, j)}=S_{1}^{(i, j)} * \text { normalize }\left(S_{S S}^{(i, j)}\right)
score = utils.normalize(EI) * utils.normalize(SC) * utils.normalize(SS)

实验结果

CIFAR-10(CNN)

CIFAR-10(CNN)

ImageNet(CNN)

ImageNet(CNN)

ModelNet40(GCN)

ModelNet40(GCN)

PPI(GCN)

PPI(GCN)

参考

[1] Li, Guohao et al. ,SGAS: Sequential Greedy Architecture Search

[2] https://zhuanlan.zhihu.com/p/134294068

[3] 直方图交叉核 https://blog.csdn.net/hong__fang/article/details/50550656

本文分享自微信公众号 - AI异构(gh_ed66a0ffe20a),作者:许柯

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-05-27

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • DARTS-:增加辅助跳连,走出搜索性能崩溃

    尽管可微分架构搜索(DARTS)发展迅速,但它长期存在性能不稳定的问题,这极大地限制了它的应用。现有的鲁棒性方法是从由此产生的恶化行为中获取线索,而不是找出其原...

    AI异构
  • HAWQ-V2:基于Hessian迹的混合比特量化策略

    量化是减少神经网络的内存占用和推理时间的有效方法。但是,超低精度量化可能会导致模型精度显着下降。解决此问题的一种有前途的方法是执行混合精度量化,其中更敏感的层保...

    AI异构
  • GOLD-NAS:针对神经网络可微分架构搜索的一次大手术

    DARTS的搜索空间非常有限,例如,对于每个边保留了一个运算符,每个节点固定接收两个前继输入,等等。这些约束有利于NAS搜索的稳定性,但它们也缩小了强大的搜索方...

    AI异构
  • 涨知识|Google语法快速高效的搜索

    在日常生活中我们经常会用到Google、百度这样的搜索引擎。但是对于大多数的用户来说搜索的效率远远达不到预期的效果。所以为了提高搜索的效率我们需要学习一些常用的...

    算法与编程之美
  • Django使用websocket实现实时消息推送和聊天

    WebSocket 是 HTML5 开始提供的一种在单个 TCP 连接上进行全双工通讯的协议。

    菲宇
  • 机器学习性能评价指标汇总

    AUC 是 ROC (Receiver Operating Characteristic) 曲线以下的面积, 介于0.1和1之间。Auc作为数值可以直观的评价分...

    莫斯
  • spring rest 容易被忽视的后端服务 chunked 性能问题

    spring boot 容易被忽视的后端服务 chunked 性能问题 标签(空格分隔): springboot springmvc chunked 背景 sp...

    王清培
  • 速读原著-TCP/IP(UDP和ARP之间的交互作用)

    使用U D P,可以看到U D P与A R P典型实现之间的有趣的(而常常未被人提及)交互作用。我们用s o c k程序来产生一个包含8 1 9 2字节数据的U...

    cwl_java
  • 优秀的代码都是如何分层的?

    说起应用分层,大部分人都会认为这个不是很简单嘛 就controller,service, mapper三层。看起来简单,很多人其实并没有把他们职责划分开,在很多...

    程序员小明
  • 【安富莱专题教程第7期】终极调试组件Event Recorder,各种Link通吃,支持时间和功耗测量,printf打印,RTX5及中间件调试

    说明: 1、继前面的专题教程推出SEGGER的RTT,JScope,Micrium的uC/Probe之后,再出一期终极调试方案Event Recoder,之所以...

    armfly

扫码关注云+社区

领取腾讯云代金券