From 85cf3bace9b13782963be13c19d691fd3e8b95c6 Mon Sep 17 00:00:00 2001 From: Tonis Tiigi Date: Mon, 15 Jul 2024 13:35:04 -0700 Subject: [PATCH] hclparser: avoid unnecessary allocations in init Signed-off-by: Tonis Tiigi --- bake/hclparser/stdlib.go | 308 ++++++++++++++++++---------------- bake/hclparser/stdlib_test.go | 2 +- 2 files changed, 162 insertions(+), 148 deletions(-) diff --git a/bake/hclparser/stdlib.go b/bake/hclparser/stdlib.go index e0cd9b5f..bbe4a748 100644 --- a/bake/hclparser/stdlib.go +++ b/bake/hclparser/stdlib.go @@ -1,6 +1,7 @@ package hclparser import ( + "errors" "time" "github.com/hashicorp/go-cty-funcs/cidr" @@ -9,174 +10,187 @@ import ( "github.com/hashicorp/go-cty-funcs/uuid" "github.com/hashicorp/hcl/v2/ext/tryfunc" "github.com/hashicorp/hcl/v2/ext/typeexpr" - "github.com/pkg/errors" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/function" "github.com/zclconf/go-cty/cty/function/stdlib" ) -var stdlibFunctions = map[string]function.Function{ - "absolute": stdlib.AbsoluteFunc, - "add": stdlib.AddFunc, - "and": stdlib.AndFunc, - "base64decode": encoding.Base64DecodeFunc, - "base64encode": encoding.Base64EncodeFunc, - "bcrypt": crypto.BcryptFunc, - "byteslen": stdlib.BytesLenFunc, - "bytesslice": stdlib.BytesSliceFunc, - "can": tryfunc.CanFunc, - "ceil": stdlib.CeilFunc, - "chomp": stdlib.ChompFunc, - "chunklist": stdlib.ChunklistFunc, - "cidrhost": cidr.HostFunc, - "cidrnetmask": cidr.NetmaskFunc, - "cidrsubnet": cidr.SubnetFunc, - "cidrsubnets": cidr.SubnetsFunc, - "coalesce": stdlib.CoalesceFunc, - "coalescelist": stdlib.CoalesceListFunc, - "compact": stdlib.CompactFunc, - "concat": stdlib.ConcatFunc, - "contains": stdlib.ContainsFunc, - "convert": typeexpr.ConvertFunc, - "csvdecode": stdlib.CSVDecodeFunc, - "distinct": stdlib.DistinctFunc, - "divide": stdlib.DivideFunc, - "element": stdlib.ElementFunc, - "equal": stdlib.EqualFunc, - "flatten": stdlib.FlattenFunc, - "floor": stdlib.FloorFunc, - "format": stdlib.FormatFunc, - "formatdate": stdlib.FormatDateFunc, - "formatlist": stdlib.FormatListFunc, - "greaterthan": stdlib.GreaterThanFunc, - "greaterthanorequalto": stdlib.GreaterThanOrEqualToFunc, - "hasindex": stdlib.HasIndexFunc, - "indent": stdlib.IndentFunc, - "index": stdlib.IndexFunc, - "indexof": indexOfFunc, - "int": stdlib.IntFunc, - "join": stdlib.JoinFunc, - "jsondecode": stdlib.JSONDecodeFunc, - "jsonencode": stdlib.JSONEncodeFunc, - "keys": stdlib.KeysFunc, - "length": stdlib.LengthFunc, - "lessthan": stdlib.LessThanFunc, - "lessthanorequalto": stdlib.LessThanOrEqualToFunc, - "log": stdlib.LogFunc, - "lookup": stdlib.LookupFunc, - "lower": stdlib.LowerFunc, - "max": stdlib.MaxFunc, - "md5": crypto.Md5Func, - "merge": stdlib.MergeFunc, - "min": stdlib.MinFunc, - "modulo": stdlib.ModuloFunc, - "multiply": stdlib.MultiplyFunc, - "negate": stdlib.NegateFunc, - "not": stdlib.NotFunc, - "notequal": stdlib.NotEqualFunc, - "or": stdlib.OrFunc, - "parseint": stdlib.ParseIntFunc, - "pow": stdlib.PowFunc, - "range": stdlib.RangeFunc, - "regex_replace": stdlib.RegexReplaceFunc, - "regex": stdlib.RegexFunc, - "regexall": stdlib.RegexAllFunc, - "replace": stdlib.ReplaceFunc, - "reverse": stdlib.ReverseFunc, - "reverselist": stdlib.ReverseListFunc, - "rsadecrypt": crypto.RsaDecryptFunc, - "sethaselement": stdlib.SetHasElementFunc, - "setintersection": stdlib.SetIntersectionFunc, - "setproduct": stdlib.SetProductFunc, - "setsubtract": stdlib.SetSubtractFunc, - "setsymmetricdifference": stdlib.SetSymmetricDifferenceFunc, - "setunion": stdlib.SetUnionFunc, - "sha1": crypto.Sha1Func, - "sha256": crypto.Sha256Func, - "sha512": crypto.Sha512Func, - "signum": stdlib.SignumFunc, - "slice": stdlib.SliceFunc, - "sort": stdlib.SortFunc, - "split": stdlib.SplitFunc, - "strlen": stdlib.StrlenFunc, - "substr": stdlib.SubstrFunc, - "subtract": stdlib.SubtractFunc, - "timeadd": stdlib.TimeAddFunc, - "timestamp": timestampFunc, - "title": stdlib.TitleFunc, - "trim": stdlib.TrimFunc, - "trimprefix": stdlib.TrimPrefixFunc, - "trimspace": stdlib.TrimSpaceFunc, - "trimsuffix": stdlib.TrimSuffixFunc, - "try": tryfunc.TryFunc, - "upper": stdlib.UpperFunc, - "urlencode": encoding.URLEncodeFunc, - "uuidv4": uuid.V4Func, - "uuidv5": uuid.V5Func, - "values": stdlib.ValuesFunc, - "zipmap": stdlib.ZipmapFunc, +type funcDef struct { + name string + fn function.Function + factory func() function.Function +} + +var stdlibFunctions = []funcDef{ + {name: "absolute", fn: stdlib.AbsoluteFunc}, + {name: "add", fn: stdlib.AddFunc}, + {name: "and", fn: stdlib.AndFunc}, + {name: "base64decode", fn: encoding.Base64DecodeFunc}, + {name: "base64encode", fn: encoding.Base64EncodeFunc}, + {name: "bcrypt", fn: crypto.BcryptFunc}, + {name: "byteslen", fn: stdlib.BytesLenFunc}, + {name: "bytesslice", fn: stdlib.BytesSliceFunc}, + {name: "can", fn: tryfunc.CanFunc}, + {name: "ceil", fn: stdlib.CeilFunc}, + {name: "chomp", fn: stdlib.ChompFunc}, + {name: "chunklist", fn: stdlib.ChunklistFunc}, + {name: "cidrhost", fn: cidr.HostFunc}, + {name: "cidrnetmask", fn: cidr.NetmaskFunc}, + {name: "cidrsubnet", fn: cidr.SubnetFunc}, + {name: "cidrsubnets", fn: cidr.SubnetsFunc}, + {name: "coalesce", fn: stdlib.CoalesceFunc}, + {name: "coalescelist", fn: stdlib.CoalesceListFunc}, + {name: "compact", fn: stdlib.CompactFunc}, + {name: "concat", fn: stdlib.ConcatFunc}, + {name: "contains", fn: stdlib.ContainsFunc}, + {name: "convert", fn: typeexpr.ConvertFunc}, + {name: "csvdecode", fn: stdlib.CSVDecodeFunc}, + {name: "distinct", fn: stdlib.DistinctFunc}, + {name: "divide", fn: stdlib.DivideFunc}, + {name: "element", fn: stdlib.ElementFunc}, + {name: "equal", fn: stdlib.EqualFunc}, + {name: "flatten", fn: stdlib.FlattenFunc}, + {name: "floor", fn: stdlib.FloorFunc}, + {name: "format", fn: stdlib.FormatFunc}, + {name: "formatdate", fn: stdlib.FormatDateFunc}, + {name: "formatlist", fn: stdlib.FormatListFunc}, + {name: "greaterthan", fn: stdlib.GreaterThanFunc}, + {name: "greaterthanorequalto", fn: stdlib.GreaterThanOrEqualToFunc}, + {name: "hasindex", fn: stdlib.HasIndexFunc}, + {name: "indent", fn: stdlib.IndentFunc}, + {name: "index", fn: stdlib.IndexFunc}, + {name: "indexof", factory: indexOfFunc}, + {name: "int", fn: stdlib.IntFunc}, + {name: "join", fn: stdlib.JoinFunc}, + {name: "jsondecode", fn: stdlib.JSONDecodeFunc}, + {name: "jsonencode", fn: stdlib.JSONEncodeFunc}, + {name: "keys", fn: stdlib.KeysFunc}, + {name: "length", fn: stdlib.LengthFunc}, + {name: "lessthan", fn: stdlib.LessThanFunc}, + {name: "lessthanorequalto", fn: stdlib.LessThanOrEqualToFunc}, + {name: "log", fn: stdlib.LogFunc}, + {name: "lookup", fn: stdlib.LookupFunc}, + {name: "lower", fn: stdlib.LowerFunc}, + {name: "max", fn: stdlib.MaxFunc}, + {name: "md5", fn: crypto.Md5Func}, + {name: "merge", fn: stdlib.MergeFunc}, + {name: "min", fn: stdlib.MinFunc}, + {name: "modulo", fn: stdlib.ModuloFunc}, + {name: "multiply", fn: stdlib.MultiplyFunc}, + {name: "negate", fn: stdlib.NegateFunc}, + {name: "not", fn: stdlib.NotFunc}, + {name: "notequal", fn: stdlib.NotEqualFunc}, + {name: "or", fn: stdlib.OrFunc}, + {name: "parseint", fn: stdlib.ParseIntFunc}, + {name: "pow", fn: stdlib.PowFunc}, + {name: "range", fn: stdlib.RangeFunc}, + {name: "regex_replace", fn: stdlib.RegexReplaceFunc}, + {name: "regex", fn: stdlib.RegexFunc}, + {name: "regexall", fn: stdlib.RegexAllFunc}, + {name: "replace", fn: stdlib.ReplaceFunc}, + {name: "reverse", fn: stdlib.ReverseFunc}, + {name: "reverselist", fn: stdlib.ReverseListFunc}, + {name: "rsadecrypt", fn: crypto.RsaDecryptFunc}, + {name: "sethaselement", fn: stdlib.SetHasElementFunc}, + {name: "setintersection", fn: stdlib.SetIntersectionFunc}, + {name: "setproduct", fn: stdlib.SetProductFunc}, + {name: "setsubtract", fn: stdlib.SetSubtractFunc}, + {name: "setsymmetricdifference", fn: stdlib.SetSymmetricDifferenceFunc}, + {name: "setunion", fn: stdlib.SetUnionFunc}, + {name: "sha1", fn: crypto.Sha1Func}, + {name: "sha256", fn: crypto.Sha256Func}, + {name: "sha512", fn: crypto.Sha512Func}, + {name: "signum", fn: stdlib.SignumFunc}, + {name: "slice", fn: stdlib.SliceFunc}, + {name: "sort", fn: stdlib.SortFunc}, + {name: "split", fn: stdlib.SplitFunc}, + {name: "strlen", fn: stdlib.StrlenFunc}, + {name: "substr", fn: stdlib.SubstrFunc}, + {name: "subtract", fn: stdlib.SubtractFunc}, + {name: "timeadd", fn: stdlib.TimeAddFunc}, + {name: "timestamp", factory: timestampFunc}, + {name: "title", fn: stdlib.TitleFunc}, + {name: "trim", fn: stdlib.TrimFunc}, + {name: "trimprefix", fn: stdlib.TrimPrefixFunc}, + {name: "trimspace", fn: stdlib.TrimSpaceFunc}, + {name: "trimsuffix", fn: stdlib.TrimSuffixFunc}, + {name: "try", fn: tryfunc.TryFunc}, + {name: "upper", fn: stdlib.UpperFunc}, + {name: "urlencode", fn: encoding.URLEncodeFunc}, + {name: "uuidv4", fn: uuid.V4Func}, + {name: "uuidv5", fn: uuid.V5Func}, + {name: "values", fn: stdlib.ValuesFunc}, + {name: "zipmap", fn: stdlib.ZipmapFunc}, } // indexOfFunc constructs a function that finds the element index for a given // value in a list. -var indexOfFunc = function.New(&function.Spec{ - Params: []function.Parameter{ - { - Name: "list", - Type: cty.DynamicPseudoType, +func indexOfFunc() function.Function { + return function.New(&function.Spec{ + Params: []function.Parameter{ + { + Name: "list", + Type: cty.DynamicPseudoType, + }, + { + Name: "value", + Type: cty.DynamicPseudoType, + }, }, - { - Name: "value", - Type: cty.DynamicPseudoType, - }, - }, - Type: function.StaticReturnType(cty.Number), - Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { - if !(args[0].Type().IsListType() || args[0].Type().IsTupleType()) { - return cty.NilVal, errors.New("argument must be a list or tuple") - } - - if !args[0].IsKnown() { - return cty.UnknownVal(cty.Number), nil - } - - if args[0].LengthInt() == 0 { // Easy path - return cty.NilVal, errors.New("cannot search an empty list") - } - - for it := args[0].ElementIterator(); it.Next(); { - i, v := it.Element() - eq, err := stdlib.Equal(v, args[1]) - if err != nil { - return cty.NilVal, err + Type: function.StaticReturnType(cty.Number), + Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { + if !(args[0].Type().IsListType() || args[0].Type().IsTupleType()) { + return cty.NilVal, errors.New("argument must be a list or tuple") } - if !eq.IsKnown() { + + if !args[0].IsKnown() { return cty.UnknownVal(cty.Number), nil } - if eq.True() { - return i, nil - } - } - return cty.NilVal, errors.New("item not found") - }, -}) + if args[0].LengthInt() == 0 { // Easy path + return cty.NilVal, errors.New("cannot search an empty list") + } + + for it := args[0].ElementIterator(); it.Next(); { + i, v := it.Element() + eq, err := stdlib.Equal(v, args[1]) + if err != nil { + return cty.NilVal, err + } + if !eq.IsKnown() { + return cty.UnknownVal(cty.Number), nil + } + if eq.True() { + return i, nil + } + } + return cty.NilVal, errors.New("item not found") + + }, + }) +} // timestampFunc constructs a function that returns a string representation of the current date and time. // // This function was imported from terraform's datetime utilities. -var timestampFunc = function.New(&function.Spec{ - Params: []function.Parameter{}, - Type: function.StaticReturnType(cty.String), - Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) { - return cty.StringVal(time.Now().UTC().Format(time.RFC3339)), nil - }, -}) +func timestampFunc() function.Function { + return function.New(&function.Spec{ + Params: []function.Parameter{}, + Type: function.StaticReturnType(cty.String), + Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) { + return cty.StringVal(time.Now().UTC().Format(time.RFC3339)), nil + }, + }) +} func Stdlib() map[string]function.Function { funcs := make(map[string]function.Function, len(stdlibFunctions)) - for k, v := range stdlibFunctions { - funcs[k] = v + for _, v := range stdlibFunctions { + if v.factory != nil { + funcs[v.name] = v.factory() + } else { + funcs[v.name] = v.fn + } } return funcs } diff --git a/bake/hclparser/stdlib_test.go b/bake/hclparser/stdlib_test.go index df2336ce..4c933f74 100644 --- a/bake/hclparser/stdlib_test.go +++ b/bake/hclparser/stdlib_test.go @@ -34,7 +34,7 @@ func TestIndexOf(t *testing.T) { for name, test := range tests { name, test := name, test t.Run(name, func(t *testing.T) { - got, err := indexOfFunc.Call([]cty.Value{test.input, test.key}) + got, err := indexOfFunc().Call([]cty.Value{test.input, test.key}) if err != nil { if test.wantErr { return