首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >Orion-MSP:深度学习终于在表格数据上超越了XGBoost

Orion-MSP:深度学习终于在表格数据上超越了XGBoost

作者头像
deephub
发布2025-11-15 11:57:49
发布2025-11-15 11:57:49
490
举报
文章被收录于专栏:DeepHub IMBADeepHub IMBA

点击上方“Deephub Imba”,关注公众号,好文章不错过 !

表格数据一直是深度学习的老大难问题。这些年CV和NLP领域被Transformer统治得服服帖帖,但在真正的业务场景里,面对表格这类的结构化数据,XGBoost这些梯度提升树还是稳坐钓鱼台。

为什么会这样?问题其实很简单。图像的像素排列有空间位置关系,文本有上下文顺序,但表格里的列是啥顺序都行——年龄放第一列和放最后一列没区别。而且这些列的类型完全不同:有数值、有类别,有的服从正态分布有的严重偏态。同样是数字50,在年龄列和交易量列意义天差地别。

ArXiv上最近新有篇论文叫"Orion-MSP: Multi-Scale Sparse Attention for Tabular In-Context Learning",来自Lexsi Labs的团队,算是正面解决了这个问题。

上下文学习这条路走得通但是有坎

最近这两年,受大语言模型启发,研究者开始尝试给表格数据做foundation model。核心想法是in-context learning(ICL)——不用针对每个新数据集重新训练,直接给模型看几个样本示例,它就能推断出任务模式。

TabPFN和TabICL是这方面的先驱。它们在海量合成数据集上做meta-training,让Transformer学会表格数据的一般规律。理想情况是让一个模型打天下,新来个表格数据,喂几个标注样本就能zero-shot分类。对AutoML来说这简直是梦想场景。

但第一代模型撞上了三堵墙:

单一尺度的视野太窄。这些模型用统一的粒度处理所有特征。就像你盯着照片看,只能选一个固定距离——凑近了看到线头,但看不出整体是件毛衣;退远了知道是毛衣,但抓不到细节。真实数据的结构是多层次的:底层是单个特征的交互(比如年龄和收入的关系),中层是特征组(人口统计信息这一块),顶层是大的数据分区(个人属性 vs 行为数据),单尺度模型对这种层次结构基本是盲的。

O(m²)的计算瓶颈卡死了宽表。标准的dense attention让每个特征token关注所有其他token,对于m个特征,复杂度是O(m²)。几十上百个特征还能扛,但基因组数据、金融衍生品、传感器阵列这种动辄上千特征的场景就彻底歇菜了,内存爆掉是常事。

信息只能单向流动。TabICL这类模型的架构是流水线式的:先embedding列,再建模行间关系,最后ICL预测。下游发现的模式(比如数据集层面的统计特性)没法反馈回去优化上游的表示。这就很浪费。

Orion-MSP针对这三个问题给出了对应的解法。

三个关键创新点

多尺度处理是第一个。Orion-MSP同时在多个粒度上处理特征——假如一行有64个特征,它会并行地看:全部64个单独特征(scale 1)、16组每组4个特征(scale 4)、4组每组16个特征(scale 16)。细粒度抓个体交互,粗粒度抓语义块的关系,就像同时用不同焦距的镜头拍摄。

块稀疏注意力解决效率问题。借鉴了NLP里Longformer的做法,用structured block-sparse attention替换dense attention。通过结合局部滑动窗口(相邻特征互相看得见)、全局token(专门负责长距离信息传递)、随机连接(保持网络表达能力),复杂度从O(m²)降到接近O(m·log(m)),这个改进算是巨大了。

Perceiver式的跨组件内存实现双向信息流。这个设计更巧妙:先让训练样本把信息"写入"一组可学习的latent vectors(可以理解成一个共享的备忘录),然后所有样本(包括测试集)都能从这个备忘录"读取"信息来增强自己的表示。而且写和读严格分离——测试数据只能读不能写,这样就不会违反ICL的因果约束,不存在数据泄露问题。

这三个部分不是独立的补丁,而是协同工作的系统。稀疏注意力让多尺度计算变得可行,Perceiver内存让不同尺度、不同组件的信息能安全地整合起来。

