Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >TiDB v5.1 体验: 我用 TiDB 训练了一个机器学习模型

TiDB v5.1 体验: 我用 TiDB 训练了一个机器学习模型

原创
作者头像
PingCAP
修改于 2021-10-21 10:03:36
修改于 2021-10-21 10:03:36
94500
代码可运行
举报
文章被收录于专栏:PingCAP的专栏PingCAP的专栏
运行总次数:0
代码可运行

作者介绍

韩明聪,TiDB Contributor,上海交通大学 IPADS 实验室博士研究生,研究方向为系统软件。本文主要介绍了如何在 TiDB 中使用纯 SQL 训练一个机器学习模型。

前言

众所周知,TiDB 5.1 版本增加了很多新特性,其中有一个特性,即 ANSI SQL 99 标准中的 Common Table Expression (CTE)。一般来说,CTE 可以被用作一个 Statement 作用于临时的 View,将一个复杂的 SQL 解耦,提高开发效率。但是,CTE 还有一个重要的使用方式,即 Recursive CTE,允许 CTE 引用自身,这是完善 SQL 功能的最后一块核心的拼图。

在 StackOverflow 中有过这样一个讨论 “Is SQL or even TSQL Turing Complete”,其中点赞最多的回复中提到这样一句话:

“ In this set of slides Andrew Gierth proves that with CTE and Windowing SQL is Turing Complete, by constructing a cyclic tag system, which has been proved to be Turing Complete. The CTE feature is the important part however – it allows you to create named sub-expressions that can refer to themselves, and thereby recursively solve problems.”

即 CTE 和 Window Function 甚至使得 SQL 成为一个图灵完备的语言。

而这又让我想起来多年前看到过的一篇文章 Deep Neural Network implemented in pure SQL over BigQuery,作者使用纯 SQL 来实现了一个 DNN 模型,但是打开 repo 后发现,他竟然是标题党!实际上他还是使用了 Python 来实现迭代训练。

因此,既然 Recursive CTE 给了我们 “迭代” 的能力,这让我想挑战一下,能否在 TiDB 中使用纯 SQL 实现机器学习模型的训练、推理

Iris Dataset

首先要选择一个简单的机器学习模型和任务,我们先尝试 sklearn 中的入门数据集 iris dataset。这个数据集共包含 3 类 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这 4 个特征预测鸢尾花卉属于 iris-setosa,iris-versicolour,iris-virginica 中的哪一品种。

当下载好数据后(已经是 CSV 格式),我们先将数据导入到 TiDB 中。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> create table iris(sl float, sw float, pl float, pw float, type  varchar(16));
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> LOAD DATA LOCAL INFILE 'iris.csv' INTO  TABLE iris FIELDS  TERMINATED  BY ',' LINES  TERMINATED  BY  '\n' ;
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> select * from iris limit 10;+------+------+------+------+-------------+| sl   | sw   | pl   | pw   | type        |+------+------+------+------+-------------+|  5.1 |  3.5 |  1.4 |  0.2 | Iris-setosa ||  4.9 |    3 |  1.4 |  0.2 | Iris-setosa ||  4.7 |  3.2 |  1.3 |  0.2 | Iris-setosa ||  4.6 |  3.1 |  1.5 |  0.2 | Iris-setosa ||    5 |  3.6 |  1.4 |  0.2 | Iris-setosa ||  5.4 |  3.9 |  1.7 |  0.4 | Iris-setosa ||  4.6 |  3.4 |  1.4 |  0.3 | Iris-setosa ||    5 |  3.4 |  1.5 |  0.2 | Iris-setosa ||  4.4 |  2.9 |  1.4 |  0.2 | Iris-setosa ||  4.9 |  3.1 |  1.5 |  0.1 | Iris-setosa |+------+------+------+------+-------------+10 rows in set (0.00 sec)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> select type, count(*) from iris group by type;+-----------------+----------+| type            | count(*) |+-----------------+----------+| Iris-versicolor |       50 || Iris-setosa     |       50 || Iris-virginica  |       50 |+-----------------+----------+3 rows in set (0.00 sec)

