Tag Go

一种应用于特定场景的支持LRU的线程安全的无锁uint32->uint32 cache实现

Category 存储
Tag Go
Posted on
View

1. 前言

几年前给公司前台业务一个QPS很高的接口做了一个优化,主要请求来源是当前在线用户,接口核心逻辑就是从codis中根据一个数字查询对应的用户id(小于1亿),这两个数字的映射关系是不变的,可以理解为codis中有一个map[uint32]uint32的映射表,这个映射表只增不改。

因为接口对codis造成压力很大,因此决定在Go内存中将映射关系缓存下来,但由于这个映射表很大所以不能全部缓存中内存。因此结合业务逻辑决定引入了一个支持LRU淘汰策略的uint32 -> uint32的高性能缓存组件。

调研之后发现市面上Go的各种线程安全还支持LRU的缓存都是有锁的,性能可能受限,因此决定根据应用场景自己搞个特殊的缓存组件。

2. 实现原理

首先还是贴一下源码仓库地址: https://github.com/Orlion/intcache

2.1 结构体定义:

type IntCache struct {
	b          uint8
	buckets    [][8]uint64
	lruBuckets []uint32
}

func New(b uint8) *IntCache {
	cap := 1 << b
	return &IntCache{
		b:          b,
		buckets:    make([][8]uint64, cap),
		lruBuckets: make([]uint32, cap),
	}
}

image.png

如上图所示,一个IntCache2^bbucketlruBucket,一个bucket有8个K-V对,一个K-V对使用uint64来存储,前32bit存储key,后32bit存储value。一个lruBucket有8个lru值,采用uint32存储,每4bit存储bucket对应的每个K-V对的lru值。

这里你可能会很奇怪为什么lru要单独存储,不要急,继续往下看,读流程时我会详细解释。

2.2 写流程

func (c *IntCache) Set(key uint32, value uint32) {
        if key == 0 && value == 0 {
		panic("key and value can't be 0")
	}
	bucketi := key & (1<<c.b - 1)
	for i := 0; i < 8; i++ {
		e := atomic.LoadUint64(&c.buckets[bucketi][i])
		if e == 0 {
			atomic.StoreUint64(&c.buckets[bucketi][i], uint64(key)<<32|uint64(value))
			c.updLru(bucketi, i)
			return
		}

		if uint32(e>>32) == key {
			e = uint64(key)<<32 | uint64(value)
			atomic.StoreUint64(&c.buckets[bucketi][i], e)
			c.updLru(bucketi, i)
			return
		}
	}

	// find the min lru
	lrus := atomic.LoadUint32(&c.lruBuckets[bucketi])
	var (
		minLru uint32
		mini   int
	)

	for i := 0; i < 8; i++ {
		lru := lrus | 0b1111<<uint32(i)
		if lru < minLru {
			minLru = lru
			mini = i
		}
	}

	atomic.StoreUint64(&c.buckets[bucketi][mini], uint64(key)<<32|uint64(value))
	c.updLru(bucketi, mini)
}

写入步骤如下:

  1. key对容量取模,计算出key落到哪个桶里,然后遍历桶中8个槽(K-V对)
  2. 如果遍历槽为0,说明这个槽还没有被占用,写入当前槽,并更新lru值为7,其他槽的lru值-1
  3. 如果遍历槽的key等于当前的key,则更新这个槽的值,并更新lru值为7,其他槽的lru值-1
  4. 如果遍历完后没有空槽也没有命中key,则找到lru值最小的,淘汰掉然后写入新key并更新lru值

2.3 读流程

func (c *IntCache) Get(key uint32) (value uint32, exists bool) {
	bucketi := key & (1<<c.b - 1)
	for i := 0; i < 8; i++ {
		e := atomic.LoadUint64(&c.buckets[bucketi][i])
                if e == 0 {
			break
		}

		if uint32(e>>32) == key {
			value = uint32(e)
			exists = true
			c.updLru(bucketi, i)
			break
		}
	}

	return
}

读取步骤如下:

  1. key对容量取模,计算出key落到哪个桶里,然后遍历桶中8个槽(K-V对)
  2. 如果遍历到的槽为0,说明后面的槽都是没有数据的,无需继续遍历
  3. 如果遍历到的槽的key等于查询的key,则返回value,并更新lru值

3. 总结

3.1 为什么lru要单独存储

每个bucket占用8*8=64B,正好是一个x86 cpu cacheline的大小,刚好填满一个cacheline,这样遍历bucket上8个槽实际只需要读取第一个槽时访问一次内存,后续访问都会直接从cpu cache中读到(当然前提是没有写请求造成cacheline过期),这样可以充分利用cpu缓存。

如果lru值与bucket存储在一起,那么系统中大量的读请求修改lru值就会造成cacheline过期的可能性就会变大,而如果分开存储,读请求不会造成cacheline过期。

你可能会问频繁的写入也会造成cacheline过期影响性能啊,但是我们这是一个典型的读多写少的系统,而且大量的bucket也降低了cacheline过期的几率。

3.2 缺陷

3.2.1 适用场景有限

由于我这个组件用在了在线用户访问的场景中,我将bucket数量设置为日活人数/8,hash冲突的几率还是比较小的,从监控看缓存命中率还是比较可观的。

但是由于不是严格的LRU,因此其他业务场景可能不适用。

3.2.2 value与lru值的更新不是原子的

因为要提高cpu cache命中率,因此value更新与lru更新是分离的,无法做到原子性,这也是很不严谨的,但是我们这个业务场景中不需要严谨的lru,所以可以忽略。

4. 基准测试

与fastcache对比的基准测试代码

func BenchmarkIntcache(b *testing.B) {
	rand.Seed(1)
	var B uint8 = 21
	m := New(B)
	for i := 0; i < b.N; i++ {
		intcacheBenchmarkFn(m)
	}
}

func BenchmarkFastcache(b *testing.B) {
	rand.Seed(1)
	m := fastcache.New(1 << 21 * (64 + 8))
	for i := 0; i < b.N; i++ {
		fastcacheBenchmarkFn(m)
	}
}

func intcacheBenchmarkFn(m *IntCache) {
	wg := &sync.WaitGroup{}
	for i := 0; i < 1000; i++ {
		wg.Add(1)
		go func() {
			for j := 0; j < 300; j++ {
				key := rand.Uint32()
				m.Set(key, key)
				m.Get(key)
			}
			wg.Done()
		}()
	}
	wg.Wait()
}

func fastcacheBenchmarkFn(m *fastcache.Cache) {
	wg := &sync.WaitGroup{}
	for i := 0; i < 1000; i++ {
		wg.Add(1)
		go func() {
			for j := 0; j < 300; j++ {
				key := rand.Uint32()
				b := make([]byte, 4) // uint64的大小为8字节
				binary.LittleEndian.PutUint32(b, key)
				m.Set(b, b)
				r := make([]byte, 4)
				m.Get(r, b)
			}
			wg.Done()
		}()
	}
	wg.Wait()
}

都是1000个协程并发读写300次,结果:

goos: darwin
goarch: arm64
pkg: github.com/Orlion/intcache
BenchmarkIntcache
BenchmarkIntcache-10                  14          78865301 ns/op
BenchmarkFastcache
BenchmarkFastcache-10                 10         113746746 ns/op
PASS
ok      github.com/Orlion/intcache      9.767s

可以看到我们这个实现要快一点。 image.png

...

阅读全文 »

记一次SIMD指令优化计算的失败经历

Category 汇编
Tag Go
Posted on
View

1. 前言

书接上回 《统计一个数字二进制位1的个数》,现在我们已经知道如何快速计算出一个int64数字的二进制位1的个数,那么回到我们最初的需求,我们的目的是快速统计一个bitmap中二进制位1的个数,假设我们使用[]uint64来实现bitmap,那么如果要统计这个bitmap中二进制位1的个数,我们可以遍历每个元素,计算出每个uint64元素二进制位1的个数,最后加起来,代码大概如下:

type Bitmap []uint64

func (bitmap Bitmap) OnesCount() (count int) {
	for _, v := range bitmap {
		count += OnesCount64(v)
	}

	return
}

const m0 = 0x5555555555555555 // 01010101 ...
const m1 = 0x3333333333333333 // 00110011 ...
const m2 = 0x0f0f0f0f0f0f0f0f // 00001111 ...

// 计算出x中二进制位1的个数,该函数上篇文章有详细解释,看不懂可以再回去看下
func OnesCount64(x uint64) int {
	const m = 1<<64 - 1
	x = x>>1&(m0&m) + x&(m0&m)
	x = x>>2&(m1&m) + x&(m1&m)
	x = (x>>4 + x) & (m2 & m)
	x += x >> 8
	x += x >> 16
	x += x >> 32
	return int(x) & (1<<7 - 1)
}

这种实现方式在bitmap元素过多,切片长度过长的情况下,计算十分耗时。那么如何优化这段代码呢?

2. 优化

现代CPU一般都支持SIMD指令,通过SIMD指令可以并行执行多个计算,以加法运算为例,如果我们要计算{A0,A1,A2,A3}四个数与{B0,B1,B2,B3}的和,不使用SIMD指令的话,需要挨个计算A0+B0A1+B1A2+B2A3+B3的和。使用SIMD指令的话,可以将{A0,A1,A2,A3}{A0,A1,A2,A3}四个数加载到xmm(128bit)/ymm(256bit)/zmm(512bit)寄存器中,然后使用一条指令就可以同时计算对应的和。这样理论上可以获得N倍的性能提升。

image.png

我们可以采用SIMD指令将OnesCount64函数并行化,并行计算4个uint64数字的结果,代码实现如下:

在popcnt.go文件中定义SimdPopcntQuad函数

package popcnt

func SimdPopcntQuad(nums [4]uint64) [4]uint64

在popcnt.s文件中我们使用汇编实现SimdPopcntQuad函数

#include "textflag.h"

TEXT ·SimdPopcntQuad(SB),NOSPLIT,$0-64
    VMOVDQU nums+0(FP), Y0 // Y0 = x,将四个uint64数字加载到Y0寄存器
    MOVQ $0x5555555555555555, AX
    MOVQ AX, X9
    VPBROADCASTQ X9, Y5 // Y5 = m0 // 上面三行代码将4个m0加载到Y5寄存器
    MOVQ $0x3333333333333333, AX
    MOVQ AX, X9
    VPBROADCASTQ X9, Y6 // Y6 = m1 // 上面三行代码将4个m1加载到Y6寄存器
    MOVQ $0x0f0f0f0f0f0f0f0f, AX
    MOVQ AX, X9
    VPBROADCASTQ X9, Y7 // Y7 = m2 // 上面三行代码将4个m2加载到Y7寄存器
    MOVQ $0x7f, AX
    MOVQ AX, X9
    VPBROADCASTQ X9, Y8 // Y8 = m;上面三行代码将4个m3加载到Y8寄存器
    VPSRLQ $1, Y0, Y1 // Y1 = x>>1;Y0寄存器上四个uint64数字并行右移1位
    VPAND Y1, Y5, Y1 // Y1 = x>>1&m0;Y1寄存器上四个uint64数字并行与Y5寄存器上的四个m0并行与,结果存到Y1寄存器
    VPAND Y0, Y5, Y2 // Y2 = x&m0
    VPADDQ Y1, Y2, Y0 // x = x>>1&m0 + x&m0
    VPSRLQ $2, Y0, Y1 // Y1 = x>>2
    VPAND Y1, Y6, Y1 // Y1 = x>>2&m1
    VPAND Y0, Y6, Y2 // Y2 = x&m1
    VPADDQ Y1, Y2, Y0 // x = x>>2&m1 + x&m1
    VPSRLQ $4, Y0, Y1 // Y1 = x>>4
    VPAND Y1, Y7, Y1 // Y1 = x>>4&m2
    VPAND Y0, Y7, Y2 // Y2 = x&m2
    VPADDQ Y1, Y2, Y0 // x = x>>2&m2 + x&m2
    VPSRLQ $8, Y0, Y1 // Y1 = x >> 8
    VPADDQ Y1, Y0, Y0 // x += x >> 8
    VPSRLQ $16, Y0, Y1 // Y1 = x >> 16
    VPADDQ Y1, Y0, Y0 // x += x >> 16
    VPSRLQ $32, Y0, Y1 // Y1 = x >> 32
    VPADDQ Y1, Y0, Y0 // x += x >> 32
    VPAND Y0, Y8, Y0 // x & (1<<7-1)
    VMOVDQU Y0, ret+32(FP) // 将结果加载到内存中返回值的位置
    RET

Benchmark

理论上讲如此优化之后我们应该可以获得四倍的性能提升,所以我们写个基准测试验证下:

// 优化之后的并行计算测试
func BenchmarkSimdPopcntQuad(b *testing.B) {
        // 使用随机数防止编译阶段被编译器预先计算出来
	rand.Seed(time.Now().UnixNano())
	nums := [4]uint64{rand.Uint64(), rand.Uint64(), rand.Uint64(), rand.Uint64()}
	for i := 0; i < b.N; i++ {
		SimdPopcntQuad(nums)
	}
}

// 优化之前的顺序计算测试
func BenchmarkSerial(b *testing.B) {
        // 使用随机数防止编译阶段被编译器预先计算出来
	rand.Seed(time.Now().UnixNano())
	nums := [4]uint64{rand.Uint64(), rand.Uint64(), rand.Uint64(), rand.Uint64()}
	for i := 0; i < b.N; i++ {
		serialPopcntQuad(nums)
	}
}

func serialPopcntQuad(nums [4]uint64) [4]uint64 {
	return [4]uint64{uint64(bits.OnesCount64(nums[0])), uint64(bits.OnesCount64(nums[1])), uint64(bits.OnesCount64(nums[2])), uint64(bits.OnesCount64(nums[3]))}
}

运行后结果如下