Orion-MSP的整体架构。输入表先做column-wise embedding得到E,然后多尺度稀疏行交互模块在不同粒度(1/4/16)上用稀疏attention处理特征,产生行embedding H。接着跨组件Perceiver内存模块实现双向通信:训练行写内存,所有行读内存得到增强表示R。最后ICL head一次前向传播预测测试标签。

架构细节

我们从头捋一遍流程。

第一步:列的distributional embedding

跟TabICL一样,Orion-MSP用Set Transformer给每列做embedding。这步很关键,因为单个cell的值脱离了列的分布就没意义。Set Transformer把每列当作无序集合,学习该列在训练集上的分布摘要,然后用这个摘要给每个cell生成context-aware的embedding。所以均值45的列里的50和均值500的列里的50,embedding完全不同。

第二步:多尺度稀疏行交互

拿到cell embedding之后要建模行内特征的关系。假设一行64个特征,Orion-MSP并行地在三个尺度上处理:

Scale 1看全部64个独立特征;Scale 4把特征分成16个块,每块4个,看块与块的关系;Scale 16分成4个大块,每块16个,做高层推理。

每个尺度用的都是block-sparse attention。

注意力机制的构成。白色表示没有attention。(a)特殊attention,包括CLS=4和global attention GB=4;(b)滑动窗口attention,w=8;(c)随机attention,r=2;(d)Orion-MSP的组合行表示。

这个稀疏模式保证了局部交互(滑动窗口)、长程依赖(global tokens)和网络表达力(随机连接)的平衡。最后把所有尺度的表示aggregate起来,得到每行的最终embedding。

代码逻辑大概是这样:

代码语言:javascript
复制
 // Algorithm 1: Multi-Scale Sparse Row-Wise Interaction (Simplified)
function MultiScaleSparseAttention(E, scales=[1, 4, 16]):
  all_scale_outputs = []

  for scale in scales:
    // 1. Group features into blocks of size 'scale'
    grouped_features = GroupFeatures(E, size=scale)

    // 2. Prepend special CLS and GLOBAL tokens
    sequence = [CLS, GLOBAL, ...grouped_features]

    // 3. Build the sparse attention mask
    //    - GLOBAL tokens attend to everything
    //    - Other tokens use sliding window + random links
    sparse_mask = BuildBlockSparseMask(sequence_length)

    // 4. Process with a Transformer encoder using the sparse mask
    processed_sequence = TransformerEncoder(sequence, mask=sparse_mask)

    // 5. Extract the output CLS token, which summarizes the row at this scale
    scale_output = processed_sequence[CLS_token_position]
    all_scale_outputs.append(scale_output)

  // 6. Aggregate the outputs from all scales (e.g., by averaging)
  final_row_embedding = Aggregate(all_scale_outputs)

   return final_row_embedding

Transformer encoder这步因为用了稀疏mask,复杂度是O(m * window_size)而不是O(m²)。位置编码用的RoPE,帮助模型理解特征在序列中的相对位置。

第三步:Perceiver内存做迭代refinement

行embedding现在已经包含了多尺度信息但还能更进一步。Cross-Component Perceiver Memory模块的工作方式:

写阶段(只有训练样本参与):训练样本的行embedding去"写"一组learnable latent vectors。这个过程把训练集的核心模式压缩成一个summary。

读阶段(所有样本):latent memory被冻结,然后所有样本(训练+测试)的embedding都去"读"这个memory,通过cross-attention获取全局context来refine自己的表示。

测试样本能利用训练集的全局信息,但不会反向影响训练表示。因果约束得到严格保证。

代码语言:javascript
复制
 // Algorithm 2: ICL with Perceiver Memory (Simplified)
function PerceiverMemoryRefinement(H_all_samples, H_train_samples):
  // 1. Initialize a learnable latent memory (the "cheat sheet")
  latent_memory = InitializeMemory()

  // --- WRITE PHASE (TRAIN ONLY) ---
  // 2. The memory attends to the training samples to encode global patterns
  for i in 1..N_write_layers:
    latent_memory = CrossAttention(query=latent_memory, key_value=H_train_samples)

  // At this point, latent_memory is a summary of the training set. It is now frozen.

  // --- READ PHASE (ALL SAMPLES) ---
  // 3. All samples attend to the memory to enrich their representations
  refined_embeddings = H_all_samples
  for i in 1..N_read_layers:
    refined_embeddings = CrossAttention(query=refined_embeddings, key_value=latent_memory)

   return refined_embeddings

