go thread-safe rand

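Go 的 math/rand 包级函数（rand.Int、rand.Intn 等）是并发安全的，由内部的 mutex 保证；但 rand.NewSource 返回的 Source（rngSource）并不是并发安全的，不能在多个 goroutine 之间共享。
这里提供一种基于 per-P 存储的实现（每个 P 持有一个独立的 rand.Rand，思路类似 TLS 以及 sync.Pool 的 per-P 缓存），以避免全局锁竞争。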

mtrnd.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package mtrnd

import (
	"math/rand"
	"sync"
	"sync/atomic"
	"time"

	"github.com/vizee/asm/hack"
)

var (
	// rndmu serializes growth of rnds; only the slow path takes it.
	rndmu sync.Mutex
	// rnds holds a []*rand.Rand indexed by P id. The fast path reads it
	// without rndmu, so every new slice is published with rnds.Store and read
	// with rnds.Load. A plain slice variable here would be a data race (a torn
	// slice header) against a concurrent grow.
	rnds atomic.Value
)

// pinrnd pins the calling goroutine with hack.ProcPin and returns the
// generator owned by the current P id, creating it if needed.
//
// The caller MUST use the returned generator only while still pinned and then
// call hack.ProcUnpin exactly once. The generator from rand.NewSource is not
// safe for concurrent use. Once unpinned, the goroutine can be preempted or
// moved, and another goroutine running on the same P would receive the same
// generator.
//
// NOTE(review): assumes hack.ProcPin/ProcUnpin behave like the runtime's
// procPin/procUnpin used by sync.Pool (no preemption while pinned), and that
// the returned id is a small non-negative P index — confirm against vizee/asm.
func pinrnd() *rand.Rand {
	pid := hack.ProcPin()
	if t, _ := rnds.Load().([]*rand.Rand); pid < len(t) {
		return t[pid]
	}
	return pinrndSlow()
}

// pinrndSlow grows rnds so that it covers the current P id. Like pinrnd, it is
// entered pinned and returns pinned; the caller still owes one hack.ProcUnpin.
func pinrndSlow() *rand.Rand {
	// Do not block on the mutex while pinned.
	hack.ProcUnpin()
	rndmu.Lock()
	defer rndmu.Unlock()
	// Re-pin; the goroutine may now be on a different P than before.
	pid := hack.ProcPin()
	old, _ := rnds.Load().([]*rand.Rand)
	if pid < len(old) {
		// Already covered: grown concurrently, or we re-pinned onto a lower id.
		return old[pid]
	}
	// Build a fresh slice rather than mutating the published one, so lock-free
	// readers holding the old slice never see a partially filled array.
	t := make([]*rand.Rand, pid+1)
	n := copy(t, old)
	for i := n; i < len(t); i++ {
		t[i] = rand.New(rand.NewSource(time.Now().UnixNano() ^ int64(i)))
	}
	rnds.Store(t)
	return t[pid]
}

// Int returns a non-negative pseudo-random int from the calling P's
// generator. It is safe for concurrent use.
func Int() int {
	v := pinrnd().Int()
	// Unpin only after the generator call; see pinrnd.
	hack.ProcUnpin()
	return v
}

// Intn returns a pseudo-random int in [0, n) from the calling P's generator.
// Like rand.Intn, it panics if n <= 0. It is safe for concurrent use.
func Intn(n int) int {
	if n <= 0 {
		// Panic before pinning so a recovered panic cannot leave the P pinned.
		panic("invalid argument to Intn")
	}
	v := pinrnd().Intn(n)
	hack.ProcUnpin()
	return v
}

// Uint64 returns a pseudo-random uint64 from the calling P's generator.
// It is safe for concurrent use.
func Uint64() uint64 {
	v := pinrnd().Uint64()
	hack.ProcUnpin()
	return v
}

mtrnd_test.go:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package mtrnd

import (
"math/rand"
"testing"
)

// TestMT draws 1e6 values from Intn(256). It checks that every value is in
// range and that every bucket count is within 10% of the uniform expectation.
func TestMT(t *testing.T) {
	const (
		buckets = 256
		draws   = 1000000
	)
	var n [buckets]int
	for i := 0; i < draws; i++ {
		v := Intn(buckets)
		if v < 0 || v >= buckets {
			t.Fatalf("Intn(%d) = %d, out of range", buckets, v)
		}
		n[v]++
	}
	// Expected count is draws/buckets ≈ 3906 with a standard deviation of ≈ 62,
	// so a ±10% band (~6σ) makes a spurious failure vanishingly unlikely.
	want := draws / buckets
	tol := want / 10
	lo, hi := n[0], n[0]
	for i, c := range n {
		if c < lo {
			lo = c
		}
		if c > hi {
			hi = c
		}
		if c < want-tol || c > want+tol {
			t.Errorf("bucket %d: count %d, want %d±%d", i, c, want, tol)
		}
	}
	t.Log("diff", hi-lo)
}

// TestMTConcurrent calls Int, Intn and Uint64 from several goroutines at
// once, so that `go test -race` can detect a generator shared without
// synchronization.
func TestMTConcurrent(t *testing.T) {
	const workers, iters = 8, 20000
	bad := make(chan int, workers)
	for w := 0; w < workers; w++ {
		go func() {
			miss := 0
			for i := 0; i < iters; i++ {
				if v := Intn(10); v < 0 || v >= 10 {
					miss++
				}
				if Int() < 0 {
					miss++
				}
				_ = Uint64()
			}
			bad <- miss
		}()
	}
	for w := 0; w < workers; w++ {
		if miss := <-bad; miss != 0 {
			t.Errorf("worker saw %d out-of-range values", miss)
		}
	}
}

// BenchmarkGlobal times the package-level rand.Int from a single goroutine.
func BenchmarkGlobal(b *testing.B) {
	for remaining := b.N; remaining > 0; remaining-- {
		rand.Int()
	}
}

// BenchmarkMT times this package's Int from a single goroutine.
func BenchmarkMT(b *testing.B) {
	for remaining := b.N; remaining > 0; remaining-- {
		Int()
	}
}

// BenchmarkGlobalP times the package-level rand.Int when called from many
// goroutines at once via b.RunParallel.
func BenchmarkGlobalP(b *testing.B) {
	work := func(pb *testing.PB) {
		for pb.Next() {
			rand.Int()
		}
	}
	b.RunParallel(work)
}

// BenchmarkMTP times this package's Int when called from many goroutines at
// once via b.RunParallel.
func BenchmarkMTP(b *testing.B) {
	work := func(pb *testing.PB) {
		for pb.Next() {
			Int()
		}
	}
	b.RunParallel(work)
}

Benchmark results (Intel Core i7, 4 cores / 8 threads):

1
2
3
4
BenchmarkGlobal-8    	50000000	        32.5 ns/op
BenchmarkMT-8 50000000 29.4 ns/op
BenchmarkGlobalP-8 10000000 170 ns/op
BenchmarkMTP-8 200000000 7.06 ns/op