# go test -bench=. -v
=== RUN   TestSimdPopcntQuad
--- PASS: TestSimdPopcntQuad (0.00s)
goos: linux
goarch: amd64
pkg: github.com/Orlion/popcnt
cpu: Intel Core Processor (Broadwell, no TSX)
BenchmarkSimdPopcntQuad
BenchmarkSimdPopcntQuad-8        3693530               330.8 ns/op
BenchmarkSerial
BenchmarkSerial-8               539924296                2.232 ns/op
PASS
ok      github.com/Orlion/popcnt        2.993s

可以看到优化后的并行计算比原始的顺序计算慢了150倍😭,失败~

image.png

3. 分析

虽然优化失败了,但是我们还是要分析复盘下其中的原因,从中汲取一些经验,下面我们从两方面来分析下。

3.1 未优化函数为什么快?

首先我们可以看到未优化的函数serialPopcntQuad计算四个数字竟然只花了2ns,根据Numbers Everyone Should Know一文,访存的时间大概是100ns,这就有点离谱了,计算竟然不从内存加载我们的参数?

下面我们写段main函数,使用随机数来调用下serialPopcntQuad函数,然后反汇编看下汇编代码分析下。

func main() {
	rand.Seed(time.Now().UnixNano())
	nums := [4]uint64{rand.Uint64(), rand.Uint64(), rand.Uint64(), rand.Uint64()}
	results := serialPopcntQuad(nums)
	fmt.Println(results)
}

func serialPopcntQuad(nums [4]uint64) [4]uint64 {
	return [4]uint64{uint64(bits.OnesCount64(nums[0])), uint64(bits.OnesCount64(nums[1])), uint64(bits.OnesCount64(nums[2])), uint64(bits.OnesCount64(nums[3]))}
}

编译后反汇编:

image.png

从汇编代码中可以看到在调用bits.OnesCount64之前会判断cpu是否支持popcnt指令,如果支持则使用popcnt指令来计算而不是调用bits.OnesCount64来计算,恰好我机器支持popcnt指令,省略了bits.OnesCount64中的一堆计算,因此计算速度非常快。

3.2 优化后为什么慢?

正如3.1中所提到的,相较于cpu计算,访存的代价是非常高的,大概是100ns,而我们汇编代码中为了使用SIMD指令实现统计算法有大量的访存操作。

受限于本人对汇编掌握程度,上面的汇编代码质量应该是很差的,并不能证明SIMD性能差,可能有性能更高的实现,请各位大佬指点。

而且当前Go汇编在不指定编译参数的情况下只能采用旧函数调用约定,必须采用内存传参,所以导致最终基准测试的结果很差。

4. 收获

这一通瞎折腾虽然最终结果失败,但还是有很多收获的。首先真实的体会到了访存有多慢,所以日后在进行性能优化时就会注意这一点,尽量使代码能命中CPU缓存。

再一个就是之前并没有使用过SIMD指令,也没有接触过这种级别的优化,这次算是入门了。

后端选手,水平有限,各位计算机科学家见笑了。

image.png

5. 参考资料

  1. 玩转SIMD指令编程

...

阅读全文 »

统计一个数字二进制位1的个数

Tag Go
Posted on
View

最近一个需求需要使用golang实现一个兼容redis的无压缩的bitmap,需要提供一个bitcoun函数来统计这个bitmap中二进制位1的个数,查了一圈并没有找到类似的第三方库,因此决定自己实现一个.(利用一切机会造轮子

1. 问题简化

问题本质实际就是给定一个数字,比如一个二进制数10101101,计算出这个数字中二进制位1的个数,对于10101101这个数字来说它有5个位为1,即:10101101

对于这个问题,最简单的办法就是挨位数,不过这个办法太笨了,没有逼格。

那么有没有银弹呢?答案是肯定的,而且还不止一种。 退后 ,我要开始装逼了

2. 查表法

对于一个8位的数字来说,它只有256个值,因此完全可以预先计算好每个值的二进制位1个个数写入到映射表中,使用时直接查询这张映射表即可。

伪代码如下所示:

var count1map = map[uint8]uint8 {
    0b0000_0000: 0,
    0b0000_0001: 1,
    ...
    0b1111_1111: 8,
}

func bitcount(x uint8) uint8 {
    return count1map[x]
}

3. 移位法

查表法虽然可以应对8位这样值数量有限的数字,但是对于uint64 or int64这样64位的数字来说,它的值数量是非常多的,我们无法在内存中维护这样巨大的映射表,因此不能使用查表法来解决

Golang在bits包中提供一个OnesCount64(x uint64) int的函数,可以计算一个64位数字中二进制为1的个数,其源码如下:

const m0 = 0x5555555555555555 // 01010101 ...
const m1 = 0x3333333333333333 // 00110011 ...
const m2 = 0x0f0f0f0f0f0f0f0f // 00001111 ...
const m3 = 0x00ff00ff00ff00ff // etc.
const m4 = 0x0000ffff0000ffff

func OnesCount64(x uint64) int {
	const m = 1<<64 - 1
	x = x>>1&(m0&m) + x&(m0&m)
	x = x>>2&(m1&m) + x&(m1&m)
	x = (x>>4 + x) & (m2 & m)
	x += x >> 8
	x += x >> 16
	x += x >> 32
	return int(x) & (1<<7 - 1)
}

初看起来是有点懵逼的,一顿位运算操作怎么就能把1的个数算出来了呢?

这段代码注释中标明其来源于Hacker’s Delight第5章

骚操作

别着急,我们还是采用自底向上的思想来拆解下。

3.1 2位数字二进制位1的个数

我们先想一下如何计算2位的数字二进制位1的个数,答案是非常简单的:

func OnesCount2(x uint2) int {
    return (x & 0b01) + ((x >> 1) & 0b01)
}

x & 0b01就是求第0位是不是1,((x >> 1) & 0b01)就是求第1位是不是1,加起来就是x这个2位数字二进制位1的个数。

3.2 4位数字二进制位1的个数

对于一个4位数字,如1011,我们先按照3.1中的算法分别求出第3位与第2位即10 和 第1位与第0位即11的二进制位1的个数,然后再加起来就得出这个4位数字的二进制位1的个数了。

伪代码如下所示:

func OnesCount4(x uint4) int {
    x = x & 0b0101 + x >> 1 & 0b0101
    return x & 0b0011 + x >> 2 & 0b0011
}

计算过程如图: image.png

3.3 8位数字二进制位1的个数

8位数字计算过程与4位计算过程本质是相同的,都是拆解组合,伪代码如下:

func OnesCount8(x uint8) int {
    x = x & 0b01010101 + x >> 1 & 0b01010101
    x = x & 0b00110011 + x >> 2 & 0b00110011
    return x & 0b00001111 + x >> 4 && 0b00001111
}

计算过程如下: image.png

64位数字重复这个过程即可,回头看golang的代码应该就可以看懂了,这里就不再详细解释了。

另外这个算法过程还可以进一步优化,详细可以参考下:计算汉明权重的SWAR(SIMD within a Register)算法 感兴趣的可以研究一下,这里就不赘述了。

4. POPCNT指令

一些较新的CPU上支持POPCNT指令,可以通过硬件直接进行计算,Golang代码示例如下:

main.go文件

package main

import (
    "fmt"
    "math/bits"
    "math/rand"
    "time"
)

func main() {
    rand.Seed(time.Now().Unix())
    for i := 0; i < 100; i++ {
        var num = rand.Uint64()
        if popcnt(num) != bits.OnesCount64(num) {

            panic(fmt.Sprintf("i: %d, popcnt(%b) = %d, bits.OnesCount64(%b) = %d", i, num, popcnt(num), num, bits.OnesCount64(num)))
        }
    }
    fmt.Println("ok")
}

func popcnt(x uint64) int

amd64.s 文件

#include "textflag.h"

TEXT main·popcnt(SB), NOSPLIT, $0-8
    MOVQ x+0(FP), AX // 将参数x移到AX寄存器
    BYTE $0xf3; BYTE $0x48; BYTE $0x0f; BYTE $0xb8; BYTE $0xc0  // 计算二进制X中1的个数,golang编译器不支持POPCNT指令,这行对应于POPCNT AX, AX
    MOVQ AX, ret+8(FP) // 将结果存入ret
    RET

...

阅读全文 »

Go数据库连接池设置不合理导致大量TIME_WAIT连接占满端口问题排查与解决

Category Golang
Tag Go
Posted on
View

1. 问题与背景

最近公司内部准备尝试使用下腾讯的TDSQL,因此组内同学写了一段很简单的查询TDSQL的go web程序,使用ab对其进行一个简单压测以获取TDSQL的性能表现,go代码如下:

package main

import (
    "crypto/md5"
    "database/sql"
    "fmt"
    "log"
    "math/rand"
    "net/http"
    "strconv"
    "time"

    "github.com/gin-gonic/gin"
    _ "github.com/go-sql-driver/mysql"
)

func main() {
    r := gin.New()
    r.Use(gin.Logger())
    r.GET("/test", func(c *gin.Context) {
        c.JSON(200, gin.H{
            "message": "test",
        })
    })

    dbconnect, err := sql.Open("mysql", "user:passwd@tcp(10.43.0.43:3306)/dbname")
    if err != nil {
        panic(err)
    }
    dbconnect.SetMaxIdleConns(5)
    dbconnect.SetMaxOpenConns(10)
    dbconnect.SetConnMaxLifetime(time.Hour)
    dbconnect.SetConnMaxIdleTime(time.Hour)
    r.GET("tdsql_test", func(context *gin.Context) {
        muid := fmt.Sprintf("%x", md5.Sum([]byte(strconv.Itoa(rand.Intn(1000000000000)))))
        rows, err := dbconnect.Query(fmt.Sprintf("select muid from rtb_channel_0 where muid='%s'", muid))
        if err != nil {
            log.Fatal(err)
            context.JSON(http.StatusInternalServerError, gin.H{
                "error_code": -3,
            })
            return
        }

        rows.Close()

        context.JSON(http.StatusOK, gin.H{
            "error_code": 0,
            "error_msg":  muid,
            "data":       "result",
        })
    })

    r.Run(":9000")
}

ab压测命令如下:

ab -c 10 -n 500000 "http://127.0.0.1:9000/tdsql_test"

压测开始不久之后代码log.Fatal(err)就打印出了错误信息并退出了: image.png

dial tcp 10.43.0.43:3306: connect: cannot assign requested address

这段错误信息是说无法连接到10.43.0.43:3306 ,原因是无法分配请求地址号,就是说本地端口号都被占用了。那么我们就开始进行排查,端口号究竟是被谁占满的?

2. 排查过程

2.1 通过netstat命令查看端口都被谁占用

netstat -nta | grep 10.43.0.43

有如下输出: image.png 可以看到有大量处于TIME_WAIT状态的TCP连接,这些连接占用了大量的端口。那么这些TIMI_WAIT状态的TCP连接是从哪来的呢?

为了弄清楚这个问题,我们必须知道TIME_WAIT状态是怎么回事。

2.1.1 TIME_WAIT

image.png

上图是经典的TCP四次挥手断开连接的过程。可以看到在四次挥手的过程中,主动关闭连接的一端在收到对端发送的FIN包之后会进入TIME_WAIT状态,会等待2MSL之后才能真正关闭连接。

MSL: 最长报文段寿命(Maximum Segment Lifetime),是一个工程值(经验值),RFC标准是2分钟,不过有点太长了,一般是30秒,1分钟,2分钟。

为什么客户端TIME_WAIT状态等待2MSL呢?四次握手最后一步客户端向服务端响应ACK,后有两种情况:

  1. 服务端没有收到ACK,这时服务端会超时重传FIN
  2. 服务端收到了ACK,但是客户端不知道服务端有没有收到

无论1还是2,客户端都需要等待,要取这两种情况等待的最大值以应对最坏情况的发生,这个最坏的情况就是: 去向ACK消息的最大生存时间(MSL) + 来向FIN消息的最大生存时间(MSL) 。可以看到加起来正好是2MSL。等待2MSL,客户端就可以放心大胆的释放TCP连接了,此时可以使用该端口号连接任何服务器。

如果没有TIME_WAIT,新连接直接复用该连接占用的端口话,恰好回复的ACK包没有达到对端,导致对方重传FIN包,这时新连接就会被错误的关闭。

2.1.2 使用了连接池为什么还会出现大量的TIME_WAIT连接呢

首先大量的TIME_WAIT连接说明了我们的go程序建立了大量的连接然后又关闭了,但是理论上使用了连接池连接都应该得到复用,不会建立大量的连接才对。

这时我首先检查了是不是连接池的ConnMaxLifetimeConnMaxIdleTime设置的太小导致连接被关闭。我回看了代码发现同事设置了一个小时的时长,那么就不可能是这个原因了。

然后我将怀疑的矛头指向了TDSQL,因为TDSQL是我们首次使用,之前使用Mysql时也没有遇到过这个问题,会不会是TDSQL发送/回复了某个特殊的包导致了客户端主动断开呢?

2.2 验证是否是TDSQL的问题

为了验证上述的猜想,使用tcpdump在服务器上抓了个包

tcpdump -i any host 10.43.0.43 -w output.pcap

然后将抓到的outout.pcap包down下来后丢到本机wireshark中进行分析,选择一个端口过滤下可以看到整个tcp连接的所有包。

image.png

可以看到TDSQL发送的都是正常的mysql协议包,并没有什么特殊的包,因此到这里基本可以确认不是TDSQL的锅。

那么排查的重点又回到了golang连接池,golang连接池为什么会主动断开连接?

2.3 golang为什么会主动断开连接?

由于golang整个sql包非常复杂,我们可以自底向上的思路来排查问题,首先我们找到mysql驱动包go-sql-driver/mysql中关闭连接的函数:

func (mc *mysqlConn) Close() (err error)

它位于connection.go文件中。

下面我们使用dlv 来启动上面的go程序

$ dlv debug main.go

进入到dlv控制台,然后在控制台中输入break mysql.(*mysqlConn).Close在这个函数上打个断点,继续输入c 命令让程序继续执行,然后使用ab命令进行这个web程序进行压测。

不出所料程序断在了mysql.(*mysqlConn).Close。然后我们使用bt命令来打印下调用栈: image.png

可以清楚的看到整个核心调用链是:

rwos.Close -> driverConn.releaseConn -> DB.putConn -> driverConn.Close -> mysqlConn.Close

问题的关键是DB.putConn ,我们可以分析下源码:

func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
    ...
    added := db.putConnDBLocked(dc, nil)
    db.mu.Unlock()

    if !added {
        dc.Close()
        return
    }
}