这个refined representation R既有行本身的信息,又融入了训练集的distributional knowledge,预测自然更稳。

第四步:split-masked Transformer做zero-shot预测

最后refined embeddings进ICL prediction head。这里用标准Transformer但加了split attention mask来enforce ICL规则:

训练样本可以互相attend;测试样本可以attend训练样本(学任务)和其他测试样本(利用query set的pattern);训练样本绝对不能attend测试样本。

然后一次forward pass输出测试label。没有gradient更新,纯inference。

实验结果

作者在三个主要benchmark上测试了Orion-MSP:TALENT、OpenML-CC18、TabZilla,几百个不同的数据集,对手包括XGBoost、CatBoost这些传统方法,还有TabPFN、TabICL、TabDPT这些新的foundation models。

三个benchmark suite的性能对比。Rank是mean rank(越小越好)。Metrics包括准确率(ACC)和加权F1。"All"列是所有suite的汇总rank。第一名和第二名用不同格式标注。

Orion-MSP拿到了3.58的overall zero-shot rank,所有benchmark里最好。准确率和F1上持续match或超过TabPFN和TabICL。

高维数据的碾压优势

按特征数量分组看性能,差异就出来了。

按特征维度(数据集宽度)的性能变化。ACC是准确率,F1是加权F1分数,范围0-1越高越好。模型按adaptation策略分组。每组内第一名第二名有格式标记。

窄表和中等宽度表上大家都还行,但到了宽表(100+特征),dense attention模型的O(m²)复杂度就成了致命伤。很多Transformer-based的模型直接OOM崩掉。Orion-MSP的稀疏attention让它在这个区间依然保持强劲性能。

金融和医疗领域表现突出

在数据天然具有层次结构的领域,多尺度架构的优势更明显。

医疗和金融数据集的性能。Rank是域内mean rank(越低越好)。ACC和F1都是0-1范围,越高越好。

医疗数据集上准确率0.8045最高。医疗数据本来就是分层的:实验室检查、生命体征、人口学信息,多尺度架构正好match这种结构。

金融数据集上mean rank 4.60排第一。金融数据也是多层次的:市场指标、工具属性、宏观经济因素,Perceiver memory帮忙整合不同scale和context的信息效果很好。

模型在不平衡的数据集上上表现也不错。多尺度attention似乎能放大minority class的信号——细粒度scale捕捉少数类的subtle pattern,粗粒度scale提供global context防止对多数类过拟合。

为什么这个工作重要

Orion-MSP不只是刷了个榜,它代表了表格数据建模思路上的转变。从单一尺度、dense attention的架构,转向hierarchical、efficient、context-aware的设计。

这也说明表格数据这个战场还没打完。但Orion-MSP至少证明了,深度学习如果properly designed,是可以在结构化数据上超越传统方法的。关键是要respect数据本身的结构特点,设计既powerful又efficient的架构。

总结

之前的tabular foundation models被三个问题限制住了——单尺度处理看不到层次结构,O(m²)的dense attention在宽表上爆炸,单向信息流浪费了context。

Orion-MSP通过多尺度处理捕获不同粒度的特征交互;块稀疏attention把复杂度降到接近线性;Perceiver-style memory实现ICL-safe的双向信息共享。

作者自己承认,在非常简单的低维数据集上,Orion-MSP的复杂架构优势不明显。小表格可能简单模型就够了。不过这个论文可以说是很炸裂了,能比XGBoost效果要好的话应该有点说法。

论文地址:https://arxiv.org/abs/2511.02818


喜欢就关注一下吧:

点个 在看 你最好看!

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-11-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 DeepHub IMBA 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 上下文学习这条路走得通但是有坎
  • 三个关键创新点
  • 架构细节
    • 第一步:列的distributional embedding
    • 第二步:多尺度稀疏行交互
    • 第三步:Perceiver内存做迭代refinement
    • 第四步:split-masked Transformer做zero-shot预测
  • 实验结果
    • 高维数据的碾压优势
    • 金融和医疗领域表现突出
  • 为什么这个工作重要
  • 总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档