diff --git a/runtime/runtime.go b/runtime/runtime.go index 73fc7cc413..0536d8eed2 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -20,6 +20,11 @@ import ( type ModuleLoader func(*starlark.Thread, string) (starlark.StringDict, error) +// ThreadInitializer is called when building a Starlark thread to run an applet +// on. It can customize the thread by overriding behavior or attaching thread +// local data. +type ThreadInitializer func(thread *starlark.Thread) *starlark.Thread + func init() { resolve.AllowFloat = true resolve.AllowLambda = true @@ -38,14 +43,20 @@ type Applet struct { main *starlark.Function } -func (a *Applet) thread() *starlark.Thread { - return &starlark.Thread{ +func (a *Applet) thread(initializers ...ThreadInitializer) *starlark.Thread { + t := &starlark.Thread{ Name: a.Id, Load: a.loadModule, Print: func(thread *starlark.Thread, msg string) { fmt.Printf("[%s] %s\n", a.Filename, msg) }, } + + for _, init := range initializers { + t = init(t) + } + + return t } // Loads an applet. The script filename is used as a descriptor only, @@ -90,7 +101,7 @@ func (a *Applet) Load(filename string, src []byte, loader ModuleLoader) (err err // Runs the applet's main function, passing it configuration as a // starlark dict. -func (a *Applet) Run(config map[string]string) (roots []render.Root, err error) { +func (a *Applet) Run(config map[string]string, initializers ...ThreadInitializer) (roots []render.Root, err error) { var args starlark.Tuple if a.main.NumParams() > 0 { starlarkConfig := starlark.NewDict(len(config)) @@ -103,7 +114,7 @@ func (a *Applet) Run(config map[string]string) (roots []render.Root, err error) args = starlark.Tuple{starlarkConfig} } - returnValue, err := a.Call(a.main, args) + returnValue, err := a.Call(a.main, args, initializers...) if err != nil { return nil, err } @@ -137,14 +148,14 @@ func (a *Applet) Run(config map[string]string) (roots []render.Root, err error) // Calls any callable from Applet.Globals. Pass args and receive a // starlark Value, or an error if you're unlucky. -func (a *Applet) Call(callable *starlark.Function, args starlark.Tuple) (val starlark.Value, err error) { +func (a *Applet) Call(callable *starlark.Function, args starlark.Tuple, initializers ...ThreadInitializer) (val starlark.Value, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("panic while running %s: %v", a.Filename, r) } }() - resultVal, err := starlark.Call(a.thread(), callable, args, nil) + resultVal, err := starlark.Call(a.thread(initializers...), callable, args, nil) if err != nil { evalErr, ok := err.(*starlark.EvalError) if ok { diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index dec7204c5c..88de68ed9f 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -217,6 +217,32 @@ def main(): } +func TestThreadInitializer(t *testing.T) { + src := ` +load("render.star", "render") +def main(): + print('foobar') + return render.Root(child=render.Box()) +` + // override the print function of the thread + var printedText string + initializer := func(thread *starlark.Thread) *starlark.Thread { + thread.Print = func(thread *starlark.Thread, msg string) { + printedText += msg + } + return thread + } + + app := &Applet{} + err := app.Load("test.star", []byte(src), nil) + assert.NoError(t, err) + _, err = app.Run(map[string]string{}, initializer) + assert.NoError(t, err) + + // our print function should have been called + assert.Equal(t, "foobar", printedText) +} + func TestXPathModule(t *testing.T) { src := ` load("render.star", r="render")