Softmax Logistic Regression

这里我们选择一个简单的机器学习模型 —— Softmax 逻辑回归,来实现多分类。(以下的图与介绍均来自百度百科

在 Softmax 回归中将 x 分类为类别 y 的概率为:

代价函数为:

可以求得梯度:

因此可以通过梯度下降方法,每次更新梯度:

Model Inference

我们先写一个 SQL 来实现 Inference,根据上面定义的模型和数据,输入的数据 X 共有五维(sl, sw, pl, pw 以及一个常数 1.0),输出使用 one-hot 编码。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> create table data(    x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30),         y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30));
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql>insert into dataselect    sl, sw, pl, pw, 1.0,     case when type='Iris-setosa'then 1 else 0 end,    case when type='Iris-versicolor'then 1 else 0 end,      case when type='Iris-virginica'then 1 else 0 endfrom iris;

参数共有 3 类 * 5 维 = 15 个:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> create table weight(    w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30),    w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30),    w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));

先全部初始化为 0.1,0.2,0.3(这里选择不同的数字是为了方便演示,也可以全部初始化为0.1):

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> insert into weight values (    0.1, 0.1, 0.1, 0.1, 0.1,    0.2, 0.2, 0.2, 0.2, 0.2,    0.3, 0.3, 0.3, 0.3, 0.3);

下面我们写一个 SQL 来统计对所有的 Data 进行 Inference 后结果的准确率。

为了方便理解,我们先给一个伪代码描述这个过程:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
weight = (       w00, w01, w02, w03, w04,    w10, w11, w12, w13, w14,    w20, w21, w22, w23, w24)for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:    exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)    exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)    exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)    sum_exp = exp0 + exp1 + exp2    // softmax    p0 = exp0  sum_exp    p1 = exp1  sum_exp    p2 = exp2  sum_exp    // inference result    r0 = p0 > p1 and p0 > p2     r1 = p1 > p0 and p1 > p2    r2 = p2 > p0 and p2 > p1         data.correct = (y0 == r0 and y1 == r1 and y2 == r2)return sum(Data.correct)  count(Data)

在上述代码中,我们对 Data 中的每一行元素进行计算,首先求三个向量点乘的 exp,然后求 softmax,最后选择 p0, p1, p2 中最大的为 1,其余为 0,这样就完成了一个样本的 Inference。如果一个样本最后 Inference 的结果与它本来的分类一致,那就是一次正确的预测,最后我们对所有样本中正确的数量求和,即可得到最后的正确率。

下面给出 SQL 的实现,我们选择把 data 中的每一行数据都和 weight (只有一行数据) join 起来,然后计算每一行数据的 Inference 结果,再对正确的样本数量求和:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
select sum(y0 = r0 and y1 = r1 and y2 = r2)  count(*)from    (select        y0, y1, y2,        p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2    from        (select             y0, y1, y2,            e0/(e0+e1+e2) as p0, e1/(e0+e1+e2) as p1,  e2/(e0+e1+e2) as p2        from            (select                  y0, y1, y2,                 exp(                     w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4                 ) as e0,                 exp(                     w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4                 ) as e1,                 exp(                     w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4                  ) as e2             from data, weight) t1        )t2    )t3;

可以看到上述 SQL 几乎是按步骤实现了上述伪代码的计算过程,得到结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
+-----------------------------------------------+| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |+-----------------------------------------------+|                                        0.3333 |+-----------------------------------------------+1 row in set (0.01 sec)

下面我们就对模型的参数进行学习。

Model Training

Notice:这里为了简化问题,不考虑 “训练集”、“验证集” 等问题,只使用全部的数据进行训练。

