From e51cdcac505ebcd3b1a453da37806e5106dab413 Mon Sep 17 00:00:00 2001 From: CrazyMax <1951866+crazy-max@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:42:43 +0100 Subject: [PATCH] bake: basic variable validation Signed-off-by: CrazyMax <1951866+crazy-max@users.noreply.github.com> --- bake/bake_test.go | 152 ++++++++++++++++++++++++++++++++++++ bake/hclparser/hclparser.go | 46 +++++++++-- 2 files changed, 193 insertions(+), 5 deletions(-) diff --git a/bake/bake_test.go b/bake/bake_test.go index bb95499a..f07794a8 100644 --- a/bake/bake_test.go +++ b/bake/bake_test.go @@ -1856,3 +1856,155 @@ func TestNetNone(t *testing.T) { require.Len(t, bo["app"].Allow, 0) require.Equal(t, "none", bo["app"].NetworkMode) } + +func TestVariableValidation(t *testing.T) { + fp := File{ + Name: "docker-bake.hcl", + Data: []byte(` +variable "FOO" { + validation { + condition = FOO != "" + error_message = "FOO is required." + } +} +target "app" { + args = { + FOO = FOO + } +} +`), + } + + ctx := context.TODO() + + t.Run("Valid", func(t *testing.T) { + t.Setenv("FOO", "bar") + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.NoError(t, err) + }) + + t.Run("Invalid", func(t *testing.T) { + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "FOO is required.") + }) +} + +func TestVariableValidationMulti(t *testing.T) { + fp := File{ + Name: "docker-bake.hcl", + Data: []byte(` +variable "FOO" { + validation { + condition = FOO != "" + error_message = "FOO is required." + } + validation { + condition = strlen(FOO) > 4 + error_message = "FOO must be longer than 4 characters." + } +} +target "app" { + args = { + FOO = FOO + } +} +`), + } + + ctx := context.TODO() + + t.Run("Valid", func(t *testing.T) { + t.Setenv("FOO", "barbar") + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.NoError(t, err) + }) + + t.Run("InvalidLength", func(t *testing.T) { + t.Setenv("FOO", "bar") + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "FOO must be longer than 4 characters.") + }) + + t.Run("InvalidEmpty", func(t *testing.T) { + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "FOO is required.") + }) +} + +func TestVariableValidationWithDeps(t *testing.T) { + fp := File{ + Name: "docker-bake.hcl", + Data: []byte(` +variable "FOO" {} +variable "BAR" { + validation { + condition = FOO != "" + error_message = "BAR requires FOO to be set." + } +} +target "app" { + args = { + BAR = BAR + } +} +`), + } + + ctx := context.TODO() + + t.Run("Valid", func(t *testing.T) { + t.Setenv("FOO", "bar") + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.NoError(t, err) + }) + + t.Run("SetBar", func(t *testing.T) { + t.Setenv("FOO", "bar") + t.Setenv("BAR", "baz") + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.NoError(t, err) + }) + + t.Run("Invalid", func(t *testing.T) { + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "BAR requires FOO to be set.") + }) +} + +func TestVariableValidationTyped(t *testing.T) { + fp := File{ + Name: "docker-bake.hcl", + Data: []byte(` +variable "FOO" { + default = 0 + validation { + condition = FOO > 5 + error_message = "FOO must be greater than 5." + } +} +target "app" { + args = { + FOO = FOO + } +} +`), + } + + ctx := context.TODO() + + t.Run("Valid", func(t *testing.T) { + t.Setenv("FOO", "10") + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.NoError(t, err) + }) + + t.Run("Invalid", func(t *testing.T) { + _, _, err := ReadTargets(ctx, []File{fp}, []string{"app"}, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "FOO must be greater than 5.") + }) +} diff --git a/bake/hclparser/hclparser.go b/bake/hclparser/hclparser.go index fe7dc772..110a72b5 100644 --- a/bake/hclparser/hclparser.go +++ b/bake/hclparser/hclparser.go @@ -25,11 +25,17 @@ type Opt struct { } type variable struct { - Name string `json:"-" hcl:"name,label"` - Default *hcl.Attribute `json:"default,omitempty" hcl:"default,optional"` - Description string `json:"description,omitempty" hcl:"description,optional"` - Body hcl.Body `json:"-" hcl:",body"` - Remain hcl.Body `json:"-" hcl:",remain"` + Name string `json:"-" hcl:"name,label"` + Default *hcl.Attribute `json:"default,omitempty" hcl:"default,optional"` + Description string `json:"description,omitempty" hcl:"description,optional"` + Validations []*variableValidation `json:"validation,omitempty" hcl:"validation,block"` + Body hcl.Body `json:"-" hcl:",body"` + Remain hcl.Body `json:"-" hcl:",remain"` +} + +type variableValidation struct { + Condition hcl.Expression `json:"condition" hcl:"condition"` + ErrorMessage hcl.Expression `json:"error_message" hcl:"error_message"` } type functionDef struct { @@ -541,6 +547,33 @@ func (p *parser) resolveBlockNames(block *hcl.Block) ([]string, error) { return names, nil } +func (p *parser) validateVariables(vars map[string]*variable, ectx *hcl.EvalContext) hcl.Diagnostics { + var diags hcl.Diagnostics + for _, v := range vars { + for _, validation := range v.Validations { + condition, condDiags := validation.Condition.Value(ectx) + if condDiags.HasErrors() { + diags = append(diags, condDiags...) + continue + } + if !condition.True() { + message, msgDiags := validation.ErrorMessage.Value(ectx) + if msgDiags.HasErrors() { + diags = append(diags, msgDiags...) + continue + } + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Validation failed", + Detail: message.AsString(), + Subject: validation.Condition.Range().Ptr(), + }) + } + } + } + return diags +} + type Variable struct { Name string Description string @@ -686,6 +719,9 @@ func Parse(b hcl.Body, opt Opt, val interface{}) (*ParseMeta, hcl.Diagnostics) { } vars = append(vars, v) } + if diags := p.validateVariables(p.vars, p.ectx); diags.HasErrors() { + return nil, diags + } for k := range p.funcs { if err := p.resolveFunction(p.ectx, k); err != nil {