Merge pull request #2649 from tonistiigi/bake-path-stdlib-functions

bake: add basename, dirname and sanitize functions
This commit is contained in:
Tõnis Tiigi 2024-08-13 13:15:12 +03:00 committed by GitHub
commit 4787b5c046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 224 additions and 8 deletions

View File

@ -2,6 +2,8 @@ package hclparser
import (
"errors"
"path"
"strings"
"time"
"github.com/hashicorp/go-cty-funcs/cidr"
@ -27,6 +29,7 @@ var stdlibFunctions = []funcDef{
{name: "and", fn: stdlib.AndFunc},
{name: "base64decode", fn: encoding.Base64DecodeFunc},
{name: "base64encode", fn: encoding.Base64EncodeFunc},
{name: "basename", factory: basenameFunc},
{name: "bcrypt", fn: crypto.BcryptFunc},
{name: "byteslen", fn: stdlib.BytesLenFunc},
{name: "bytesslice", fn: stdlib.BytesSliceFunc},
@ -45,6 +48,7 @@ var stdlibFunctions = []funcDef{
{name: "contains", fn: stdlib.ContainsFunc},
{name: "convert", fn: typeexpr.ConvertFunc},
{name: "csvdecode", fn: stdlib.CSVDecodeFunc},
{name: "dirname", factory: dirnameFunc},
{name: "distinct", fn: stdlib.DistinctFunc},
{name: "divide", fn: stdlib.DivideFunc},
{name: "element", fn: stdlib.ElementFunc},
@ -91,6 +95,7 @@ var stdlibFunctions = []funcDef{
{name: "reverse", fn: stdlib.ReverseFunc},
{name: "reverselist", fn: stdlib.ReverseListFunc},
{name: "rsadecrypt", fn: crypto.RsaDecryptFunc},
{name: "sanitize", factory: sanitizeFunc},
{name: "sethaselement", fn: stdlib.SetHasElementFunc},
{name: "setintersection", fn: stdlib.SetIntersectionFunc},
{name: "setproduct", fn: stdlib.SetProductFunc},
@ -170,6 +175,67 @@ func indexOfFunc() function.Function {
})
}
// basenameFunc constructs a function that returns the last element of a path.
func basenameFunc() function.Function {
return function.New(&function.Spec{
Params: []function.Parameter{
{
Name: "path",
Type: cty.String,
},
},
Type: function.StaticReturnType(cty.String),
Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) {
in := args[0].AsString()
return cty.StringVal(path.Base(in)), nil
},
})
}
// dirnameFunc constructs a function that returns the directory of a path.
func dirnameFunc() function.Function {
return function.New(&function.Spec{
Params: []function.Parameter{
{
Name: "path",
Type: cty.String,
},
},
Type: function.StaticReturnType(cty.String),
Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) {
in := args[0].AsString()
return cty.StringVal(path.Dir(in)), nil
},
})
}
// sanitizyFunc constructs a function that replaces all non-alphanumeric characters with a underscore,
// leaving only characters that are valid for a Bake target name.
func sanitizeFunc() function.Function {
return function.New(&function.Spec{
Params: []function.Parameter{
{
Name: "name",
Type: cty.String,
},
},
Type: function.StaticReturnType(cty.String),
Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) {
in := args[0].AsString()
// only [a-zA-Z0-9_-]+ is allowed
var b strings.Builder
for _, r := range in {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '_' || r == '-' {
b.WriteRune(r)
} else {
b.WriteRune('_')
}
}
return cty.StringVal(b.String()), nil
},
})
}
// timestampFunc constructs a function that returns a string representation of the current date and time.
//
// This function was imported from terraform's datetime utilities.

View File

@ -3,6 +3,7 @@ package hclparser
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/zclconf/go-cty/cty"
)
@ -35,15 +36,164 @@ func TestIndexOf(t *testing.T) {
name, test := name, test
t.Run(name, func(t *testing.T) {
got, err := indexOfFunc().Call([]cty.Value{test.input, test.key})
if err != nil {
if test.wantErr {
return
}
t.Fatalf("unexpected error: %s", err)
}
if !got.RawEquals(test.want) {
t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.want)
if test.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.want, got)
}
})
}
}
func TestBasename(t *testing.T) {
type testCase struct {
input cty.Value
want cty.Value
wantErr bool
}
tests := map[string]testCase{
"empty": {
input: cty.StringVal(""),
want: cty.StringVal("."),
},
"slash": {
input: cty.StringVal("/"),
want: cty.StringVal("/"),
},
"simple": {
input: cty.StringVal("/foo/bar"),
want: cty.StringVal("bar"),
},
"simple no slash": {
input: cty.StringVal("foo/bar"),
want: cty.StringVal("bar"),
},
"dot": {
input: cty.StringVal("/foo/bar."),
want: cty.StringVal("bar."),
},
"dotdot": {
input: cty.StringVal("/foo/bar.."),
want: cty.StringVal("bar.."),
},
"dotdotdot": {
input: cty.StringVal("/foo/bar..."),
want: cty.StringVal("bar..."),
},
}
for name, test := range tests {
name, test := name, test
t.Run(name, func(t *testing.T) {
got, err := basenameFunc().Call([]cty.Value{test.input})
if test.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.want, got)
}
})
}
}
func TestDirname(t *testing.T) {
type testCase struct {
input cty.Value
want cty.Value
wantErr bool
}
tests := map[string]testCase{
"empty": {
input: cty.StringVal(""),
want: cty.StringVal("."),
},
"slash": {
input: cty.StringVal("/"),
want: cty.StringVal("/"),
},
"simple": {
input: cty.StringVal("/foo/bar"),
want: cty.StringVal("/foo"),
},
"simple no slash": {
input: cty.StringVal("foo/bar"),
want: cty.StringVal("foo"),
},
"dot": {
input: cty.StringVal("/foo/bar."),
want: cty.StringVal("/foo"),
},
"dotdot": {
input: cty.StringVal("/foo/bar.."),
want: cty.StringVal("/foo"),
},
"dotdotdot": {
input: cty.StringVal("/foo/bar..."),
want: cty.StringVal("/foo"),
},
}
for name, test := range tests {
name, test := name, test
t.Run(name, func(t *testing.T) {
got, err := dirnameFunc().Call([]cty.Value{test.input})
if test.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.want, got)
}
})
}
}
func TestSanitize(t *testing.T) {
type testCase struct {
input cty.Value
want cty.Value
}
tests := map[string]testCase{
"empty": {
input: cty.StringVal(""),
want: cty.StringVal(""),
},
"simple": {
input: cty.StringVal("foo/bar"),
want: cty.StringVal("foo_bar"),
},
"simple no slash": {
input: cty.StringVal("foobar"),
want: cty.StringVal("foobar"),
},
"dot": {
input: cty.StringVal("foo/bar."),
want: cty.StringVal("foo_bar_"),
},
"dotdot": {
input: cty.StringVal("foo/bar.."),
want: cty.StringVal("foo_bar__"),
},
"dotdotdot": {
input: cty.StringVal("foo/bar..."),
want: cty.StringVal("foo_bar___"),
},
"utf8": {
input: cty.StringVal("foo/🍕bar"),
want: cty.StringVal("foo__bar"),
},
"symbols": {
input: cty.StringVal("foo/bar!@(ba+z)"),
want: cty.StringVal("foo_bar___ba_z_"),
},
}
for name, test := range tests {
name, test := name, test
t.Run(name, func(t *testing.T) {
got, err := sanitizeFunc().Call([]cty.Value{test.input})
require.NoError(t, err)
require.Equal(t, test.want, got)
})
}
}