我们还是先给出一个伪代码,然后根据伪代码写出一个 SQL:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
weight = (       w00, w01, w02, w03, w04,    w10, w11, w12, w13, w14,    w20, w21, w22, w23, w24)for iter in iterations:    sum00 = 0    sum01 = 0    ...    sum23 = 0    sum24 = 0    for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:        exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)        exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)        exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)        sum_exp = exp0 + exp1 + exp2        // softmax        p0 = y0 - exp0  sum_exp        p1 = y1 - exp1  sum_exp        p2 = y2 - exp2  sum_exp        sum00 += p0 * x0        sum01 += p0 * x1        sum02 += p0 * x2        ...        sum23 += p2 * x3        sum24 += p2 * x4    w00 = w00 + learning_rate * sum00  Data.size    w01 = w01 + learning_rate * sum01  Data.size    ...    w23 = w23 + learning_rate * sum23  Data.size    w24 = w24 + learning_rate * sum24  Data.size

看上去比较繁琐,因为我们这里选择把 sum, w 等向量给手动展开。

接着我们开始写 SQL 训练,我们先写只有一次迭代的 SQL:

设置学习率和样本数量

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> set @lr = 0.1;Query OK, 0 rows affected (0.00 sec)mysql> set @dsize = 150;Query OK, 0 rows affected (0.00 sec)

迭代一次:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
select     w00 + @lr * sum(d00)  @dsize as w00, w01 + @lr * sum(d01)  @dsize as w01, w02 + @lr * sum(d02)  @dsize as w02, w03 + @lr * sum(d03)  @dsize as w03, w04 + @lr * sum(d04)  @dsize as w04 ,    w10 + @lr * sum(d10)  @dsize as w10, w11 + @lr * sum(d11)  @dsize as w11, w12 + @lr * sum(d12)  @dsize as w12, w13 + @lr * sum(d13)  @dsize as w13, w14 + @lr * sum(d14)  @dsize as w14,    w20 + @lr * sum(d20)  @dsize as w20, w21 + @lr * sum(d21)  @dsize as w21, w22 + @lr * sum(d22)  @dsize as w22, w23 + @lr * sum(d23)  @dsize as w23, w24 + @lr * sum(d24)  @dsize as w24from    (select        w00, w01, w02, w03, w04,        w10, w11, w12, w13, w14,        w20, w21, w22, w23, w24,        p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,        p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,        p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24    from        (select          w00, w01, w02, w03, w04,         w10, w11, w12, w13, w14,         w20, w21, w22, w23, w24,         x0, x1, x2, x3, x4,         y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2         from            (select                 w00, w01, w02, w03, w04,                w10, w11, w12, w13, w14,                w20, w21, w22, w23, w24,                x0, x1, x2, x3, x4, y0, y1, y2,                exp(                    w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4                ) as e0,                exp(                    w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4                ) as e1,                exp(                    w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4                 ) as e2             from data, weight) t1        )t2    )t3;

得到的结果是一次迭代后的模型参数:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+| w00                              | w01                              | w02                              | w03                              | w04                              | w10                              | w11                              | w12                              | w13                              | w14                              | w20                              | w21                              | w22                              | w23                              | w24                              |+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+| 0.242000022455130986666666666667 | 0.199736070114635900000000000000 | 0.135689102774125773333333333333 | 0.104372938417325687333333333333 | 0.128775320011717430666666666667 | 0.296128284590438133333333333333 | 0.237124925707748246666666666667 | 0.281477497498236260000000000000 | 0.225631554555397960000000000000 | 0.215390025342499213333333333333 | 0.061871692954430866666666666667 | 0.163139004177615846666666666667 | 0.182833399727637980000000000000 | 0.269995507027276353333333333333 | 0.255834654645783353333333333333 |+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+1 row in set (0.03 sec)

下面就是核心部分,我们使用 Recursive CTE 来进行迭代训练:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
mysql> set @num_iterations = 1000;Query OK, 0 rows affected (0.00 sec)

