diff --git a/bake/hclparser/stdlib.go b/bake/hclparser/stdlib.go index bbe4a748..b0e094f6 100644 --- a/bake/hclparser/stdlib.go +++ b/bake/hclparser/stdlib.go @@ -2,6 +2,8 @@ package hclparser import ( "errors" + "path" + "strings" "time" "github.com/hashicorp/go-cty-funcs/cidr" @@ -27,6 +29,7 @@ var stdlibFunctions = []funcDef{ {name: "and", fn: stdlib.AndFunc}, {name: "base64decode", fn: encoding.Base64DecodeFunc}, {name: "base64encode", fn: encoding.Base64EncodeFunc}, + {name: "basename", factory: basenameFunc}, {name: "bcrypt", fn: crypto.BcryptFunc}, {name: "byteslen", fn: stdlib.BytesLenFunc}, {name: "bytesslice", fn: stdlib.BytesSliceFunc}, @@ -45,6 +48,7 @@ var stdlibFunctions = []funcDef{ {name: "contains", fn: stdlib.ContainsFunc}, {name: "convert", fn: typeexpr.ConvertFunc}, {name: "csvdecode", fn: stdlib.CSVDecodeFunc}, + {name: "dirname", factory: dirnameFunc}, {name: "distinct", fn: stdlib.DistinctFunc}, {name: "divide", fn: stdlib.DivideFunc}, {name: "element", fn: stdlib.ElementFunc}, @@ -91,6 +95,7 @@ var stdlibFunctions = []funcDef{ {name: "reverse", fn: stdlib.ReverseFunc}, {name: "reverselist", fn: stdlib.ReverseListFunc}, {name: "rsadecrypt", fn: crypto.RsaDecryptFunc}, + {name: "sanitize", factory: sanitizeFunc}, {name: "sethaselement", fn: stdlib.SetHasElementFunc}, {name: "setintersection", fn: stdlib.SetIntersectionFunc}, {name: "setproduct", fn: stdlib.SetProductFunc}, @@ -170,6 +175,67 @@ func indexOfFunc() function.Function { }) } +// basenameFunc constructs a function that returns the last element of a path. +func basenameFunc() function.Function { + return function.New(&function.Spec{ + Params: []function.Parameter{ + { + Name: "path", + Type: cty.String, + }, + }, + Type: function.StaticReturnType(cty.String), + Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) { + in := args[0].AsString() + return cty.StringVal(path.Base(in)), nil + }, + }) +} + +// dirnameFunc constructs a function that returns the directory of a path. +func dirnameFunc() function.Function { + return function.New(&function.Spec{ + Params: []function.Parameter{ + { + Name: "path", + Type: cty.String, + }, + }, + Type: function.StaticReturnType(cty.String), + Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) { + in := args[0].AsString() + return cty.StringVal(path.Dir(in)), nil + }, + }) +} + +// sanitizyFunc constructs a function that replaces all non-alphanumeric characters with a underscore, +// leaving only characters that are valid for a Bake target name. +func sanitizeFunc() function.Function { + return function.New(&function.Spec{ + Params: []function.Parameter{ + { + Name: "name", + Type: cty.String, + }, + }, + Type: function.StaticReturnType(cty.String), + Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) { + in := args[0].AsString() + // only [a-zA-Z0-9_-]+ is allowed + var b strings.Builder + for _, r := range in { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_' || r == '-' { + b.WriteRune(r) + } else { + b.WriteRune('_') + } + } + return cty.StringVal(b.String()), nil + }, + }) +} + // timestampFunc constructs a function that returns a string representation of the current date and time. // // This function was imported from terraform's datetime utilities. diff --git a/bake/hclparser/stdlib_test.go b/bake/hclparser/stdlib_test.go index 4c933f74..e0d3dd29 100644 --- a/bake/hclparser/stdlib_test.go +++ b/bake/hclparser/stdlib_test.go @@ -3,6 +3,7 @@ package hclparser import ( "testing" + "github.com/stretchr/testify/require" "github.com/zclconf/go-cty/cty" ) @@ -35,15 +36,164 @@ func TestIndexOf(t *testing.T) { name, test := name, test t.Run(name, func(t *testing.T) { got, err := indexOfFunc().Call([]cty.Value{test.input, test.key}) - if err != nil { - if test.wantErr { - return - } - t.Fatalf("unexpected error: %s", err) - } - if !got.RawEquals(test.want) { - t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.want) + if test.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.want, got) } }) } } + +func TestBasename(t *testing.T) { + type testCase struct { + input cty.Value + want cty.Value + wantErr bool + } + tests := map[string]testCase{ + "empty": { + input: cty.StringVal(""), + want: cty.StringVal("."), + }, + "slash": { + input: cty.StringVal("/"), + want: cty.StringVal("/"), + }, + "simple": { + input: cty.StringVal("/foo/bar"), + want: cty.StringVal("bar"), + }, + "simple no slash": { + input: cty.StringVal("foo/bar"), + want: cty.StringVal("bar"), + }, + "dot": { + input: cty.StringVal("/foo/bar."), + want: cty.StringVal("bar."), + }, + "dotdot": { + input: cty.StringVal("/foo/bar.."), + want: cty.StringVal("bar.."), + }, + "dotdotdot": { + input: cty.StringVal("/foo/bar..."), + want: cty.StringVal("bar..."), + }, + } + + for name, test := range tests { + name, test := name, test + t.Run(name, func(t *testing.T) { + got, err := basenameFunc().Call([]cty.Value{test.input}) + if test.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.want, got) + } + }) + } +} + +func TestDirname(t *testing.T) { + type testCase struct { + input cty.Value + want cty.Value + wantErr bool + } + tests := map[string]testCase{ + "empty": { + input: cty.StringVal(""), + want: cty.StringVal("."), + }, + "slash": { + input: cty.StringVal("/"), + want: cty.StringVal("/"), + }, + "simple": { + input: cty.StringVal("/foo/bar"), + want: cty.StringVal("/foo"), + }, + "simple no slash": { + input: cty.StringVal("foo/bar"), + want: cty.StringVal("foo"), + }, + "dot": { + input: cty.StringVal("/foo/bar."), + want: cty.StringVal("/foo"), + }, + "dotdot": { + input: cty.StringVal("/foo/bar.."), + want: cty.StringVal("/foo"), + }, + "dotdotdot": { + input: cty.StringVal("/foo/bar..."), + want: cty.StringVal("/foo"), + }, + } + + for name, test := range tests { + name, test := name, test + t.Run(name, func(t *testing.T) { + got, err := dirnameFunc().Call([]cty.Value{test.input}) + if test.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.want, got) + } + }) + } +} + +func TestSanitize(t *testing.T) { + type testCase struct { + input cty.Value + want cty.Value + } + tests := map[string]testCase{ + "empty": { + input: cty.StringVal(""), + want: cty.StringVal(""), + }, + "simple": { + input: cty.StringVal("foo/bar"), + want: cty.StringVal("foo_bar"), + }, + "simple no slash": { + input: cty.StringVal("foobar"), + want: cty.StringVal("foobar"), + }, + "dot": { + input: cty.StringVal("foo/bar."), + want: cty.StringVal("foo_bar_"), + }, + "dotdot": { + input: cty.StringVal("foo/bar.."), + want: cty.StringVal("foo_bar__"), + }, + "dotdotdot": { + input: cty.StringVal("foo/bar..."), + want: cty.StringVal("foo_bar___"), + }, + "utf8": { + input: cty.StringVal("foo/🍕bar"), + want: cty.StringVal("foo__bar"), + }, + "symbols": { + input: cty.StringVal("foo/bar!@(ba+z)"), + want: cty.StringVal("foo_bar___ba_z_"), + }, + } + + for name, test := range tests { + name, test := name, test + t.Run(name, func(t *testing.T) { + got, err := sanitizeFunc().Call([]cty.Value{test.input}) + require.NoError(t, err) + require.Equal(t, test.want, got) + }) + } +}