From df7a318ec01dda1ed87a2936efa28a138e1ea590 Mon Sep 17 00:00:00 2001 From: Tonis Tiigi Date: Thu, 25 Mar 2021 19:36:03 -0700 Subject: [PATCH] bake: allow user functions in variables and vice-versa Signed-off-by: Tonis Tiigi --- bake/bake.go | 7 + bake/hcl.go | 402 ++++++++++++++++++++++++++++++++-------- bake/hcl_test.go | 53 ++++++ util/userfunc/decode.go | 87 +-------- util/userfunc/public.go | 83 +++++++++ 5 files changed, 470 insertions(+), 162 deletions(-) diff --git a/bake/bake.go b/bake/bake.go index ac7adcd6..386cff38 100644 --- a/bake/bake.go +++ b/bake/bake.go @@ -406,6 +406,13 @@ type Variable struct { Default *hcl.Attribute `json:"default,omitempty" hcl:"default,optional"` } +type Function struct { + Name string `json:"-" hcl:"name,label"` + Params *hcl.Attribute `json:"params,omitempty" hcl:"params"` + Variadic *hcl.Attribute `json:"variadic_param,omitempty" hcl:"variadic_params"` + Result *hcl.Attribute `json:"result,omitempty" hcl:"result"` +} + type Group struct { Name string `json:"-" hcl:"name,label"` Targets []string `json:"targets" hcl:"targets"` diff --git a/bake/hcl.go b/bake/hcl.go index 6ba868ec..6e0ff41e 100644 --- a/bake/hcl.go +++ b/bake/hcl.go @@ -4,8 +4,10 @@ import ( "math" "math/big" "os" + "reflect" "strconv" "strings" + "unsafe" "github.com/docker/buildx/util/userfunc" "github.com/hashicorp/go-cty-funcs/cidr" @@ -131,14 +133,18 @@ var ( type StaticConfig struct { Variables []*Variable `hcl:"variable,block"` + Functions []*Function `hcl:"function,block"` Remain hcl.Body `hcl:",remain"` attrs hcl.Attributes - defaults map[string]*hcl.Attribute - env map[string]string - values map[string]cty.Value - progress map[string]struct{} + defaults map[string]*hcl.Attribute + funcDefs map[string]*Function + funcs map[string]function.Function + env map[string]string + ectx hcl.EvalContext + progress map[string]struct{} + progressF map[string]struct{} } func mergeStaticConfig(scs []*StaticConfig) *StaticConfig { @@ -148,6 +154,7 @@ func mergeStaticConfig(scs []*StaticConfig) *StaticConfig { sc := scs[0] for _, s := range scs[1:] { sc.Variables = append(sc.Variables, s.Variables...) + sc.Functions = append(sc.Functions, s.Functions...) for k, v := range s.attrs { sc.attrs[k] = v } @@ -155,9 +162,16 @@ func mergeStaticConfig(scs []*StaticConfig) *StaticConfig { return sc } -func (sc *StaticConfig) Values(withEnv bool) (map[string]cty.Value, error) { +func (sc *StaticConfig) EvalContext(withEnv bool) (*hcl.EvalContext, error) { + // json parser also parses blocks as attributes + delete(sc.attrs, "target") + delete(sc.attrs, "function") + sc.defaults = map[string]*hcl.Attribute{} for _, v := range sc.Variables { + if v.Name == "target" { + continue + } sc.defaults[v.Name] = v.Default } @@ -171,35 +185,305 @@ func (sc *StaticConfig) Values(withEnv bool) (map[string]cty.Value, error) { } } - sc.values = map[string]cty.Value{} - sc.progress = map[string]struct{}{} + sc.funcDefs = map[string]*Function{} + for _, v := range sc.Functions { + sc.funcDefs[v.Name] = v + } + sc.ectx = hcl.EvalContext{ + Variables: map[string]cty.Value{}, + Functions: stdlibFunctions, + } + sc.funcs = map[string]function.Function{} + sc.progress = map[string]struct{}{} + sc.progressF = map[string]struct{}{} for k := range sc.attrs { - if _, err := sc.resolveValue(k); err != nil { + if err := sc.resolveValue(k); err != nil { return nil, err } } for k := range sc.defaults { - if _, err := sc.resolveValue(k); err != nil { + if err := sc.resolveValue(k); err != nil { return nil, err } } - return sc.values, nil + + for k := range sc.funcDefs { + if err := sc.resolveFunction(k); err != nil { + return nil, err + } + } + return &sc.ectx, nil } -func (sc *StaticConfig) resolveValue(name string) (v *cty.Value, err error) { - if v, ok := sc.values[name]; ok { - return &v, nil +type jsonExp interface { + ExprList() []hcl.Expression + ExprMap() []hcl.KeyValuePair +} + +func elementExpressions(je jsonExp, exp hcl.Expression) []hcl.Expression { + list := je.ExprList() + if len(list) != 0 { + exp := make([]hcl.Expression, 0, len(list)) + for _, e := range list { + if je, ok := e.(jsonExp); ok { + exp = append(exp, elementExpressions(je, e)...) + } + } + return exp + } + kvlist := je.ExprMap() + if len(kvlist) != 0 { + exp := make([]hcl.Expression, 0, len(kvlist)*2) + for _, p := range kvlist { + exp = append(exp, p.Key) + if je, ok := p.Value.(jsonExp); ok { + exp = append(exp, elementExpressions(je, p.Value)...) + } + } + return exp + } + return []hcl.Expression{exp} +} + +func jsonFuncCallsRecursive(exp hcl.Expression) ([]string, error) { + je, ok := exp.(jsonExp) + if !ok { + return nil, errors.Errorf("invalid expression type %T", exp) + } + m := map[string]struct{}{} + for _, e := range elementExpressions(je, exp) { + if err := appendJSONFuncCalls(e, m); err != nil { + return nil, err + } + } + arr := make([]string, 0, len(m)) + for n := range m { + arr = append(arr, n) + } + return arr, nil +} + +func appendJSONFuncCalls(exp hcl.Expression, m map[string]struct{}) error { + v := reflect.ValueOf(exp) + if v.Kind() != reflect.Ptr || v.IsNil() { + return errors.Errorf("invalid json expression kind %T %v", exp, v.Kind()) + } + if v.Elem().Kind() != reflect.Struct { + return errors.Errorf("invalid json expression pointer to %T %v", exp, v.Elem().Kind()) + } + src := v.Elem().FieldByName("src") + if src.IsZero() { + return errors.Errorf("%v has no property src", v.Elem().Type()) + } + if src.Kind() != reflect.Interface { + return errors.Errorf("%v src is not interface: %v", src.Type(), src.Kind()) + } + src = src.Elem() + if src.IsNil() { + return nil + } + if src.Kind() == reflect.Ptr { + src = src.Elem() + } + if src.Kind() != reflect.Struct { + return errors.Errorf("%v is not struct: %v", src.Type(), src.Kind()) + } + + // hcl/v2/json/ast#stringVal + val := src.FieldByName("Value") + if val.IsZero() { + return nil + } + rng := src.FieldByName("SrcRange") + if val.IsZero() { + return nil + } + var stringVal struct { + Value string + SrcRange hcl.Range + } + + if !val.Type().AssignableTo(reflect.ValueOf(stringVal.Value).Type()) { + return nil + } + if !rng.Type().AssignableTo(reflect.ValueOf(stringVal.SrcRange).Type()) { + return nil + } + // reflect.Set does not work for unexported fields + stringVal.Value = *(*string)(unsafe.Pointer(val.UnsafeAddr())) + stringVal.SrcRange = *(*hcl.Range)(unsafe.Pointer(rng.UnsafeAddr())) + + expr, diags := hclsyntax.ParseExpression([]byte(stringVal.Value), stringVal.SrcRange.Filename, stringVal.SrcRange.Start) + if diags.HasErrors() { + return nil + } + + fns, err := funcCalls(expr) + if err != nil { + return err + } + + for _, fn := range fns { + m[fn] = struct{}{} + } + + return nil +} + +func funcCalls(exp hcl.Expression) ([]string, hcl.Diagnostics) { + node, ok := exp.(hclsyntax.Node) + if !ok { + fns, err := jsonFuncCallsRecursive(exp) + if err != nil { + return nil, hcl.Diagnostics{ + &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Invalid expression", + Detail: err.Error(), + Subject: exp.Range().Ptr(), + Context: exp.Range().Ptr(), + }, + } + } + return fns, nil + } + + var funcnames []string + hcldiags := hclsyntax.VisitAll(node, func(n hclsyntax.Node) hcl.Diagnostics { + if fe, ok := n.(*hclsyntax.FunctionCallExpr); ok { + funcnames = append(funcnames, fe.Name) + } + return nil + }) + if hcldiags.HasErrors() { + return nil, hcldiags + } + return funcnames, nil +} + +func (sc *StaticConfig) loadDeps(exp hcl.Expression, exclude map[string]struct{}) hcl.Diagnostics { + fns, hcldiags := funcCalls(exp) + if hcldiags.HasErrors() { + return hcldiags + } + + for _, fn := range fns { + if err := sc.resolveFunction(fn); err != nil { + return hcl.Diagnostics{ + &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Invalid expression", + Detail: err.Error(), + Subject: exp.Range().Ptr(), + Context: exp.Range().Ptr(), + }, + } + } + } + + for _, v := range exp.Variables() { + if _, ok := exclude[v.RootName()]; ok { + continue + } + if err := sc.resolveValue(v.RootName()); err != nil { + return hcl.Diagnostics{ + &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Invalid expression", + Detail: err.Error(), + Subject: v.SourceRange().Ptr(), + Context: v.SourceRange().Ptr(), + }, + } + } + } + + return nil +} + +func (sc *StaticConfig) resolveFunction(name string) error { + if _, ok := sc.funcs[name]; ok { + return nil + } + f, ok := sc.funcDefs[name] + if !ok { + if _, ok := sc.ectx.Functions[name]; ok { + return nil + } + return errors.Errorf("undefined function %s", name) + } + if _, ok := sc.progressF[name]; ok { + return errors.Errorf("function cycle not allowed for %s", name) + } + sc.progressF[name] = struct{}{} + + paramExprs, paramsDiags := hcl.ExprList(f.Params.Expr) + if paramsDiags.HasErrors() { + return paramsDiags + } + var diags hcl.Diagnostics + params := map[string]struct{}{} + for _, paramExpr := range paramExprs { + param := hcl.ExprAsKeyword(paramExpr) + if param == "" { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Invalid param element", + Detail: "Each parameter name must be an identifier.", + Subject: paramExpr.Range().Ptr(), + }) + } + params[param] = struct{}{} + } + var variadic hcl.Expression + if f.Variadic != nil { + variadic = f.Variadic.Expr + param := hcl.ExprAsKeyword(variadic) + if param == "" { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Invalid param element", + Detail: "Each parameter name must be an identifier.", + Subject: f.Variadic.Range.Ptr(), + }) + } + params[param] = struct{}{} + } + if diags.HasErrors() { + return diags + } + + if diags := sc.loadDeps(f.Result.Expr, params); diags.HasErrors() { + return diags + } + + v, diags := userfunc.NewFunction(f.Params.Expr, variadic, f.Result.Expr, func() *hcl.EvalContext { + return &sc.ectx + }) + if diags.HasErrors() { + return diags + } + sc.funcs[name] = v + sc.ectx.Functions[name] = v + + return nil +} + +func (sc *StaticConfig) resolveValue(name string) (err error) { + if _, ok := sc.ectx.Variables[name]; ok { + return nil } if _, ok := sc.progress[name]; ok { - return nil, errors.Errorf("variable cycle not allowed") + return errors.Errorf("variable cycle not allowed for %s", name) } sc.progress[name] = struct{}{} + var v *cty.Value defer func() { if v != nil { - sc.values[name] = *v + sc.ectx.Variables[name] = *v } }() @@ -207,43 +491,22 @@ func (sc *StaticConfig) resolveValue(name string) (v *cty.Value, err error) { if !ok { def, ok = sc.defaults[name] if !ok { - return nil, errors.Errorf("undefined variable %q", name) + return errors.Errorf("undefined variable %q", name) } } if def == nil { - v := cty.StringVal(sc.env[name]) - return &v, nil + vv := cty.StringVal(sc.env[name]) + v = &vv + return } - ectx := &hcl.EvalContext{ - Variables: map[string]cty.Value{}, - Functions: stdlibFunctions, // user functions not possible atm + if diags := sc.loadDeps(def.Expr, nil); diags.HasErrors() { + return diags } - for _, v := range def.Expr.Variables() { - value, err := sc.resolveValue(v.RootName()) - if err != nil { - var diags hcl.Diagnostics - if !errors.As(err, &diags) { - return nil, err - } - r := v.SourceRange() - return nil, hcl.Diagnostics{ - &hcl.Diagnostic{ - Severity: hcl.DiagError, - Summary: "Invalid expression", - Detail: err.Error(), - Subject: &r, - Context: &r, - }, - } - } - ectx.Variables[v.RootName()] = *value - } - - vv, diags := def.Expr.Value(ectx) + vv, diags := def.Expr.Value(&sc.ectx) if diags.HasErrors() { - return nil, diags + return diags } _, isVar := sc.defaults[name] @@ -252,29 +515,33 @@ func (sc *StaticConfig) resolveValue(name string) (v *cty.Value, err error) { if vv.Type().Equals(cty.Bool) { b, err := strconv.ParseBool(envv) if err != nil { - return nil, errors.Wrapf(err, "failed to parse %s as bool", name) + return errors.Wrapf(err, "failed to parse %s as bool", name) } - v := cty.BoolVal(b) - return &v, nil + vv := cty.BoolVal(b) + v = &vv + return nil } else if vv.Type().Equals(cty.String) { - v := cty.StringVal(envv) - return &v, nil + vv := cty.StringVal(envv) + v = &vv + return nil } else if vv.Type().Equals(cty.Number) { n, err := strconv.ParseFloat(envv, 64) if err == nil && (math.IsNaN(n) || math.IsInf(n, 0)) { err = errors.Errorf("invalid number value") } if err != nil { - return nil, errors.Wrapf(err, "failed to parse %s as number", name) + return errors.Wrapf(err, "failed to parse %s as number", name) } - v := cty.NumberVal(big.NewFloat(n)) - return &v, nil + vv := cty.NumberVal(big.NewFloat(n)) + v = &vv + return nil } else { // TODO: support lists with csv values - return nil, errors.Errorf("unsupported type %s for variable %s", v.Type(), name) + return errors.Errorf("unsupported type %s for variable %s", v.Type(), name) } } - return &vv, nil + v = &vv + return nil } func ParseHCLFile(dt []byte, fn string) (*hcl.File, *StaticConfig, error) { @@ -331,36 +598,11 @@ func parseHCLFile(dt []byte, fn string) (f *hcl.File, _ *StaticConfig, err error } func ParseHCL(b hcl.Body, sc *StaticConfig) (_ *Config, err error) { - - // evaluate variables - variables, err := sc.Values(true) + ctx, err := sc.EvalContext(true) if err != nil { return nil, err } - userFunctions, _, diags := userfunc.DecodeUserFunctions(b, "function", func() *hcl.EvalContext { - return &hcl.EvalContext{ - Functions: stdlibFunctions, - Variables: variables, - } - }) - if diags.HasErrors() { - return nil, diags - } - - functions := make(map[string]function.Function) - for k, v := range stdlibFunctions { - functions[k] = v - } - for k, v := range userFunctions { - functions[k] = v - } - - ctx := &hcl.EvalContext{ - Variables: variables, - Functions: functions, - } - var c Config // Decode with variables and functions. diff --git a/bake/hcl_test.go b/bake/hcl_test.go index 70af0e3d..6f30aae1 100644 --- a/bake/hcl_test.go +++ b/bake/hcl_test.go @@ -513,3 +513,56 @@ func TestJSONAttributes(t *testing.T) { require.Equal(t, c.Targets[0].Name, "app") require.Equal(t, "pre-abc-def", c.Targets[0].Args["v1"]) } + +func TestJSONFunctions(t *testing.T) { + dt := []byte(`{ + "FOO": "abc", + "function": { + "myfunc": { + "params": ["inp"], + "result": "<${upper(inp)}-${FOO}>" + } + }, + "target": { + "app": { + "args": { + "v1": "pre-${myfunc(\"foo\")}" + } + } + }}`) + + c, err := ParseFile(dt, "docker-bake.json") + require.NoError(t, err) + + require.Equal(t, 1, len(c.Targets)) + require.Equal(t, c.Targets[0].Name, "app") + require.Equal(t, "pre-", c.Targets[0].Args["v1"]) +} + +func TestHCLFunctionInAttr(t *testing.T) { + dt := []byte(` + function "brace" { + params = [inp] + result = "[${inp}]" + } + function "myupper" { + params = [val] + result = "${upper(val)} <> ${brace(v2)}" + } + + v1=myupper("foo") + v2=lower("BAZ") + target "app" { + args = { + "v1": v1 + } + } + `) + + c, err := ParseFile(dt, "docker-bake.hcl") + require.NoError(t, err) + + require.Equal(t, 1, len(c.Targets)) + require.Equal(t, c.Targets[0].Name, "app") + require.Equal(t, "FOO <> [baz]", c.Targets[0].Args["v1"]) +} diff --git a/util/userfunc/decode.go b/util/userfunc/decode.go index 6c1e4ca4..6ca2b6be 100644 --- a/util/userfunc/decode.go +++ b/util/userfunc/decode.go @@ -2,7 +2,6 @@ package userfunc import ( "github.com/hashicorp/hcl/v2" - "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/function" ) @@ -53,7 +52,7 @@ func decodeUserFunctions(body hcl.Body, blockType string, contextFunc ContextFun } funcs = make(map[string]function.Function) -Blocks: + for _, block := range content.Blocks { name := block.Labels[0] funcContent, funcDiags := block.Body.Content(funcBodySchema) @@ -68,88 +67,12 @@ Blocks: if funcContent.Attributes["variadic_param"] != nil { varParamExpr = funcContent.Attributes["variadic_param"].Expr } - - var params []string - var varParam string - - paramExprs, paramsDiags := hcl.ExprList(paramsExpr) - diags = append(diags, paramsDiags...) - if paramsDiags.HasErrors() { + f, funcDiags := NewFunction(paramsExpr, varParamExpr, resultExpr, getBaseCtx) + if funcDiags.HasErrors() { + diags = append(diags, funcDiags...) continue } - for _, paramExpr := range paramExprs { - param := hcl.ExprAsKeyword(paramExpr) - if param == "" { - diags = append(diags, &hcl.Diagnostic{ - Severity: hcl.DiagError, - Summary: "Invalid param element", - Detail: "Each parameter name must be an identifier.", - Subject: paramExpr.Range().Ptr(), - }) - continue Blocks - } - params = append(params, param) - } - - if varParamExpr != nil { - varParam = hcl.ExprAsKeyword(varParamExpr) - if varParam == "" { - diags = append(diags, &hcl.Diagnostic{ - Severity: hcl.DiagError, - Summary: "Invalid variadic_param", - Detail: "The variadic parameter name must be an identifier.", - Subject: varParamExpr.Range().Ptr(), - }) - continue - } - } - - spec := &function.Spec{} - for _, paramName := range params { - spec.Params = append(spec.Params, function.Parameter{ - Name: paramName, - Type: cty.DynamicPseudoType, - }) - } - if varParamExpr != nil { - spec.VarParam = &function.Parameter{ - Name: varParam, - Type: cty.DynamicPseudoType, - } - } - impl := func(args []cty.Value) (cty.Value, error) { - ctx := getBaseCtx() - ctx = ctx.NewChild() - ctx.Variables = make(map[string]cty.Value) - - // The cty function machinery guarantees that we have at least - // enough args to fill all of our params. - for i, paramName := range params { - ctx.Variables[paramName] = args[i] - } - if spec.VarParam != nil { - varArgs := args[len(params):] - ctx.Variables[varParam] = cty.TupleVal(varArgs) - } - - result, diags := resultExpr.Value(ctx) - if diags.HasErrors() { - // Smuggle the diagnostics out via the error channel, since - // a diagnostics sequence implements error. Caller can - // type-assert this to recover the individual diagnostics - // if desired. - return cty.DynamicVal, diags - } - return result, nil - } - spec.Type = func(args []cty.Value) (cty.Type, error) { - val, err := impl(args) - return val.Type(), err - } - spec.Impl = func(args []cty.Value, retType cty.Type) (cty.Value, error) { - return impl(args) - } - funcs[name] = function.New(spec) + funcs[name] = f } return funcs, remain, diags diff --git a/util/userfunc/public.go b/util/userfunc/public.go index 5415c8c9..2a5c394b 100644 --- a/util/userfunc/public.go +++ b/util/userfunc/public.go @@ -2,6 +2,7 @@ package userfunc import ( "github.com/hashicorp/hcl/v2" + "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/function" ) @@ -40,3 +41,85 @@ type ContextFunc func() *hcl.EvalContext func DecodeUserFunctions(body hcl.Body, blockType string, context ContextFunc) (funcs map[string]function.Function, remain hcl.Body, diags hcl.Diagnostics) { return decodeUserFunctions(body, blockType, context) } + +// NewFunction creates a new function instance from preparsed HCL expressions. +func NewFunction(paramsExpr, varParamExpr, resultExpr hcl.Expression, getBaseCtx func() *hcl.EvalContext) (function.Function, hcl.Diagnostics) { + var params []string + var varParam string + + paramExprs, paramsDiags := hcl.ExprList(paramsExpr) + if paramsDiags.HasErrors() { + return function.Function{}, paramsDiags + } + for _, paramExpr := range paramExprs { + param := hcl.ExprAsKeyword(paramExpr) + if param == "" { + return function.Function{}, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: "Invalid param element", + Detail: "Each parameter name must be an identifier.", + Subject: paramExpr.Range().Ptr(), + }} + } + params = append(params, param) + } + + if varParamExpr != nil { + varParam = hcl.ExprAsKeyword(varParamExpr) + if varParam == "" { + return function.Function{}, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: "Invalid variadic_param", + Detail: "The variadic parameter name must be an identifier.", + Subject: varParamExpr.Range().Ptr(), + }} + } + } + + spec := &function.Spec{} + for _, paramName := range params { + spec.Params = append(spec.Params, function.Parameter{ + Name: paramName, + Type: cty.DynamicPseudoType, + }) + } + if varParamExpr != nil { + spec.VarParam = &function.Parameter{ + Name: varParam, + Type: cty.DynamicPseudoType, + } + } + impl := func(args []cty.Value) (cty.Value, error) { + ctx := getBaseCtx() + ctx = ctx.NewChild() + ctx.Variables = make(map[string]cty.Value) + + // The cty function machinery guarantees that we have at least + // enough args to fill all of our params. + for i, paramName := range params { + ctx.Variables[paramName] = args[i] + } + if spec.VarParam != nil { + varArgs := args[len(params):] + ctx.Variables[varParam] = cty.TupleVal(varArgs) + } + + result, diags := resultExpr.Value(ctx) + if diags.HasErrors() { + // Smuggle the diagnostics out via the error channel, since + // a diagnostics sequence implements error. Caller can + // type-assert this to recover the individual diagnostics + // if desired. + return cty.DynamicVal, diags + } + return result, nil + } + spec.Type = func(args []cty.Value) (cty.Type, error) { + val, err := impl(args) + return val.Type(), err + } + spec.Impl = func(args []cty.Value, retType cty.Type) (cty.Value, error) { + return impl(args) + } + return function.New(spec), nil +}