working commit

This commit is contained in:
2026-03-13 19:02:42 +02:00
parent bebbf79c7a
commit 5c1da77f4c
1329 changed files with 314708 additions and 39 deletions
+694
View File
@@ -0,0 +1,694 @@
package extism
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"unsafe"
// TODO: is there a better package for this?
"github.com/gobwas/glob"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
type ValueType = byte
const (
// ValueTypeI32 is a 32-bit integer.
ValueTypeI32 = api.ValueTypeI32
// ValueTypeI64 is a 64-bit integer.
ValueTypeI64 = api.ValueTypeI64
// ValueTypeF32 is a 32-bit floating point number.
ValueTypeF32 = api.ValueTypeF32
// ValueTypeF64 is a 64-bit floating point number.
ValueTypeF64 = api.ValueTypeF64
// ValueTypePTR represents a pointer to an Extism memory block. Alias for ValueTypeI64
ValueTypePTR = ValueTypeI64
)
// HostFunctionStackCallback is a Function implemented in Go instead of a wasm binary.
// The plugin parameter is the calling plugin, used to access memory or
// exported functions and logging.
//
// The stack is includes any parameters encoded according to their ValueType.
// Its length is the max of parameter or result length. When there are results,
// write them in order beginning at index zero. Do not use the stack after the
// function returns.
//
// Here's a typical way to read three parameters and write back one.
//
// // read parameters in index order
// argv, argvBuf := DecodeU32(inputs[0]), DecodeU32(inputs[1])
//
// // write results back to the stack in index order
// stack[0] = EncodeU32(ErrnoSuccess)
//
// This function can be non-deterministic or cause side effects. It also
// has special properties not defined in the WebAssembly Core specification.
// Notably, this uses the caller's memory (via Module.Memory). See
// https://www.w3.org/TR/wasm-core-1/#host-functions%E2%91%A0
//
// To safely decode/encode values from/to the uint64 inputs/ouputs, users are encouraged to use
// Extism's EncodeXXX or DecodeXXX functions.
type HostFunctionStackCallback func(ctx context.Context, p *CurrentPlugin, stack []uint64)
// HostFunction represents a custom function defined by the host.
type HostFunction struct {
stackCallback HostFunctionStackCallback
Name string
Namespace string
Params []api.ValueType
Returns []api.ValueType
}
func (f *HostFunction) SetNamespace(namespace string) {
f.Namespace = namespace
}
// NewHostFunctionWithStack creates a new instance of a HostFunction, which is designed
// to provide custom functionality in a given host environment.
// Here's an example multiplication function that loads operands from memory:
//
// mult := NewHostFunctionWithStack(
// "mult",
// func(ctx context.Context, plugin *CurrentPlugin, stack []uint64) {
// a := DecodeI32(stack[0])
// b := DecodeI32(stack[1])
//
// stack[0] = EncodeI32(a * b)
// },
// []ValueType{ValueTypeI64, ValueTypeI64},
// ValueTypeI64
// )
func NewHostFunctionWithStack(
name string,
callback HostFunctionStackCallback,
params []ValueType,
returnTypes []ValueType) HostFunction {
return HostFunction{
stackCallback: callback,
Name: name,
Namespace: "extism:host/user",
Params: params,
Returns: returnTypes,
}
}
type CurrentPlugin struct {
plugin *Plugin
}
func (p *Plugin) currentPlugin() *CurrentPlugin {
return &CurrentPlugin{p}
}
func (p *CurrentPlugin) Log(level LogLevel, message string) {
p.plugin.Log(level, message)
}
func (p *CurrentPlugin) Logf(level LogLevel, format string, args ...any) {
p.plugin.Logf(level, format, args...)
}
// Memory returns the plugin's WebAssembly memory interface.
func (p *CurrentPlugin) Memory() api.Memory {
return p.plugin.Memory()
}
// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {
return p.AllocWithContext(context.Background(), n)
}
// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) {
out, err := p.plugin.extism.ExportedFunction("alloc").Call(ctx, uint64(n))
if err != nil {
return 0, err
} else if len(out) != 1 {
return 0, fmt.Errorf("expected 1 return, go %v", len(out))
}
return uint64(out[0]), nil
}
// Free the memory block specified by the given offset
func (p *CurrentPlugin) Free(offset uint64) error {
return p.FreeWithContext(context.Background(), offset)
}
// Free the memory block specified by the given offset
func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error {
_, err := p.plugin.extism.ExportedFunction("free").Call(ctx, uint64(offset))
if err != nil {
return err
}
return nil
}
// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) Length(offs uint64) (uint64, error) {
return p.LengthWithContext(context.Background(), offs)
}
// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) {
out, err := p.plugin.extism.ExportedFunction("length").Call(ctx, uint64(offs))
if err != nil {
return 0, err
} else if len(out) != 1 {
return 0, fmt.Errorf("expected 1 return, go %v", len(out))
}
return uint64(out[0]), nil
}
// Write a string to wasm memory and return the offset
func (p *CurrentPlugin) WriteString(s string) (uint64, error) {
return p.WriteBytes([]byte(s))
}
// WriteBytes writes a string to wasm memory and return the offset
func (p *CurrentPlugin) WriteBytes(b []byte) (uint64, error) {
ptr, err := p.Alloc(uint64(len(b)))
if err != nil {
return 0, err
}
ok := p.Memory().Write(uint32(ptr), b)
if !ok {
return 0, fmt.Errorf("failed to write to memory")
}
return ptr, nil
}
// ReadString reads a string from wasm memory
func (p *CurrentPlugin) ReadString(offset uint64) (string, error) {
buffer, err := p.ReadBytes(offset)
if err != nil {
return "", err
}
return string(buffer), nil
}
// ReadBytes reads a byte array from memory
func (p *CurrentPlugin) ReadBytes(offset uint64) ([]byte, error) {
length, err := p.Length(offset)
if err != nil {
return []byte{}, err
}
buffer, ok := p.Memory().Read(uint32(offset), uint32(length))
if !ok {
return []byte{}, fmt.Errorf("invalid memory block")
}
cpy := make([]byte, len(buffer))
copy(cpy, buffer)
return cpy, nil
}
func buildHostModule(ctx context.Context, rt wazero.Runtime, name string, funcs []HostFunction) (api.Module, error) {
builder := rt.NewHostModuleBuilder(name)
defineCustomHostFunctions(builder, funcs)
return builder.Instantiate(ctx)
}
func defineCustomHostFunctions(builder wazero.HostModuleBuilder, funcs []HostFunction) {
for _, f := range funcs {
// Go closures capture variables by reference, not by value.
// This means that if you directly use f inside the closure without creating
// a separate variable (closure) and assigning the value of f to it, you might run into unexpected behavior.
// All the closures created in the loop would end up referencing the same f, which could lead to incorrect or unintended results.
// See: https://github.com/extism/go-sdk/issues/5#issuecomment-1666774486
closure := f.stackCallback
builder.NewFunctionBuilder().WithGoFunction(api.GoFunc(func(ctx context.Context, stack []uint64) {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
closure(ctx, &CurrentPlugin{plugin}, stack)
return
}
panic("Invalid context, `plugin` key not found")
}), f.Params, f.Returns).Export(f.Name)
}
}
func instantiateEnvModule(ctx context.Context, rt wazero.Runtime) (api.Module, error) {
builder := rt.NewHostModuleBuilder("extism:host/env")
// A wrapper that creates allows calls from guest -> go host -> extism kernel wasm
// See https://github.com/extism/proposals/blob/main/EIP-007-extism-runtime-kernel.md.
extismFunc := func(name string, params []ValueType, results []ValueType) {
builder.
NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(func(ctx context.Context, m api.Module, stack []uint64) {
extism, ok := ctx.Value(PluginCtxKey("extism")).(api.Module)
if !ok {
panic("Invalid context, `extism` key not found")
}
f := extism.ExportedFunction(name)
if f == nil {
panic(fmt.Errorf("function %q not found in extism:host", name))
}
err := f.CallWithStack(ctx, stack)
if err != nil {
panic(err)
}
}), params, results).
Export(name)
}
extismFunc("alloc", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64})
extismFunc("free", []ValueType{ValueTypeI64}, []ValueType{})
extismFunc("load_u8", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI32})
extismFunc("input_load_u8", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI32})
extismFunc("store_u64", []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{})
extismFunc("store_u8", []ValueType{ValueTypeI64, ValueTypeI32}, []ValueType{})
extismFunc("input_set", []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{})
extismFunc("output_set", []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{})
extismFunc("input_length", []ValueType{}, []ValueType{ValueTypeI64})
extismFunc("input_offset", []ValueType{}, []ValueType{ValueTypeI64})
extismFunc("output_length", []ValueType{}, []ValueType{ValueTypeI64})
extismFunc("output_offset", []ValueType{}, []ValueType{ValueTypeI64})
extismFunc("length", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64})
extismFunc("length_unsafe", []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64})
extismFunc("reset", []ValueType{}, []ValueType{})
extismFunc("error_set", []ValueType{ValueTypeI64}, []ValueType{})
extismFunc("error_get", []ValueType{}, []ValueType{ValueTypeI64})
extismFunc("memory_bytes", []ValueType{}, []ValueType{ValueTypeI64})
builder.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(api.GoModuleFunc(inputLoad_u64)), []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64}).
Export("input_load_u64")
builder.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(load_u64), []ValueType{ValueTypeI64}, []ValueType{ValueTypeI64}).
Export("load_u64")
builder.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(store_u64), []ValueType{ValueTypeI64, ValueTypeI64}, []ValueType{}).
Export("store_u64")
hostFunc := func(name string, f interface{}) {
builder.NewFunctionBuilder().WithFunc(f).Export(name)
}
hostFunc("config_get", configGet)
hostFunc("var_get", varGet)
hostFunc("var_set", varSet)
hostFunc("http_request", httpRequest)
hostFunc("http_status_code", httpStatusCode)
hostFunc("http_headers", httpHeaders)
hostFunc("get_log_level", getLogLevel)
logFunc := func(name string, level LogLevel) {
hostFunc(name, func(ctx context.Context, m api.Module, offset uint64) {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
if LogLevel(pluginLogLevel.Load()) > level {
plugin.currentPlugin().Free(offset)
return
}
message, err := plugin.currentPlugin().ReadString(offset)
if err != nil {
panic(fmt.Errorf("failed to read log message from memory: %v", err))
}
plugin.Log(level, message)
plugin.currentPlugin().Free(offset)
return
}
panic("Invalid context, `plugin` key not found")
})
}
logFunc("log_trace", LogLevelTrace)
logFunc("log_debug", LogLevelDebug)
logFunc("log_info", LogLevelInfo)
logFunc("log_warn", LogLevelWarn)
logFunc("log_error", LogLevelError)
return builder.Instantiate(ctx)
}
func store_u64(ctx context.Context, mod api.Module, stack []uint64) {
p, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin)
if !ok {
panic("Invalid context")
}
offset := stack[0]
value := stack[1]
ok = p.Memory().WriteUint64Le(uint32(offset), value)
if !ok {
panic(fmt.Sprintf("could not write value '%v' at offset: %v", value, offset))
}
}
func load_u64(ctx context.Context, mod api.Module, stack []uint64) {
p, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin)
if !ok {
panic("Invalid context")
}
stack[0], ok = p.Memory().ReadUint64Le(uint32(stack[0]))
if !ok {
panic(fmt.Sprintf("could not read value at offset: %v", stack[0]))
}
}
func inputLoad_u64(ctx context.Context, mod api.Module, stack []uint64) {
p, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin)
if !ok {
panic("Invalid context")
}
offset, ok := ctx.Value(InputOffsetKey("inputOffset")).(uint64)
if !ok {
panic("Invalid context")
}
stack[0], ok = p.Memory().ReadUint64Le(uint32(stack[0] + offset))
if !ok {
panic(fmt.Sprintf("could not read value at offset: %v", stack[0]))
}
}
func configGet(ctx context.Context, m api.Module, offset uint64) uint64 {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
cp := plugin.currentPlugin()
name, err := cp.ReadString(offset)
if err != nil {
panic(fmt.Errorf("failed to read config name from memory: %v", err))
}
value, ok := plugin.Config[name]
if !ok {
// Return 0 without an error if key is not found
return 0
}
offset, err = cp.WriteString(value)
if err != nil {
panic(fmt.Errorf("failed to write config value to memory: %v", err))
}
return offset
}
panic("Invalid context, `plugin` key not found")
}
func varGet(ctx context.Context, m api.Module, offset uint64) uint64 {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
cp := plugin.currentPlugin()
name, err := cp.ReadString(offset)
if err != nil {
panic(fmt.Errorf("failed to read var name from memory: %v", err))
}
cp.Free(offset)
value, ok := plugin.Var[name]
if !ok {
// Return 0 without an error if key is not found
return 0
}
offset, err = cp.WriteBytes(value)
if err != nil {
panic(fmt.Errorf("failed to write var value to memory: %v", err))
}
return offset
}
panic("Invalid context, `plugin` key not found")
}
func varSet(ctx context.Context, m api.Module, nameOffset uint64, valueOffset uint64) {
plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin)
if !ok {
panic("Invalid context, `plugin` key not found")
}
if plugin.MaxVarBytes == 0 {
panic("Vars are disabled by this host")
}
cp := plugin.currentPlugin()
name, err := cp.ReadString(nameOffset)
if err != nil {
panic(fmt.Errorf("failed to read var name from memory: %v", err))
}
cp.Free(nameOffset)
// Remove if the value offset is 0
if valueOffset == 0 {
delete(plugin.Var, name)
return
}
value, err := cp.ReadBytes(valueOffset)
if err != nil {
panic(fmt.Errorf("failed to read var value from memory: %v", err))
}
cp.Free(valueOffset)
// Calculate size including current key/value
size := int(unsafe.Sizeof([]byte{})+unsafe.Sizeof("")) + len(name) + len(value)
for k, v := range plugin.Var {
size += len(k)
size += len(v)
size += int(unsafe.Sizeof([]byte{}) + unsafe.Sizeof(""))
}
if size >= int(plugin.MaxVarBytes) && valueOffset != 0 {
panic("Variable store is full")
}
plugin.Var[name] = value
}
func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOffset uint64) uint64 {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
cp := plugin.currentPlugin()
if plugin.LastResponseHeaders != nil {
for k := range plugin.LastResponseHeaders {
delete(plugin.LastResponseHeaders, k)
}
}
plugin.LastStatusCode = 0
requestJson, err := cp.ReadBytes(requestOffset)
if err != nil {
panic(fmt.Errorf("failed to read http request from memory: %v", err))
}
var request HttpRequest
err = json.Unmarshal(requestJson, &request)
cp.Free(requestOffset)
if err != nil {
panic(fmt.Errorf("invalid http request: %v", err))
}
// default method to GET and force to be upper
if request.Method == "" {
request.Method = "GET"
}
request.Method = strings.ToUpper(request.Method)
url, err := url.Parse(request.Url)
if err != nil {
panic(fmt.Errorf("invalid url: %v", err))
}
// deny all requests by default
hostMatches := false
for _, allowedHost := range plugin.AllowedHosts {
if allowedHost == url.Hostname() {
hostMatches = true
break
}
pattern := glob.MustCompile(allowedHost)
if pattern.Match(url.Hostname()) {
hostMatches = true
break
}
}
if !hostMatches {
panic(fmt.Errorf("HTTP request to '%v' is not allowed", request.Url))
}
var bodyReader io.Reader = nil
if bodyOffset != 0 {
body, err := cp.ReadBytes(bodyOffset)
if err != nil {
panic("failed to read response body from memory")
}
cp.Free(bodyOffset)
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequestWithContext(ctx, request.Method, request.Url, bodyReader)
if err != nil {
panic(err)
}
for key, value := range request.Headers {
req.Header.Set(key, value)
}
client := http.DefaultClient
resp, err := client.Do(req)
if err != nil {
panic(err)
}
defer resp.Body.Close()
if plugin.LastResponseHeaders != nil {
for k, v := range resp.Header {
plugin.LastResponseHeaders[strings.ToLower(k)] = strings.Join(v, ",")
}
}
plugin.LastStatusCode = resp.StatusCode
limiter := http.MaxBytesReader(nil, resp.Body, int64(plugin.MaxHttpResponseBytes))
body, err := io.ReadAll(limiter)
if err != nil {
panic(err)
}
if len(body) == 0 {
return 0
} else {
offset, err := cp.WriteBytes(body)
if err != nil {
panic("Failed to write resposne body to memory")
}
return offset
}
}
panic("Invalid context, `plugin` key not found")
}
func httpStatusCode(ctx context.Context, m api.Module) int32 {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
return int32(plugin.LastStatusCode)
}
panic("Invalid context, `plugin` key not found")
}
func httpHeaders(ctx context.Context, _ api.Module) uint64 {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
if plugin.LastResponseHeaders == nil {
return 0
}
data, err := json.Marshal(plugin.LastResponseHeaders)
if err != nil {
panic(err)
}
mem, err := plugin.currentPlugin().WriteBytes(data)
if err != nil {
panic(err)
}
return mem
}
panic("Invalid context, `plugin` key not found")
}
func getLogLevel(ctx context.Context, m api.Module) int32 {
// if _, ok := callCtx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
// panic("Invalid context, `plugin` key not found")
// }
return LogLevel(pluginLogLevel.Load()).ExtismCompat()
}
// EncodeI32 encodes the input as a ValueTypeI32.
func EncodeI32(input int32) uint64 {
return api.EncodeI32(input)
}
// DecodeI32 decodes the input as a ValueTypeI32.
func DecodeI32(input uint64) int32 {
return api.DecodeI32(input)
}
// EncodeU32 encodes the input as a ValueTypeI32.
func EncodeU32(input uint32) uint64 {
return api.EncodeU32(input)
}
// DecodeU32 decodes the input as a ValueTypeI32.
func DecodeU32(input uint64) uint32 {
return api.DecodeU32(input)
}
// EncodeI64 encodes the input as a ValueTypeI64.
func EncodeI64(input int64) uint64 {
return api.EncodeI64(input)
}
// EncodeF32 encodes the input as a ValueTypeF32.
//
// See DecodeF32
func EncodeF32(input float32) uint64 {
return api.EncodeF32(input)
}
// DecodeF32 decodes the input as a ValueTypeF32.
//
// See EncodeF32
func DecodeF32(input uint64) float32 {
return api.DecodeF32(input)
}
// EncodeF64 encodes the input as a ValueTypeF64.
//
// See EncodeF32
func EncodeF64(input float64) uint64 {
return api.EncodeF64(input)
}
// DecodeF64 decodes the input as a ValueTypeF64.
//
// See EncodeF64
func DecodeF64(input uint64) float64 {
return api.DecodeF64(input)
}