核心的思路是,每次迭代的输入都是上一次迭代的结果,然后我们再加一个递增的迭代变量来控制迭代次数,大体的架构:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
with recursive cte(iter, weight) as(select 1, init_weightunion allselect iter+1, new_weightfrom cte where ites < @num_iterations)

接着,我们把一次迭代的 SQL 和这个迭代的框架结合起来(为了提高计算精度,在中间结果里加入了一些类型转换):

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
with recursive weight( iter,         w00, w01, w02, w03, w04,        w10, w11, w12, w13, w14,        w20, w21, w22, w23, w24) as(select 1,     cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast (0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),    cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),    cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))union allselect     iter + 1,    w00 + @lr * cast(sum(d00) as DECIMAL(35, 30))  @dsize as w00, w01 + @lr * cast(sum(d01) as DECIMAL(35, 30))  @dsize as w01, w02 + @lr * cast(sum(d02) as DECIMAL(35, 30))  @dsize as w02, w03 + @lr * cast(sum(d03) as DECIMAL(35, 30))  @dsize as w03, w04 + @lr * cast(sum(d04) as DECIMAL(35, 30))  @dsize as w04 ,    w10 + @lr * cast(sum(d10) as DECIMAL(35, 30))  @dsize as w10, w11 + @lr * cast(sum(d11) as DECIMAL(35, 30))  @dsize as w11, w12 + @lr * cast(sum(d12) as DECIMAL(35, 30))  @dsize as w12, w13 + @lr * cast(sum(d13) as DECIMAL(35, 30))  @dsize as w13, w14 + @lr * cast(sum(d14) as DECIMAL(35, 30))  @dsize as w14,    w20 + @lr * cast(sum(d20) as DECIMAL(35, 30))  @dsize as w20, w21 + @lr * cast(sum(d21) as DECIMAL(35, 30))  @dsize as w21, w22 + @lr * cast(sum(d22) as DECIMAL(35, 30))  @dsize as w22, w23 + @lr * cast(sum(d23) as DECIMAL(35, 30))  @dsize as w23, w24 + @lr * cast(sum(d24) as DECIMAL(35, 30))  @dsize as w24    from    (select        iter, w00, w01, w02, w03, w04,        w10, w11, w12, w13, w14,        w20, w21, w22, w23, w24,        p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,        p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,        p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24    from        (select          iter, w00, w01, w02, w03, w04,         w10, w11, w12, w13, w14,         w20, w21, w22, w23, w24,         x0, x1, x2, x3, x4,         y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2         from            (select                 iter, w00, w01, w02, w03, w04,                w10, w11, w12, w13, w14,                w20, w21, w22, w23, w24,                x0, x1, x2, x3, x4, y0, y1, y2,                exp(                    w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
                ) as e0,

                exp(

                    w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4

                ) as e1,

                exp(

                    w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4 

                ) as e2

             from data, weight where iter < @num_iterations) t1

        )t2

    )t3

having count(*) > 0

)

select * from weight where iter = @num_iterations;

这个版本和上面迭代一次的版本的区别在于两点:

  1. 在 data join weight 后,我们增加一个 where iter < @num_iterations 用于控制迭代次数,并且在最后的输出中增加了一列 iter + 1 as ite;
  2. 最后我们还增加了 having count(*) > 0 ,避免当最后没有输入数据时,aggregation 还是会输出数据,导致迭代不能结束。

然后我们得到结果:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery

啊这……

recursive cte 竟然不允许在 recursive part 里有子查询!不过把上面的子查询全部都合并到一起也不是不可以,那我手动合并一下,然后再试一下:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block

不允许子查询我可以手动改 SQL,但是不允许用 aggregate function 我是真的没办法了!

在这里我们只能宣布挑战失败…诶,为啥我不能去改一下 TiDB 的实现呢?

