前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >SQLite3 自定义函数(UDF)

SQLite3 自定义函数(UDF)

作者头像
一个会写诗的程序员
发布2022-05-13 15:12:48
6110
发布2022-05-13 15:12:48
举报
文章被收录于专栏:一个会写诗的程序员的博客
代码语言:javascript
复制
package main

import (
    "database/sql"
    "fmt"
    "log"
    "math"
    "math/rand"

    sqlite "github.com/mattn/go-sqlite3"
)

// Computes x^y
func pow(x, y int64) int64 {
    return int64(math.Pow(float64(x), float64(y)))
}

// Computes the bitwise exclusive-or of all its arguments
func xor(xs ...int64) int64 {
    var ret int64
    for _, x := range xs {
        ret ^= x
    }
    return ret
}

// Returns a random number. It's actually deterministic here because
// we don't seed the RNG, but it's an example of a non-pure function
// from SQLite's POV.
func getrand() int64 {
    return rand.Int63()
}

// Computes the standard deviation of a GROUPed BY set of values
type stddev struct {
    xs []int64
    // Running average calculation
    sum int64
    n   int64
}

func newStddev() *stddev { return &stddev{} }

func (s *stddev) Step(x int64) {
    s.xs = append(s.xs, x)
    s.sum += x
    s.n++
}

func (s *stddev) Done() float64 {
    mean := float64(s.sum) / float64(s.n)
    var sqDiff []float64
    for _, x := range s.xs {
        sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2))
    }
    var dev float64
    for _, x := range sqDiff {
        dev += x
    }
    dev /= float64(len(sqDiff))
    return math.Sqrt(dev)
}

func main() {
    sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{
        ConnectHook: func(conn *sqlite.SQLiteConn) error {
            if err := conn.RegisterFunc("pow", pow, true); err != nil {
                return err
            }
            if err := conn.RegisterFunc("xor", xor, true); err != nil {
                return err
            }
            if err := conn.RegisterFunc("rand", getrand, false); err != nil {
                return err
            }
            if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil {
                return err
            }
            return nil
        },
    })

    db, err := sql.Open("sqlite3_custom", ":memory:")
    if err != nil {
        log.Fatal("Failed to open database:", err)
    }
    defer db.Close()

    var i int64
    err = db.QueryRow("SELECT pow(2,3)").Scan(&i)
    if err != nil {
        log.Fatal("POW query error:", err)
    }
    fmt.Println("pow(2,3) =", i) // 8

    err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i)
    if err != nil {
        log.Fatal("XOR query error:", err)
    }
    fmt.Println("xor(1,2,3,4,5) =", i) // 7

    err = db.QueryRow("SELECT rand()").Scan(&i)
    if err != nil {
        log.Fatal("RAND query error:", err)
    }
    fmt.Println("rand() =", i) // pseudorandom

    _, err = db.Exec("create table foo (department integer, profits integer)")
    if err != nil {
        log.Fatal("Failed to create table:", err)
    }
    _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)")
    if err != nil {
        log.Fatal("Failed to insert records:", err)
    }

    rows, err := db.Query("select department, stddev(profits) from foo group by department")
    if err != nil {
        log.Fatal("STDDEV query error:", err)
    }
    defer rows.Close()
    for rows.Next() {
        var dept int64
        var dev float64
        if err := rows.Scan(&dept, &dev); err != nil {
            log.Fatal(err)
        }
        fmt.Printf("dept=%d stddev=%f\n", dept, dev)
    }
    if err := rows.Err(); err != nil {
        log.Fatal(err)
    }
}
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-05-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档