go thread-safe rand

go 全局的 rand 是线程安全的, 通过 mutex 来保证, 但是 go 提供的 rand.NewSource 使用的 rngSource 并不是线程安全的
这里提供一种类似 TLS 的实现: 为每个 P (调度器中的逻辑处理器) 维护一个独立的 rand.Rand, 从而避免全局锁的竞争

mtrnd.go:

package mtrnd

import (
    "math/rand"
    "sync"
    "sync/atomic"
    "time"

    "github.com/vizee/asm/hack"
)

var (
    // rndmu serializes growth of the generator table; lookups never take it.
    rndmu sync.Mutex
    // rnds holds a []*rand.Rand indexed by the id returned from hack.ProcPin.
    // The slice is replaced wholesale on growth and never mutated in place,
    // so a loaded snapshot can be indexed without holding rndmu.
    rnds atomic.Value
)

// pinrnd pins the calling goroutine and returns the generator that belongs
// to the id reported by hack.ProcPin. The goroutine is STILL PINNED on
// return: the caller must finish using the generator and then call
// hack.ProcUnpin, without blocking or panicking in between. Holding the pin
// for the whole draw is what keeps two goroutines from touching the same
// (non-thread-safe) *rand.Rand at once.
//
// NOTE(review): assumes hack.ProcPin/ProcUnpin wrap the runtime's
// procPin/procUnpin — confirm against the hack package.
func pinrnd() *rand.Rand {
    for {
        pid := hack.ProcPin()
        if tbl, _ := rnds.Load().([]*rand.Rand); pid < len(tbl) {
            return tbl[pid]
        }
        // Growing takes a mutex and allocates, so do it unpinned. The id may
        // differ once we pin again, hence the retry loop.
        hack.ProcUnpin()
        growrnds(pid + 1)
    }
}

// growrnds makes sure the published table has at least n generators. Existing
// generators are carried over; new slots are seeded from the current time
// mixed with the slot index so that slots created together differ.
func growrnds(n int) {
    rndmu.Lock()
    defer rndmu.Unlock()

    old, _ := rnds.Load().([]*rand.Rand)
    if n <= len(old) {
        // Another goroutine grew the table while we waited for the lock.
        return
    }
    t := make([]*rand.Rand, n)
    for i := copy(t, old); i < n; i++ {
        t[i] = rand.New(rand.NewSource(time.Now().UnixNano() ^ int64(i)))
    }
    rnds.Store(t)
}

// Int returns a non-negative pseudo-random int. It is safe for concurrent use.
func Int() int {
    r := pinrnd()
    v := r.Int()
    hack.ProcUnpin()
    return v
}

// Intn returns a pseudo-random int in [0, n). It panics if n <= 0, with the
// same message math/rand uses. It is safe for concurrent use.
func Intn(n int) int {
    // Validate before pinning so the panic is never raised while pinned.
    if n <= 0 {
        panic("invalid argument to Intn")
    }
    r := pinrnd()
    v := r.Intn(n)
    hack.ProcUnpin()
    return v
}

// Uint64 returns a pseudo-random uint64. It is safe for concurrent use.
func Uint64() uint64 {
    r := pinrnd()
    v := r.Uint64()
    hack.ProcUnpin()
    return v
}

mtrnd_test.go:

package mtrnd

import (
    "math/rand"
    "testing"
)

// TestMT draws one million values from Intn(256), logs the per-bucket counts
// and the spread between the fullest and emptiest bucket, and fails if any
// bucket strays too far from the count a uniform source would produce.
func TestMT(t *testing.T) {
    const (
        buckets = 256
        draws   = 1000000
        want    = draws / buckets
        // 10% of the expected count is roughly six standard deviations for a
        // uniform source, so a healthy generator should essentially never
        // trip this bound while a badly skewed one will.
        slack = want / 10
    )

    var n [buckets]int
    for i := 0; i < draws; i++ {
        n[Intn(buckets)]++
    }

    lo, hi := n[0], n[0]
    for i, c := range n {
        if c < lo {
            lo = c
        }
        if c > hi {
            hi = c
        }
        t.Log(c)
        if c < want-slack || c > want+slack {
            t.Errorf("bucket %d: got %d draws, want %d±%d", i, c, want, slack)
        }
    }
    t.Log("diff", hi-lo)
}

// BenchmarkGlobal measures the package-level math/rand generator when it is
// called b.N times from a single goroutine.
func BenchmarkGlobal(b *testing.B) {
    for remaining := b.N; remaining > 0; remaining-- {
        rand.Int()
    }
}

// BenchmarkMT measures this package's Int when it is called b.N times from a
// single goroutine.
func BenchmarkMT(b *testing.B) {
    for remaining := b.N; remaining > 0; remaining-- {
        Int()
    }
}

// BenchmarkGlobalP measures the package-level math/rand generator when it is
// called from the goroutines started by b.RunParallel.
func BenchmarkGlobalP(b *testing.B) {
    worker := func(pb *testing.PB) {
        for pb.Next() {
            rand.Int()
        }
    }
    b.RunParallel(worker)
}

// BenchmarkMTP measures this package's Int when it is called from the
// goroutines started by b.RunParallel.
func BenchmarkMTP(b *testing.B) {
    worker := func(pb *testing.PB) {
        for pb.Next() {
            Int()
        }
    }
    b.RunParallel(worker)
}

benchmark result (i7 4c8t):

BenchmarkGlobal-8       50000000            32.5 ns/op
BenchmarkMT-8           50000000            29.4 ns/op
BenchmarkGlobalP-8      10000000           170 ns/op
BenchmarkMTP-8          200000000            7.06 ns/op

标签: none

添加新评论