根据 proposal 中的介绍,recursive CTE 的实现并没有脱离 TiDB 基本的执行框架,咨询了 @wjhuang2016 之后,得知之所以不允许使用子查询和 aggregate function 的原因应该有两点:

  1. MySQL 也不允许
  2. 如果允许的话,有很多的 corner case 需要处理,非常的复杂

但是这里我们只是需要试验一下功能,暂时把这个 check 给删除掉也未尝不可,diff 里删除了对子查询和 aggregation function 的检查。

下面我们再次执行一遍:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+

| iter | w00                              | w01                              | w02                               | w03                               | w04                              | w10                              | w11                               | w12                               | w13                               | w14                              | w20                               | w21                               | w22                              | w23                              | w24                               |

+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+

| 1000 | 0.988746701341992382020000000002 | 2.154387045383744124308666666676 | -2.717791657467537500866666666671 | -1.219905459264249309799999999999 | 0.523764101056271250025665250523 | 0.822804724410132626693333333336 | -0.100577045244777709968533333327 | -0.033359805866941626546666666669 | -1.046591158370568595420000000005 | 0.757865074561280001352887284083 | -1.511551425752124944953333333333 | -1.753810000138966371560000000008 | 3.051151463334479351666666666650 | 2.566496617634817948266666666655 | -0.981629175617551201349829226980 |

+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+

成功了!我们得到了迭代 1000 次后的参数!

下面我们用新的参数来重新计算正确率:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
+-------------------------------------------------+

| sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |

+-------------------------------------------------+

|                                          0.9867 |

+-------------------------------------------------+

1 row in set (0.02 sec)

这次正确率到达了 98%。

Conclusion

我们这次成功使用纯 SQL 在 TiDB 中训练了一个 Softmax logistic regression model,主要利用了 TiDB v5.1 版本的 Recursive CTE 功能。在测试的过程中,我们发现了目前 TiDB 的 Recursive CTE 不允许存在 subquery 和 aggregate function,我们简单修改了 TiDB 的代码,绕过了这个限制,最终成功训练出了一个模型,并在 iris dataset 上得到了 98% 的准确率。

Discussion

  • 经过一些测试后,发现 PostgreSQL 和 MySQL 均不支持在 Recursive CTE 使用聚合函数,可能实现起来确实存在一些难以处理的 corner case,具体大家可以讨论一下。
  • 本次的尝试,是手动把所有的维度全部展开,实际上我还写了一个不需要展开所有维度的实现(例如 data 表的 schema 是 (idx, dim, value)),但是这种实现方式需要 join 两次 weight 表,也就是在 CTE 里需要递归访问两次,这还需要修改 TiDB Executor 的实现,所以就没有写在这里。但实际上,这种实现方式更加的通用,一个 SQL 可以处理所有维度数量的模型(我最初想尝试用 TiDB 训练 MINIST)。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
