1- //go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))
2-
3- // The above build constraint is based of the cgo directives in this file:
4- // https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go
51package wasm
62
73import (
4+ "bytes"
85 "context"
96 "crypto/sha256"
107 "errors"
@@ -15,10 +12,11 @@ import (
1512 "os"
1613 "path/filepath"
1714 "runtime"
18- "runtime/trace"
1915 "strings"
2016
21- wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
17+ "github.com/tetratelabs/wazero"
18+ "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
19+ "github.com/tetratelabs/wazero/sys"
2220 "golang.org/x/sync/singleflight"
2321 "google.golang.org/grpc"
2422 "google.golang.org/grpc/codes"
@@ -31,31 +29,13 @@ import (
3129 "github.com/sqlc-dev/sqlc/internal/plugin"
3230)
3331
34- func Enabled () bool {
35- return true
36- }
37-
38- // This version must be updated whenever the wasmtime-go dependency is updated
39- const wasmtimeVersion = `v14.0.0`
32+ var flight singleflight.Group
4033
41- func cacheDir () (string , error ) {
42- cache := os .Getenv ("SQLCCACHE" )
43- if cache != "" {
44- return cache , nil
45- }
46- cacheHome := os .Getenv ("XDG_CACHE_HOME" )
47- if cacheHome == "" {
48- home , err := os .UserHomeDir ()
49- if err != nil {
50- return "" , err
51- }
52- cacheHome = filepath .Join (home , ".cache" )
53- }
54- return filepath .Join (cacheHome , "sqlc" ), nil
34+ type runtimeAndCode struct {
35+ rt wazero.Runtime
36+ code wazero.CompiledModule
5537}
5638
57- var flight singleflight.Group
58-
5939// Verify the provided sha256 is valid.
6040func (r * Runner ) getChecksum (ctx context.Context ) (string , error ) {
6141 if r .SHA256 != "" {
@@ -70,67 +50,26 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
7050 return sum , nil
7151}
7252
73- func (r * Runner ) loadModule (ctx context.Context , engine * wasmtime. Engine ) (* wasmtime. Module , error ) {
53+ func (r * Runner ) loadAndCompile (ctx context.Context ) (* runtimeAndCode , error ) {
7454 expected , err := r .getChecksum (ctx )
7555 if err != nil {
7656 return nil , err
7757 }
78- value , err , _ := flight .Do (expected , func () (interface {}, error ) {
79- return r .loadSerializedModule (ctx , engine , expected )
80- })
81- if err != nil {
82- return nil , err
83- }
84- data , ok := value .([]byte )
85- if ! ok {
86- return nil , fmt .Errorf ("returned value was not a byte slice" )
87- }
88- return wasmtime .NewModuleDeserialize (engine , data )
89- }
90-
91- func (r * Runner ) loadSerializedModule (ctx context.Context , engine * wasmtime.Engine , expectedSha string ) ([]byte , error ) {
9258 cacheDir , err := cache .PluginsDir ()
9359 if err != nil {
9460 return nil , err
9561 }
96-
97- pluginDir := filepath .Join (cacheDir , expectedSha )
98- modName := fmt .Sprintf ("plugin_%s_%s_%s.module" , runtime .GOOS , runtime .GOARCH , wasmtimeVersion )
99- modPath := filepath .Join (pluginDir , modName )
100- _ , staterr := os .Stat (modPath )
101- if staterr == nil {
102- data , err := os .ReadFile (modPath )
103- if err != nil {
104- return nil , err
105- }
106- return data , nil
107- }
108-
109- wmod , err := r .loadWASM (ctx , cacheDir , expectedSha )
62+ value , err , _ := flight .Do (expected , func () (interface {}, error ) {
63+ return r .loadAndCompileWASM (ctx , cacheDir , expected )
64+ })
11065 if err != nil {
11166 return nil , err
11267 }
113-
114- moduRegion := trace .StartRegion (ctx , "wasmtime.NewModule" )
115- module , err := wasmtime .NewModule (engine , wmod )
116- moduRegion .End ()
117- if err != nil {
118- return nil , fmt .Errorf ("define wasi: %w" , err )
119- }
120-
121- err = os .Mkdir (pluginDir , 0755 )
122- if err != nil && ! os .IsExist (err ) {
123- return nil , fmt .Errorf ("mkdirall: %w" , err )
124- }
125- out , err := module .Serialize ()
126- if err != nil {
127- return nil , fmt .Errorf ("serialize: %w" , err )
128- }
129- if err := os .WriteFile (modPath , out , 0444 ); err != nil {
130- return nil , fmt .Errorf ("cache wasm: %w" , err )
68+ data , ok := value .(* runtimeAndCode )
69+ if ! ok {
70+ return nil , fmt .Errorf ("returned value was not a compiled module" )
13171 }
132-
133- return out , nil
72+ return data , nil
13473}
13574
13675func (r * Runner ) fetch (ctx context.Context , uri string ) ([]byte , string , error ) {
@@ -174,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error)
174113 return wmod , actual , nil
175114}
176115
177- func (r * Runner ) loadWASM (ctx context.Context , cache string , expected string ) ([] byte , error ) {
116+ func (r * Runner ) loadAndCompileWASM (ctx context.Context , cache string , expected string ) (* runtimeAndCode , error ) {
178117 pluginDir := filepath .Join (cache , expected )
179118 pluginPath := filepath .Join (pluginDir , "plugin.wasm" )
180119 _ , staterr := os .Stat (pluginPath )
@@ -203,7 +142,26 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([
203142 }
204143 }
205144
206- return wmod , nil
145+ wazeroCache , err := wazero .NewCompilationCacheWithDir (filepath .Join (cache , "wazero" ))
146+ if err != nil {
147+ return nil , fmt .Errorf ("wazero.NewCompilationCacheWithDir: %w" , err )
148+ }
149+
150+ config := wazero .NewRuntimeConfig ().WithCompilationCache (wazeroCache )
151+ rt := wazero .NewRuntimeWithConfig (ctx , config )
152+
153+ if _ , err := wasi_snapshot_preview1 .Instantiate (ctx , rt ); err != nil {
154+ return nil , fmt .Errorf ("wasi_snapshot_preview1 instantiate: %w" , err )
155+ }
156+
157+ // Compile the Wasm binary once so that we can skip the entire compilation
158+ // time during instantiation.
159+ code , err := rt .CompileModule (ctx , wmod )
160+ if err != nil {
161+ return nil , fmt .Errorf ("compile module: %w" , err )
162+ }
163+
164+ return & runtimeAndCode {rt : rt , code : code }, nil
207165}
208166
209167// removePGCatalog removes the pg_catalog schema from the request. There is a
@@ -245,75 +203,34 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
245203 return fmt .Errorf ("failed to encode codegen request: %w" , err )
246204 }
247205
248- engine := wasmtime .NewEngine ()
249- module , err := r .loadModule (ctx , engine )
250- if err != nil {
251- return fmt .Errorf ("loadModule: %w" , err )
252- }
253-
254- linker := wasmtime .NewLinker (engine )
255- if err := linker .DefineWasi (); err != nil {
256- return err
257- }
258-
259- dir , err := os .MkdirTemp (os .Getenv ("SQLCTMPDIR" ), "out" )
206+ runtimeAndCode , err := r .loadAndCompile (ctx )
260207 if err != nil {
261- return fmt .Errorf ("temp dir: %w" , err )
262- }
263-
264- defer os .RemoveAll (dir )
265- stdinPath := filepath .Join (dir , "stdin" )
266- stderrPath := filepath .Join (dir , "stderr" )
267- stdoutPath := filepath .Join (dir , "stdout" )
268-
269- if err := os .WriteFile (stdinPath , stdinBlob , 0755 ); err != nil {
270- return fmt .Errorf ("write file: %w" , err )
208+ return fmt .Errorf ("loadBytes: %w" , err )
271209 }
272210
273- // Configure WASI imports to write stdout into a file.
274- wasiConfig := wasmtime .NewWasiConfig ()
275- wasiConfig .SetArgv ([]string {"plugin.wasm" , method })
276- wasiConfig .SetStdinFile (stdinPath )
277- wasiConfig .SetStdoutFile (stdoutPath )
278- wasiConfig .SetStderrFile (stderrPath )
211+ var stderr , stdout bytes.Buffer
279212
280- keys := []string {"SQLC_VERSION" }
281- vals := []string {info .Version }
213+ conf := wazero .NewModuleConfig ().
214+ WithName ("" ).
215+ WithArgs ("plugin.wasm" , method ).
216+ WithStdin (bytes .NewReader (stdinBlob )).
217+ WithStdout (& stdout ).
218+ WithStderr (& stderr ).
219+ WithEnv ("SQLC_VERSION" , info .Version )
282220 for _ , key := range r .Env {
283- keys = append (keys , key )
284- vals = append (vals , os .Getenv (key ))
285- }
286- wasiConfig .SetEnv (keys , vals )
287-
288- store := wasmtime .NewStore (engine )
289- store .SetWasi (wasiConfig )
290-
291- linkRegion := trace .StartRegion (ctx , "linker.DefineModule" )
292- err = linker .DefineModule (store , "" , module )
293- linkRegion .End ()
294- if err != nil {
295- return fmt .Errorf ("define wasi: %w" , err )
221+ conf = conf .WithEnv (key , os .Getenv (key ))
296222 }
297223
298- // Run the function
299- fn , err := linker .GetDefault (store , "" )
300- if err != nil {
301- return fmt .Errorf ("wasi: get default: %w" , err )
224+ result , err := runtimeAndCode .rt .InstantiateModule (ctx , runtimeAndCode .code , conf )
225+ if result != nil {
226+ defer result .Close (ctx )
302227 }
303-
304- callRegion := trace .StartRegion (ctx , "call _start" )
305- _ , err = fn .Call (store )
306- callRegion .End ()
307-
308- if cerr := checkError (err , stderrPath ); cerr != nil {
228+ if cerr := checkError (err , stderr ); cerr != nil {
309229 return cerr
310230 }
311231
312232 // Print WASM stdout
313- stdoutBlob , err := os .ReadFile (stdoutPath )
314- if err != nil {
315- return fmt .Errorf ("read file: %w" , err )
316- }
233+ stdoutBlob := stdout .Bytes ()
317234
318235 resp , ok := reply .(protoreflect.ProtoMessage )
319236 if ! ok {
@@ -331,23 +248,21 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
331248 return nil , status .Error (codes .Unimplemented , "" )
332249}
333250
334- func checkError (err error , stderrPath string ) error {
251+ func checkError (err error , stderr bytes. Buffer ) error {
335252 if err == nil {
336253 return err
337254 }
338255
339- var wtError * wasmtime.Error
340- if errors .As (err , & wtError ) {
341- if code , ok := wtError .ExitStatus (); ok {
342- if code == 0 {
343- return nil
344- }
256+ if exitErr , ok := err .(* sys.ExitError ); ok {
257+ if exitErr .ExitCode () == 0 {
258+ return nil
345259 }
346260 }
261+
347262 // Print WASM stdout
348- stderrBlob , rferr := os . ReadFile ( stderrPath )
349- if rferr == nil && len (stderrBlob ) > 0 {
350- return errors .New (string ( stderrBlob ) )
263+ stderrBlob := stderr . String ( )
264+ if len (stderrBlob ) > 0 {
265+ return errors .New (stderrBlob )
351266 }
352267 return fmt .Errorf ("call: %w" , err )
353268}
0 commit comments