Go：解析函数声明

基于 Go 自带的 go/parser 和 go/ast 库解析 Go 源码、提取函数声明信息，并据此生成执行 SYSCALL 的汇编代码。
gensyscall.go
win32 库之前采用纯字符串解析，这种做法并不可靠，所以也打算改用基于 AST 的解析方式。

package main

import (
    "flag"
    "fmt"
    "go/ast"
    "go/parser"
    "go/token"
    "html/template"
    "io/ioutil"
    "os"
    "os/exec"
    "regexp"
    "strconv"
    "strings"
)

// templateGoFunc renders the Go signature of one funcdecl on a single line,
// e.g. "func read(fd int, p *byte, n int) int". main registers it as the
// "gofunc" template, and the assembly template prints it inside a comment
// above each generated TEXT symbol.
//
// Parameters are always printed as "Name Type". Results are wrapped in
// parentheses when there is more than one, or when the single result is
// named; a lone unnamed result is printed bare after the parameter list.
const templateGoFunc = `func {{.Name -}}
({{range $i, $e := .Params -}}
{{with $e}}{{if $i}}, {{end}}{{.Name}} {{.Type}}{{end}}
{{- end}}){{if .Results | len | lt 1}} (
{{- range $i, $e := .Results -}}
{{with $e}}{{if $i}}, {{end}}{{if .Name | len}}{{.Name}} {{end}}{{.Type}}{{end}}
{{- end}})
{{- else if .Results | len | eq 1 -}}
{{with index .Results 0}}{{if .Name | len}} ({{.Name}} {{.Type}}){{else}} {{.Type}}{{end}}{{end}}
{{- end}}`

// templateLinuxAMD64 is the "genasm" template; main executes it with a
// *sourceData. It first emits a "generated, DO NOT EDIT" header that records
// GOFILE/GOPACKAGE/GOOS/GOARCH, then includes textflag.h.
//
// For each entry of .Funcs it emits a NOSPLIT TEXT symbol with a zero-size
// local frame and a .FrameSize-byte argument frame. The symbol's body:
//   - calls runtime·entersyscall first when .Blocking is set;
//   - loads each parameter from Name+Offset(FP) into the register picked by
//     ArgToReg, using the MOV width picked by SizeToChar;
//   - loads the trap number into AX and executes SYSCALL;
//   - stores the registers picked by ResToReg into Label+Offset(FP), one per
//     result (the raw register values; no errno translation is done here);
//   - calls runtime·exitsyscall when .Blocking is set, then returns.
//
// The {{- / -}} trim markers control the line layout of the output, so edit
// them with care.
const templateLinuxAMD64 = `// generated by gensyscall.go
// GOFILE={{.GOFILE}} GOPACKAGE={{.GOPACKAGE}} GOOS={{.GOOS}} GOARCH={{.GOARCH}}
// DO NOT EDIT!

#include "textflag.h"

{{range .Funcs -}}
// {{template "gofunc" .}}
TEXT ·{{.Name}}(SB), NOSPLIT, $0-{{.FrameSize}}
{{- if .Blocking}}
    CALL runtime·entersyscall(SB)
{{- end -}}
{{range $i, $e := .Params -}}
{{with $e}}
    MOV{{.Size | SizeToChar}} {{.Name}}+{{.Offset}}(FP), {{$i | ArgToReg}}
{{- end}}
{{- end}}
    MOVQ ${{.Trap}}, AX
    SYSCALL
{{- range $i, $e := .Results -}}
{{with $e}}
    MOV{{.Size | SizeToChar}} {{$i | ResToReg}}, {{.Label}}+{{.Offset}}(FP){{end}}
{{- end}}{{- if .Blocking}}
    CALL runtime·exitsyscall(SB)
{{- end}}
    RET

{{end}}`

// sizes gives the size in bytes of each non-pointer Go type that may appear
// in a syscall signature, for the amd64 target. Any other type is unknown
// (sizeof reports 0 for it).
var sizes = map[string]int{
    // one byte
    "int8": 1, "uint8": 1, "byte": 1, "bool": 1,
    // two bytes
    "int16": 2, "uint16": 2,
    // four bytes
    "int32": 4, "uint32": 4,
    // one machine word
    "int": 8, "uint": 8, "int64": 8, "uint64": 8, "uintptr": 8,
    "unsafe.Pointer": 8, "syscall.Handle": 8,
}

