背景
无论是单体项目,还是分布式项目,一个请求进来总会有一定的链路,单体项目中会调用各种方法,分布式服务中更麻烦一点,跨服务调用。于是乎,我们就希望有一个全局的traceId可以把一个请求过程中经过的所有链路的关键信息串联起来,这样的话在检索日志的时候可以带来极大的方便,根据traceId把整个链路上的日志全部打印出来。
在golang项目中,通用的写法是通过context实现traceId信息传递。那么gorm如何通过context把traceId传进去,以实现打印日志带上traceId信息呢?
我们得通过阅读源码来寻找这个问题的解决方案。
gorm源码解读
我们首先需要了解gorm是如何打印日志的。任意找一个执行SQL的方法跟进去即可,比如查询方法Find:
// Find (quoted from gorm's source) optionally turns conds into a WHERE clause,
// records dest on the statement, and runs the Query callback chain via Execute.
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
	// getInstance supplies the *DB whose Statement (and its Context) is used below.
	tx = db.getInstance()
	if len(conds) > 0 {
		// Inline conditions are converted into a WHERE clause when non-empty.
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.Dest = dest
	// Execute is where the logger's Trace gets called (see the excerpt below).
	return tx.callbacks.Query().Execute(tx)
}
顺着调用链继续往下找,定位到Execute方法,打印日志的逻辑就在其中:
// Excerpt from gorm's callback Execute: once SQL has been built, the logger's
// Trace is called with the statement's Context as its first argument.
if stmt.SQL.Len() > 0 {
	// The SQL text and row count are produced lazily by this closure; Trace
	// decides whether to call it at all.
	db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
		sql, vars := stmt.SQL.String(), stmt.Vars
		// A logger that also implements ParamsFilter may rewrite SQL and vars.
		if filter, ok := db.Logger.(ParamsFilter); ok {
			sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
		}
		return db.Dialector.Explain(sql, vars...), db.RowsAffected
	}, db.Error)
}
到这里我们发现,打印日志所调用的Trace方法,第一个参数就是Context。于是我们继续顺藤摸瓜,看看这个Context是如何传进来的。Context是从db.Statement中获取的,所以我们需要找到给db.Statement赋值的方法,也就是getInstance:
// getInstance (quoted from gorm's source) returns the *DB a chained call
// operates on. When db.clone > 0 it builds a new *DB sharing db.Config; with
// clone == 1 the new Statement copies Context (and a few other fields) from
// the current Statement, otherwise the whole Statement is cloned.
func (db *DB) getInstance() *DB {
	if db.clone > 0 {
		tx := &DB{Config: db.Config, Error: db.Error}

		if db.clone == 1 {
			// clone with new statement
			tx.Statement = &Statement{
				DB:        tx,
				ConnPool:  db.Statement.ConnPool,
				Context:   db.Statement.Context,
				Clauses:   map[string]clause.Clause{},
				Vars:      make([]interface{}, 0, 8),
				SkipHooks: db.Statement.SkipHooks,
			}
		} else {
			// with clone statement
			tx.Statement = db.Statement.clone()
			tx.Statement.DB = tx
		}

		return tx
	}

	return db
}
可以看到,getInstance只是沿用已有Statement中的Context,并不负责设置它,因此还要继续往上找。最终,我们在WithContext方法中找到了把context传进来的入口:
// WithContext (quoted from gorm's source) opens a Session carrying ctx. This
// is the public entry point for attaching a request context (e.g. one holding
// a traceId) to subsequent queries; how Session stores it on the Statement is
// not shown here — see gorm's Session implementation.
func (db *DB) WithContext(ctx context.Context) *DB {
	return db.Session(&Session{Context: ctx})
}
找到了传入Context的入口,那么gorm如何根据Context中的自定义值打印日志呢?比如,Context中存放了自定义的traceId键值对。
回到前面打印日志的地方,可以看到Trace方法属于gorm logger包中的Interface接口:
// Interface (quoted from gorm's logger package) is the contract a custom gorm
// logger implements. Every method except LogMode receives a context.Context,
// which is what makes per-request fields such as a traceId possible.
type Interface interface {
	LogMode(LogLevel) Interface
	Info(context.Context, string, ...interface{})
	Warn(context.Context, string, ...interface{})
	Error(context.Context, string, ...interface{})
	// Trace is the method invoked from the statement-execution path; fc lazily
	// yields the SQL text and the affected row count.
	Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
// Trace (gorm's default logger, quoted) logs one SQL statement as exactly one
// of: an error line, a slow-query warning, or an info line. Note that the ctx
// argument is never read, so context values cannot reach the output.
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= Silent {
		return
	}

	elapsed := time.Since(begin)
	switch {
	case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
		// Failed statement; ErrRecordNotFound is skipped only when configured to ignore it.
		sql, rows := fc()
		// A row count of -1 is printed as "-".
		if rows == -1 {
			l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
		} else {
			l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
		}
	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
		// Slow statement; a zero SlowThreshold disables this branch.
		sql, rows := fc()
		slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
		if rows == -1 {
			l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
		} else {
			l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
		}
	case l.LogLevel == Info:
		// Every other statement, only at Info level.
		sql, rows := fc()
		if rows == -1 {
			l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
		} else {
			l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
		}
	}
}
再看gorm默认logger对Trace的实现,会发现它完全没有使用ctx参数。这其实很正常:gorm无法预知业务方会在Context中存放哪些键值。
所以,我们需要自定义一个logger对象,实现gorm的logger.Interface接口(包括Trace方法),在打印日志时从Context中取出traceId等字段一并输出。
解决方案
直接上代码。
package mainimport ("context""encoding/json""errors""fmt""time""go.uber.org/zap""gorm.io/gorm""gorm.io/gorm/logger""gorm.io/gorm/utils""gorm.io/driver/mysql"
)func main() {zapL, err := zap.NewProduction()if err != nil {panic(err)}log := New(zapL,WithCustomFields(String("timeStamp", time.Now().Format("2006-01-02 15:04:05")),func(ctx context.Context) zap.Field {v := ctx.Value("requestId")if v == nil {return zap.Skip()}if vv, ok := v.(string); ok {return zap.String("trace", vv)}return zap.Skip()},func(ctx context.Context) zap.Field {v := ctx.Value("method")if v == nil {return zap.Skip()}if vv, ok := v.(string); ok {return zap.String("method", vv)}return zap.Skip()},),WithConfig(logger.Config{SlowThreshold: 200 * time.Millisecond,Colorful: false,IgnoreRecordNotFoundError: false,LogLevel: logger.Info,}),)mysqlConfig := mysql.Config{DSN: "*******", // DSN data source nameDefaultStringSize: 191, // string 类型字段的默认长度SkipInitializeWithVersion: false, // 根据版本自动配置}// your dialectordb, _ := gorm.Open(mysql.New(mysqlConfig), &gorm.Config{Logger: log})// do your thingsresult := make(map[string]interface{})ctx := context.WithValue(context.Background(), "method", "method")db.WithContext(context.WithValue(ctx, "requestId", "requestId123456")).Table("privacy_detail").Find(&result)db.WithContext(context.WithValue(context.Background(), "requestId", "requestId123457")).Table("privacy_detail").Find(&result)db.WithContext(context.WithValue(context.Background(), "requestId", "requestId123458")).Table("privacy_detail").Create(&result)log.Info(context.WithValue(context.Background(), "requestId", "requestId123456"), "msg", "args")
}// Logger logger for gorm2
type Logger struct {log *zap.Loggerlogger.ConfigcustomFields []func(ctx context.Context) zap.Field
}// Option logger/recover option
// Option mutates a Logger under construction; New applies options in the
// order they are passed.
type Option func(*Logger)

// WithCustomFields optional custom field
// WithCustomFields returns an Option that sets the functions used to derive
// extra zap fields from each logging call's context (for example a traceId).
// It replaces fields set by an earlier WithCustomFields option.
//
// The functions are copied into a private slice, so later writes to a slice
// the caller expanded with `...` do not affect the logger. Nil functions are
// dropped, because invoking one would panic on every log call.
func WithCustomFields(fields ...func(ctx context.Context) zap.Field) Option {
	own := make([]func(ctx context.Context) zap.Field, 0, len(fields))
	for _, f := range fields {
		if f != nil {
			own = append(own, f)
		}
	}
	return func(l *Logger) {
		l.customFields = own
	}
}

// WithConfig optional custom logger.Config
// WithConfig returns an Option that overwrites the logger's entire
// logger.Config with cfg, replacing the defaults chosen by New.
func WithConfig(cfg logger.Config) Option {
	apply := func(l *Logger) {
		l.Config = cfg
	}
	return apply
}

// SetGormDBLogger set db logger
// SetGormDBLogger installs l as db's logger by assigning the Logger field of
// db's gorm configuration.
func SetGormDBLogger(db *gorm.DB, l logger.Interface) {
	db.Config.Logger = l
}

// New logger for gorm2
func New(zapLogger *zap.Logger, opts ...Option) logger.Interface {l := &Logger{log: zapLogger,Config: logger.Config{SlowThreshold: 200 * time.Millisecond,Colorful: false,IgnoreRecordNotFoundError: false,LogLevel: logger.Warn,},}for _, opt := range opts {opt(l)}return l
}// LogMode log mode
func (l *Logger) LogMode(level logger.LogLevel) logger.Interface {newLogger := *lnewLogger.LogLevel = levelreturn &newLogger
}// Info print info
func (l Logger) Info(ctx context.Context, msg string, args ...interface{}) {if l.LogLevel >= logger.Info {//预留10个字段位置fields := make([]zap.Field, 0, 10+len(l.customFields))fields = append(fields, zap.String("file", utils.FileWithLineNum()))for _, customField := range l.customFields {fields = append(fields, customField(ctx))}for _, arg := range args {if vv, ok := arg.(zapcore.Field); ok {if len(vv.String) > 0 {fields = append(fields, zap.String(vv.Key, vv.String))} else if vv.Integer > 0 {fields = append(fields, zap.Int64(vv.Key, vv.Integer))} else {fields = append(fields, zap.Any(vv.Key, vv.Interface))}}}l.log.Info(msg, fields...)}
}// Warn print warn messages
func (l Logger) Warn(ctx context.Context, msg string, args ...interface{}) {if l.LogLevel >= logger.Warn {//预留10个字段位置fields := make([]zap.Field, 0, 10+len(l.customFields))fields = append(fields, zap.String("file", utils.FileWithLineNum()))for _, customField := range l.customFields {fields = append(fields, customField(ctx))}for _, arg := range args {if vv, ok := arg.(zapcore.Field); ok {if len(vv.String) > 0 {fields = append(fields, zap.String(vv.Key, vv.String))} else if vv.Integer > 0 {fields = append(fields, zap.Int64(vv.Key, vv.Integer))} else {fields = append(fields, zap.Any(vv.Key, vv.Interface))}}}l.log.Warn(msg, fields...)}
}// Error print error messages
func (l Logger) Error(ctx context.Context, msg string, args ...interface{}) {if l.LogLevel >= logger.Error {//预留10个字段位置fields := make([]zap.Field, 0, 10+len(l.customFields))fields = append(fields, zap.String("file", utils.FileWithLineNum()))for _, customField := range l.customFields {fields = append(fields, customField(ctx))}for _, arg := range args {if vv, ok := arg.(zapcore.Field); ok {if len(vv.String) > 0 {fields = append(fields, zap.String(vv.Key, vv.String))} else if vv.Integer > 0 {fields = append(fields, zap.Int64(vv.Key, vv.Integer))} else {fields = append(fields, zap.Any(vv.Key, vv.Interface))}}}l.log.Error(msg, fields...)}
}// Trace print sql message
func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {if l.LogLevel <= logger.Silent {return}fields := make([]zap.Field, 0, 6+len(l.customFields))elapsed := time.Since(begin)switch {case err != nil && l.LogLevel >= logger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):for _, customField := range l.customFields {fields = append(fields, customField(ctx))}fields = append(fields,zap.Error(err),zap.String("file", utils.FileWithLineNum()),zap.Duration("latency", elapsed),)sql, rows := fc()if rows == -1 {fields = append(fields, zap.String("rows", "-"))} else {fields = append(fields, zap.Int64("rows", rows))}fields = append(fields, zap.String("sql", sql))l.log.Error("", fields...)case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn:for _, customField := range l.customFields {fields = append(fields, customField(ctx))}fields = append(fields,zap.Error(err),zap.String("file", utils.FileWithLineNum()),zap.String("slow!!!", fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)),zap.Duration("latency", elapsed),)sql, rows := fc()if rows == -1 {fields = append(fields, zap.String("rows", "-"))} else {fields = append(fields, zap.Int64("rows", rows))}fields = append(fields, zap.String("sql", sql))l.log.Warn("", fields...)case l.LogLevel == logger.Info:for _, customField := range l.customFields {fields = append(fields, customField(ctx))}fields = append(fields,zap.Error(err),zap.String("file", utils.FileWithLineNum()),zap.Duration("latency", elapsed),)sql, rows := fc()if rows == -1 {fields = append(fields, zap.String("rows", "-"))} else {fields = append(fields, zap.Int64("rows", rows))}fields = append(fields, zap.String("sql", sql))l.log.Info("", fields...)}
}// Immutable custom immutable field
// Immutable returns a custom-field function that ignores its context and
// always yields the same zap.Any(key, value) field, built once up front.
//
// Deprecated: use Any instead
func Immutable(key string, value interface{}) func(ctx context.Context) zap.Field {
	fixed := zap.Any(key, value)
	return func(context.Context) zap.Field {
		return fixed
	}
}

// Any custom immutable any field
func Any(key string, value interface{}) func(ctx context.Context) zap.Field {field := zap.Any(key, value)return func(ctx context.Context) zap.Field { return field }
}// String custom immutable string field
func String(key string, value string) func(ctx context.Context) zap.Field {field := zap.String(key, value)return func(ctx context.Context) zap.Field { return field }
}// Int64 custom immutable int64 field
func Int64(key string, value int64) func(ctx context.Context) zap.Field {field := zap.Int64(key, value)return func(ctx context.Context) zap.Field { return field }
}// Uint64 custom immutable uint64 field
func Uint64(key string, value uint64) func(ctx context.Context) zap.Field {field := zap.Uint64(key, value)return func(ctx context.Context) zap.Field { return field }
}// Float64 custom immutable float32 field
func Float64(key string, value float64) func(ctx context.Context) zap.Field {field := zap.Float64(key, value)return func(ctx context.Context) zap.Field { return field }
}
自定义结构体
// Logger is the zap-backed gorm logger defined in the listing above.
type Logger struct {
	// log is the zap logger every entry is written to.
	log *zap.Logger
	// Config is gorm's logger configuration (LogLevel, SlowThreshold, ...).
	logger.Config
	// customFields are called with each logging call's ctx; each returns one
	// extra zap field, e.g. a traceId read from the context.
	customFields []func(ctx context.Context) zap.Field
}
关键在于customFields:它是一组以Context为参数、返回zap.Field的函数。初始化日志时传入从Context中取值的函数(比如从Context中取出traceId),Info、Warn、Error和Trace在打印日志时会逐一调用这些函数,把返回的字段附加到日志中。
至此,就实现了gorm日志携带traceId的目标。