在后端接口开发中,往往需要针对某一张表写相对应的增删改查的sql方法,比如我们查询某张表的数据
// GetHostsByModel selects rows filtered by the fields of model that are in use.
//
// A numeric field takes part in the WHERE clause when it is non-zero, a string
// field when it is not made up solely of spaces. A nil model adds no
// conditions. When startNum >= 0 and pagesize > 0 the statement is limited with
// " LIMIT ?,? ". It returns an error when the database handle is nil, otherwise
// the rows filled in by ExecSelect together with its error.
func (r *DomainDao) GetHostsByModel(model *config.HostsModel, startNum int, pagesize int) ([]*config.HostsModel, error) {
	conn := r.Db.GetDB()
	if conn == nil {
		return nil, errors.New("db is nil")
	}

	query := "SELECT * FROM table WHERE 1=1 "
	params := make([]interface{}, 0)

	// filter appends one equality condition and its bound value when active is true.
	filter := func(active bool, column string, value interface{}) {
		if !active {
			return
		}
		query += " AND " + column + "=? "
		params = append(params, value)
	}
	// notBlank reports whether s still has content after stripping surrounding spaces.
	notBlank := func(s string) bool {
		return len(strings.Trim(s, " ")) != 0
	}

	if model != nil {
		filter(model.Gid != 0, "gid", model.Gid)
		filter(notBlank(model.CdnDomain), "cdn_domain", model.CdnDomain)
		filter(notBlank(model.CdnName), "cdn_name", model.CdnName)
		filter(model.CdnType != 0, "cdn_type", model.CdnType)
		filter(model.Master != 0, "master", model.Master)
		filter(model.Mode != 0, "mode", model.Mode)
		filter(model.AuthOutTime != 0, "auth_out_time", model.AuthOutTime)
		filter(model.Enable != 0, "enable", model.Enable)
		filter(notBlank(model.AuthKey), "auth_key", model.AuthKey)
		filter(notBlank(model.Reserve), "reserve", model.Reserve)
	}

	if startNum >= 0 && pagesize > 0 {
		query += " LIMIT ?,? "
		params = append(params, startNum, pagesize)
	}

	rows := make([]*config.HostsModel, 0)
	var err error
	if len(params) == 0 {
		_, err = conn.ExecSelect(query, &rows)
	} else {
		_, err = conn.ExecSelect(query, &rows, params...)
	}
	return rows, err
}
出于封装的目的,我们常常将查询条件动态拼接,以满足不同的查询业务。但每张表的字段都不相同,如果每张表都像这样手写一遍,就显得冗余低效了,所以需要对其再做一层封装。
通过反射进行封装:
// GetSelectTableSql builds a generic, parameterized SELECT statement for tableName.
//
// model is a struct, or a pointer to one, that describes the filter. Every
// exported field tagged `gorm:"column:<name>"` whose value is set contributes
// an " AND <name> = ? " condition plus a bound argument. A field counts as set
// when it is a non-empty string, a non-zero int/int8/int16/int32/int64, or a
// time.Time that tool.CheckIsDefaultTime does not report as default. A default
// time.Time on the column "delete_at" adds " AND delete_at IS NULL " instead.
// Fields of any other type, unexported fields and fields without a column tag
// are ignored. A nil model, a nil pointer or a non-struct value adds no
// conditions.
//
// When startNum >= 0 and pagesize > 0 a " LIMIT ?,? " clause is appended with
// startNum and pagesize as its arguments.
//
// tableName and the column names are concatenated into the statement verbatim,
// so they must come from trusted code and never from user input.
//
// It returns the SQL text and the arguments to bind, in placeholder order.
func (db *MySQL) GetSelectTableSql(model interface{}, tableName string, startNum int, pagesize int) (string, []interface{}) {
	var sqlBuf strings.Builder
	sqlBuf.WriteString(" SELECT * FROM " + tableName + " WHERE 1=1 ")
	args := make([]interface{}, 0)

	if model != nil {
		refValue := reflect.ValueOf(model)
		// Unwrap pointers so callers may pass *T as well as T; NumField panics on a pointer.
		for refValue.Kind() == reflect.Ptr && !refValue.IsNil() {
			refValue = refValue.Elem()
		}
		if refValue.Kind() == reflect.Struct {
			refType := refValue.Type()
			for i := 0; i < refValue.NumField(); i++ {
				fieldType := refType.Field(i)
				if fieldType.PkgPath != "" {
					// Unexported field: Interface() would panic on it.
					continue
				}

				// Look for the "column:" entry anywhere in the gorm tag rather
				// than assuming it is the first entry.
				column := ""
				for _, part := range strings.Split(fieldType.Tag.Get("gorm"), ";") {
					part = strings.TrimSpace(part)
					if strings.HasPrefix(part, "column:") {
						column = strings.TrimSpace(strings.TrimPrefix(part, "column:"))
						break
					}
				}
				if column == "" {
					// No column mapping: emitting " AND  = ? " would be invalid SQL.
					continue
				}

				fieldValue := refValue.Field(i)
				isadd := false
				switch val := fieldValue.Interface().(type) {
				case string:
					isadd = len(val) > 0
				case int, int8, int16, int32, int64:
					isadd = fieldValue.Int() != 0
				case time.Time:
					if !tool.CheckIsDefaultTime(val) {
						isadd = true
					} else if column == "delete_at" {
						sqlBuf.WriteString(" AND delete_at IS NULL ")
					}
				}
				if isadd {
					sqlBuf.WriteString(" AND " + column + " = ? ")
					args = append(args, fieldValue.Interface())
				}
			}
		}
	}

	if startNum >= 0 && pagesize > 0 {
		sqlBuf.WriteString(" LIMIT ?,? ")
		args = append(args, startNum, pagesize)
	}
	return sqlBuf.String(), args
}
// GetSelectTableCountSql builds a generic, parameterized SELECT COUNT(1)
// statement for tableName.
//
// model is a struct, or a pointer to one, that describes the filter. Every
// exported field tagged `gorm:"column:<name>"` whose value is set contributes
// an " AND <name> = ? " condition plus a bound argument. A field counts as set
// when it is a non-empty string, a non-zero int/int8/int16/int32/int64, or a
// time.Time that tool.CheckIsDefaultTime does not report as default. A default
// time.Time on the column "delete_at" adds " AND delete_at IS NULL " instead.
// Fields of any other type, unexported fields and fields without a column tag
// are ignored. A nil model, a nil pointer or a non-struct value adds no
// conditions.
//
// tableName and the column names are concatenated into the statement verbatim,
// so they must come from trusted code and never from user input.
//
// It returns the SQL text and the arguments to bind, in placeholder order.
func (db *MySQL) GetSelectTableCountSql(model interface{}, tableName string) (string, []interface{}) {
	var sqlBuf strings.Builder
	sqlBuf.WriteString(" SELECT COUNT(1) FROM " + tableName + " WHERE 1=1 ")
	args := make([]interface{}, 0)
	if model == nil {
		return sqlBuf.String(), args
	}

	refValue := reflect.ValueOf(model)
	// Unwrap pointers so callers may pass *T as well as T; NumField panics on a pointer.
	for refValue.Kind() == reflect.Ptr && !refValue.IsNil() {
		refValue = refValue.Elem()
	}
	if refValue.Kind() != reflect.Struct {
		return sqlBuf.String(), args
	}

	refType := refValue.Type()
	for i := 0; i < refValue.NumField(); i++ {
		fieldType := refType.Field(i)
		if fieldType.PkgPath != "" {
			// Unexported field: Interface() would panic on it.
			continue
		}

		// Look for the "column:" entry anywhere in the gorm tag rather than
		// assuming it is the first entry.
		column := ""
		for _, part := range strings.Split(fieldType.Tag.Get("gorm"), ";") {
			part = strings.TrimSpace(part)
			if strings.HasPrefix(part, "column:") {
				column = strings.TrimSpace(strings.TrimPrefix(part, "column:"))
				break
			}
		}
		if column == "" {
			// No column mapping: emitting " AND  = ? " would be invalid SQL.
			continue
		}

		fieldValue := refValue.Field(i)
		isadd := false
		switch val := fieldValue.Interface().(type) {
		case string:
			isadd = len(val) > 0
		case int, int8, int16, int32, int64:
			isadd = fieldValue.Int() != 0
		case time.Time:
			if !tool.CheckIsDefaultTime(val) {
				isadd = true
			} else if column == "delete_at" {
				sqlBuf.WriteString(" AND delete_at IS NULL ")
			}
		}
		if isadd {
			sqlBuf.WriteString(" AND " + column + " = ? ")
			args = append(args, fieldValue.Interface())
		}
	}
	return sqlBuf.String(), args
}
// InsertTable inserts one row into tableName using the set fields of model.
//
// model is a struct, or a pointer to one. Every exported field tagged
// `gorm:"column:<name>"` whose value is set becomes one inserted column. A
// field counts as set when it is a non-empty string, a non-zero
// int/int8/int16/int32/int64, or a time.Time that tool.CheckIsDefaultTime does
// not report as default. Fields of any other type, unexported fields and
// fields without a column tag are ignored.
//
// tableName and the column names are concatenated into the statement verbatim,
// so they must come from trusted code and never from user input.
//
// It returns -1 and an error when model is nil, a nil pointer or not a struct,
// or when no field is set; otherwise it returns the result of db.Exec.
func (db *MySQL) InsertTable(model interface{}, tableName string) (int64, error) {
	if model == nil {
		return -1, errors.New("model is nil")
	}
	refValue := reflect.ValueOf(model)
	// Unwrap pointers so callers may pass *T as well as T; NumField panics on a pointer.
	for refValue.Kind() == reflect.Ptr {
		if refValue.IsNil() {
			return -1, errors.New("model is nil")
		}
		refValue = refValue.Elem()
	}
	if refValue.Kind() != reflect.Struct {
		return -1, errors.New("model is not a struct")
	}

	refType := refValue.Type()
	fieldCount := refValue.NumField()
	columns := make([]string, 0, fieldCount)
	args := make([]interface{}, 0, fieldCount)
	for i := 0; i < fieldCount; i++ {
		fieldType := refType.Field(i)
		if fieldType.PkgPath != "" {
			// Unexported field: Interface() would panic on it.
			continue
		}

		// Look for the "column:" entry anywhere in the gorm tag rather than
		// assuming it is the first entry.
		column := ""
		for _, part := range strings.Split(fieldType.Tag.Get("gorm"), ";") {
			part = strings.TrimSpace(part)
			if strings.HasPrefix(part, "column:") {
				column = strings.TrimSpace(strings.TrimPrefix(part, "column:"))
				break
			}
		}
		if column == "" {
			// No column mapping: an empty name would corrupt the column list.
			continue
		}

		fieldValue := refValue.Field(i)
		isadd := false
		switch val := fieldValue.Interface().(type) {
		case string:
			isadd = len(val) > 0
		case int, int8, int16, int32, int64:
			isadd = fieldValue.Int() != 0
		case time.Time:
			isadd = !tool.CheckIsDefaultTime(val)
		}
		if isadd {
			columns = append(columns, column)
			args = append(args, fieldValue.Interface())
		}
	}
	if len(args) < 1 {
		return -1, errors.New("args is nil")
	}

	// One "?" per collected column, comma separated.
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(args)), ",")
	insertSql := "insert " + tableName + " (" + strings.Join(columns, ",") + ")  values (" + placeholders + ")"
	return db.Exec(insertSql, args...)
}
// UpdateTableByColumn updates rows of tableName with the set fields of model.
//
// model is a struct, or a pointer to one. Every exported field tagged
// `gorm:"column:<name>"` whose value is set becomes a "<name>=?" assignment in
// the SET clause. A field counts as set when it is a non-empty string, a
// non-zero int/int8/int16/int32/int64, or a time.Time that
// tool.CheckIsDefaultTime does not report as default. A default time.Time on
// the column "delete_at" adds "delete_at=NULL" instead. Fields of any other
// type, unexported fields and fields without a column tag are ignored.
//
// mapcolumn holds the WHERE conditions as column -> value, combined with AND.
// When it is nil or empty the row is addressed by the model's "gid" column,
// which may be any signed integer type; an empty map is never turned into an
// unconditional update.
//
// tableName and all column names, including the keys of mapcolumn, are
// concatenated into the statement verbatim, so they must come from trusted
// code and never from user input.
//
// It returns -1 and an error when model is nil, a nil pointer or not a struct,
// when no field is set, or when no WHERE condition can be derived; otherwise it
// returns the result of db.Exec.
func (db *MySQL) UpdateTableByColumn(model interface{}, tableName string, mapcolumn map[string]interface{}) (int64, error) {
	if model == nil {
		return -1, errors.New("model is nil")
	}
	refValue := reflect.ValueOf(model)
	// Unwrap pointers so callers may pass *T as well as T; NumField panics on a pointer.
	for refValue.Kind() == reflect.Ptr {
		if refValue.IsNil() {
			return -1, errors.New("model is nil")
		}
		refValue = refValue.Elem()
	}
	if refValue.Kind() != reflect.Struct {
		return -1, errors.New("model is not a struct")
	}

	var setBuf strings.Builder
	setBuf.WriteString("update " + tableName + " SET ")
	args := make([]interface{}, 0)
	refType := refValue.Type()
	var gid int64
	for i := 0; i < refValue.NumField(); i++ {
		fieldType := refType.Field(i)
		if fieldType.PkgPath != "" {
			// Unexported field: Interface() would panic on it.
			continue
		}

		// Look for the "column:" entry anywhere in the gorm tag rather than
		// assuming it is the first entry.
		column := ""
		for _, part := range strings.Split(fieldType.Tag.Get("gorm"), ";") {
			part = strings.TrimSpace(part)
			if strings.HasPrefix(part, "column:") {
				column = strings.TrimSpace(strings.TrimPrefix(part, "column:"))
				break
			}
		}
		if column == "" {
			// No column mapping: "=?" without a name would be invalid SQL.
			continue
		}

		fieldValue := refValue.Field(i)
		isadd := false
		switch val := fieldValue.Interface().(type) {
		case string:
			isadd = len(val) > 0
		case int, int8, int16, int32, int64:
			n := fieldValue.Int()
			if column == "gid" {
				// Read through Int() so a gid of any signed width works,
				// instead of asserting int64 and panicking otherwise.
				gid = n
			}
			isadd = n != 0
		case time.Time:
			if !tool.CheckIsDefaultTime(val) {
				isadd = true
			} else if column == "delete_at" {
				setBuf.WriteString("delete_at=NULL,")
			}
		}
		if isadd {
			setBuf.WriteString(column + "=?,")
			args = append(args, fieldValue.Interface())
		}
	}
	if len(args) < 1 {
		return -1, errors.New("args is nil")
	}

	// Default to addressing the row by gid. An empty map is treated like nil so
	// the statement can never degrade to "where 1=1" and touch every row.
	if len(mapcolumn) == 0 {
		if gid == 0 {
			return -1, errors.New("update where is nil")
		}
		mapcolumn = map[string]interface{}{"gid": gid}
	}

	// Drop the trailing comma of the SET list before adding the WHERE clause.
	updateStr := strings.TrimSuffix(setBuf.String(), ",") + " where 1=1"
	for k, v := range mapcolumn {
		updateStr += " AND " + k + "=? "
		args = append(args, v)
	}
	return db.Exec(updateStr, args...)
}
调用:
// InsertRatetemplate inserts model into the table named by
// ratetemplateTableName through the generic InsertTable helper.
//
// It returns -1 and an error when the database handle or model is nil;
// otherwise it returns the result of InsertTable.
func (r *RatetemplateDao) InsertRatetemplate(model *config.RatetemplateModel) (int64, error) {
	db := r.Db.GetDB()
	if db == nil {
		return -1, errors.New(dao.DbErrMsg)
	}
	if model == nil {
		return -1, errors.New(daoErrMsg)
	}
	// model is known to be non-nil here, so hand the struct value straight to
	// the reflection-based helper.
	return db.InsertTable(*model, ratetemplateTableName)
}
// UpdateRatetemplateById updates the row of the table named by
// ratetemplateTableName through the generic UpdateTableByColumn helper.
//
// No explicit where-columns are passed (nil), which leaves row selection to
// UpdateTableByColumn. It returns -1 and an error when the database handle or
// model is nil; otherwise it returns the result of UpdateTableByColumn.
func (r *RatetemplateDao) UpdateRatetemplateById(model *config.RatetemplateModel) (int64, error) {
	db := r.Db.GetDB()
	if db == nil {
		return -1, errors.New(dao.DbErrMsg)
	}
	if model == nil {
		return -1, errors.New(daoErrMsg)
	}
	// model is known to be non-nil here, so hand the struct value straight to
	// the reflection-based helper.
	return db.UpdateTableByColumn(*model, ratetemplateTableName, nil)
}
// GetRatetemplateByModel queries the table named by ratetemplateTableName with
// the statement produced by GetSelectTableSql for model, startNum and pagesize.
//
// A nil model is forwarded as a nil interface; otherwise the dereferenced
// struct value is forwarded. It returns an error when the database handle is
// nil, otherwise the rows filled in by ExecSelect together with its error.
func (r *RatetemplateDao) GetRatetemplateByModel(model *config.RatetemplateModel, startNum int, pagesize int) ([]*config.RatetemplateModel, error) {
	conn := r.Db.GetDB()
	if conn == nil {
		return nil, errors.New(dao.DbErrMsg)
	}

	// The zero value of an interface is nil, so only the non-nil case needs an assignment.
	var filter interface{}
	if model != nil {
		filter = *model
	}
	query, params := conn.GetSelectTableSql(filter, ratetemplateTableName, startNum, pagesize)

	rows := make([]*config.RatetemplateModel, 0)
	var err error
	switch {
	case len(params) > 0:
		_, err = conn.ExecSelect(query, &rows, params...)
	default:
		_, err = conn.ExecSelect(query, &rows)
	}
	return rows, err
}
传入的参数对象结构体需要在tag里面定义相应的解析值column:
// RatetemplateModel is a table row model whose gorm tags carry the
// "column:<name>" entries read by the reflection-based SQL helpers.
type RatetemplateModel struct {
	// Gid maps to column "gid", tagged as the auto-increment primary key.
	Gid int64 `sql:"Gid" gorm:"column:gid;primary_key;auto_increment;comment:'唯一标识';type:bigint(20)" json:"gid"`
	// CdnGid maps to column "cdn_gid".
	CdnGid int64 `sql:"CdnGid" gorm:"column:cdn_gid;not null;comment:'拉流域名gid';type:bigint(20)" json:"cdn_gid"`
	// AppName maps to column "app_name".
	AppName string `sql:"AppName" gorm:"column:app_name;not null;comment:'业务线名称(live)';type:varchar(32)" json:"app_name"`
}
这样我们就不需要再写每张表的常规增删改查的sql语句了,而且当有大量的单一业务时,可以写一个代码生成工具根据数据库来生成这些代码
ps:反射的时候相应的传入对象结构体不能是指针类型的,如外层业务传入的是指针类型,需要转换为值类型,关键代码:
intoModel = *model,以上内容只是方便记录理解反射逻辑,某些具体sql执行方法未贴出