func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
    if db.closed {
        return false
    }
    if db.maxOpen > 0 && db.numOpen > db.maxOpen {
        return false
    }
    if c := len(db.connRequests); c > 0 {
        var req chan connRequest
        var reqKey uint64
        for reqKey, req = range db.connRequests {
            break
        }
        delete(db.connRequests, reqKey) // Remove from pending requests.
        if err == nil {
            dc.inUse = true
        }
        req <- connRequest{
            conn: dc,
            err:  err,
        }
        return true
    } else if err == nil && !db.closed {
        if db.maxIdleConnsLocked() > len(db.freeConn) { // db.maxIdleConnsLocked()取自db.maxIdleCount
            db.freeConn = append(db.freeConn, dc)
            db.startCleanerLocked()
            return true
        }
        db.maxIdleClosed++
    }
    return false
}

从源码中我们可以得知是DB.putConnDBLocked返回了false导致了连接关闭。为了验证这一点可以继续使用dlv 在dc.Close()这一行打个断点,然后重新压测这个程序: image.png 可以看到程序确实是走到了dc.Close()这一行,我们继续打印上下文的数据: image.png

到这里我们就知道了是由于db.maxIdleCount == len(db.freeConn)导致了连接没有被复用。

db.maxIdleCount是我们代码中设置的dbconnect.SetMaxIdleConns(5)也就是5

那么问题的原因其实就很简单了,我们设置了最大闲置连接数为5,最大可建立连接数为10,那么进程中最多可出现10个连接,这10个连接中只有5个可以被丢回到连接池中复用,而另外5个连接由于超过了我们设置最大闲置连接数5所以不会被丢回到连接池中复用,因此使用完就close了。当并发高的情况下就会出现大量的连接打开与关闭。

3. 解决

最大闲置连接数设置成一个大于等于最大连接数的值即可,比如下面这样:

dbconnect.SetMaxIdleConns(10)
dbconnect.SetMaxOpenConns(10)

...

阅读全文 »

与世界分享我刚编的mysql http隧道工具-hersql原理与使用

Category Mysql
Tag Mysql
Tag Go
Posted on
View

原文地址:https://blog.fanscore.cn/a/53/

1. 前言

本文是与世界分享我刚编的转发ntunnel_mysql.php的工具的后续,之前的实现有些拉胯,这次重构了下。需求背景是为了在本地macbook上通过开源的mysql可视化客户端(dbeaver、Sequel Ace等)访问我司测试环境的mysql,整个测试环境的如图所示:

image.png

那么就有以下几种方式:

  • 客户端直连mysql #Pass# 测试环境mysql只提供了内网ip,只允许测试环境上的机器连接,因此不可行
  • 通过ssh隧道连接 #Pass# 测试环境机器倒是可以ssh上去,但是只能通过堡垒机接入,且堡垒机不允许ssh隧道,因此不可行
  • navicat http隧道连接 #Pass# 测试环境有机器提供了公网ip开放了http服务,因此技术上是可行的,但navicat非开源免费软件,我司禁止使用,因此不可行
  • 测试环境选一台机器建立mysql代理转发请求 #Pass# 测试环境机器只开放了80端口,且已被nginx占用,因此不可行
  • 内网穿透 这个想法很好,下次不要再想了

image.png

既然上面的方式都不行,那怎么办呢?因此我产生了一个大胆的想法

2. 一个大胆的想法

大概架构如下 image.png

首先,在本地pc上启动一个sidecar进程,该进程监听3306端口,实现mysql协议,将自己伪装为一个mysql server。本地pc上的mysql客户端连接到sidecar,发送请求数据包给sidecar,从sidecar读取响应包。

然后在测试环境某台机器上启动transport进程,该进程启动http服务,由nginx代理转发请求,相当于监听在80端口,然后连接到测试环境的mysql server。

sidecar会将来自客户端的请求包通过http请求转发给transporttransport将请求包转发到测试环境对应的mysql server,然后读取mysql的响应数据包,然后将响应数据包返回给sidecarsidecar再将响应包返回给mysql客户端。

遵循上述的基本原理,我将其实现出来: https://github.com/Orlion/hersql。但是在描述hersql的实现细节之前我们有必要了解下mysql协议

3. mysql协议

mysql客户端与服务端交互过程主要分为两个阶段:握手阶段与命令阶段。交互流程如下: image.png

在最新版本中,握手过程比上面要复杂,会多几次交互

3.1 握手阶段

在握手阶段,3次握手建立tcp连接后服务端会首先发送一个握手初始化包,包含了 * 协议版本号:指示所使用的协议版本。 * 服务器版本:指示MySQL服务器版本的字符串。 * 连接ID:在当前连接中唯一标识客户端的整数。 * 随机数据:包含一个随机字符串,用于后续的身份验证。 * 服务器支持的特性标志:指示服务器支持的客户端功能的位掩码。 * 字符集:指示服务器使用的默认字符集。 * 默认的身份验证插件名(低版本没有该数据)

随后客户端会发送一个登录认证包,包含了:

  • 协议版本号:指示所使用的协议版本。
  • 用户名:用于身份验证的用户名。
  • 加密密码:客户端使用服务端返回的随机数对密码进行加密
  • 数据库名称:连接后要使用的数据库名称。
  • 客户端标志:客户端支持的功能的位掩码。
  • 最大数据包大小:客户端希望接收的最大数据包大小。
  • 字符集:客户端希望使用的字符集。
  • 插件名称:客户端希望使用的身份验证插件的名称。

服务端收到客户端发来的登录认证包验证通过后会发送一个OK包,告知客户端连接成功,可以转入命令交互阶段

在mysql 8.0默认的身份验证插件为caching_sha2_password,低版本为mysql_native_password,两者的验证交互流程有所不同个,caching_sha2_password在缓存未命中的情况下还会多几次交互。另外如果服务端与客户端的验证插件不同的话,也是会多几次交互。

3.2 命令阶段

在命令阶段,客户端会发送命令请求包到服务端。数据包的第一个字节标识了当前请求的类型,常见的命令有:

  • COM_QUERY命令,执行SQL查询语句。
  • COM_INIT_DB命令,连接到指定的数据库。
  • COM_QUIT命令,关闭MySQL连接。
  • COM_FIELD_LIST命令,列出指定表的字段列表。
  • COM_PING命令,向MySQL服务器发送PING请求。
  • COM_STMT_系列预处理语句命令

请求响应的模式是客户端会发一个请求包,服务端会回复n(n>=0)个响应包

最后客户端断开连接时会主动发送一个COM_QUIT命令包通知服务端断开连接

4. hersql数据流转过程

在了解mysql协议之后我们就可以来看下hersql的数据流转过程了。

image.png

transport连接mysql server时必须要知道目标数据库的地址与端口号(mysql client连接的是sidecar),所以hersql要求mysql client需要在数据库名中携带目标数据库的地址与端口号。

transport发给mysql server的登录请求包中需要包含用mysql server发来的随机数加密之后的密码,但是mysql client给到sidecar的登录请求包中的密码是用sidecar给的随机数加密的,因此无法直接拿来使用,所以hersql要求mysql client需要在数据库名中携带密码原文,transport会用mysql server给的随机数进行加密, 这也是hersql的局限。

5. hersql使用

上面介绍了一堆原理性的东西,那么如何使用呢?

5.1 在一台能够请求目标mysql server的机器上部署hersql transport

首先你需要下载下来hersql的源码:https://github.com/Orlion/hersql,还需要安装下golang,这些都完成后你就可以启动hersql transport了。但是先别着急,我先解释下transport的配置文件tranport.example.yaml:

server:
  # transport http服务监听的地址
  addr: :8080

log:
  # 标准输出的日志的日志级别
  stdout_level: debug
  # 文件日志的日志级别
  level: error
  # 文件日志的文件地址
  filename: ./storage/transport.log
  # 日志文件的最大大小(以MB为单位), 默认为 100MB。日志文件超过此大小会创建个新文件继续写入
  maxsize: 100
  # maxage 是根据文件名中编码的时间戳保留旧日志文件的最大天数。 
  maxage: 168
  # maxbackups 是要保留的旧日志文件的最大数量。默认是保留所有旧日志文件。
  maxbackups: 3
  # 是否应使用 gzip 压缩旋转的日志文件。默认是不执行压缩。
  compress: false

你可以根据你的需求修改配置,然后就可以启动transport

$ go run cmd/transport/main.go -conf=transport.example.yaml

一般情况下都是会先编译为可执行文件,由systemd之类的工具托管transport进程,保证transport存活。这里简单期间直接用go run起来

5.2 在你本地机器部署启动hersql sidecar

同样的,你需要下载下来hersql的源码:https://github.com/Orlion/hersql,提前安装好golang。修改下sidecar的配置文件sidecar.example.yaml:

server:
  # sidecar 监听的地址,之后mysql client会连接这个地址
  addr: 127.0.0.1:3306
  # transport http server的地址
  transport_addr: http://x.x.x.x:xxxx
log:
  # 与transport配置相同

就可以启动sidecar

$ go run cmd/sidecar/main.go -conf=sidecar.example.yaml

同样的,一般情况下也都是会先编译为可执行文件,mac上是launchctl之类的工具托管sidecar进程,保证sidecar存活。这里简单期间直接用go run起来

5.3 客户端连接

上面的步骤都执行完成后,就可以打开mysql客户端使用了。数据库地址和端口号需要填写sidecar配置文件中的addr地址,sidercar不会校验用户名和密码,因此用户名密码可以随意填写

重点来了: 数据库名必须要填写,且必须要按照以下格式填写

[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]

举个例子:

root:123456@tcp(10.10.123.123:3306)/BlogDB

如图所示: image.png

5.4 举个例子

目标mysql服务器

  • 地址:10.10.123.123:3306
  • 数据库:BlogDB
  • 用户名:root
  • 密码:123456

可以直连目标mysql服务器的机器

  • 地址:10.10.123.100
  • 开放端口:8080

那么transport可以配置为

server:
  addr: :8080

sidecar可以配置为

server:
  addr: 127.0.0.1:3306
  transport_addr: http://10.10.123.100:8080

客户端连接配置

  • 服务器地址:127.0.0.1
  • 端口: 3306
  • 数据库名root:123456@tcp(10.10.123.123:3306)/BlogDB

5.5 局限

hersql目前只支持mysql_native_password的认证方式,mysql8默认的认证方式是caching_sha2_password,所以如果要通过hersql连接mysql8需要注意登录用户的认证方式是否是mysql_native_password,如果是caching_sha2_password那暂时是无法使用的。

6. 参考资料

如果hersql对你有帮助欢迎点个star

...

阅读全文 »

redis georadius源码分析与性能优化

Category Redis
Tag Redis
Tag Go
Posted on
View

背景

最近接到一个需求,开发中使用了redis georadius命令取附近给定距离内的点。完工后对服务进行压测后发现georadius的性能比预期要差,因此我分析了georadius的源码,并对原始的实现方案进行了优化,总结成了本文。

我们生产环境使用的redis版本为4.0.13,因此本文redis源码皆为4.0.13版本的源码

redis geo原理

往redis中添加坐标的命令是GEOADD key longitude latitude member [longitude latitude member ...],实际上redis会将经纬度转成一个52bit的整数作为zsetscore,然后添加到zset中,所以实际上redis geo底层就是个zset,你甚至可以直接使用zset的命令来操作一个geo类型的key。

那么经纬度是如何转成52bit整数的呢?业内广泛使用的方法是首先对经纬度分别按照二分法编码,然后将各自的编码交叉组合成最后的编码。我们以116.505021, 39.950898这个坐标为例看下如何编码:

  • 第一次二分操作,把经度分为两个区间:[-180,0)[0,180]116.505021落在右区间,因此用1表示第一次编码后的值
  • 第二次二分操作,把[0,180]分为两个区间[0,90)[90,180]116.505021落在右区间,因此用1表示第二次编码后的值
  • 第三次二分操作,把[90,180]分为两个区间[90,135)[135,180]116.505021落在左区间,因此用0表示第二次编码后的值
  • 按照这种方法依次处理,做完5次后,得到经度值的5位编码值:11010
分区次数 左区间 右区间 经度116.505021在区间 编码值
1 [-180, 0) [0, 180] [0, 180] 1
2 [0, 90) [90, 180] [90, 180] 1
3 [90, 135) [135, 180] [90, 135]) 0
4 [90, 112.5) [112.5, 135] [112.5, 135] 1
5 [112.5, 123.75) [123.75, 180] [112.5, 123.75] 0
  • 按照同样的方法对纬度值进行编码,得到纬度值的5位编码值:10111
分区次数 左区间 右区间 纬度39.950898在区间 编码值
1 [-90, 0) [0, 90] [0, 90] 1
2 [0, 45) [45, 90] [0, 45] 0
3 [0, 22.5) [22.5, 45] [22.5, 45]) 1
4 [22.5, 33.75) [33.75, 45] [33.75, 45] 1
5 [33.75, 39.375) [39.375, 45] [39.375, 45] 1

然后将经度编码11010和纬度编码值10111交叉得到最终geohash值1110011101

image.png

通常会使用base32将编码值转成字符串表示的hash值,与本文无关这里不多做介绍

根据如上的算法通常可以直观的写出如下的代码:

// 该代码来源于https://github.com/HDT3213/godis/blob/master/lib/geohash/geohash.go
func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) {
	box := [2][2]float64{
		{-180, 180}, // lng
		{-90, 90},   // lat
	}
	pos := [2]float64{longitude, latitude}
	hash := &bytes.Buffer{}
	bit := 0
	var precision uint = 0
	code := uint8(0)
	for precision < bitSize {
		for direction, val := range pos {
			mid := (box[direction][0] + box[direction][1]) / 2
			if val < mid {
				box[direction][1] = mid
			} else {
				box[direction][0] = mid
				code |= bits[bit]
			}
			bit++
			if bit == 8 {
				hash.WriteByte(code)
				bit = 0
				code = 0
			}
			precision++
			if precision == bitSize {
				break
			}
		}
	}
	if code > 0 {
		hash.WriteByte(code)
	}
	return hash.Bytes(), box
}

可以看到基本就是上述算法的实际描述,但是redis源码中却是另外一种算法:

int geohashEncode(const GeoHashRange *long_range, const GeoHashRange *lat_range,
                  double longitude, double latitude, uint8_t step,
                  GeoHashBits *hash) {
    // 参数检查此处代码省略
    ...
    
    double lat_offset =
        (latitude - lat_range->min) / (lat_range->max - lat_range->min);
    double long_offset =
        (longitude - long_range->min) / (long_range->max - long_range->min);

    lat_offset *= (1 << step);
    long_offset *= (1 << step);
    // lat_offset与long_offset交叉
    hash->bits = interleave64(lat_offset, long_offset);
    return 1;
}

那么该如何理解redis的这种算法呢?我们假设经度用3位来编码 image.png 可以看到编码值从左到右实际就是从000111依次加1递进的,给定的经度值在这条线的位置(偏移量)就是其编码值。假设给定经度值为50,那么它在这条线的偏移量就是(50 - -180) / (180 - -180) * 8 = 5即101

georadius原理

georadius命令格式为GEORADIUS key longitude latitude radius m|km|ft|mi [WITHCOORD] [WITHDIST] [WITHHASH] [COUNT count] [ASC|DESC] [STORE key] [STOREDIST key],以给定的经纬度为中心, 返回键包含的位置元素当中, 与中心的距离不超过给定最大距离的所有位置元素。

image.png

首先需要明确一点的是并非两个坐标点编码相近其距离越近,以上图为例,虽然A所在区块的编码与C所在区块编码较之B更相近,但实际B点距离A点更近。为了避免这种问题redis中会先计算出给定点东南西北以及东北、东南、西北、西南八个区块以及自己身所在的区块即九宫格区域内所有坐标点,然后计算与当前点的距离,再进一步筛选出符合距离条件的点。

假设要查附近100km的点,那么要保证矩形的边长要大于100km,才能保证能获取到所有符合条件的点,地球半径约6372.797km,第一次分割后可以得到四个东西长6372.797*π,南北长3186.319*π,继续切割:

分割次数 东西长(km) 南北长(km)
1 6372.797*π 3186.319*π
2 3186.319*π 1593.160*π
3 1593.160*π 796.58*π
4 796.58*π 398.29*π
5 398.29*π 199.145*π
6 199.145*π 99.573*π
7 99.573*π 49.787*π

分割到第七次时南北长49.787*π,如果再切分长度为24.894*π,长度小于100km,因此停止分割,所以如果要查附近100km的点,我们需要的精度为7

redis中根据给定的距离估算出需要的精度的代码如下

const double MERCATOR_MAX = 20037726.37;

uint8_t geohashEstimateStepsByRadius(double range_meters, double lat) {
    if (range_meters == 0) return 26;
    int step = 1;
    while (range_meters < MERCATOR_MAX) {
        range_meters *= 2;
        step++;
    }
    step -= 2;
    // 高纬度地区地球半径小因此适当降低精度
    if (lat > 66 || lat < -66) {
        step--;
        if (lat > 80 || lat < -80) step--;
    }

    if (step < 1) step = 1;
    if (step > 26) step = 26;
    return step;
}

调用encode0函数就能计算出给定点在step = geohashEstimateStepsByRadius()精度级别所在矩形区域的geohash值。接下来计算该矩形区域附近的八个区域。

...
// 调用encode0函数计算geohash
geohashEncode(&long_range,&lat_range,longitude,latitude,steps,&hash);
// 计算出附近八个区域
geohashNeighbors(&hash,&neighbors);
...

一个区域的东侧区域只要将经度的编码值+1即可,反之西侧区域只要将经度编码值-1即可,北侧区域只要将纬度的编码值+1即可,南侧区域只要将纬度的编码值-1即可。对应redis源码如下:

void geohashNeighbors(const GeoHashBits *hash, GeoHashNeighbors *neighbors) {
    neighbors->east = *hash;
    neighbors->west = *hash;
    neighbors->north = *hash;
    neighbors->south = *hash;
    neighbors->south_east = *hash;
    neighbors->south_west = *hash;
    neighbors->north_east = *hash;
    neighbors->north_west = *hash;
    // 纬度加1就是东侧区域
    geohash_move_x(&neighbors->east, 1);
    geohash_move_y(&neighbors->east, 0);
    // 纬度减1就是西侧区域
    geohash_move_x(&neighbors->west, -1);
    geohash_move_y(&neighbors->west, 0);
    // 精度减1就是南侧区域
    geohash_move_x(&neighbors->south, 0);
    geohash_move_y(&neighbors->south, -1);

    geohash_move_x(&neighbors->north, 0);
    geohash_move_y(&neighbors->north, 1);

    geohash_move_x(&neighbors->north_west, -1);
    geohash_move_y(&neighbors->north_west, 1);

    geohash_move_x(&neighbors->north_east, 1);
    geohash_move_y(&neighbors->north_east, 1);

    geohash_move_x(&neighbors->south_east, 1);
    geohash_move_y(&neighbors->south_east, -1);

    geohash_move_x(&neighbors->south_west, -1);
    geohash_move_y(&neighbors->south_west, -1);
}

image.png 如上图所示,当给定点在中心区域的东北侧时,西北、西、西南、南、东南五个方向的区域中的所有点距离给定点肯定超过了给定距离,所以可以过滤掉,redis代码如下所示:

if (steps >= 2) {
    if (area.latitude.min < min_lat) {
        GZERO(neighbors.south); // 南侧区域置零,过滤南侧区域
        GZERO(neighbors.south_west);
        GZERO(neighbors.south_east);
    }
    if (area.latitude.max > max_lat) {
        GZERO(neighbors.north);
        GZERO(neighbors.north_east);
        GZERO(neighbors.north_west);
    }
    if (area.longitude.min < min_lon) {
        GZERO(neighbors.west);
        GZERO(neighbors.south_west);
        GZERO(neighbors.north_west);
    }
    if (area.longitude.max > max_lon) {
        GZERO(neighbors.east);
        GZERO(neighbors.south_east);
        GZERO(neighbors.north_east);
    }
}

计算出区块后下一步就需要将九宫格区域中的所有坐标点拿出来,依次计算与给定点的距离,然后过滤出符合给定距离的点

// 遍历九宫格内所有点,依次计算与给定点的距离,然后过滤出符合给定距离的点添加到ga中
int membersOfAllNeighbors(robj *zobj, GeoHashRadius n, double lon, double lat, double radius, geoArray *ga) {
    GeoHashBits neighbors[9];
    unsigned int i, count = 0, last_processed = 0;
    int debugmsg = 1;

    neighbors[0] = n.hash;
    neighbors[1] = n.neighbors.north;
    neighbors[2] = n.neighbors.south;
    neighbors[3] = n.neighbors.east;
    neighbors[4] = n.neighbors.west;
    neighbors[5] = n.neighbors.north_east;
    neighbors[6] = n.neighbors.north_west;
    neighbors[7] = n.neighbors.south_east;
    neighbors[8] = n.neighbors.south_west;

    // 遍历九宫格
    for (i = 0; i < sizeof(neighbors) / sizeof(*neighbors); i++) {
        ...
        // 当给定距离过大时,区块可能会重复
        if (last_processed &&
            neighbors[i].bits == neighbors[last_processed].bits &&
            neighbors[i].step == neighbors[last_processed].step)
        {
            continue;
        }
        // 取出宫格内所有点,依次计算距离,符合条件后添加到ga中
        count += membersOfGeoHashBox(zobj, neighbors[i], ga, lon, lat, radius);
        last_processed = i;
    }
    return count;
}

int membersOfGeoHashBox(robj *zobj, GeoHashBits hash, geoArray *ga, double lon, double lat, double radius) {
    GeoHashFix52Bits min, max;
    // 根据区块的geohash值计算出对应的zset的score的上下限[min,max]
    scoresOfGeoHashBox(hash,&min,&max);
    // 取出底层的zset中的[min,max]范围内的元素,依次计算距离,符合条件后添加到ga中
    return geoGetPointsInRange(zobj, min, max, lon, lat, radius, ga);
}

georadius优化

从上一节中可以看到,给定距离范围越大,则九宫格区域越大,九宫格区域内的点就越多,而每个点都需要计算与中间点的距离,距离计算又涉及到大量的三角函数计算,所以这部分计算是十分消耗CPU的。又因为redis工作线程是单线程的,因此无法充分利用多核,无法通过增加redis server的CPU核数来提升性能,只能添加从库。

距离计算算法及优化可以看下美团的这篇文章: https://tech.meituan.com/2014/09/05/lucene-distance.html

对于这个问题,我们可以将九宫格以及距离计算部分提升到我们的应用程序即redis客户端来进行,步骤如下: * 在客户端计算出九宫格区域,然后转为zset score的范围 * 使用zrangebyscore命令从redis取出score范围内的所有点 * 遍历所有点依次计算与给定点的距离,筛选出符合距离条件的点

陌陌好像也是使用了这种方案:https://mp.weixin.qq.com/s/DL2P49y4R1AE2MIdkxkZtQ

由于我们使用golang进行开发,因此我将redis中的georadius部分代码转为了golang代码,并整理成一个库开源在了github:https://github.com/Orlion/go-georadius

原本的写法是:

client.GeoRadius(key, longitude, latitude, &redis.GeoRadiusQuery{
	Radius:    1000,
	Unit:      "m", // 距离单位
	Count:     1,          // 返回1条
	WithCoord: true,       // 将位置元素的经纬度一并返回
	WithDist:  true,       // 一并返回距离
})

改造后:

ga := make([]redis.Z, 0)
ranges := geo.NeighborRanges(longitude, latitude, 1000)
for _, v := range ranges {
    zs, _ := client.ZRangeByScoreWithScores(key, redis.ZRangeBy{
		Min: strconv.Itoa(int(v[0])),
		Max: strconv.Itoa(int(v[1])),
	}).Result()
	for _, z := range zs {
	    dist := geox.GetDistanceByScore(longitude, latitude, uint64(z.Score))
		if dist < 1000 {
		    ga = append(ga, z)
		}
	}
}

压测结果对比

43w坐标点,取附近50000m(九宫格内有14774点,符合条件的点约6000个)

50km优化前

