目录

singleflight

防止缓存穿透算法

代码参考:
https://github.com/coredns/coredns/blob/v1.9.1/plugin/pkg/singleflight/singleflight.go

什么是singleflight算法?

项目开发中有一个常见场景:
        给redis缓存内容设置一个过期时间,当缓存未命中的时候,再访问数据库获取数据。

缓存穿透 的意思是:大量请求透过缓存,直接访问到数据库。

如图所示,因为缓存失效,在步骤 2 完成前,大量请求已经运行到步骤 1 阶段,服务器就要承受较大的压力(看你请求数有大)。如何解决这个问题? singleflight算法提供了很好的思路

../images/2112211222.jpg
缓存穿透场景示意

思路:只允许一个请求穿透,其他所有请求在步骤 3 处排队等待。
singleflight 源码如下

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package singleflight

import "sync"

type call struct {
    wg  sync.WaitGroup
    val interface{}
    err error
}

type Group struct {
    mu sync.Mutex     
    m  map[uint64]*call // 懒加载
}

func (g *Group) Do(key uint64, fn func() (interface{}, error)) (interface{}, error) {
  +----------------------------------+
  | g.mu.Lock()                      |
  | if g.m == nil {                  |
  |     g.m = make(map[uint64]*call) |
  | }                                |         //1号访问者访问时
  | if c, ok := g.m[key]; ok {       |         //无法从map里获取值(缓存击穿)
  |     g.mu.Unlock()                +------>  //它新建一个map的值(是个指针)
  |     c.wg.Wait()                  |         //其他访问者可以获取到该指针值,并阻塞
  |     return c.val, c.err          |         //直到1号访问者获取值后共享结果
  | }                                |         
  | c := new(call)                   |
  | c.wg.Add(1)                      |
  | g.m[key] = c                     |
  | g.mu.Unlock()                    |
  +----------------------------------+

  +----------------------+
  | c.val, c.err = fn()  |        //仅1号访问者访问,获取一次结果
  | c.wg.Done()          +----->  //结果存map里
  +----------------------+        //其他访问者复用此结果

  +------------------+
  | g.mu.Lock()      |       //仅一号访问者能运行到这
  | delete(g.m, key) +---->  //等其他请求访问完后
  | g.mu.Unlock()    |       //删除掉map中的数据,因为已经写到redis了,不需要了
  +------------------+

    return c.val, c.err
}

模拟缓存穿透

思路:

  1. 用httprouter写一个服务器,用两个map模拟redis和mysql的数据,初始时只给mysql分配数据(所以初始开5k gorutine 访问的时候就相当于缓存穿透的过程)
  2. 客户端并发产生5k请求,观察结果

服务端

以下内容见代码:

  1. 接口
  2. 模拟数据存储方式

为了更真实地还原缓穿透,在访问mysql读数据前睡眠两秒钟,能保证未使用singleflight的时候大量请求访问到mysql,更好地对比观察结果

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// server/main.go
package main

import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/julienschmidt/httprouter"
)
var group *Group
var mux sync.RWMutex
var rwMux sync.RWMutex
var redisDataBase map[string]string
var mysqlDataBase map[string]string
var countRedisHit int
var countMysqlHit int

func GetFromRedis(key string) (string, error) {
	if data, ok := redisDataBase[key]; ok {
		mux.Lock()
		countRedisHit++
		mux.Unlock()
		return data, nil
	}
	if data, err := GetFromMySql(key); err == nil {
		return data, nil
	} else {
		return "", err
	}
}

func GetFromMySql(key string) (string, error) {
	time.Sleep(time.Second * 1)
	// defer fmt.Println(time.Now())
	if data, ok := mysqlDataBase[key]; ok {
		// write to redis
		rwMux.Lock()
		redisDataBase[key] = "data stored in redis"
		rwMux.Unlock()
		// 相当于设置过期时间 2s
		// 每次从存到redis缓存的值都会在2s以后过期
		go func(key string) {
			time.Sleep(time.Second * 2)
			rwMux.Lock()
			delete(redisDataBase, key)
			rwMux.Unlock()
		}(key)
		mux.Lock()
		countMysqlHit++
		mux.Unlock()
		return data, nil
	} else {
		return "", errors.New("data not found!")
	}
}