// sizeof returns the size in bytes of the type spelled typ. Any pointer type
// (leading '*') occupies one 8-byte word. Other types are looked up in sizes,
// and unknown types yield 0. typ must be non-empty.
func sizeof(typ string) int {
    if typ[0] != '*' {
        return sizes[typ]
    }
    return 8
}

// systraps maps a syscall name (the suffix of a __NR_ macro, e.g. "getpid")
// to its trap number.
var systraps = make(map[string]int)

// fucksystraps fills systraps from every "#define __NR_<name> <number>" line
// in the system's x86-64 unistd_64.h header. It panics if the header cannot
// be read.
func fucksystraps() {
    const header = "/usr/include/x86_64-linux-gnu/asm/unistd_64.h"
    content, err := ioutil.ReadFile(header)
    if err != nil {
        panic(err)
    }
    pattern := regexp.MustCompile(`#define\s+__NR_([_a-zA-Z0-9]+)\s+(\d+)`)
    for _, m := range pattern.FindAllSubmatch(content, -1) {
        name, num := string(m[1]), string(m[2])
        // The pattern only captures \d+, so Atoi can fail only on overflow;
        // that case is ignored and the entry is left as 0.
        systraps[name], _ = strconv.Atoi(num)
    }
}

// isblocking holds the function names given via -blocking. Their stubs are
// bracketed by runtime·entersyscall/exitsyscall in the generated assembly.
var isblocking = make(map[string]bool)

// stackalign is the alignment, in bytes, applied to the argument frame after
// the parameters, so that the results start on a word boundary.
const stackalign = 8

// fielddecl describes one parameter or result of a syscall stub.
type fielddecl struct {
    // Name and Type are the identifier (may be empty) and the Go type
    // spelling taken from the declaration.
    Name, Type string
    // Offset and Size locate the value in the argument frame, in bytes.
    Offset, Size int
    // Label is the symbol name preprocess assigns for name+offset(FP)
    // references in the generated assembly.
    Label string
}

// funcdecl describes one function whose body is generated as assembly.
type funcdecl struct {
    Name string
    // Blocking selects the entersyscall/exitsyscall bracketing.
    Blocking bool
    // Trap is the syscall number; FrameSize is the argument-frame size in
    // bytes used in the TEXT directive.
    Trap, FrameSize int
    Params, Results []*fielddecl
}

// align rounds n up to the next multiple of a. a must be a power of two, as
// the bit-mask arithmetic relies on it.
func align(n, a int) int {
    mask := a - 1
    return (n + mask) & ^mask
}

// preprocess fills in the fields of fn that are derived from its signature:
//   - Blocking, from isblocking;
//   - Trap, from systraps;
//   - the Offset, Size and Label of every parameter and result;
//   - FrameSize, the total argument-frame size used in the TEXT directive.
//
// Parameters are laid out back to back, each aligned to its own size. The
// results start at the next stackalign boundary. Unnamed parameters are named
// arg, arg1, arg2, ... and unnamed results are labelled ret, ret1, ... (the
// names go vet's asmdecl check uses), so every generated name+offset(FP)
// operand carries a symbol.
//
// It panics when fn has more parameters or results than there are registers
// for them, when no trap number is known for fn.Name, or when a parameter or
// result type has no known size.
func (fn *funcdecl) preprocess() {
    // Check register counts here so the failure is reported before any
    // output is written, not as an index error during template execution.
    if len(fn.Params) > len(argtoreg) {
        panic(fmt.Sprintf("%s: %d parameters, but at most %d fit in syscall registers",
            fn.Name, len(fn.Params), len(argtoreg)))
    }
    if len(fn.Results) > len(restoreg) {
        panic(fmt.Sprintf("%s: %d results, but at most %d are supported",
            fn.Name, len(fn.Results), len(restoreg)))
    }
    fn.Blocking = isblocking[fn.Name]
    var ok bool
    fn.Trap, ok = systraps[fn.Name]
    if !ok {
        panic(fmt.Sprintf("%s: unknown syscall trap number (add a mapping with -traps)", fn.Name))
    }

    // place aligns off to f's size, records f's offset and size, and returns
    // the offset just past f.
    place := func(f *fielddecl, off int) int {
        n := sizeof(f.Type)
        if n == 0 {
            panic(fmt.Sprintf("%s: unsupported type %s", fn.Name, f.Type))
        }
        off = align(off, n)
        f.Offset = off
        f.Size = n
        return off + n
    }
    // anon returns the name used for the i-th unnamed field: prefix, then
    // prefix1, prefix2, ...
    anon := func(prefix string, i int) string {
        if i == 0 {
            return prefix
        }
        return prefix + strconv.Itoa(i)
    }

    size := 0
    for i, param := range fn.Params {
        if param.Name == "" {
            // The template addresses parameters as Name+Offset(FP); an empty
            // name would produce an invalid "+0(FP)" operand.
            param.Name = anon("arg", i)
        }
        param.Label = param.Name
        size = place(param, size)
    }
    size = align(size, stackalign)
    for i, result := range fn.Results {
        result.Label = result.Name
        if result.Label == "" {
            result.Label = anon("ret", i)
        }
        size = place(result, size)
    }
    fn.FrameSize = size
}

