Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions cmd/plug/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@ package main

import (
"bytes"
"errors"
"fmt"
"go/format"
"go/types"
"io"
"io/fs"
"log"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"

"github.com/lufia/plug/plugcore"
"golang.org/x/tools/go/ast/astutil"

"github.com/lufia/plug/plugcore"
)

type Stub struct { // Plug?
Expand All @@ -23,11 +28,25 @@ type Stub struct { // Plug?
func Rewrite(stub *Stub) (string, error) {
filePath := stub.f.path
name := filepath.Base(filePath)
dir := filepath.Join("plug", stub.f.pkg.path)
cacheDir, err := os.UserCacheDir()
if err != nil {
return "", fmt.Errorf("failed to get cachedir: %w", err)
}
dir := filepath.Join(cacheDir, "plug", runtime.Version(), stub.f.pkg.PathVersion())
if err := os.MkdirAll(dir, 0755); err != nil && !os.IsExist(err) {
return "", fmt.Errorf("failed to create %s: %w", dir, err)
}
if verbose {
log.Printf("cachedir: %s\n", dir)
}
file := filepath.Join(dir, name)
_, err = os.Stat(file)
if err == nil {
return file, nil
}
if !errors.Is(err, fs.ErrNotExist) {
return "", fmt.Errorf("failed to stat %s: %w", file, err)
}
w, err := os.Create(file)
if err != nil {
return "", fmt.Errorf("failed to create %s: %w", file, err)
Expand All @@ -43,21 +62,17 @@ func Rewrite(stub *Stub) (string, error) {
return file, nil
}

func pkgPath(v any) string {
return reflect.TypeOf(v).PkgPath()
}

func rewriteFile(w io.Writer, stub *Stub) error {
fset := stub.f.pkg.c.Fset
path := pkgPath(plugcore.Object{})
path := reflect.TypeOf(plugcore.Object{}).PkgPath()
astutil.AddImport(fset, stub.f.f, path)

var buf bytes.Buffer
for _, fn := range stub.fns {
rewriteFunc(&buf, fn)
}
if verbose {
fmt.Printf("====\n%s\n====\n", buf.Bytes())
log.Printf("====\n%s\n====\n", buf.Bytes())
}
s, err := format.Source(buf.Bytes())
if err != nil {
Expand Down
37 changes: 24 additions & 13 deletions cmd/plug/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ func main() {
flag.BoolVar(&verbose, "v", false, "enable verbose log")
flag.Parse()

pkgPath, err := loadPackagePath(".")
pkgPath, modVers, err := loadPackagePath(".")
if err != nil {
log.Fatal(err)
}
syms, err := FindPlugSyms(pkgPath)
if err != nil {
log.Fatal(err)
}
stubs := Group(syms)
stubs := Group(syms, modVers)

var o Overlay
for filePath, stub := range stubs {
Expand All @@ -56,11 +56,11 @@ func main() {
}
}

func loadPackagePath(dir string) (string, error) {
func loadPackagePath(dir string) (string, map[string]string, error) {
// loader.Import does not handle "." notation that means current package.
dir, err := filepath.Abs(dir)
if err != nil {
return "", err
return "", nil, err
}
s := dir
file := filepath.Join(s, "go.mod")
Expand All @@ -70,36 +70,47 @@ func loadPackagePath(dir string) (string, error) {
break
}
if !os.IsNotExist(err) {
return "", err
return "", nil, err
}
up := filepath.Dir(s)
if up == s {
return "", fmt.Errorf("go.mod is not exist")
return "", nil, fmt.Errorf("go.mod is not exist")
}
s = up
file = filepath.Join(s, "go.mod")
}
data, err := os.ReadFile(file)
if err != nil {
return "", err
return "", nil, err
}
modPath := modfile.ModulePath(data)

f, err := modfile.Parse(file, data, nil)
if err != nil {
return "", nil, err
}
modPath := f.Module.Mod.Path
if modPath == "" {
return "", fmt.Errorf("%s: invalid go.mod syntax", file)
return "", nil, fmt.Errorf("%s: invalid go.mod syntax", file)
}
slug, err := filepath.Rel(s, dir)
if err != nil {
return "", err
return "", nil, err
}
pkgPath := path.Join(modPath, filepath.ToSlash(slug))

modVers := make(map[string]string)
for _, r := range f.Require {
modVers[r.Mod.Path] = r.Mod.Version
}
return path.Join(modPath, filepath.ToSlash(slug)), nil
return pkgPath, modVers, nil
}

// Group returns a map of Stub indexed by filePath.
func Group(syms []*Sym) map[string]*Stub {
func Group(syms []*Sym, modVers map[string]string) map[string]*Stub {
stubs := make(map[string]*Stub)
for _, sym := range syms {
pkgPath := sym.PkgPath()
pkg, err := LoadPackage(pkgPath)
pkg, err := LoadPackage(pkgPath, modVers[pkgPath])
if err != nil {
log.Fatalf("failed to load package %s: %v\n", pkgPath, err)
}
Expand Down
17 changes: 13 additions & 4 deletions cmd/plug/pkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,17 @@ import (

type Pkg struct {
*loader.PackageInfo
c *loader.Config
path string
c *loader.Config
path string
version string // If it is empty, maybe it is the stdlib
}

func (pkg *Pkg) PathVersion() string {
s := pkg.path
if v := pkg.version; v != "" {
s += "@" + v
}
return s
}

type File struct {
Expand All @@ -30,7 +39,7 @@ type Func struct {

var pkgCache = make(map[string]*Pkg)

func LoadPackage(pkgPath string) (*Pkg, error) {
func LoadPackage(pkgPath, modVersion string) (*Pkg, error) {
if pkg, ok := pkgCache[pkgPath]; ok {
return pkg, nil
}
Expand All @@ -42,7 +51,7 @@ func LoadPackage(pkgPath string) (*Pkg, error) {
if err != nil {
return nil, err
}
pkg := &Pkg{p.Package(pkgPath), &c, pkgPath}
pkg := &Pkg{p.Package(pkgPath), &c, pkgPath, modVersion}
pkgCache[pkgPath] = pkg
return pkg, nil
}
Expand Down