// 正常版
func GetUserInfo(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
	// 从缓存中查结果
	queryValues := req.URL.Query()
	if res, err := GetFromRedis(queryValues.Get("name")); err != nil {
		log.Fatal("err:", err)
	} else {
		fmt.Fprintf(w, res)
	}
}

// singflight版本
func GetUserInfo1(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
	// 从缓存中查结果
	queryValues := req.URL.Query()
	function := func() (interface{}, error) {
		if res, err := GetFromRedis(queryValues.Get("name")); err == nil {
			return res, nil
		} else {
			return "", err
		}
	}
	// 原来是没有共用一个group啊,我说怎么没效果
	res, err := group.Do(uint64(1), function)
	if err != nil {
		log.Fatal("err from singleflight:", err)
	}
	fmt.Fprintf(w, res.(string))
}

func ShowHitCount(w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
	fmt.Fprintf(w, "redis:%d\nmysql:%d\n", countRedisHit, countMysqlHit)
}

func init() {
	redisDataBase, mysqlDataBase = map[string]string{}, map[string]string{}
	mysqlDataBase["zqw"] = "data stored in mysql" // 初始时redis里无数据
	group = &Group{}
}

func hello(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
	queryValues := req.URL.Query()
    fmt.Fprintf(w, "hello, %s!\n", queryValues.Get("name"))
}

func main() {

	router := httprouter.New()
	router.GET("/", hello)
	router.GET("/userinfo", GetUserInfo)
	router.GET("/userinfo1", GetUserInfo1)
	router.GET("/count", ShowHitCount)
	log.Fatal(http.ListenAndServe(":9090", router))

}

客户端

思路比较简单(思路在代码里),并发请求,打印输出结果,计时

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"sync"
	"time"
)

func curl(str string) string {
	resp, err := client.Get(str)
	if err != nil {
		log.Println("error:", err)
		
	}
	res, err := io.ReadAll(resp.Body)
	defer resp.Body.Close()
	return string(res)
}

var client http.Client
func init() {
	client = http.Client{}
}
func main() {
	timeStampA := time.Now()
	defer client.CloseIdleConnections()
	var wg sync.WaitGroup
	for i := 0; i < 5000; i++ {
		// time.Sleep(time.Microsecond*1)
		wg.Add(1)
		go func() {
			defer wg.Done()
            // 下面两个轮着来
            //curl("http://localhost:9090/userinfo?name=zqw")
			curl("http://localhost:9090/userinfo1?name=zqw")
		}()
	}
	wg.Wait()
	res := curl("http://localhost:9090/count")
	fmt.Println(res)
	timeStampB := time.Now()
	fmt.Println("运行共用时: ", timeStampB.Sub(timeStampA).Seconds())
}

结果

未使用singleflight:

1
2
3
4
redis:0
mysql:5000

运行共用时:  1.5597763709999999

使用singleflight:

1
2
3
4
redis:0
mysql:1  # 其余4999次访问全部取的是第一次访问的结果

运行共用时:  1.216015732

结果可以复现
如果出现异常的话可能操作系统默认设置的打开文件数太少,建议把并发量改小点
我已经默认设置了 ulimit -n 8192

尝试1w并发,结果异常
尝试10w并发,结果不断panic(在客户端读io的时候,或者服务器程序挂掉)

结论

singleflight算法短小精悍,其思想非常值得学习
本质上是把访问磁盘的压力转移到了访问内存

写代码的时候出现了很多问题:

  1. 访问完mysql并向redis写入数据的时候(实际上是向map写数据),服务器频繁宕机(panic提示读取内存错误),通过加锁解决了
  2. unlock写成lock导致服务器端死锁,不能响应服务
  3. singleflight的变量没有设置成全局,每次都是在handleFunc处理函数里面重新定义,导致没有发挥singleflight的作用,后面检查代码找到了错误
  4. 并发量超过5w以后会出现各种莫名其妙的错误(runtime io panic),运行总耗时50倍以上,相当于性能降低10倍以上,放弃