gorm 框架是国内的大神 jinzhu 基于 go 语言开源实现的一款数据库 orm 框架. 【gorm】一词恢弘大气,前缀 go 代表 go 语言, 后缀 orm 全称 Object Relation Mapping,指的是使用对象映射的方式,让使用方能够像操作本地对象实例一样轻松便捷地完成远端数据库的操作.
1 入口
gorm 框架通过一个 gorm.DB
实例来指代我们所操作的数据库. 使用 gorm 的第一步就是要通过 Open 方法创建出一个 gorm.DB 实例,其中首个入参为连接器 dialector,本身是个抽象的 interface,其实现类关联了具体数据库类型.
po 模型
基于 orm 的思路,与某张数据表所关联映射的是 po (persist object)模型.
定义 po 类时,可以通过声明 TableName 方法,来指定该类对应的表名.
type Reward struct {
gorm.Model
Amount sql.NullInt64 `gorm:"column:amount"`
Type string `gorm:"not null"`
UserID int64 `gorm:"not null"`
}
func (r Reward) TableName() string {
return "reward"
}
定义 po 类时,可以通过组合 gorm.Model 的方式,完成主键、增删改时间等4列信息的一键添加,并且由于声明了 DeletedAt 字段,gorm 将会默认会启动软删除模式.
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt DeletedAt `gorm:"index"`
}
核心类
gorm.DB 是 gorm 定义的数据库类. 所有执行的数据库的操作都将紧密围绕这个类,以链式调用的方式展开. 每当执行过链式调用后,新生成的 DB 对象中就存储了一些当前请求特有的状态信息,我们把这种对象称作“会话”.
DB 类中的核心字段包括:
- Config:用户自定义的配置项
- Error:一次会话执行过程中遇到的错误
- RowsAffected:该请求影响的行数
- Statement:一次会话的状态信息,比如请求和响应信息
- clone:会话被克隆的次数. 倘若 clone = 1,代表是始祖 DB 实例;倘若 clone > 1,代表是从始祖 DB 克隆出来的会话
DB.Table
请求在执行时,需要明确操作的是哪张数据表.使用方可以通过链式调用 DB.Table 方法,显式声明本次操作所针对的数据表,这种方式的优先级是最高的.在 DB.Table 方法缺省的情况下,gorm 则会尝试通过 po 类的 TableName 方法获取表名.
DB.Statement
会话状态 statement 类,里面存储了一次会话中包含的状态信息,比如请求中的条件、sql 语句拼接格式、响应参数类型、数据表的名称等等.
// Statement statement
type Statement struct {
// 数据库实例
*DB
// ...
// 表名
Table string
// 操作的 po 模型
Model interface{}
// ...
// 处理结果反序列化到此处
Dest interface{}
// ...
// 各种条件语句
Clauses map[string]clause.Clause
// ...
// 是否启用 distinct 模式
Distinct bool
// select 语句
Selects []string // selected columns
// omit 语句
Omits []string // omit columns
// join
Joins []join
// ...
// 连接池,通常情况下是 database/sql 库下的 *DB 类型. 在 prepare 模式为 gorm.PreparedStmtDB
ConnPool ConnPool
// 操作表的概要信息
Schema *schema.Schema
// 上下文,请求生命周期控制管理
Context context.Context
// 在未查找到数据记录时,是否抛出 recordNotFound 错误
RaiseErrorOnNotFound bool
// ...
// 执行的 sql,调用 state.Build 方法后,会将 sql 各部分文本依次追加到其中. 具体可见 2.5 小节
SQL strings.Builder
// 存储的变量
Vars []interface{}
// ...
}
connPool
这里额外强调一下 connPool 字段,其含义是连接池,和数据库的交互操作都需要依赖它才得以执行
db 克隆
DB 的克隆流程,所有在始祖 DB 基础上追加状态信息,克隆出来的 DB 实例都可以称为“会话”.
会话的状态信息主要存储在 statement 当中的,所以在克隆 DB 时,很重要的一环就是完成对 其中 statement 部分的创建/克隆.
该流程对应的方法为 DB.getInstance 方法,主要通过 DB 中的 clone 字段来判断当前是首次从始祖 DB 中执行克隆操作还是在一个会话的基础上克隆出一个新的会话实例
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
// 倘若是首次对 db 进行 clone,则需要构造出一个新的 statement 实例
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
}
// 倘若已经 db clone 过了,则还需要 clone 原先的 statement
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
return tx
}
return db
}
PrepareStmtDB
在 prepare 预处理模式下,DB 中连接池 connPool 的实现类为 PreparedStmtDB. 定义该类的目的是为了使用 database/sql 标准库中的 prepare 能力,完成预处理状态 statement 的构造和复用.
// prepare 模式下的 connPool 实现类.
type PreparedStmtDB struct {
// 各 stmt 实例. 其中 key 为 sql 模板,stmt 是对封 database/sql 中 *Stmt 的封装
Stmts map[string]*Stmt
// ...
Mux *sync.RWMutex
// 内置的 ConnPool 字段通常为 database/sql 中的 *DB
ConnPool
}
Stmt 类是 gorm 框架对 database/sql 标准库下 Stmt 类的简单封装,两者区别并不大:
type Stmt struct {
// database/sql 标准库下的 statement
*sql.Stmt
// 是否处于事务
Transaction bool
// 标识当前 stmt 是否已初始化完成
prepared chan struct{}
prepareErr error
}
processor执行器
gorm 框架执行 crud 操作逻辑时使用到的执行器 processor,针对 crud 操作的处理函数会以 list 的形式聚合在对应类型 processor 的 fns 字段当中.各类 processor 的初始化是通过 initializeCallbacks 方法完成
type callbacks struct {
// 对应存储了 crud 等各类操作对应的执行器 processor
// query -> query processor
// create -> create processor
// update -> update processor
// delete -> delete processor
processors map[string]*processor
}
后续在请求执行过程中,会根据 crud 的类型,从 callbacks 中获取对应类型的 processor. 比如一笔查询操作,会通过 callbacks.Query() 方法获取对应的 processor
执行器 processor 具体的类定义如下,其中核心字段包括:
- db:从属的 gorm.DB 实例
- Clauses:根据 crud 类型确定的 SQL 格式模板,后续用于拼接生成 sql
- fns:对应于 crud 类型的执行函数链
所有请求遵循的处理思路都是,首先根据其从属的 crud 类型,找到对应的 processor,然后调用 processor 的 Execute 方法,执行该 processor 下的 fns 函数链.
在 Execute 方法中,还有一项很重要的事情,是根据 crud 的类型,获取 sql 拼接格式 clauses,将其赋值到该 processor 的 BuildClauses 字段当中. crud 各类 clauses 格式展示如下:
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
clause 条件
接下来要介绍的是 gorm 框架中的条件 Clause. 一条执行 sql 中,各个部分都属于一个 clause,比如一条 SELECT * FROM reward WHERE id < 10 ORDER by id 的 SQL,其中就包含了 SELECT、FROM、WHERE 和 ORDER 四个 clause.
当使用方通过链式操作克隆 DB时,对应追加的状态信息就会生成一个新的 clause,追加到 statement 对应的 clauses 集合当中. 当请求实际执行时,会取出 clauses 集合,拼接生成完整的 sql 用于执行.
条件 clause 本身是个抽象的 interface,定义如下:
// Interface clause interface
type Interface interface {
// clause 名称
Name() string
// 生成对应的 sql 部分
Build(Builder)
// 和同类 clause 合并
MergeClause(*Clause)
}
不同的 clause 有不同的实现类,我们以 SELECT 为例进行展示:
type Select struct {
// 使用使用 distinct 模式
Distinct bool
// 是否 select 查询指定的列,如 select id,name
Columns []Column
Expression Expression
}
func (s Select) Name() string {
return "SELECT"
}
func (s Select) Build(builder Builder) {
// select 查询指定的列
if len(s.Columns) > 0 {
if s.Distinct {
builder.WriteString("DISTINCT ")
}
// 将指定列追加到 sql 语句中
for idx, column := range s.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
// 不查询指定列,则使用 select *
} else {
builder.WriteByte('*')
}
}
拼接 sql 是通过调用 Statement.Build 方法来实现的,入参对应的是 crud 中某一类 processor 的 BuildClauses.
func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
if firstClauseWritten {
stmt.WriteByte(' ')
}
firstClauseWritten = true
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b(c, stmt)
} else {
c.Build(stmt)
}
}
}
}
以 query 查询类为例,会遵循 “SELECT”->“FROM”->“WHERE”->“GROUP BY”->“ORDER BY”->“LIMIT”->“FOR” 的顺序,依次从 statement 中获取对应的 clause,通过调用 clause.Build 方法,将 sql 本文组装到 statement 的 SQL 字段中.
初始化
创建db
gorm.Open 方法是创建 DB 实例的入口方法,其中包含如下几项核心步骤:
- 完成 gorm.Config 配置的创建和注入
- 完成连接器 dialector 的注入,本篇使用的是 mysql 版本
- 完成 callbacks 中 crud 等几类 processor 的创建 ( 通过 initializeCallbacks(…) 方法 )
- 完成 connPool 的创建以及各类 processor fns 函数的注册( 通过 dialector.Initialize(…) 方法 )
- 倘若启用了 prepare 模式,需要使用 preparedStmtDB 进行 connPool 的平替
- 构造 statement 实例
- 根据策略,决定是否通过 ping 请求测试连接
- 返回创建好的 db 实例
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
// ...
// 表、列命名策略
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
}
// ...
// 连接器
if dialector != nil {
config.Dialector = dialector
}
// ...
db = &DB{Config: config, clone: 1}
// 初始化 callback 当中的各个 processor
db.callbacks = initializeCallbacks(db)
// ...
if config.Dialector != nil {
// 在其中会对 crud 各个方法的 callback 方法进行注册
// 会对 db.connPool 进行初始化,通常情况下是 database/sql 库下 *sql.DB 的类型
err = config.Dialector.Initialize(db)
// ...
}
// 是否启用 prepare 模式
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
// 倘若启用了 prepare 模式,会对 conn 进行替换
db.ConnPool = preparedStmt
}
// 构造一个 statement 用于存储处理链路中的一些状态信息
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
// 倘若未禁用 AutomaticPing,
if err == nil && !config.DisableAutomaticPing {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
}
}
// ...
return
}
初始化dialector
mysql 是我们常用的数据库,对应于 mysql 版本的 dialector 实现类位于 github.com/go-sql-driver/mysql 包下. 使用方可以通过 Open 方法,将传入的 dsn 解析成配置,然后返回 mysql 版本的 Dialector 实例.
package mysql
func Open(dsn string) gorm.Dialector {
dsnConf, _ := mysql.ParseDSN(dsn)
return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}}
}
通过 Dialector.Initialize 方法完成连接器初始化操作,其中也会涉及到对连接池 connPool 的初构造,并通过 callbacks.RegisterDefaultCallbacks 方法完成 crud 四类 processor 当中 fns 的注册操作:
import(
"github.com/go-sql-driver/mysql"
)
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.DriverName == "" {
dialector.DriverName = "mysql"
}
// connPool 初始化
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
if err != nil {
return err
}
}
// ...
// register callbacks
callbackConfig := &callbacks.Config{
CreateClauses: CreateClauses,
QueryClauses: QueryClauses,
UpdateClauses: UpdateClauses,
DeleteClauses: DeleteClauses,
}
// ...完成 crud 类操作 callback 函数的注册
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
// ...
return
}
注册crud函数
对应于 crud 四类 processor,注册的函数链 fns 的****内容和顺序是固定的,展示如上图. 相应的源码展示如下,对应的方法为 RegisterDefaultCallbacks(…):
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
// ...
// 创建类 create processor
createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
createCallback.Clauses = config.CreateClauses
// 查询类 query processor
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
// 删除类 delete processor
deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
// 更新类 update processor
updateCallback := db.Callback().Update() updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
updateCallback.Clauses = config.UpdateClauses
// row 类
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
// raw 类
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
}
查询
以 db.First 方法作为入口,展示数据库查询的方法链路:
在 db.First 方法当中:
- 遵循 First 的语义,通过 limit 和 order 追加 clause,限制只取满足条件且主键最小的一笔数据
- 追加用户传入的一系列 condition,进行 clause 追加
- 在 First、Take、Last 等方法中,会设置 RaiseErrorOnNotFound 标识为 true,倘若未找到记录,则会抛出 ErrRecordNotFound 错误
- 设置 statement 中的 dest 为用户传入的 dest,作为反序列化响应结果的对象实例
- 获取 query 类型的 processor,调用 Execute 方法执行其中的 fn 函数链,完成 query 操作
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
// order by id limit 1
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
// append clauses
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
// set RaiseErrorOnNotFound
tx.Statement.RaiseErrorOnNotFound = true
// set dest
tx.Statement.Dest = dest
// execute ...
return tx.callbacks.Query().Execute(tx)
}
执行查询类操作时,通常会通过链式调用的方式,传入一些查询限制条件,比如 Where、Group By、Order、Limit 之类. 我们以 Limit 为例,进行展开介绍:
- 首先调用 db.getInstance() 方法,克隆出一份 DB 会话实例
- 调用 statement.AddClause 方法,将 limit 条件追加到 statement 的 Clauses map 中
func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: &limit})
return
}
func (stmt *Statement) AddClause(v clause.Interface) {
// ...
name := v.Name()
c := stmt.Clauses[name]
c.Name = name
v.MergeClause(&c)
stmt.Clauses[name] = c
}
核心Query方法
在 query 类型 processor 的 fns 函数链中,最主要的函数是 Query,其中涉及的核心步骤包括:
- 调用 BuildQuerySQL(…) 方法,根据传入的 clauses 组装生成 sql
- 调用 connPool.QueryContext(…) ,完成查询类 sql 的执行,返回查到的行数据 rows(非 prepare 模式下,此处会对接 database/sql 库,走到 sql.DB.QueryContext(…) 方法中)
- 调用 gorm.Scan() 方法,将结果数据反序列化到 statement 的 dest 当中
func Query(db *gorm.DB) {
if db.Error == nil {
// 拼接生成 sql
BuildQuerySQL(db)
if !db.DryRun && db.Error == nil {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
}
}
}
Scan扫描数据
gorm.Scan() 方法,其作用是将查询结果数据反序列化到 dest 当中:
- 通过对 statement 中的 dest 进行分类,采取的不同的处理方式
- 核心方法都是通过 rows.Scan(…) 方法,将响应数据反序列化到 dest 当中
- 调用 rows.Err() 方法,抛出请求过程中遇到的错误
- 倘若启用了 RaiseErrorOnNotFound 模式且查询到的行数为 0,则抛出错误 ErrRecordNotFound
对应源码展示如下:
// Scan 方法将 rows 中的数据扫描解析到 db statement 中的 dest 当中
// 其中 rows 通常为 database/sql 下的 *Rows 类型
// 扫描数据的核心在于调用了 rows.Scan 方法
func Scan(rows Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
// 影响的行数
db.RowsAffected = 0
// 根据 dest 类型进行断言分配
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() {
// ...
db.RowsAffected++
// 扫描数据的核心在于,调用 rows
db.AddError(rows.Scan(values...))
// ...
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
// ...
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
*float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(dest))
}
default:
// ...
// 根据 dest 类型进行前处理 ...
db.AddError(rows.Scan(dest))
// ...
}
// 倘若 rows 中存在错误,需要抛出
if err := rows.Err(); err != nil && err != db.Error {
db.AddError(err)
}
// 在 first、last、take 模式下,RaiseErrorOnNotFound 标识为 true,在没有查找到数据时,会抛出 ErrRecordNotFound 错误
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
db.AddError(ErrRecordNotFound)
}
}