Concurrency Level:      5
Time taken for tests:   89.770 seconds
Complete requests:      5000
Failed requests:        0
Write errors:           0
Total transferred:      720000 bytes
HTML transferred:       0 bytes
Requests per second:    55.70 [#/sec] (mean)
Time per request:       89.770 [ms] (mean)
Time per request:       17.954 [ms] (mean, across all concurrent requests)
Transfer rate:          7.83 [Kbytes/sec] received

Connection Times (ms)
              min  mean[+/-sd] median   max
Connect:        0    0   0.0      0       0
Processing:    23   90  10.7     90     159
Waiting:       23   89  10.7     89     159
Total:         23   90  10.7     90     159

Percentage of the requests served within a certain time (ms)
  50%     90
  66%     93
  75%     96
  80%     97
  90%    102
  95%    107
  98%    111
  99%    116
 100%    159 (longest request)

50km优化后

Concurrency Level:      5
Time taken for tests:   75.447 seconds
Complete requests:      5000
Failed requests:        0
Write errors:           0
Total transferred:      720000 bytes
HTML transferred:       0 bytes
Requests per second:    66.27 [#/sec] (mean)
Time per request:       75.447 [ms] (mean)
Time per request:       15.089 [ms] (mean, across all concurrent requests)
Transfer rate:          9.32 [Kbytes/sec] received

Connection Times (ms)
              min  mean[+/-sd] median   max
Connect:        0    0   0.0      0       0
Processing:    21   75  14.2     75     159
Waiting:       21   75  14.1     75     159
Total:         21   75  14.2     75     159

Percentage of the requests served within a certain time (ms)
  50%     75
  66%     80
  75%     84
  80%     86
  90%     92
  95%     98
  98%    104
  99%    111
 100%    159 (longest request)

可以看到性能并没有巨大的提升,我们减小距离范围到5km(符合条件的点有130个)再看下压测结果

5km优化前

Concurrency Level:      5
Time taken for tests:   14.006 seconds
Complete requests:      5000
Failed requests:        0
Write errors:           0
Total transferred:      720000 bytes
HTML transferred:       0 bytes
Requests per second:    356.99 [#/sec] (mean)
Time per request:       14.006 [ms] (mean)
Time per request:       2.801 [ms] (mean, across all concurrent requests)
Transfer rate:          50.20 [Kbytes/sec] received

Connection Times (ms)
              min  mean[+/-sd] median   max
Connect:        0    0   0.0      0       0
Processing:     2   14   5.5     12      33
Waiting:        2   14   5.5     12      33
Total:          2   14   5.5     12      34

Percentage of the requests served within a certain time (ms)
  50%     12
  66%     16
  75%     19
  80%     20
  90%     22
  95%     23
  98%     27
  99%     28
 100%     34 (longest request)

5km优化后

Concurrency Level:      5
Time taken for tests:   16.661 seconds
Complete requests:      5000
Failed requests:        0
Write errors:           0
Total transferred:      720000 bytes
HTML transferred:       0 bytes
Requests per second:    300.11 [#/sec] (mean)
Time per request:       16.661 [ms] (mean)
Time per request:       3.332 [ms] (mean, across all concurrent requests)
Transfer rate:          42.20 [Kbytes/sec] received

Connection Times (ms)
              min  mean[+/-sd] median   max
Connect:        0    0   0.0      0       0
Processing:     3   17   5.8     16      66
Waiting:        3   16   5.8     16      66
Total:          3   17   5.8     16      66

Percentage of the requests served within a certain time (ms)
  50%     16
  66%     20
  75%     21
  80%     22
  90%     24
  95%     26
  98%     28
  99%     30
 100%     66 (longest request)

可以看到当优化后性能更差了

image.png

猜测造成这个结果的原因应该是附近5km九宫格内的点比较少,所以优化后实际没减少多少距离计算,但多了n(n<=9)倍的请求数,多了额外的命令解析与响应内容的消耗,因此这种优化方案仅仅适用于附近点特别多的情况

参考资料

...

阅读全文 »

golang sync.Pool分析

Category Golang
Tag Go
Posted on
View

如何使用就不讲了,网上很多文章

1. 结构

type Pool struct {
	noCopy noCopy // 用于保证pool不会被复制
	local     unsafe.Pointer // 实际类型是 [P]poolLocal
	localSize uintptr        // local的size
	victim     unsafe.Pointer // 在新一轮GC来临时接管local,用于减少GC之后冷启动之后的性能抖动
	victimSize uintptr        // 在新一轮GC来临时接管localSize
	New func() interface{} // 当pool中没有对象时会调用这个函数生成一个新的
}

type poolLocal struct {
	poolLocalInternal
	// 避免false sharing问题
	pad [128 - unsafe.Sizeof(poolLocalInternal{})%128]byte
}

type poolLocalInternal struct {
    // P的私有缓存区,使用时无需加锁,Put对象时优先放到这里
	private interface{}
	// 公共缓存区,本地P可以pushHead/popHead,其他P只能popTail
	shared  poolChain
}

// 双端队列
type poolChain struct {
	head *poolChainElt
	tail *poolChainElt
}

type poolChainElt struct {
	poolDequeue
	next, prev *poolChainElt
}

// 环形队列
type poolDequeue struct {
	headTail uint64 // 头尾指针,之所以用一个变量持有两个字段大概率是为了方便原子操作一次性修改两个值吧
	vals []eface // 容量从8开始,依次x2,上限为2 ^ 30
}

type eface struct {
	typ, val unsafe.Pointer
}

image.png

2. Get

2.1 主流程

func (p *Pool) Get() interface{} {
    // 当G与P绑定禁止抢占,返回P对应的poolLocal以及P的id
	l, pid := p.pin()
	x := l.private
	l.private = nil
	if x == nil {
	    // 如果private为空则从shared头部pop出一个
		x, _ = l.shared.popHead()
		if x == nil {
		    // 如果shared中也没有则尝试从其他P的shared尾部偷一个
			x = p.getSlow(pid)
		}
	}
	// 解除非抢占
	runtime_procUnpin()
	if x == nil && p.New != nil {
	    // 如果上面的步骤都没有取到则New个出来
		x = p.New()
	}
	return x
}

其中涉及到了一些函数,我们再看下具体实现

2.2 pin

pin的作用是将当前G与P绑定,禁止被抢占。那么为什么要禁止被抢占呢?原因是G被抢占后再恢复执行之后再绑定的可能就不是被抢占之前的P了

func (p *Pool) pin() (*poolLocal, int) {
    // 执行绑定并返回当前pid
	pid := runtime_procPin()
	s := atomic.LoadUintptr(&p.localSize) 
	l := p.local
	if uintptr(pid) < s { // pid<localSize说明已经完成了poolLocal的创建,可以取
		return indexLocal(l, pid), pid
	}
	// pid>=localSize说明poolLocal还没有创建或者用户通过runtime.GOMAXPROCS(X)增加了p的数量,需要先创建
	return p.pinSlow()
}

// 返回local[i]即当前P的poolLocal
func indexLocal(l unsafe.Pointer, i int) *poolLocal {
	lp := unsafe.Pointer(uintptr(l) + uintptr(i)*unsafe.Sizeof(poolLocal{}))
	return (*poolLocal)(lp)
}

// runtime/proc.go
func procPin() int {
	_g_ := getg()
	mp := _g_.m
    // 完成禁止抢占
    // 调度器执行抢占g之前会canPreemptM(mp *m)判断是否可以执行抢占,而canPreemptM有一个条件为m.locks==0
	mp.locks++
	return int(mp.p.ptr().id)
}

2.2.1 pinSlow

pinSlow主要用来在poolLocal还未创建时创建新poolLocal

func (p *Pool) pinSlow() (*poolLocal, int) {
    // 解除禁止抢占
	runtime_procUnpin()
	// 上锁
	allPoolsMu.Lock()
	defer allPoolsMu.Unlock()
	// 禁止抢占
	pid := runtime_procPin()
	s := p.localSize
	l := p.local
	if uintptr(pid) < s {
	    // 上锁之前可能其他线程已经进入到pinSlow了,所以再判断一下
		return indexLocal(l, pid), pid
	}
	if p.local == nil {
	    // 说明local第一次初始化,需要将pool加到allPools中
		allPools = append(allPools, p)
	}
	// 获取p的数量
	size := runtime.GOMAXPROCS(0)
	// 创建local
	local := make([]poolLocal, size)
	atomic.StorePointer(&p.local, unsafe.Pointer(&local[0]))
	atomic.StoreUintptr(&p.localSize, uintptr(size))
	return &local[pid], pid
}

2.3 poolChain.popHead

我们再看下Get主流程中从shared中通过popHead从shared头部pop出一个对象的实现

func (c *poolChain) popHead() (interface{}, bool) {
	d := c.head
	for d != nil { // 从链表头开始遍历,d的type为*poolDequeue
		if val, ok := d.popHead(); ok {
			return val, ok
		}
		d = loadPoolChainElt(&d.prev)
	}
	return nil, false
}

2.3.1 poolDequeue.popHead

poolDequeue.popHead用来从环形队列头部pop出一个缓存对象

func (d *poolDequeue) popHead() (interface{}, bool) {
	var slot *eface
	for {
		ptrs := atomic.LoadUint64(&d.headTail)
		head, tail := d.unpack(ptrs)
		if tail == head {
		    // 如果头尾指针相等则队列为空
			return nil, false
		}

        // 通过不断重试来实现无锁编程
        // 尝试将head-1然后修改poolDequeue.headTail
		head--
		ptrs2 := d.pack(head, tail)
		if atomic.CompareAndSwapUint64(&d.headTail, ptrs, ptrs2) {
		    // 取到head对应的槽
			slot = &d.vals[head&uint32(len(d.vals)-1)]
			break
		}
	}

	val := *(*interface{})(unsafe.Pointer(slot))
	if val == dequeueNil(nil) {
		val = nil // 通过3.2.1 poolDequeue.pushHead分析,貌似val不可能为nil
	}
	
	// 清空槽
	*slot = eface{}
	return val, true
}

func (d *poolDequeue) unpack(ptrs uint64) (head, tail uint32) {
    // dequeueBits=32
	const mask = 1<<dequeueBits - 1
	head = uint32((ptrs >> dequeueBits) & mask) // &mask为了将高32位清零
	tail = uint32(ptrs & mask)
	return
}

func (d *poolDequeue) pack(head, tail uint32) uint64 {
	const mask = 1<<dequeueBits - 1
	return (uint64(head) << dequeueBits) |
		uint64(tail&mask)
}

type dequeueNil *struct{}

2.4 getSlow

再看下主流程中的getSlow函数的实现,getSlow用于在当前P缓存中没有时从其他P的共享缓存区偷缓存对象

func (p *Pool) getSlow(pid int) interface{} {
	size := atomic.LoadUintptr(&p.localSize) 
	locals := p.local
	// 从其他P偷
	for i := 0; i < int(size); i++ {
		l := indexLocal(locals, (pid+i+1)%int(size))
		// 从其他P的共享区的尾部偷
		if x, _ := l.shared.popTail(); x != nil {
			return x
		}
	}

    // 如果没有偷到,则尝试victim cache。我们将在尝试从所有主缓存中偷取之后这样做,因为我们想让victim cache中的对象尽可能的老化
	size = atomic.LoadUintptr(&p.victimSize)
	if uintptr(pid) >= size {
		return nil
	}
	locals = p.victim
	l := indexLocal(locals, pid)
	if x := l.private; x != nil {
		l.private = nil
		return x
	}
	for i := 0; i < int(size); i++ {
		l := indexLocal(locals, (pid+i)%int(size))
		if x, _ := l.shared.popTail(); x != nil {
			return x
		}
	}

    // 走到这里说明victim cache中也没有对象
	// 将victim cache标记为空,下次就不用尝试victim cache了
	atomic.StoreUintptr(&p.victimSize, 0)

	return nil
}

2.4.1 poolChain.popTail

getSlow中会通过poolChain.popTail从双端队列尾部pop对象,看下具体是如何操作的

func (c *poolChain) popTail() (interface{}, bool) {
    // 先获取双端队列的尾节点,因为这时被偷的P与当前P在并行,所以需要通过原子操作获取
	d := loadPoolChainElt(&c.tail)
	if d == nil {
	    // 尾节点为空说明被偷的P的双端队列为空直接返回即可
		return nil, false
	}

	for {
	    // 在pop tail之前load下一个指针是非常重要的,通常,d可能暂时为空,但是如果
	    // 在pop之前next非nil并且pop失败,那么d永远为空,这是唯一可以安全的从链表
	    // 中删除d的方法
		d2 := loadPoolChainElt(&d.next)

		if val, ok := d.popTail(); ok {
			return val, ok
		}

		if d2 == nil {
		    // next为空遍历终止
			return nil, false
		}
		
		// 走到这里说明当前尾节点为空,并且有下一个节点
		// 此时当前尾节点不可能有新对象被push进来了,可以删除掉了
		// 尝试将环形队列的尾节点指针改成它的下一个节点
		if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&c.tail)), unsafe.Pointer(d), unsafe.Pointer(d2)) {
			// 走到这里说明赢得了race,清除prev指针以便gc能够收集空dequeue
			// 因此popHead不会在必要时再次备份
			storePoolChainElt(&d2.prev, nil)
		}
		d = d2
	}
}

要修改poolChain的空尾节点指针为尾节点的下一个节点必须同时满足下面两个条件(即删除当前尾节点) * 当前尾节点环形队列为空 * 当前尾节点必须有下一个节点

我们注意到golang中先获取了当前尾节点的next再popTail,这是为什么呢?如果先popTail再获取next有可能遇到这样的情况:

  1. d的队列为空popTail没有获取到数据
  2. 另外一个线程向d中push了n个对象,此时d不为空,并且生成了下一个节点
  3. 原子获取next,next不为空
  4. 误将还有缓存对象的d删除

2.4.1.1

func (d *poolDequeue) popTail() (interface{}, bool) {
	var slot *eface
	for {
		ptrs := atomic.LoadUint64(&d.headTail)
		head, tail := d.unpack(ptrs)
		if tail == head {
			// 说明队列为空
			return nil, false
		}

		ptrs2 := d.pack(head, tail+1)
		if atomic.CompareAndSwapUint64(&d.headTail, ptrs, ptrs2) {
			// 修改成功,这个solt就被我们占有了
			slot = &d.vals[tail&uint32(len(d.vals)-1)]
			break
		}
	}

    // 从槽中取值
	val := *(*interface{})(unsafe.Pointer(slot))
	if val == dequeueNil(nil) {
		val = nil
	}

	slot.val = nil
	atomic.StorePointer(&slot.typ, nil)

	return val, true
}

3. Put

3.1 主流程

// sync/pool.go
func (p *Pool) Put(x interface{}) {
	if x == nil {
		return
	}
	l, _ := p.pin() // 禁止G被抢占,并返回当前P对应的poolLocal
	if l.private == nil {
	    // 如果私有缓存为空则直接放到私有缓存区
		l.private = x
		x = nil
	}
	if x != nil {
	    // 如果私有缓存已经被占了,则放到共享缓存区头
		l.shared.pushHead(x)
	}
	runtime_procUnpin() // 解除禁止抢占
}

接下来我们看下新元素具体是如何放到共享缓冲区头部的

3.2 poolChain.pushHead

func (c *poolChain) pushHead(val interface{}) {
	d := c.head
	if d == nil {
	    // 头结点为空则需要进行初始化
		const initSize = 8 // 第一个节点的环形队列的长度为8
		d = new(poolChainElt)
		d.vals = make([]eface, initSize)
		// 其他P可能会从尾部偷对象,所以poolChain的tail需要用atomic set,保证对其他P可见
		c.head = d
		storePoolChainElt(&c.tail, d)
	}

    // push到头结点双端队列的头部
	if d.pushHead(val) {
		return
	}

	// 走到这里说明当前头节点的环形队列已经满了,所以申请一个新的节点
	// 新节点环形队列的长度为旧队列的两倍,但如果大于1 << 32 / 4则长度为1 << 32 / 4
	newSize := len(d.vals) * 2
	if newSize >= dequeueLimit {
		newSize = dequeueLimit
	}

	d2 := &poolChainElt{prev: d}
	d2.vals = make([]eface, newSize)
	// 修改双端队列的头节点为新创建的节点
	c.head = d2
	// 其他P可能会用到next所以需要原子store
	storePoolChainElt(&d.next, d2)
	d2.pushHead(val)
}

3.2.1 poolDequeue.pushHead

poolDequeue.pushHead用于将对象放到环形队列上

func (d *poolDequeue) pushHead(val interface{}) bool {
	ptrs := atomic.LoadUint64(&d.headTail)
	head, tail := d.unpack(ptrs)
	// dequeueBits = 32
	if (tail+uint32(len(d.vals)))&(1<<dequeueBits-1) == head {
		// 如果环形队列tail加上长度等于head,说明队列实际已经满了
		return false
	}
	// 找到head对应的槽,slot的类型为*eface
	slot := &d.vals[head&uint32(len(d.vals)-1)]

	typ := atomic.LoadPointer(&slot.typ)
	if typ != nil {
		return false
	}

	if val == nil {
		val = dequeueNil(nil) // 追了下代码调用链貌似val不可能为nil
	}
	// 将val放到槽上
	*(*interface{})(unsafe.Pointer(slot)) = val

	// 增加头指针
	atomic.AddUint64(&d.headTail, 1<<dequeueBits)
	return true
}

4. GC

上面的流程中我们清楚了对象是如何被缓存已经如何被写入和获取的,但是缓存池容量不是无限的,何时清理呢?答案是GC时。

sync/pool.go中有个init函数,在这个函数中注册了GC时如何清理Pool的函数

func init() {
    // 编译器会将poolCleanup赋值给runtime/mgc.go文件中`poolcleanup`变量
    // 在runtime.clearpools()函数中会调用poolcleanup,而在gcStart函数中在开始标记之前会调用clearpools()
	runtime_registerPoolCleanup(poolCleanup)
}

func poolCleanup() {
	for _, p := range oldPools {
	    // 先清除所有旧pool中的victim
	    // 之后gc就能标记清理旧pool中缓存的对象了
		p.victim = nil
		p.victimSize = 0
	}

	for _, p := range allPools {
	    // 用victim接管pool
		p.victim = p.local
		p.victimSize = p.localSize
		p.local = nil
		p.localSize = 0
	}

	oldPools, allPools = allPools, nil
}

当时看到这里时还有一点疑惑,poolCleanup为什么不写成下面这样:

func poolCleanup() {
	for _, p := range allPools {
	    // 用victim接管pool
		p.victim = p.local
		p.victimSize = p.localSize
		p.local = nil
		p.localSize = 0
	}
}

看起来也能清理缓存队列。但实际有个非常浅而已见的坑,那就是allPools这个切片一直在增长,必须将allPools设置为nil清理下才行,所以就必须引入oldPools。

5. 总结

sync.Pool为每个P搞一个缓存队列,避免所有线程共用同一个队列引发的锁竞争问题。

5.1 Put流程

  1. push到双端队列的头部的环形队列头部,如果环形队列已满则创建一个新的环形队列
  2. 将环形队列作为双端队列的新头部

5.2 Get流程

  1. 先从当前P缓冲区的私有缓存取
  2. 如果私有缓存没有从共享缓存区的双端队列的环形队列的头部pop
  3. 还没获取到则从其他P的共享缓存区的双端队列的环形队列的尾部pop
  4. 还没获取到则从victim cache中取

5.3 总结

总的来说只要清楚了sync.Pool的数据结构基本都理解的大差不差了,还是很简单的。

...

阅读全文 »

Golang切片与实现原理

Category Golang
Tag Go
Posted on
View

本文Golang版本为1.13.4

Slice底层结构

go中切片实际是一个结构体,它位于runtime包的slice.go文件中

type slice struct {
	array unsafe.Pointer
	len   int
	cap   int
}

array是切片用来存储数据的底层数组的指针,len为切片中元素的数量,cap为切片的容量即数组的长度

切片的初始化

创建一个切片有以下几种方式

1. 通过字面量创建

arr1 := [3]int{1,2,3} // 创建一个数组
s1 := []int{1,2,3} // 创建一个len为3,cap为3的切片

上面的创建方式非常容易与数组的另一个创建方式弄混

arr2 := [...]int{1,2,3} // 创建一个数组,数组长度由编译器推断

s1在内存上的结构如下图: image.png

2. 通过make()函数创建

s1 := make([]int, 10) // 创建一个长度为10,容量为10的切片
s2 := make([]int, 5, 10) // 创建一个长度为5,容量为5的切片

s2的内存结构如图: image.png

3. 通过数组/切片创建另一个切片

通过数组/切片创建另一个切片语法为

slice[i:j:k]

其中i表示开始切的位置,包括该位置,如果没有则表示从0开始切;j表示切到的位置,不包括该位置,如果没有j则切到最后;k控制切片的容量切到的位置,不包括该位置,如果没有则切到尾部。下面举几个例子说明:

a := [10]int{0,1,2,3,4,5,6,7,8,9}
s1 := a[2:5:9] // s1结果为[2,3,4], len:3, cap:7
s2 := a[2:5:10] // s2结果为[2,3,4] len:3, cap:8
s3 := a[2:7:10] // s3结果为[2,3,4,5,6] len:5, cap:8
s4 := a[2:] // s4结果为[2,3,4,5,6,7,8,9] len:8, cap:8
s5 := a[:3] // s5结果为[0,1,2] len:3, cap:10
s6 := a[::3] // 编译报错: middle index required in 3-index slice
s7 := a[:] s7结果为[0,1,2,3,4,5,6,7,8,9], len:10, cap:10
s10 := s1[1:3] s10结果为[3,4], len:2, cap: 6。注意s10的cap是6,而不是7!

s1与s10在内存上的结构如图: image.png

由于as1s2s3s4s5s7共享同一个数组,所以其中任意一个变量通过索引修改了底层数组元素的值,相当于修改了以上所有变量:

s2[3] = 30

执行上面的代码后:变量a变成了[0,1,2,30,4,5,6,7,8,9]、s1变成了[2,30,4]、…… s7变成了[0,1,2,30,4,5,6,7,8,9]

nil切片与空切片

var s11 []int
var s12 = make([]int, 0)

上面的s11为nil,s12是空切片,他们在内存上的结构如图: image.png

我写了段代码验证了下:

var s10 = make([]int, 0)
sh10 := (*reflect.SliceHeader)(unsafe.Pointer(&s10))
println(unsafe.Pointer(sh10.Data))
var s11 []int
sh11 := (*reflect.SliceHeader)(unsafe.Pointer(&s11))
println(unsafe.Pointer(sh11.Data))
var s12 = make([]int, 0)
sh12 := (*reflect.SliceHeader)(unsafe.Pointer(&s12))
println(unsafe.Pointer(sh12.Data))
var s13 = make([]int, 0)
sh13 := (*reflect.SliceHeader)(unsafe.Pointer(&s13))
println(unsafe.Pointer(sh13.Data))

打印结果如下:

0xc00006af08
0x0
0xc00006af08
0xc00006af08

根据打印结果可以看到是上面的结构无误

切片创建源码

我们打印下下面代码对应的汇编,看下golang是如何为我们创建出来一个切片的

func main() {
	tttttt := make([]int, 999)
	fmt.Println(tttttt)
}

通过go tool compile -S -l slice.go打印对应汇编(-l是禁止内联),下面只摘取关键部分

"".main STEXT size=181 args=0x0 locals=0x48
	...
	// 栈增加72个字节
        0x0013 00019 (slice.go:5)       SUBQ    $72, SP
	// 将当前栈底地址加载到到当前栈顶地址+64处
        0x0017 00023 (slice.go:5)       MOVQ    BP, 64(SP)
	// 栈底修改为栈顶地址+64
        0x001c 00028 (slice.go:5)       LEAQ    64(SP), BP
        ...
        0x0021 00033 (slice.go:6)       LEAQ    type.int(SB), AX
        ...
	// 下面三行实际是把tuntime.makeslice放到栈上的指定位置
        0x0028 00040 (slice.go:6)       MOVQ    AX, (SP)
        0x002c 00044 (slice.go:6)       MOVQ    $999, 8(SP)
        0x0035 00053 (slice.go:6)       MOVQ    $999, 16(SP)

上面的部分画个图可能更清晰些: image.png

继续看汇编:

	// 调用runtime.makeslice函数
        0x003e 00062 (slice.go:6)       CALL    runtime.makeslice(SB)
        ...
	// 将返回值加载到AX寄存器
        0x0043 00067 (slice.go:6)       MOVQ    24(SP), AX
        ...
	// 下面就是调用fmt.Println函数的代码了
        0x0048 00072 (slice.go:7)       MOVQ    AX, (SP)
        0x004c 00076 (slice.go:7)       MOVQ    $999, 8(SP)
        0x0055 00085 (slice.go:7)       MOVQ    $999, 16(SP)
        0x005e 00094 (slice.go:7)       CALL    runtime.convTslice(SB)
        ...
        0x0063 00099 (slice.go:7)       MOVQ    24(SP), AX
        ...
        0x0068 00104 (slice.go:7)       XORPS   X0, X0
        0x006b 00107 (slice.go:7)       MOVUPS  X0, ""..autotmp_1+48(SP)
        ...
        0x0070 00112 (slice.go:7)       LEAQ    type.[]int(SB), CX
        ...
        0x0077 00119 (slice.go:7)       MOVQ    CX, ""..autotmp_1+48(SP)
        ...
        0x007c 00124 (slice.go:7)       MOVQ    AX, ""..autotmp_1+56(SP)
        ...
        0x0081 00129 (slice.go:7)       LEAQ    ""..autotmp_1+48(SP), AX
        ...
        0x0086 00134 (slice.go:7)       MOVQ    AX, (SP)
        0x008a 00138 (slice.go:7)       MOVQ    $1, 8(SP)
        0x0093 00147 (slice.go:7)       MOVQ    $1, 16(SP)
        0x009c 00156 (slice.go:7)       CALL    fmt.Println(SB)
        0x00a1 00161 (slice.go:8)       MOVQ    64(SP), BP
        0x00a6 00166 (slice.go:8)       ADDQ    $72, SP
        0x00aa 00170 (slice.go:8)       RET
        0x00ab 00171 (slice.go:8)       NOP
        ...
        0x00ab 00171 (slice.go:5)       CALL    runtime.morestack_noctxt(SB)
        ...
        0x00b0 00176 (slice.go:5)       JMP     0

上面出现了一个关键函数,即runtime.makeslice,(在堆上分配时才会调用这个函数)我们看下它的实现:

func makeslice(et *_type, len, cap int) unsafe.Pointer {
	// 这里实际是计算切片所占的内存大小,即元素的大小乘容量
	// mem为所需内存大小,overflow标识是否溢出
	mem, overflow := math.MulUintptr(et.size, uintptr(cap))
	if overflow || mem > maxAlloc || len < 0 || len > cap {
		// 如果溢出或者所需内存大于最大可分配内存或者len、cap不合法则报错
		mem, overflow := math.MulUintptr(et.size, uintptr(len))
		if overflow || mem > maxAlloc || len < 0 {
			panicmakeslicelen()
		}
		panicmakeslicecap()
	}
	// 调用mallocgc从go内存管理器获取一块内存
	return mallocgc(mem, et, true)
}

函数传参

切片作为函数参数传参时实际上是复制了一个runtime.slice结构体,而非是传递的runtime.slice结构体指针,举个栗子:

func main() {
	slice := []int{0,1,2}
	foo(slice)
}
func foo(slice []int) {
	...
}

其实就等价于

type Slice struct {
	ptr *[3]int
        len int
	cap int
}

func main() {
	slice := Slice{&[3]int{1,2,3}, 0, 0}
	foo(slice)
}
func foo(slice Slice) {
	...
}

因为函数的形参与实参共享同一个数组,这就导致当把一个切片作为参数传递到另一个函数时,在函数内修改形参的某个下标的值时也会修改了实参。描述的比较绕,下面看一个实例:

func main() {
	param := []int{0, 1, 2}
	foo(param)
	fmt.Println(param)
}

func foo(arg []int) {
	arg[1] = 10
}

打印结果为[0,10,2],原因是param与arg共享同一个底层数组,函数foo内修改了arg[1]实际是将两者的底层数组下标为1的元素修改为了10,所以main函数中的param[1]也就变成了10。 在foo函数内修改arg的len字段,是不会影响到param的len的,下面我们验证下:

func main() {
	param := []int{0, 1, 2}
	foo(param)
	fmt.Println(param)
	fmt.Println(len(param))
}

func foo(arg []int) {
	arg[1] = 10
	argSlice := (*reflect.SliceHeader)(unsafe.Pointer(&arg))
	argSlice.Len = 10
	fmt.Println(len(arg))
}

打印结果如下:

10
[0 10 2]
3

验证成功。

切片扩容

当通过append函数向切片中添加数据时,如果切片的容量不足,需要进行扩容,实际调用的是runtime包中的growslice()函数

// runtime/slice.go
func growslice(et *_type, old slice, cap int) slice {
	...

	// 下面就是计算新容量的部分了
	newcap := old.cap
	doublecap := newcap + newcap
	if cap > doublecap {
		// 如果所需容量大于当前容量的两倍,则新容量为所需容量
		newcap = cap
	} else {
		// 下面是所需容量<=当前容量两倍的逻辑
		if old.len < 1024 {
			// 如果当前长度<1024则新容量为当前容量x2
			newcap = doublecap
		} else {
			// 下面是当前长度>=1024的逻辑
			// 新容量每次增加自身的1/4,直到超过所需容量
			for 0 < newcap && newcap < cap {
				newcap += newcap / 4
			}
			// 如果溢出则新容量为所需容量
			if newcap <= 0 {
				newcap = cap
			}
		}
	}

	// 此处省略分配内存的代码
	...

	// p为新分配的底层数组的地址
	// 从old.array处拷贝lenmem个字节到p
	memmove(p, old.array, lenmem)
	// 返回新的切片
	return slice{p, old.len, newcap}
}

...

阅读全文 »

Go源码解析之sync.Mutex锁

Category Golang
Tag Go
Posted on
View

本文使用Golang版本为:go1.13.4

Mutex的使用

先通过一段简单代码看下Go中Mutex的用法

func main() {
	a := 1
	m := sync.Mutex{}
	go func(){
		m.Lock()
		b := a
		a = b + 1
		m.Unlock()
	}()

	m.Lock()
	fmt.Println(a)
	m.Unlock()
}

Mutex的设计

在解释Lock()和Unlock()源码之前我们必须先整体了解下Mutex的设计,不然下面的源码很难看懂。

我们首先看下sync.Mutex这个结构体

type Mutex struct {
	state int32 // 锁的当前状态,共三种
	sema  uint32 // 信号量,用于阻塞和唤醒goroutine
}

锁的三个状态,它们使用Mutex.state的低三位来标识

mutexLocked = 1 << iota // 锁定状态,二进制表示即 ...001
mutexWoken // 唤醒状态,二进制表示即 ...010
mutexStarving // 饥饿状态,二进制表示即...100

mutexLocked位于state的第一位,mutexWoken位于state的第二位,mutexStarving位于state的第三位,如下图: image.png

Mutex锁有两种模式:正常模式和饥饿模式。正常模式时,waiter按照先到先得的方式获取锁,一个waiter被唤醒后并不能直接获取到锁,它需要与新到的goroutine抢占锁,但是新到的goroutine已经在CPU上运行了,所以它大概率抢不过新到的goroutine,如果抢不到锁waiter就需要在等待队列队头继续等待,而这可能会导致一个waiter等待很长时间。为了避免waiter等待过久,当waiter超过1ms没有抢到锁时就会将当前锁切换到饥饿模式。

切换到饥饿模式后,锁将从解锁的goroutine切换到等待队列的队头waiter,新来的goroutine不会去尝试获取锁,也不会自旋,它们会排到等待队列的队尾。

如果某waiter获取到了锁,那么在满足以下两个条件之一时,它会将当前锁从饥饿模式切换到正常模式。

  1. 它是最后一个waiter
  2. 它等待锁的时间不到1ms

了解了Mutex的设计后我们再继续看Lock()与Unlock()的实现。

加锁Lock()的实现

func (m *Mutex) Lock() {
	if atomic.CompareAndSwapInt32(&m.state, 0, mutexLocked) {
		// 这里本有竞争检测的代码,无意义,已被我删除
		return
	}
	m.lockSlow()
}

函数中首先通过CAS操作尝试获得锁,如果m.state为0即当前锁闲置就将它设置为1,如果尝试失败则进入m.lockSlow()

m.lockSlow()的实现

m.lockSlow()中用到了这几个函数:runtime_canSpin()runtime_doSpin()runtime_SemacquireMutex(),我们先挨个解释下这几个函数的作用再看m.lockSlow()的源码。

runtime_canSpin()

该函数的作用是判断能够进入自旋,下面看下源码

// Active spinning for sync.Mutex.
//go:linkname sync_runtime_canSpin sync.runtime_canSpin
//go:nosplit
func sync_runtime_canSpin(i int) bool { // i是当前自旋次数
	if i >= 4|| ncpu <= 1 || gomaxprocs <= int32(sched.npidle+sched.nmspinning)+1 {
		return false
	}
	if p := getg().m.p.ptr(); !runqempty(p) {
		return false
	}
	return true
}

通过这个函数我们可以看到,runtime层判断能够自旋必须满足以下几个条件

  • 当前自旋次数不能>=4
  • 必须是多核CPU
  • 至少有一个其他正在运行的P
  • 当前P本地G队列为空

这里解释下gomaxprocs <= int32(sched.npidle+sched.nmspinning)+1这个条件: gomaxprocs是进程中P数量上限,sched.npidle是空闲的P的数量、sched.nmspinning是自旋中的M的数量gomaxprocs - sched.npidle - sched.nmspinning=当前运行中的P的数量,当前运行中的P数量-1(当前P) = 其他P的数量,所以这个条件就是至少有一个其他正在运行的P。

runtime_doSpin()

其源码为:

//go:linkname sync_runtime_doSpin sync.runtime_doSpin
//go:nosplit
func sync_runtime_doSpin() {
	procyield(30)
}

这里我们仅看下AMD64平台上proyield的实现:

TEXT runtime·procyield(SB),NOSPLIT,$0-0
	MOVL	cycles+0(FP), AX // 将第一个参数即30加载到AX寄存器
again:
	PAUSE // CPU空转,达到占用CPU的效果
	SUBL	$1, AX // AX寄存器-1
	JNZ	again // 如果不为0则继续执行PAUSE指令,否则退出
	RET

到这里可以看出runtime_doSpin()实际就是CPU空转30次。

runtime_SemacquireMutex()

其实现位于runtime包的sema.go文件中

//go:linkname sync_runtime_SemacquireMutex sync.runtime_SemacquireMutex
func sync_runtime_SemacquireMutex(addr *uint32, lifo bool, skipframes int) {
	semacquire1(addr, lifo, semaBlockProfile|semaMutexProfile, skipframes)
}

semacquire1的实现并非本文重点,这里大概解释下这个函数的作用:

  1. 如果lifo为true,则加到等待队列队头
  2. 如果lifo为false,则加到等待队列队尾
m.lockSlow()

了解了上面几个函数后我们来看下m.lockSlow()中是怎么处理的吧

func (m *Mutex) lockSlow() {
	var waitStartTime int64
	starving := false // 饥饿模式标志
	awoke := false // 唤醒标志
	iter := 0 // 已进行的自旋次数
	old := m.state // 保存当前锁状态
	for {
		// 进入自旋需要满足三个条件
		// 1. 当前锁状态是锁定状态,如果不是锁定状态就退出自旋尝试获取锁
		// 2. 当前不是饥饿状态,原因是饥饿状态时自旋无意义,因为锁会交给等待队列中的第一个waiter
		// 3. runtime_canSpin判断能够自旋
		if old&(mutexLocked|mutexStarving) == mutexLocked && runtime_canSpin(iter) {
			if !awoke && old&mutexWoken == 0 && old>>mutexWaiterShift != 0 &&
				atomic.CompareAndSwapInt32(&m.state, old, old|mutexWoken) {
				// 如果没有唤醒 且 当前锁状态不在唤醒状态
				// 且 当前有等待者则尝试通过CAS将锁状态标记为唤醒
				// 标记为唤醒后,Unlock()中就不会通过信号量唤醒其他锁定的goroutine了
				// 如果CAS成功则标识自己为唤醒
				awoke = true
			}
			// CPU空转30次
			runtime_doSpin()
			// 自旋次数+1
			iter++
			// 更新当前锁状态
			old = m.state
			// 继续尝试自旋
			continue
		}

		// 如果判断不能进入自旋则进入以下逻辑
		// 进到这里有三种情况:
		// 1. 当前已解锁,锁处于正常状态
		// 2. 当前已解锁,锁处于饥饿状态
		// 3. 当前未解锁,锁处于正常状态
		// 4. 当前未解锁,锁处于饥饿状态

		// old是锁的当前状态,new是期望状态,在下面会尝试将锁通过CAS更新为期望状态
		new := old
		if old&mutexStarving == 0 {
			// 如果当前锁是正常状态则尝试获取锁
			new |= mutexLocked
		}
		if old&(mutexLocked|mutexStarving) != 0 {
			// 等待数+1
			// 如果锁当前处于饥饿状态,当前goroutine不能获取锁,需要进到等待队列队尾排队等待,所以等待数需要+1
			// 如果当前锁处于锁定状态,也需要进到等待队列等待
			new += 1 << mutexWaiterShift
		}
		if starving && old&mutexLocked != 0 {
			// 如果当前处于饥饿模式并且锁定状态
			// 则尝试设置为饥饿状态
			new |= mutexStarving
		}
		if awoke {
			if new&mutexWoken == 0 {
				// 如果当前goroutine抢到了唤醒,但是唤醒标志还为0说明出现了异常情况
				throw("sync: inconsistent mutex state")
			}
			// 如果在自旋时当前goroutine抢到唤醒了,则尝试将锁标记为未唤醒
			new &^= mutexWoken
		}
		// 尝试将锁状态由旧状态修改为期望状态
		if atomic.CompareAndSwapInt32(&m.state, old, new) {
			// 修改成功
			// 如果旧状态既不是锁定状态也不是饥饿状态
			// 说明了抢到了锁,则退出循环
			if old&(mutexLocked|mutexStarving) == 0 {
				break
			}
			
			queueLifo := waitStartTime != 0
			if waitStartTime == 0 {
				// 记录等待开始时间
				waitStartTime = runtime_nanotime()
			}
			// 通过信号量阻塞当前goroutine
			// 如果waitStartTime为0,则说明当前goroutine是一个新来的goroutine,那么queueLifo=false,意味加到队尾。
			// 如果waitStartTime不为0,意味当前goroutine是一个被唤醒的goroutine,那么queueLifo=true,意味着加到队头
			runtime_SemacquireMutex(&m.sema, queueLifo, 1)
			// 如果等待时间超过了1ms则切换到饥饿模式
			starving = starving || runtime_nanotime()-waitStartTime > starvationThresholdNs
			// 更新当前锁状态
			old = m.state
			// 如果当前锁处于饥饿状态
			if old&mutexStarving != 0 {
				// 如果当前锁处于锁定状态或者唤醒状态或者没有waiter,异常
				if old&(mutexLocked|mutexWoken) != 0 || old>>mutexWaiterShift == 0 {
					throw("sync: inconsistent mutex state")
				}
				// 因为当前goroutine已经获取了锁,delta用于将等待队列-1
				delta := int32(mutexLocked - 1<<mutexWaiterShift)
				// 如果当前不是锁定模式或者只有一个waiter
				// 就通过delta -= mutexStarving和atomic.AddInt32操作将锁的饥饿状态位设置为0,表示为正常模式
				if !starving || old>>mutexWaiterShift == 1 {
					delta -= mutexStarving
				}
				atomic.AddInt32(&m.state, delta)
				break
			}
			awoke = true
			iter = 0
		} else {
			old = m.state
		}
	}
}

同样的,我已将无关代码和注释删除。

解锁Unlock()的实现

func (m *Mutex) Unlock() {
        // 将锁定状态置为0
	new := atomic.AddInt32(&m.state, -mutexLocked)
	if new != 0 {
	    // 如果锁上存在等待者或者处于饥饿模式则进入unlockSlow()
		m.unlockSlow(new)
	}
}

Unlock()本身非常简单,下面重点关注下unlockSlow()的实现

func (m *Mutex) unlockSlow(new int32) {
	if (new+mutexLocked)&mutexLocked == 0 {
		// 如果解锁一个未锁定的锁则抛出异常
		throw("sync: unlock of unlocked mutex")
	}
	if new&mutexStarving == 0 {
		// 处于正常模式
		old := new
		for {
			// 如果没有等待者则无需唤醒任何goroutine,另外以下三种情况也无需唤醒
			// 1. 锁处于锁定状态,说明Unlock()解锁后紧接着就被其他goroutine获取,就不用再唤醒了
			// 2. 锁处于唤醒状态,说明有等待的goroutine已经被唤醒了,不用再尝试唤醒了
			// 3. 锁处于饥饿模式,锁会交给等待队列队头的等待者,不能往下进行
			if old>>mutexWaiterShift == 0 || old&(mutexLocked|mutexWoken|mutexStarving) != 0 {
				
				return
			}
			// 流程走到这里说明当前有等待者并且锁处于空闲状态(三个标志位都为0)
			// 说明等待者还没有被唤醒,需要唤醒等待者
			// 通过CAS将等待者数量-1,并且设置为唤醒
			new = (old - 1<<mutexWaiterShift) | mutexWoken
			if atomic.CompareAndSwapInt32(&m.state, old, new) {
				// 通过信号量唤醒等待者goroutine,然后退出
				runtime_Semrelease(&m.sema, false, 1)
				return
			}
			// CAS修改失败,说明锁的状态已经被修改,有以下几种可能性:
			// 1. 有新的等待者进来
			// 2. 锁被其他goroutine获取(Unlokc()中已经解锁了,走到这里可能已经被其他goroutine)
			// 3. 锁进入了饥饿模式
	
			// 更新锁状态,进入到下一个循环
			old = m.state
		}
	} else {
		// 处于饥饿模式则直接通过信号量唤醒等待队列头的goroutine
		// 此时state的mutexLocked还没有加锁,唤醒的goroutine会持有锁
		// 在此期间,如果有新的goroutine来请求锁, 因为mutex处于饥饿状态,不会抢占锁
		runtime_Semrelease(&m.sema, true, 1)
	}
}

后言

Mutex虽然代码简单,但由于并行的原因导致case太多,所以还是不太好理解了,建议大家代入到具体的场景中去分析。

...

阅读全文 »

深入理解原子操作的本质

Category Golang
Tag Go
Posted on
View

引言

本文以go1.14 darwin/amd64中的原子操作为例,探究原子操作的汇编实现,引出LOCK指令前缀可见性MESI协议Store BufferInvalid Queue内存屏障,通过对CPU体系结构的探究,从而理解以上概念,并在最终给出一些事实。

Go中的原子操作

我们以atomic.CompareAndSwapInt32为例,它的函数原型是:

func CompareAndSwapInt32(addr *int32, old, new int32) (swapped bool)

对应的汇编代码为:

// sync/atomic/asm.s 24行
TEXT ·CompareAndSwapInt32(SB),NOSPLIT,$0
	JMP	runtime∕internal∕atomic·Cas(SB)

通过跳转指令JMP跳转到了runtime∕internal∕atomic·Cas(SB),由于架构的不同对应的汇编代码也不同,我们看下amd64平台对应的代码:

// runtime/internal/atomic/asm_amd64.s 17行
TEXT runtime∕internal∕atomic·Cas(SB),NOSPLIT,$0-17
	MOVQ	ptr+0(FP), BX // 将函数第一个实参即addr加载到BX寄存器
	MOVL	old+8(FP), AX // 将函数第二个实参即old加载到AX寄存器
	MOVL	new+12(FP), CX // // 将函数第一个实参即new加载到CX寄存器
	LOCK // 本文关键指令,下面会详述
	CMPXCHGL	CX, 0(BX) // 把AX寄存器中的内容(即old)与BX寄存器中地址数据(即addr)指向的数据做比较如果相等则把第一个操作数即CX中的数据(即new)赋值给第二个操作数
	SETEQ	ret+16(FP) // SETEQ与CMPXCHGL配合使用,在这里如果CMPXCHGL比较结果相等则设置本函数返回值为1,否则为0(16(FP)是返回值即swapped的地址)
	RET // 函数返回

从上面代码中可以看到本文的关键:LOCK。它实际是一个指令前缀,它后面必须跟read-modify-write指令,比如:ADD, ADC, AND, BTC, BTR, BTS, CMPXCHG, CMPXCH8B, CMPXCHG16B, DEC, INC, NEG, NOT, OR, SBB, SUB, XOR, XADD, XCHG

LOCK实现原理

在早期CPU上LOCK指令会锁总线,即其他核心不能再通过总线与内存通讯,从而实现该核心对内存的独占。

这种做法虽然解决了问题但是性能太差,所以在Intel P6 CPU(P6是一个架构,并非具体CPU)引入一个优化:如果数据已经缓存在CPU cache中,则锁缓存,否则还是锁总线。

Cache Coherency

CPU Cache与False Sharing 一文中详细介绍了CPU缓存的结构,CPU缓存带来了一致性问题,举个简单的例子:

// 假设CPU0执行了该函数
var a int = 0
go func fnInCpu0() {
    time.Sleep(1 * time.Second)
    a = 1 // 2. 在CPU1加载完a之后CPU0仅修改了自己核心上的cache但是没有同步给CPU1
}()
// CPU1执行了该函数
go func fnInCpu1() {
    fmt.Println(a) // 1. CPU1将a加载到自己的cache,此时a=0
    time.Sleep(3 * time.Second)
    fmt.Println(a) // 3. CPU1从cache中读到a=0,但此时a已经被CPU0修改为0了
}()

上例中由于CPU没有保证缓存的一致性,导致了两个核心之间的同一数据不可见从而程序出现了问题,所以CPU必须保证缓存的一致性,下面将介绍CPU是如何通过MESI协议做到缓存一致的。

MESI是以下四种cacheline状态的简称:

  • M(Modified):此状态为该cacheline被该核心修改,并且保证不会在其他核心的cacheline上
  • E(Exclusive):标识该cacheline被该核心独占,其他核心上没有该行的副本。该核心可直接修改该行而不用通知其他核心。
  • S(Share):该cacheline存在于多个核心上,但是没有修改,当前核心不能直接修改,修改该行必须与其他核心协商。
  • I(Invaild):该cacheline无效,cacheline的初始状态,说明要么不在缓存中,要么内容已过时。

核心之间协商通信需要以下消息机制:

  • Read: CPU发起数据读取请求,请求中包含数据的地址
  • Read Response: Read消息的响应,该消息有可能是内存响应的,有可能是其他核心响应的(即该地址存在于其他核心上cacheline中,且状态为Modified,这时必须返回最新数据)
  • Invalidate: 核心通知其他核心将它们自己核心上对应的cacheline置为Invalid
  • Invalidate ACK: 其他核心对Invalidate通知的响应,将对应cacheline置为Invalid之后发出该确认消息
  • Read Invalidate: 相当于Read消息+Invalidate消息,即当前核心要读取数据并修改该数据。
  • Write Back: 写回,即将Modified的数据写回到低一级存储器中,写回会尽可能地推迟内存更新,只有当替换算法要驱逐更新过的块时才写回到低一级存储器中。

手画状态转移图

image.png

这里有个存疑的地方:CPU从内存中读到数据I状态是转移到S还是E,查资料时两种说法都有。个人认为应该是E,因为这样另外一个核心要加载副本时只需要去当前核心上取就行了不需要读内存,性能会更高些,如果你有不同看法欢迎在评论区交流。

一些规律

  1. CPU在修改cacheline时要求其他持有该cacheline副本的核心失效,并通过Invalidate ACK来接收反馈
  2. cacheline为M意味着内存上的数据不是最新的,最新的数据在该cacheline上
  3. 数据在cacheline时,如果状态为E,则直接修改;如果状态为S则需要广播Invalidate消息,收到Invalidate ACK后修改状态为M;如果状态为I(包括cache miss)则需要发出Read Invalidate

Store Buffer

当CPU要修改一个S状态的数据时需要发出Invalidate消息并等待ACK才写数据,这个过程显然是一个同步过程,但这对于对计算速度要求极高的CPU来说显然是不可接受的,必须对此优化。 因此我们考虑在CPU与cache之间加一个buffer,CPU可以先将数据写入到这个buffer中并发出消息,然后它就可以去做其他事了,待消息响应后再从buffer写入到cache中。但这有个明显的逻辑漏洞,考虑下这段代码:

a = 1
b = a + 1

假设a初始值为0,然后CPU执行a=1,数据被写入Store Buffer还没有落地就紧接着执行了b=a+1,这时由于a还没有修改落地,因此CPU读到的还是0,最终计算出来b=1。

为了解决这个明显的逻辑漏洞,又提出了Store Forwarding:CPU可以把Buffer读出来传递(forwarding)给下面的读取操作,而不用去cache中读。 image.png

这倒是解决了上面的漏洞,但是还存在另外一个问题,我们看下面这段代码:

a = 0
flag = false
func runInCpu0() {
    a = 1
    flag = true
}

func runInCpu1() {
    while (!flag) {
   	continue
    }
    print(a)
}

对于上面的代码我们假设有如下执行步骤:

  1. 假定当前a存在于cpu1的cache中,flag存在于cpu0的cache中,状态均为E。
  2. cpu1先执行while(!flag),由于flag不存在于它的cache中,所以它发出Read flag消息
  3. cpu0执行a=1,它的cache中没有a,因此它将a=1写入Store Buffer,并发出Invalidate a消息
  4. cpu0执行flag=true,由于flag存在于它的cache中并且状态为E,所以将flag=true直接写入到cache,状态修改为M
  5. cpu0接收到Read flag消息,将cache中的flag=true发回给cpu1,状态修改为S
  6. cpu1收到cpu0的Read Response:flat=true,结束while(!flag)循环
  7. cpu1打印a,由于此时a存在于它的cache中a=0,所以打印出来了0
  8. cpu1此时收到Invalidate a消息,将cacheline状态修改为I,但为时已晚
  9. cpu0收到Invalidate ACK,将Store Buffer中的数据a=1刷到cache中

从代码角度看,我们的代码好像变成了

func runInCpu0() {
    flag = true
    a = 1
}

好像是被重新排序了,这其实是一种 伪重排序,必须提出新的办法来解决上面的问题

写屏障

CPU从软件层面提供了 写屏障(write memory barrier) 指令来解决上面的问题,linux将CPU写屏障封装为smp_wmb()函数。写屏障解决上面问题的方法是先将当前Store Buffer中的数据刷到cache后再执行屏障后面的写入操作。

SMP: Symmetrical Multi-Processing,即多处理器。

这里你可能好奇上面的问题是硬件问题,CPU为什么不从硬件上自己解决问题而要求软件开发者通过指令来避免呢?其实很好回答:CPU不能为了这一个方面的问题而抛弃Store Buffer带来的巨大性能提升,就像CPU不能因为分支预测错误会损耗性能增加功耗而放弃分支预测一样。

还是以上面的代码为例,前提保持不变,这时我们加入写屏障:

a = 0
flag = false
func runInCpu0() {
    a = 1
    smp_wmb()
    flag = true
}

func runInCpu1() {
    while (!flag) {
   	continue
    }
    print(a)
}

当cpu0执行flag=true时,由于Store Buffer中有a=1还没有刷到cache上,所以会先将a=1刷到cache之后再执行flag=true,当cpu1读到flag=true时,a也就=1了。

有文章指出CPU还有一种实现写屏障的方法:CPU将当前store buffer中的条目打标,然后将屏障后的“写入操作”也写到Store Buffer中,cpu继续干其他的事,当被打标的条目全部刷到cache中,之后再刷后面的条目。

Invalid Queue

上文通过写屏障解决了伪重排序的问题后,还要思考另一个问题,那就是Store Buffer size是有限的,当Store Buffer满了之后CPU还是要卡住等待Invalidate ACK。Invalidate ACK耗时的主要原因是CPU需要先将自己cacheline状态修改I后才响应ACK,如果一个CPU很繁忙或者处于S状态的副本特别多,可能所有CPU都在等它的ACK。

CPU优化这个问题的方式是搞一个Invalid Queue,CPU先将Invalidate消息放到这个队列中,接着就响应Invalidate ACK。然而这又带来了新的问题,还是以上面的代码为例

a = 0
flag = false
func runInCpu0() {
    a = 1
    smp_wmb()
    flag = true
}

func runInCpu1() {
    while (!flag) {
   	continue
    }
    print(a)
}

我们假设a在CPU0和CPU1中,且状态均为S,flag由CPU0独占

  1. CPU0执行a=1,因为a状态为S,所以它将a=1写入Store Buffer,并发出Invalidate a消息
  2. CPU1执行while(!flag),由于其cache中没有flag,所以它发出Read flag消息
  3. CPU1收到CPU0的Invalidate a消息,并将此消息写入了Invalid Queue,接着就响应了Invlidate ACK
  4. CPU0收到CPU1的Invalidate ACK后将a=1刷到cache中,并将其状态修改为了M
  5. CPU0执行到smp_wmb(),由于Store Buffer此时为空所以就往下执行了
  6. CPU0执行flag=true,因为flag状态为E,所以它直接将flag=true写入到cache,状态被修改为了M
  7. CPU0收到了Read flag消息,因为它cache中有flag,因此它响应了Read Response,并将状态修改为S
  8. CPU1收到Read flag Response,此时flag=true,所以结束了while循环
  9. CPU1打印a,由于a存在于它的cache中且状态为S,所以直接将cache中的a打印出来了,此时a=0,这显然发生了错误。
  10. CPU1这时才处理Invalid Queue中的消息将a状态修改为I,但为时已晚

为了解决上面的问题,CPU提出了读屏障指令,linux将其封装为了smp_rwm()函数。放到我们的代码中就是这样:

...
func runInCpu1() {
    while (!flag) {
   	continue
    }
    smp_rwm()
    print(a)
}

当CPU执行到smp_rwm()时,会将Invalid Queue中的数据处理完成后再执行屏障后面的读取操作,这就解决了上面的问题了。

除了上面提到的读屏障和写屏障外,还有一种全屏障,它其实是读屏障和写屏障的综合体,兼具两种屏障的作用,在linux中它是smp_mb()函数。 文章开始提到的LOCK指令其实兼具了内存屏障的作用。

几个问题

问题1: CPU采用MESI协议实现缓存同步,为什么还要LOCK

答: 1. MESI协议只维护缓存一致性,与可见性有关,与原子性无关。一个非原子性的指令需要加上lock前缀才能保证原子性。

问题2: 一条汇编指令是原子性的吗

  1. read-modify-write 内存的指令不是原子性的,以INC mem_addr为例,我们假设数据已经缓存在了cache上,指令的执行需要先将数据从cache读到执行单元中,再执行+1,然后写回到cache。
  2. 对于没有对齐的内存,读取内存可能需要多次读取,这不是原子性的。(在某些CPU上读取未对齐的内存是不被允许的)
  3. 其他未知原因…

问题3: Go中的原子读

我们看一个读取8字节数据的例子,直接看golang atomic.LoadUint64()汇编:

// uint64 atomicload64(uint64 volatile* addr);
1. TEXT runtime∕internal∕atomic·Load64(SB), NOSPLIT, $0-12
2.	MOVL	ptr+0(FP), AX // 将第一个参数加载到AX寄存器
3.	TESTL	$7, AX // 判断内存是否对齐
4.	JZ	2(PC) // 跳到这条指令的下两条处,即跳转到第6行
5.	MOVL	0, AX // crash with nil ptr deref 引用0x0地址会触发错误
6.	MOVQ	(AX), M0 // 将内存地址指向的数据加载到M0寄存器
7.	MOVQ	M0, ret+4(FP) // 将M0寄存器中数据(即内存指向的位置)给返回值
8.	EMMS // 清除M0寄存器
9.	RET

第3行TESTL指令对两个操作数按位与,如果结果为0,则将ZF设置为1,否则为0。所以这一行其实是判断传进来的内存地址是不是8的整数倍。

第4行JZ指令判断如果ZF即零标志位为1则执行跳转到第二个操作数指定的位置,结合第三行就是如果传入的内存地址是8的整数倍,即内存已对齐,则跳转到第6行,否则继续往下执行。

关于内存对齐可以看下我这篇文章:理解内存对齐

虽然MOV指令是原子性的,但是汇编中貌似没有加入内存屏障,那Golang是怎么实现可见性的呢?我这里也并没有完全的理解,不过大概意思是Golang的atomic会保证顺序一致性,详情可看下这篇文章:Memory Order Guarantees in Go

问题4:Go中的原子写

仍然以写一个8字节数据的操作为例,直接看golang atomic.LoadUint64()汇编:

TEXT runtime∕internal∕atomic·Store64(SB), NOSPLIT, $0-16
	MOVQ	ptr+0(FP), BX
	MOVQ	val+8(FP), AX
	XCHGQ	AX, 0(BX)
	RET

虽然没有LOCK指令,但XCHGQ指令具有LOCK的效果,所以还是原子性而且可见的。

总结

这篇文章花费了我大量的时间与精力,主要原因是刚开始觉得原子性只是个小问题,但是随着不断的深入挖掘,翻阅无数资料,才发现底下潜藏了无数的坑。 s70KdH.png

由于精力原因本文还有一些很重要的点没有讲到,比如acquire/release 语义等等。

另外客观讲本文问题很多,较真的话可能会对您造成一定的困扰,建议您可以将本文作为您研究计算机底层架构的一个契机,自行研究这方面的技术。

参考资料

...

阅读全文 »