深度神经网络权值初始化的几种方式及为什么不能初始化为零(1)
写在前面:该篇文章的内容以及相关代码(代码在最后),都是我亲自手敲出来的,相关结论分析也是花了挺长时间做出来的,如需转载该文章,请务必先联系我,在后台留言即可。
用户7699929
2020/08/27
2.4K0
学习前馈神经网络的数学原理
在我上一篇博客中,我们讨论了人工神经网络的动机是来源于生理。这一篇博文,我们将讨论如何实现人工神经网络。在人工神经网络中,我们使用不同层数的网络来解决问题。使用多少层的网络才能解决一个特定的问题是另一个话题,我很快将为此写一个博客。但是,目前我们仍然可以着手实现网络,并学习如何用它去解决问题。
wheel_BL
2018/02/01
1K0
学习前馈神经网络的数学原理
某老牌反作弊产品分析-(存在加密漏洞可被中间人攻击)
本文仅限学习交流,请勿用于非法以及商业用途,由于时间和水平有限,文中错漏之处在所难免,请多多批评指正。
我是小三
2022/06/19
2K0
某老牌反作弊产品分析-(存在加密漏洞可被中间人攻击)
神经网络的基本原理
神经网络是由一个个的被称为“神经元”的基本单元构成,单个神经元的结构如下图所示:
felixzhao
2022/05/12
1.2K0
神经网络的基本原理
【NLP】Transformer理论解读
Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,目前已经在目标检测、自然语言处理、时序预测等多个深度学习领域获得了应用,成为了新的研究热点。
zstar
2022/09/16
6040
【NLP】Transformer理论解读
机器学习(33)之局部线性嵌入(LLE)【降维】总结
关键字全网搜索最新排名 【机器学习算法】:排名第一 【机器学习】:排名第一 【Python】:排名第三 【算法】:排名第四 前言 局部线性嵌入(Locally Linear Embedding,简称LLE)也是非常重要的降维方法。和传统的PCA,LDA等关注样本方差的降维方法相比,LLE关注于降维时保持样本局部的线性特征,由于LLE在降维时保持了样本的局部特征,它广泛的用于图像图像识别,高维数据可视化等领域。 什么是流形学习 LLE属于流形学习(Manifold Learning)的一种。因此我们首先看看什
昱良
2018/04/04
1.9K0
机器学习(33)之局部线性嵌入(LLE)【降维】总结
一文搞定BP神经网络——从原理到应用(原理篇)「建议收藏」
本文着重讲述经典BP神经网络的数学推导过程,并辅助一个小例子。本文不会介绍机器学习库(比如sklearn, TensorFlow等)的使用。 欲了解卷积神经网络的内容,请参见我的另一篇博客一文搞定卷积神经网络——从原理到应用。
全栈程序员站长
2022/09/09
4.7K1
一文搞定BP神经网络——从原理到应用(原理篇)「建议收藏」
低于0.01%的极致Crash率是怎么做到的?
作者卢子填, 腾讯移动互联网 高级开发工程师 商业转载请联系腾讯WeTest获得授权,非商业转载请注明出处。 WeTest 导读 看似系统Bug的Crash 99%都不是系统问题!本文将与你一起探索Crash分析的科学方法。 在移动互联网闯荡多年的iOS手机管家,经过不断迭代创新,已经涵盖了隐私(加密相册)、安全(骚扰拦截、短信过滤)、工具(网络检测、照片清理、极简提醒等)等等各个方面,为千万用户提供安全专业的服务。但与此同时,工程代码也越来越庞大(近30万行),一丁点的问题都会影响大量的用户,所以手管一
WeTest质量开放平台团队
2018/07/11
2.2K0
[ByteCTF 2021 Final] Master of HTTPD && exsc 题解
IoT题,aarch64,题目修改了mini_httpd的身份验证部分,加了一个输出认证信息的函数——没留意终端STDOUT...这里耽误了点时间。mini_httpd的源码可以在官网下载。
赤道企鹅
2022/08/01
9190
[ByteCTF 2021 Final] Master of HTTPD && exsc 题解
激活函数、正向传播、反向传播及softmax分类器,一篇就够了!
原文链接:https://juejin.im/post/5d46816e51882560b9544ac1
mantch
2019/08/14
1.3K0
美团买菜IOS版设备风控浅析与算法还原
本文仅限学习交流,请勿用于非法以及商业用途,由于时间和水平有限,文中错漏之处在所难免,敬请各位大佬多多批评指正。
我是小三
2021/11/29
5.7K0
美团买菜IOS版设备风控浅析与算法还原
方法的查找流程——快速查找
消息的接收者是objc_super类型,其内部携带了当前方法的调用者——实例对象自身,以及实例对象的父类。
拉维
2021/03/10
6530
方法的查找流程——快速查找
机器学习-10-神经网络python实现-鸢尾花分类
懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合 +算法评估+持续调优+工程化接口实现
用户2225445
2024/05/05
3180
机器学习-10-神经网络python实现-鸢尾花分类
汽车APP产品分析-亿盾加固1
代码安全只是表面,核心是帮助客户满足业务不被阻断、关键数据资产不被窃取的安全需求。因为加固自身不创造价值,加固的价值必须和公司业务挂钩,来间接体现。通过安全体系建立为业务服务保障,增加收益与减少了资损率。
我是小三
2023/03/11
1.7K0
汽车APP产品分析-亿盾加固1
机器学习-10-神经网络python实现-从零开始
本文来源原文链接:https://blog.csdn.net/weixin_66845445/article/details/133828686
用户2225445
2024/04/23
4960
机器学习-10-神经网络python实现-从零开始
如何编写一个Android inline hook框架
缺点:1、不支持函数替换(即hook后不执行原函数),现在只能修改参数寄存器,无法修改返回值。2、不支持定义同类型的hook函数来接受处理参数,只能通过修改寄存器的方式修改参数。多余4个/或者占两个字节的参数,那么参数还要自己从栈上捞取。所以issues中说的把mov r0,sp去掉用来接收参数也是有问题的,就是参数在栈上的情况,传过来的时候sp不是原来的sp了。
FB客服
2020/02/26
3.5K0
机器学习-05-特征工程
特征工程是指使用专业的背景知识和技巧处理数据,使得特征能在机器学习算法上发生更好的作用的过程。更好的特征意味着更强的灵活性,只需简单模型就能得到更好的结果,因此,特征工程在机器学习中占有相当重要的地位,可以说是决定结果成败的最关键和决定性的因素。
用户2225445
2024/03/21
7320
机器学习-05-特征工程
五万字总结,深度学习基础。「建议收藏」
人工神经网络(Artificial Neural Networks,简写为ANNs)是一种模仿动物神经网络行为特征,进行分布式并行信息处理的算法数学模型。这种网络依靠系统的复杂程度,通过调整内部大量节点之间相互连接的关系,从而达到处理信息的目的,并具有自学习和自适应的能力。神经网络类型众多,其中最为重要的是多层感知机。为了详细地描述神经网络,我们先从最简单的神经网络说起。
全栈程序员站长
2022/08/31
1K0
五万字总结,深度学习基础。「建议收藏」
机器学习(五)使用Python和R语言从头开始理解和编写神经网络介绍目录神经网络背后的直观知识多层感知器及其基础知识什么是激活函数?前向传播,反向传播和训练次数(epochs)多层感知器全批量梯度下降
本篇文章是原文的翻译过来的,自己在学习和阅读之后觉得文章非常不错,文章结构清晰,由浅入深、从理论到代码实现,最终将神经网络的概念和工作流程呈现出来。自己将其翻译成中文,以便以后阅读和复习和网友参考。因时间(文字纯手打加配图)紧促和翻译水平有限,文章有不足之处请大家指正。 介绍 你可以通过两种方式学习和实践一个概念: 选项1:您可以了解一个特定主题的整个理论,然后寻找应用这些概念的方法。所以,你阅读整个算法的工作原理,背后的数学知识、假设理论、局限,然后去应用它。这样学习稳健但是需要花费大量的时间去准备。
致Great
2018/04/11
1.3K0
机器学习(五)使用Python和R语言从头开始理解和编写神经网络介绍目录神经网络背后的直观知识多层感知器及其基础知识什么是激活函数?前向传播,反向传播和训练次数(epochs)多层感知器全批量梯度下降
机器学习算法(一)SVM
机器学习的一般框架: 训练集 => 提取特征向量 => 结合一定的算法(分类器:比如决策树、KNN)=>得到结果
全栈程序员站长
2022/11/08
2.4K0
机器学习算法(一)SVM
推荐阅读
相关推荐
深度神经网络权值初始化的几种方式及为什么不能初始化为零(1)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验