// appendExpr appends the Go source spelling of the type expression expr to
// buf and returns the extended slice. Supported forms:
//   - identifiers;
//   - basic literals (array lengths);
//   - qualified identifiers (pkg.Name);
//   - array/slice types and pointer types.
// Anything else panics.
func appendExpr(buf []byte, expr ast.Expr) []byte {
    switch e := expr.(type) {
    case *ast.StarExpr:
        return appendExpr(append(buf, '*'), e.X)
    case *ast.ArrayType:
        buf = append(buf, '[')
        if e.Len != nil {
            buf = appendExpr(buf, e.Len)
        }
        return appendExpr(append(buf, ']'), e.Elt)
    case *ast.SelectorExpr:
        buf = appendExpr(buf, e.X)
        return append(append(buf, '.'), e.Sel.Name...)
    case *ast.BasicLit:
        return append(buf, e.Value...)
    case *ast.Ident:
        return append(buf, e.Name...)
    }
    panic(fmt.Sprintf("surprise motherfucker: %#+v", expr))
}

// parseFields flattens an AST field list into one fielddecl per declared
// name. A field declared without names (e.g. an unnamed result) yields a
// single fielddecl with an empty Name. A nil list yields nil.
func parseFields(list *ast.FieldList) []*fielddecl {
    if list == nil {
        return nil
    }
    var out []*fielddecl
    // scratch is reused across fields; each type string is copied out of it.
    scratch := make([]byte, 0, 64)
    for _, f := range list.List {
        scratch = appendExpr(scratch[:0], f.Type)
        typ := string(scratch)
        if f.Names == nil {
            out = append(out, &fielddecl{Type: typ})
            continue
        }
        for _, id := range f.Names {
            out = append(out, &fielddecl{Name: id.Name, Type: typ})
        }
    }
    return out
}

// parseDecl parses the Go source file at path file. It returns one
// preprocessed funcdecl for every top-level function declared without a
// body, i.e. every function whose implementation the generated assembly must
// provide. Functions that have a Go body are skipped.
//
// It returns the parser's error if the file cannot be parsed. preprocess
// panics on declarations it cannot lay out.
func parseDecl(file string) ([]*funcdecl, error) {
    fset := token.NewFileSet()
    astf, err := parser.ParseFile(fset, file, nil, parser.DeclarationErrors)
    if err != nil {
        return nil, err
    }

    var decls []*funcdecl
    for _, d := range astf.Decls {
        fd, ok := d.(*ast.FuncDecl)
        if !ok || fd.Body != nil {
            // Only body-less declarations are implemented in assembly.
            // Ordinary Go helpers in the same file need no stub (and would
            // otherwise panic in preprocess for lack of a trap number).
            continue
        }
        decl := &funcdecl{
            Name:    fd.Name.Name,
            Params:  parseFields(fd.Type.Params),
            Results: parseFields(fd.Type.Results),
        }
        decl.preprocess()
        decls = append(decls, decl)
    }
    return decls, nil
}

// Lookup tables used by the assembly template's helper functions.
var (
    // argtoreg[i] is the register the i-th syscall argument is loaded into.
    argtoreg = [...]string{"DI", "SI", "DX", "R10", "R8", "R9"}
    // restoreg[i] is the register the i-th result is stored from.
    restoreg = [...]string{"AX", "DX"}
    // sizetochar maps an operand size in bytes to its MOV instruction suffix.
    sizetochar = map[int]string{1: "B", 2: "W", 4: "L", 8: "Q"}
)

