Files
mstore/vendor/github.com/extism/go-sdk/runtime.go
T
2026-03-13 19:02:42 +02:00

196 lines
4.8 KiB
Go

package extism
import (
	"bytes"
	"context"

	"github.com/tetratelabs/wazero/api"
)
// TODO: test runtime initialization for WASI and Haskell

// runtimeType records which guest runtime detectModuleRuntime recognized in a
// module.
type runtimeType uint8

// Recognized guest runtimes; the numeric values are fixed at 0, 1 and 2.
const (
	// None means no runtime-specific initialization function was detected.
	None runtimeType = 0
	// Haskell marks a module that exports `hs_init`.
	Haskell runtimeType = 1
	// Wasi marks a WASI reactor or command module.
	Wasi runtimeType = 2
)
type (
	// guestRuntime aggregates the runtimes detected for every module of a
	// plugin.
	guestRuntime struct {
		// mainRuntime is the runtime detected for the plugin's main module.
		mainRuntime moduleRuntime
		// runtimes holds the runtime detected for each additional module,
		// keyed like Plugin.modules.
		runtimes map[string]moduleRuntime
		// init runs the initializer of every module runtime, main module last.
		init func(ctx context.Context) error
		// initialized is not set in this file; presumably the caller sets it
		// once init succeeds — confirm at call sites.
		initialized bool
	}

	// moduleRuntime describes the runtime detected for a single module.
	moduleRuntime struct {
		// runtimeType is the kind of runtime that was detected.
		runtimeType runtimeType
		// init performs the runtime-specific initialization for the module.
		init func(ctx context.Context) error
		// initialized records whether init has already completed.
		initialized bool
	}
)
// detectGuestRuntime detects the runtime of the main module and of every other
// module in p. It returns a guestRuntime whose init function invokes the
// initialization function of each non-main module (in Go map iteration order,
// which is unspecified) and then that of the main module, last.
//
// Each module runtime is marked initialized as soon as its init succeeds, and
// runtimes already marked initialized are skipped. A retry after a partial
// failure therefore does not re-run initializers that already completed.
// Runtimes of type None start out initialized, so they are never re-invoked.
// This function never sets guestRuntime.initialized itself.
func detectGuestRuntime(p *Plugin) guestRuntime {
	r := guestRuntime{runtimes: make(map[string]moduleRuntime, len(p.modules))}
	r.mainRuntime = detectModuleRuntime(p, p.mainModule)
	for k, m := range p.modules {
		r.runtimes[k] = detectModuleRuntime(p, m)
	}

	// The closure captures the local r. The runtimes map is shared with the
	// copy returned to the caller, but r.mainRuntime is not: the main module's
	// initialized flag is tracked only inside this closure.
	r.init = func(ctx context.Context) error {
		for k, v := range r.runtimes {
			if v.initialized {
				continue
			}
			p.Logf(LogLevelDebug, "Initializing runtime for module %v", k)
			if err := v.init(ctx); err != nil {
				return err
			}
			// v is a copy of the map entry; write it back so the flag sticks.
			// Updating an existing key during range is permitted.
			v.initialized = true
			r.runtimes[k] = v
		}

		if !r.mainRuntime.initialized {
			p.Logf(LogLevelDebug, "Initializing runtime for main module")
			if err := r.mainRuntime.init(ctx); err != nil {
				return err
			}
			r.mainRuntime.initialized = true
		}
		return nil
	}
	return r
}
// detectModuleRuntime classifies a single module m. Haskell detection is tried
// first, then WASI, and the first detector that matches wins. A module that
// matches neither gets a None runtime: its init is a no-op and it starts out
// marked as initialized.
func detectModuleRuntime(p *Plugin, m api.Module) moduleRuntime {
	detectors := []func(*Plugin, api.Module) (moduleRuntime, bool){
		haskellRuntime,
		wasiRuntime,
	}
	for _, detect := range detectors {
		if rt, found := detect(p, m); found {
			return rt
		}
	}

	p.Log(LogLevelTrace, "No runtime detected")
	noop := func(context.Context) error { return nil }
	return moduleRuntime{runtimeType: None, init: noop, initialized: true}
}
// haskellRuntime checks whether m exports an `hs_init` function that takes
// two i32 parameters. `hs_exit` is not checked.
//
// When it does, the returned runtime's init performs two steps:
//   - It first calls the module's `_initialize` export, if present. An error
//     there is logged but not returned.
//   - It then calls `hs_init(0, 0)` and returns that call's error.
//
// haskellRuntime returns false when `hs_init` is missing or has a different
// signature, so detection can fall through to the other runtimes.
func haskellRuntime(p *Plugin, m api.Module) (moduleRuntime, bool) {
	initFunc := m.ExportedFunction("hs_init")
	if initFunc == nil {
		return moduleRuntime{}, false
	}

	params := initFunc.Definition().ParamTypes()
	if len(params) != 2 || params[0] != api.ValueTypeI32 || params[1] != api.ValueTypeI32 {
		// Like findFunc: an export with an unexpected signature is not the
		// function we are looking for. Invoking it as hs_init(0, 0) would not
		// be meaningful, so do not classify the module as Haskell.
		p.Logf(LogLevelTrace, "hs_init function found with type %v", params)
		return moduleRuntime{}, false
	}

	reactorInit := m.ExportedFunction("_initialize")
	init := func(ctx context.Context) error {
		if reactorInit != nil {
			// Best-effort: a failing _initialize is logged and hs_init still runs.
			if _, err := reactorInit.Call(ctx); err != nil {
				p.Logf(LogLevelError, "Error running reactor _initialize: %v", err)
			}
		}
		_, err := initFunc.Call(ctx, 0, 0)
		if err == nil {
			p.Log(LogLevelDebug, "Initialized Haskell language runtime.")
		}
		return err
	}

	p.Log(LogLevelTrace, "Haskell runtime detected")
	return moduleRuntime{runtimeType: Haskell, init: init}, true
}
// wasiRuntime detects a WASI module, but only when the plugin has WASI
// enabled (p.hasWasi). The WASI application ABI distinguishes reactors from
// commands. The reactor check runs first and takes priority; otherwise the
// command check decides the result.
// See: https://github.com/WebAssembly/WASI/blob/main/legacy/application-abi.md
func wasiRuntime(p *Plugin, m api.Module) (moduleRuntime, bool) {
	if p.hasWasi {
		if rt, isReactor := reactorModule(m, p); isReactor {
			return rt, true
		}
		return commandModule(m, p)
	}
	return moduleRuntime{}, false
}
// reactorModule reports whether m is a WASI reactor, i.e. whether it exports a
// zero-parameter `_initialize` function. If it does, the returned runtime's
// init calls that export.
func reactorModule(m api.Module, p *Plugin) (moduleRuntime, bool) {
	if initialize := findFunc(m, p, "_initialize"); initialize != nil {
		p.Logf(LogLevelTrace, "WASI runtime detected")
		p.Logf(LogLevelTrace, "Reactor module detected")
		return moduleRuntime{runtimeType: Wasi, init: initialize}, true
	}
	return moduleRuntime{}, false
}
// commandModule reports whether m exports a zero-parameter `__wasm_call_ctors`
// function, in which case it is treated as a WASI command module. The
// returned runtime's init calls that export.
func commandModule(m api.Module, p *Plugin) (moduleRuntime, bool) {
	if callCtors := findFunc(m, p, "__wasm_call_ctors"); callCtors != nil {
		p.Logf(LogLevelTrace, "WASI runtime detected")
		p.Logf(LogLevelTrace, "Command module detected")
		return moduleRuntime{runtimeType: Wasi, init: callCtors}, true
	}
	return moduleRuntime{}, false
}
// findFunc looks up the export called name in m. If the export exists and
// takes no parameters, findFunc returns a closure that logs at debug level,
// calls the export and discards its results. It returns nil when the export is
// missing, or when it has parameters (that case is logged at trace level).
func findFunc(m api.Module, p *Plugin, name string) func(context.Context) error {
	fn := m.ExportedFunction(name)
	if fn == nil {
		return nil
	}
	if paramTypes := fn.Definition().ParamTypes(); len(paramTypes) > 0 {
		p.Logf(LogLevelTrace, "%v function found with type %v", name, paramTypes)
		return nil
	}

	return func(ctx context.Context) error {
		p.Logf(LogLevelDebug, "Calling %v", name)
		_, err := fn.Call(ctx)
		return err
	}
}
// equal reports whether actual and expected hold the same bytes in the same
// order. A nil slice and an empty slice compare equal.
func equal(actual []byte, expected []byte) bool {
	return bytes.Equal(actual, expected)
}