// sourceData is the root value handed to the assembly template. It holds
// the GOFILE/GOPACKAGE/GOOS/GOARCH values (from flags or the environment)
// and the functions to emit stubs for.
type sourceData struct {
    GOFILE, GOPACKAGE string
    GOOS, GOARCH      string
    Funcs             []*funcdecl
}

// main is the generator's entry point. It:
//   - takes the target file, package and platform from flags, defaulting to
//     the GOFILE, GOPACKAGE, GOOS and GOARCH environment variables;
//   - loads syscall trap numbers;
//   - parses the function declarations in GOFILE;
//   - writes syscall_GOOS_GOARCH.s with one assembly stub per function;
//   - runs `go vet` in the current directory.
//
// Flags:
//   -blocking  comma-separated function names whose stubs call
//              runtime·entersyscall/exitsyscall around SYSCALL.
//   -traps     comma-separated name:trap pairs. trap is either a syscall
//              name known from unistd_64.h or a decimal number.
//
// Any failure panics.
func main() {
    var (
        blocking string
        traps    string
        data     sourceData
    )
    flag.StringVar(&blocking, "blocking", "", "blocking functions")
    flag.StringVar(&traps, "traps", "", "custom syscall trap")
    flag.StringVar(&data.GOFILE, "gofile", os.Getenv("GOFILE"), "GOFILE")
    flag.StringVar(&data.GOPACKAGE, "gopackage", os.Getenv("GOPACKAGE"), "GOPACKAGE")
    flag.StringVar(&data.GOOS, "goos", os.Getenv("GOOS"), "GOOS")
    flag.StringVar(&data.GOARCH, "goarch", os.Getenv("GOARCH"), "GOARCH")
    flag.Parse()

    if data.GOFILE == "" || data.GOPACKAGE == "" {
        panic("missing GOFILE/GOPACKAGE")
    }
    if data.GOOS != "linux" || data.GOARCH != "amd64" {
        panic(fmt.Sprintf("unsupport platform: %s/%s", data.GOOS, data.GOARCH))
    }

    for _, s := range strings.Split(blocking, ",") {
        s = strings.TrimSpace(s)
        if s == "" {
            continue
        }
        isblocking[s] = true
    }

    fucksystraps()
    // Custom traps are applied after the system table so they can override
    // it. The right-hand side may name an existing syscall or be a number.
    for _, s := range strings.Split(traps, ",") {
        s = strings.TrimSpace(s)
        if s == "" {
            continue
        }
        pair := strings.SplitN(s, ":", 2)
        if len(pair) != 2 || pair[0] == "" {
            continue
        }
        trap, ok := systraps[pair[1]]
        if !ok {
            var err error
            trap, err = strconv.Atoi(pair[1])
            if err != nil {
                panic(err)
            }
        }
        systraps[pair[0]] = trap
    }

    var err error
    data.Funcs, err = parseDecl(data.GOFILE)
    if err != nil {
        panic(err)
    }

    tpl := template.New("gensyscall")
    tpl.Funcs(template.FuncMap{
        "ArgToReg":   func(i int) string { return argtoreg[i] },
        "ResToReg":   func(i int) string { return restoreg[i] },
        "SizeToChar": func(n int) string { return sizetochar[n] },
    })
    template.Must(tpl.New("gofunc").Parse(templateGoFunc))
    template.Must(tpl.New("genasm").Parse(templateLinuxAMD64))

    outputfile := "syscall_" + data.GOOS + "_" + data.GOARCH + ".s"
    f, err := os.Create(outputfile)
    if err != nil {
        panic(err)
    }
    err = tpl.ExecuteTemplate(f, "genasm", &data)
    // Close explicitly (not deferred) so a failed close is reported and the
    // file is complete before go vet reads it.
    if cerr := f.Close(); err == nil {
        err = cerr
    }
    if err != nil {
        // Best effort: a truncated .s file would break the package build
        // until regenerated, so try to remove it before failing.
        _ = os.Remove(outputfile)
        panic(err)
    }

    // Forward vet's diagnostics; with nil Stdout/Stderr they would be
    // discarded and a failure would surface only as "exit status 1".
    vet := exec.Command("go", "vet")
    vet.Stdout = os.Stdout
    vet.Stderr = os.Stderr
    if err := vet.Run(); err != nil {
        panic(err)
    }
}

标签: none

添加新评论