From 93addfb88eebfc9dd89d84380e40e0a882c4885f Mon Sep 17 00:00:00 2001 From: liujian <54946465+redscholar@users.noreply.github.com> Date: Mon, 10 Mar 2025 19:10:53 +0800 Subject: [PATCH] feat: Adding generic methods to template parsing (#2503) Signed-off-by: joyceliu Co-authored-by: joyceliu --- api/project/v1/conditional.go | 19 ++++-- api/project/v1/playbook.go | 12 ++++ pkg/const/common.go | 3 - pkg/converter/tmpl/template.go | 86 ++++++++++++++------------- pkg/converter/tmpl/template_test.go | 92 ++++++++++++++--------------- pkg/executor/task_executor.go | 8 +-- pkg/modules/assert.go | 20 ++++--- pkg/modules/assert_test.go | 8 +-- pkg/modules/copy.go | 61 ++++++++----------- pkg/modules/debug.go | 4 +- pkg/modules/set_fact.go | 9 ++- pkg/modules/template.go | 73 +++++++++-------------- pkg/variable/helper.go | 31 ++++------ pkg/variable/helper_test.go | 2 +- pkg/variable/variable_get.go | 2 +- pkg/variable/variable_merge.go | 73 ++++++++--------------- 16 files changed, 237 insertions(+), 266 deletions(-) diff --git a/api/project/v1/conditional.go b/api/project/v1/conditional.go index d1f656d3..638c1429 100644 --- a/api/project/v1/conditional.go +++ b/api/project/v1/conditional.go @@ -36,12 +36,23 @@ type When struct { func (w *When) UnmarshalYAML(node *yaml.Node) error { switch node.Kind { case yaml.ScalarNode: - w.Data = []string{node.Value} - - return nil + if IsTmplSyntax(node.Value) { + w.Data = []string{node.Value} + } else { + w.Data = []string{"{{ " + node.Value + " }}"} + } case yaml.SequenceNode: - return node.Decode(&w.Data) + if err := node.Decode(&w.Data); err != nil { + return err + } + for i, v := range w.Data { + if !IsTmplSyntax(v) { + w.Data[i] = ParseTmplSyntax(node.Value) + } + } default: return errors.New("unsupported type, excepted string or array of strings") } + + return nil } diff --git a/api/project/v1/playbook.go b/api/project/v1/playbook.go index 2f5f28fd..07521af8 100644 --- a/api/project/v1/playbook.go +++ b/api/project/v1/playbook.go @@ -18,6 +18,7 @@ package v1 import ( "errors" + "strings" ) // NOTE: @@ -47,3 +48,14 @@ func (p *Playbook) Validate() error { return nil } + +// IsTmplSyntax Check if the string conforms to the template syntax. +func IsTmplSyntax(s string) bool { + return strings.Contains(s, "{{") && strings.Contains(s, "}}") +} + +// ParseTmplSyntax wraps a string with template syntax delimiters "{{" and "}}" +// to make it a valid Go template expression +func ParseTmplSyntax(s string) string { + return "{{ " + s + "}}" +} diff --git a/pkg/const/common.go b/pkg/const/common.go index a3f9a084..5f151570 100644 --- a/pkg/const/common.go +++ b/pkg/const/common.go @@ -83,9 +83,6 @@ const ( // === From runtime === ) const ( // === From env === - // ENV_VARIABLE_PARSE_DEPTH Defines the depth of parameter transformation, specifying the number of levels in which parameters can reference other unprocessed parameters. - // The default value is 3. - ENV_VARIABLE_PARSE_DEPTH = "VARIABLE_PARSE_DEPTH" // ENV_SHELL which shell operator use in local connector. ENV_SHELL = "SHELL" // ENV_EXECUTOR_IMAGE which image use in pipeline pod. diff --git a/pkg/converter/tmpl/template.go b/pkg/converter/tmpl/template.go index 1fdeb5a7..7f22ed0b 100644 --- a/pkg/converter/tmpl/template.go +++ b/pkg/converter/tmpl/template.go @@ -19,59 +19,63 @@ package tmpl import ( "bytes" "fmt" - "strings" + kkprojectv1 "github.com/kubesphere/kubekey/api/project/v1" "k8s.io/klog/v2" "github.com/kubesphere/kubekey/v4/pkg/converter/internal" ) +// ParseFunc parses a template string using the provided context and parse function. +// It takes a context map C, an input string that may contain template syntax, +// and a parse function that converts the template result to the desired Output type. +// If the input is not a template, it directly applies the parse function. +// For template inputs, it parses and executes the template with the context, +// then applies the parse function to the result. +// Returns the parsed output and any error encountered during template processing. +func ParseFunc[C ~map[string]any, Output any](ctx C, input string, f func([]byte) Output) (Output, error) { + // If input doesn't contain template syntax, return directly + if !kkprojectv1.IsTmplSyntax(input) { + return f([]byte(input)), nil + } + // Parse the template string + tl, err := internal.Template.Parse(input) + if err != nil { + return f(nil), fmt.Errorf("failed to parse template '%s': %w", input, err) + } + // Execute template with provided context + result := bytes.NewBuffer(nil) + if err := tl.Execute(result, ctx); err != nil { + return f(nil), fmt.Errorf("failed to execute template '%s': %w", input, err) + } + // Log successful parsing + klog.V(6).InfoS(" parse template succeed", "result", result.String()) + + // Apply parse function to result and return + return f(result.Bytes()), nil +} + +// Parse is a helper function that wraps ParseFunc to directly return bytes. +// It takes a context map C and input string, and returns the parsed bytes. +func Parse[C ~map[string]any](ctx C, input string) ([]byte, error) { + return ParseFunc(ctx, input, func(o []byte) []byte { + return o + }) +} + // ParseBool parse template string to bool -func ParseBool(ctx map[string]any, inputs []string) (bool, error) { +func ParseBool(ctx map[string]any, inputs ...string) (bool, error) { for _, input := range inputs { - if !IsTmplSyntax(input) { - input = "{{ " + input + " }}" - } - - tl, err := internal.Template.Parse(input) + output, err := ParseFunc(ctx, input, func(o []byte) bool { + return bytes.EqualFold(o, []byte("true")) + }) if err != nil { - return false, fmt.Errorf("failed to parse template '%s': %w", input, err) + return false, err } - - result := bytes.NewBuffer(nil) - if err := tl.Execute(result, ctx); err != nil { - return false, fmt.Errorf("failed to execute template '%s': %w", input, err) - } - klog.V(6).InfoS(" parse template succeed", "result", result.String()) - if result.String() != "true" { - return false, nil + if !output { + return output, nil } } return true, nil } - -// ParseString parse template string to actual string -func ParseString(ctx map[string]any, input string) (string, error) { - if !IsTmplSyntax(input) { - return strings.Trim(input, "\r\n"), nil - } - - tl, err := internal.Template.Parse(input) - if err != nil { - return "", fmt.Errorf("failed to parse template '%s': %w", input, err) - } - - result := bytes.NewBuffer(nil) - if err := tl.Execute(result, ctx); err != nil { - return "", fmt.Errorf("failed to execute template '%s': %w", input, err) - } - klog.V(6).InfoS(" parse template succeed", "result", result.String()) - - return strings.Trim(result.String(), "\r\n"), nil -} - -// IsTmplSyntax Check if the string conforms to the template syntax. -func IsTmplSyntax(s string) bool { - return strings.Contains(s, "{{") && strings.Contains(s, "}}") -} diff --git a/pkg/converter/tmpl/template_test.go b/pkg/converter/tmpl/template_test.go index 29f9496e..5d183089 100644 --- a/pkg/converter/tmpl/template_test.go +++ b/pkg/converter/tmpl/template_test.go @@ -181,7 +181,7 @@ func TestParseBool(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - b, err := ParseBool(tc.variable, tc.condition) + b, err := ParseBool(tc.variable, tc.condition...) if err != nil { t.Fatal(err) } @@ -195,7 +195,7 @@ func TestParseValue(t *testing.T) { name string input string variable map[string]any - excepted string + excepted []byte }{ { name: "single level", @@ -203,7 +203,7 @@ func TestParseValue(t *testing.T) { variable: map[string]any{ "foo": "bar", }, - excepted: "bar", + excepted: []byte("bar"), }, { name: "multi level 1", @@ -213,7 +213,7 @@ func TestParseValue(t *testing.T) { "foo": "bar", }, }, - excepted: "bar", + excepted: []byte("bar"), }, { name: "multi level 2", @@ -223,7 +223,7 @@ func TestParseValue(t *testing.T) { "foo": "bar", }, }, - excepted: "bar", + excepted: []byte("bar"), }, { name: "multi level 2", @@ -233,7 +233,7 @@ func TestParseValue(t *testing.T) { "foo": "bar", }, }, - excepted: "bar", + excepted: []byte("bar"), }, { name: "multi level 3", @@ -247,7 +247,7 @@ func TestParseValue(t *testing.T) { }, }, }, - excepted: "bar", + excepted: []byte("bar"), }, { name: "exist value", @@ -255,13 +255,13 @@ func TestParseValue(t *testing.T) { variable: map[string]any{ "foo": "bar", }, - excepted: "bar", + excepted: []byte("bar"), }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - output, _ := ParseString(tc.variable, tc.input) + output, _ := Parse(tc.variable, tc.input) assert.Equal(t, tc.excepted, output) }) } @@ -272,7 +272,7 @@ func TestParseFunction(t *testing.T) { name string input string variable map[string]any - excepted string + excepted []byte }{ // ======= if ======= { @@ -286,7 +286,7 @@ func TestParseFunction(t *testing.T) { }, }, }, - excepted: "bar2", + excepted: []byte("bar2"), }, { name: "if map 1", @@ -299,7 +299,7 @@ func TestParseFunction(t *testing.T) { }, }, }, - excepted: "bar2", + excepted: []byte("bar2"), }, // ======= range ======= { @@ -313,7 +313,7 @@ func TestParseFunction(t *testing.T) { }, }, }, - excepted: "bar1bar2", + excepted: []byte("bar1bar2"), }, { name: "range map value 1", @@ -326,7 +326,7 @@ func TestParseFunction(t *testing.T) { }, }, }, - excepted: "bar1", + excepted: []byte("bar1"), }, { name: "range map top-value 1", @@ -340,7 +340,7 @@ func TestParseFunction(t *testing.T) { }, "foo1": "bar11", }, - excepted: "bar11", + excepted: []byte("bar11"), }, { name: "range slice value 1", @@ -353,7 +353,7 @@ func TestParseFunction(t *testing.T) { }, }, }, - excepted: "bar1", + excepted: []byte("bar1"), }, { name: "range slice value 1", @@ -363,27 +363,27 @@ func TestParseFunction(t *testing.T) { "foo1", "bar1", }, }, - excepted: "foo1bar1", + excepted: []byte("foo1bar1"), }, // ======= default ======= { name: "default string 1", input: "{{ .foo | default \"bar\" }}", variable: make(map[string]any), - excepted: "bar", + excepted: []byte("bar"), }, { name: "default string 2", input: "{{ default .foo \"bar\" }}", variable: make(map[string]any), - excepted: "bar", + excepted: []byte("bar"), }, { name: "default number 1", input: "{{ .foo | default 1 }}", variable: make(map[string]any), - excepted: "1", + excepted: []byte("1"), }, // ======= split ======= { @@ -392,7 +392,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": "a,b", }, - excepted: "map[_0:a _1:b]", + excepted: []byte("map[_0:a _1:b]"), }, { name: "split 2", @@ -400,7 +400,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": "a,b", }, - excepted: "map[_0:a _1:b]", + excepted: []byte("map[_0:a _1:b]"), }, // ======= len ======= { @@ -409,7 +409,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "2", + excepted: []byte("2"), }, { name: "len 2", @@ -417,7 +417,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "2", + excepted: []byte("2"), }, // ======= index ======= { @@ -428,7 +428,7 @@ func TestParseFunction(t *testing.T) { "foo": "a", }, }, - excepted: "a", + excepted: []byte("a"), }, { name: "index 2", @@ -438,7 +438,7 @@ func TestParseFunction(t *testing.T) { "foo": "a", }, }, - excepted: "false", + excepted: []byte("false"), }, { name: "index 3", @@ -450,7 +450,7 @@ func TestParseFunction(t *testing.T) { }, }, }, - excepted: "b", + excepted: []byte("b"), }, // ======= first ======= { @@ -459,7 +459,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "a", + excepted: []byte("a"), }, { name: "first 2", @@ -467,7 +467,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "a", + excepted: []byte("a"), }, // ======= last ======= { @@ -476,7 +476,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "b", + excepted: []byte("b"), }, { name: "last 2", @@ -484,7 +484,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "b", + excepted: []byte("b"), }, // ======= slice ======= { @@ -493,7 +493,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "[a b]", + excepted: []byte("[a b]"), }, // ======= join ======= { @@ -502,7 +502,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "a.b", + excepted: []byte("a.b"), }, // ======= toJson ======= { @@ -511,7 +511,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": []string{"a", "b"}, }, - excepted: "[\"a\",\"b\"]", + excepted: []byte("[\"a\",\"b\"]"), }, // ======= toYaml ======= { @@ -523,7 +523,7 @@ func TestParseFunction(t *testing.T) { "a2": "b2", }, }, - excepted: "a1: b1\na2: b2", + excepted: []byte("a1: b1\na2: b2"), }, // ======= indent ======= { @@ -532,7 +532,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": "a1: b1\na2: b2", }, - excepted: " a1: b1\n a2: b2", + excepted: []byte(" a1: b1\n a2: b2"), }, // ======= printf ======= { @@ -541,7 +541,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": "a", }, - excepted: "http://a", + excepted: []byte("http://a"), }, { name: "printf 2", @@ -549,7 +549,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": "a", }, - excepted: "http://a", + excepted: []byte("http://a"), }, // ======= div ======= @@ -559,7 +559,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": 5, }, - excepted: "1", + excepted: []byte("1"), }, { name: "div 1", @@ -567,7 +567,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": 4, }, - excepted: "0", + excepted: []byte("0"), }, // ======= sub ======= { @@ -576,7 +576,7 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": 5, }, - excepted: "3", + excepted: []byte("3"), }, // ======= trimPrefix ======= { @@ -585,13 +585,13 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": "v1.1", }, - excepted: "1.1", + excepted: []byte("1.1"), }, { name: "trimPrefix 2", input: `{{ .foo | default "" |trimPrefix "v" }}`, variable: make(map[string]any), - excepted: "", + excepted: nil, }, // ======= fromJson ======= { @@ -600,13 +600,13 @@ func TestParseFunction(t *testing.T) { variable: map[string]any{ "foo": "[\"a\",\"b\"]", }, - excepted: "a", + excepted: []byte("a"), }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - output, err := ParseString(tc.variable, tc.input) + output, err := Parse(tc.variable, tc.input) if err != nil { t.Fatal(err) } @@ -671,7 +671,7 @@ func TestParseCustomFunction(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - output, err := ParseString(tc.variable, tc.input) + output, err := ParseFunc(tc.variable, tc.input, func(b []byte) string { return string(b) }) if err != nil { t.Fatal(err) } diff --git a/pkg/executor/task_executor.go b/pkg/executor/task_executor.go index 2651f89e..821e78e0 100644 --- a/pkg/executor/task_executor.go +++ b/pkg/executor/task_executor.go @@ -2,6 +2,7 @@ package executor import ( "context" + "encoding/json" "fmt" "os" "strings" @@ -12,7 +13,6 @@ import ( kkcorev1alpha1 "github.com/kubesphere/kubekey/api/core/v1alpha1" "github.com/schollz/progressbar/v3" "gopkg.in/yaml.v3" - "k8s.io/apimachinery/pkg/util/json" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" @@ -337,7 +337,7 @@ func (e *taskExecutor) dealLoop(ha map[string]any) []any { // dealWhen "when" argument in task. func (e *taskExecutor) dealWhen(had map[string]any, stdout, stderr *string) bool { if len(e.task.Spec.When) > 0 { - ok, err := tmpl.ParseBool(had, e.task.Spec.When) + ok, err := tmpl.ParseBool(had, e.task.Spec.When...) if err != nil { klog.V(5).ErrorS(err, "validate when condition error", "task", ctrlclient.ObjectKeyFromObject(e.task)) *stderr = fmt.Sprintf("parse when condition error: %v", err) @@ -357,7 +357,7 @@ func (e *taskExecutor) dealWhen(had map[string]any, stdout, stderr *string) bool // dealFailedWhen "failed_when" argument in task. func (e *taskExecutor) dealFailedWhen(had map[string]any, stdout, stderr *string) bool { if len(e.task.Spec.FailedWhen) > 0 { - ok, err := tmpl.ParseBool(had, e.task.Spec.FailedWhen) + ok, err := tmpl.ParseBool(had, e.task.Spec.FailedWhen...) if err != nil { klog.V(5).ErrorS(err, "validate failed_when condition error", "task", ctrlclient.ObjectKeyFromObject(e.task)) *stderr = fmt.Sprintf("parse failed_when condition error: %v", err) @@ -381,7 +381,7 @@ func (e *taskExecutor) dealRegister(stdout, stderr, host string) error { var stdoutResult any = stdout var stderrResult any = stderr // try to convert by json or yaml - if (strings.HasPrefix(stdout, "{") || strings.HasPrefix(stdout, "[")) && (strings.HasSuffix(stdout, "}") || strings.HasSuffix(stdout, "]")) { + if json.Valid([]byte(stdout)) { _ = json.Unmarshal([]byte(stdout), &stdoutResult) _ = json.Unmarshal([]byte(stderr), &stderrResult) } else { diff --git a/pkg/modules/assert.go b/pkg/modules/assert.go index a19b185d..bdddf381 100644 --- a/pkg/modules/assert.go +++ b/pkg/modules/assert.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + kkprojectv1 "github.com/kubesphere/kubekey/api/project/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/klog/v2" ctrlclient "sigs.k8s.io/controller-runtime/pkg/client" @@ -43,6 +44,11 @@ func newAssertArgs(_ context.Context, raw runtime.RawExtension, vars map[string] if aa.that, err = variable.StringSliceVar(vars, args, "that"); err != nil { return nil, errors.New("\"that\" should be []string or string") } + for i, s := range aa.that { + if !kkprojectv1.IsTmplSyntax(s) { + aa.that[i] = kkprojectv1.ParseTmplSyntax(s) + } + } aa.successMsg, _ = variable.StringVar(vars, args, "success_msg") if aa.successMsg == "" { aa.successMsg = StdoutTrue @@ -71,15 +77,15 @@ func ModuleAssert(ctx context.Context, options ExecOptions) (string, string) { return "", err.Error() } - ok, err := tmpl.ParseBool(ha, aa.that) + ok, err := tmpl.ParseBool(ha, aa.that...) if err != nil { return "", fmt.Sprintf("parse \"that\" error: %v", err) } // condition is true if ok { - r, err := tmpl.ParseString(ha, aa.successMsg) + r, err := tmpl.Parse(ha, aa.successMsg) if err == nil { - return r, "" + return string(r), "" } klog.V(4).ErrorS(err, "parse \"success_msg\" error", "task", ctrlclient.ObjectKeyFromObject(&options.Task)) @@ -87,17 +93,17 @@ func ModuleAssert(ctx context.Context, options ExecOptions) (string, string) { } // condition is false and fail_msg is not empty if aa.failMsg != "" { - r, err := tmpl.ParseString(ha, aa.failMsg) + r, err := tmpl.Parse(ha, aa.failMsg) if err == nil { - return StdoutFalse, r + return StdoutFalse, string(r) } klog.V(4).ErrorS(err, "parse \"fail_msg\" error", "task", ctrlclient.ObjectKeyFromObject(&options.Task)) } // condition is false and msg is not empty if aa.msg != "" { - r, err := tmpl.ParseString(ha, aa.msg) + r, err := tmpl.Parse(ha, aa.msg) if err == nil { - return StdoutFalse, r + return StdoutFalse, string(r) } klog.V(4).ErrorS(err, "parse \"msg\" error", "task", ctrlclient.ObjectKeyFromObject(&options.Task)) } diff --git a/pkg/modules/assert_test.go b/pkg/modules/assert_test.go index 2a1afaa6..e0bc8d1d 100644 --- a/pkg/modules/assert_test.go +++ b/pkg/modules/assert_test.go @@ -46,7 +46,7 @@ func TestAssert(t *testing.T) { opt: ExecOptions{ Host: "local", Args: runtime.RawExtension{ - Raw: []byte(`{"that": ["true", "eq .testvalue \"a\""]}`), + Raw: []byte(`{"that": ["true", "{{ eq .testvalue \"a\" }}"]}`), }, Variable: &testVariable{ value: map[string]any{ @@ -61,7 +61,7 @@ func TestAssert(t *testing.T) { opt: ExecOptions{ Host: "local", Args: runtime.RawExtension{ - Raw: []byte(`{"that": ["true", "eq .k1 \"v1\""], "success_msg": "success {{ .k2 }}"}`), + Raw: []byte(`{"that": ["true", "{{ eq .k1 \"v1\" }}"], "success_msg": "success {{ .k2 }}"}`), }, Variable: &testVariable{ value: map[string]any{ @@ -77,7 +77,7 @@ func TestAssert(t *testing.T) { opt: ExecOptions{ Host: "local", Args: runtime.RawExtension{ - Raw: []byte(`{"that": ["true", "eq .k1 \"v2\""]}`), + Raw: []byte(`{"that": ["true", "{{ eq .k1 \"v2\" }}"]}`), }, Variable: &testVariable{ value: map[string]any{ @@ -94,7 +94,7 @@ func TestAssert(t *testing.T) { opt: ExecOptions{ Host: "local", Args: runtime.RawExtension{ - Raw: []byte(`{"that": ["true", "eq .k1 \"v2\""], "fail_msg": "failed {{ .k2 }}"}`), + Raw: []byte(`{"that": ["true", "{{ eq .k1 \"v2\" }}"], "fail_msg": "failed {{ .k2 }}"}`), }, Variable: &testVariable{ value: map[string]any{ diff --git a/pkg/modules/copy.go b/pkg/modules/copy.go index a425ac9b..939ecc4d 100644 --- a/pkg/modules/copy.go +++ b/pkg/modules/copy.go @@ -101,44 +101,57 @@ func ModuleCopy(ctx context.Context, options ExecOptions) (string, string) { // copySrc copy src file to dest func (ca copyArgs) copySrc(ctx context.Context, options ExecOptions, conn connector.Connector) (string, string) { - if filepath.IsAbs(ca.src) { // if src is absolute path. find it in local path + dealAbsoluteFilePath := func() (string, string) { fileInfo, err := os.Stat(ca.src) if err != nil { return "", fmt.Sprintf(" get src file %s in local path error: %v", ca.src, err) } - if fileInfo.IsDir() { // src is dir if err := ca.absDir(ctx, conn); err != nil { return "", fmt.Sprintf("sync copy absolute dir error %s", err) } } else { // src is file - if err := ca.absFile(ctx, fileInfo.Mode(), conn); err != nil { + data, err := os.ReadFile(ca.src) + if err != nil { + return "", fmt.Sprintf("read file error: %s", err) + } + if err := ca.readFile(ctx, data, fileInfo.Mode(), conn); err != nil { return "", fmt.Sprintf("sync copy absolute dir error %s", err) } } - } else { // if src is not absolute path. find file in project + + return StdoutSuccess, "" + } + dealRelativeFilePath := func() (string, string) { pj, err := project.New(ctx, options.Pipeline, false) if err != nil { return "", fmt.Sprintf("get project error: %v", err) } - fileInfo, err := pj.Stat(ca.src, project.GetFileOption{IsFile: true, Role: options.Task.Annotations[kkcorev1alpha1.TaskAnnotationRole]}) if err != nil { return "", fmt.Sprintf("get file %s from project error %v", ca.src, err) } - if fileInfo.IsDir() { if err := ca.relDir(ctx, pj, options.Task.Annotations[kkcorev1alpha1.TaskAnnotationRole], conn); err != nil { return "", fmt.Sprintf("sync copy relative dir error %s", err) } } else { - if err := ca.relFile(ctx, pj, options.Task.Annotations[kkcorev1alpha1.TaskAnnotationRole], fileInfo.Mode(), conn); err != nil { + data, err := pj.ReadFile(ca.src, project.GetFileOption{IsFile: true, Role: options.Task.Annotations[kkcorev1alpha1.TaskAnnotationRole]}) + if err != nil { + return "", fmt.Sprintf("read file error: %s", err) + } + if err := ca.readFile(ctx, data, fileInfo.Mode(), conn); err != nil { return "", fmt.Sprintf("sync copy relative dir error %s", err) } } - } - return StdoutSuccess, "" + return StdoutSuccess, "" + } + if filepath.IsAbs(ca.src) { // if src is absolute path. find it in local path + return dealAbsoluteFilePath() + } + // if src is not absolute path. find file in project + return dealRelativeFilePath() } // copyContent convert content param and copy to dest @@ -158,29 +171,6 @@ func (ca copyArgs) copyContent(ctx context.Context, mode fs.FileMode, conn conne return StdoutSuccess, "" } -// relFile when copy.src is relative dir, get all files from project, and copy to remote. -func (ca copyArgs) relFile(ctx context.Context, pj project.Project, role string, mode fs.FileMode, conn connector.Connector) any { - data, err := pj.ReadFile(ca.src, project.GetFileOption{IsFile: true, Role: role}) - if err != nil { - return fmt.Errorf("read file error: %w", err) - } - - dest := ca.dest - if strings.HasSuffix(ca.dest, "/") { - dest = filepath.Join(ca.dest, filepath.Base(ca.src)) - } - - if ca.mode != nil { - mode = os.FileMode(*ca.mode) - } - - if err := conn.PutFile(ctx, data, dest, mode); err != nil { - return fmt.Errorf("copy file error: %w", err) - } - - return nil -} - // relDir when copy.src is relative dir, get all files from project, and copy to remote. func (ca copyArgs) relDir(ctx context.Context, pj project.Project, role string, conn connector.Connector) error { if err := pj.WalkDir(ca.src, project.GetFileOption{IsFile: true, Role: role}, func(path string, d fs.DirEntry, err error) error { @@ -228,12 +218,7 @@ func (ca copyArgs) relDir(ctx context.Context, pj project.Project, role string, } // absFile when copy.src is absolute file, get file from os, and copy to remote. -func (ca copyArgs) absFile(ctx context.Context, mode fs.FileMode, conn connector.Connector) error { - data, err := os.ReadFile(ca.src) - if err != nil { - return fmt.Errorf("read file error: %w", err) - } - +func (ca copyArgs) readFile(ctx context.Context, data []byte, mode fs.FileMode, conn connector.Connector) error { dest := ca.dest if strings.HasSuffix(ca.dest, "/") { dest = filepath.Join(ca.dest, filepath.Base(ca.src)) diff --git a/pkg/modules/debug.go b/pkg/modules/debug.go index 4a260677..e8de3a55 100644 --- a/pkg/modules/debug.go +++ b/pkg/modules/debug.go @@ -35,12 +35,12 @@ func ModuleDebug(_ context.Context, options ExecOptions) (string, string) { args := variable.Extension2Variables(options.Args) // var is defined. return the value of var if varParam, err := variable.StringVar(ha, args, "var"); err == nil { - result, err := tmpl.ParseString(ha, fmt.Sprintf("{{ %s }}", varParam)) + result, err := tmpl.Parse(ha, fmt.Sprintf("{{ %s }}", varParam)) if err != nil { return "", fmt.Sprintf("failed to parse var: %v", err) } - return result, "" + return string(result), "" } // msg is defined. return the actual msg if msgParam, err := variable.StringVar(ha, args, "msg"); err == nil { diff --git a/pkg/modules/set_fact.go b/pkg/modules/set_fact.go index d89ba65c..0487e4c2 100644 --- a/pkg/modules/set_fact.go +++ b/pkg/modules/set_fact.go @@ -20,7 +20,6 @@ import ( "context" "encoding/json" "fmt" - "strings" "gopkg.in/yaml.v2" @@ -41,15 +40,15 @@ func ModuleSetFact(_ context.Context, options ExecOptions) (string, string) { case bool, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: args[k] = val case string: - sv, err := tmpl.ParseString(ha, val) + sv, err := tmpl.Parse(ha, val) if err != nil { return "", fmt.Sprintf("parse %q error: %v", k, err) } var ssvResult any - if (strings.HasPrefix(sv, "{") || strings.HasPrefix(sv, "[")) && (strings.HasSuffix(sv, "}") || strings.HasSuffix(sv, "]")) { - _ = json.Unmarshal([]byte(sv), &ssvResult) + if json.Valid(sv) { + _ = json.Unmarshal(sv, &ssvResult) } else { - _ = yaml.Unmarshal([]byte(sv), &ssvResult) + _ = yaml.Unmarshal(sv, &ssvResult) } args[k] = ssvResult default: diff --git a/pkg/modules/template.go b/pkg/modules/template.go index fb3b7eb7..45791502 100644 --- a/pkg/modules/template.go +++ b/pkg/modules/template.go @@ -97,7 +97,7 @@ func ModuleTemplate(ctx context.Context, options ExecOptions) (string, string) { } defer conn.Close(ctx) - if filepath.IsAbs(ta.src) { + dealAbsoluteFilePath := func() (string, string) { fileInfo, err := os.Stat(ta.src) if err != nil { return "", fmt.Sprintf(" get src file %s in local path error: %v", ta.src, err) @@ -108,11 +108,18 @@ func ModuleTemplate(ctx context.Context, options ExecOptions) (string, string) { return "", fmt.Sprintf("sync template absolute dir error %s", err) } } else { // src is file - if err := ta.absFile(ctx, fileInfo.Mode(), conn, ha); err != nil { + data, err := os.ReadFile(ta.src) + if err != nil { + return "", fmt.Sprintf("read file error: %s", err) + } + if err := ta.readFile(ctx, string(data), fileInfo.Mode(), conn, ha); err != nil { return "", fmt.Sprintf("sync template absolute file error %s", err) } } - } else { + + return StdoutSuccess, "" + } + dealRelativeFilePath := func() (string, string) { pj, err := project.New(ctx, options.Pipeline, false) if err != nil { return "", fmt.Sprintf("get project error: %v", err) @@ -128,23 +135,27 @@ func ModuleTemplate(ctx context.Context, options ExecOptions) (string, string) { return "", fmt.Sprintf("sync template relative dir error: %s", err) } } else { - if err := ta.relFile(ctx, pj, options.Task.Annotations[kkcorev1alpha1.TaskAnnotationRole], fileInfo.Mode(), conn, ha); err != nil { + data, err := pj.ReadFile(ta.src, project.GetFileOption{IsTemplate: true, Role: options.Task.Annotations[kkcorev1alpha1.TaskAnnotationRole]}) + if err != nil { + return "", fmt.Sprintf("read file error: %s", err) + } + if err := ta.readFile(ctx, string(data), fileInfo.Mode(), conn, ha); err != nil { return "", fmt.Sprintf("sync template relative dir error: %s", err) } } + + return StdoutSuccess, "" + } + if filepath.IsAbs(ta.src) { + return dealAbsoluteFilePath() } - return StdoutSuccess, "" + return dealRelativeFilePath() } // relFile when template.src is relative file, get file from project, parse it, and copy to remote. -func (ta templateArgs) relFile(ctx context.Context, pj project.Project, role string, mode fs.FileMode, conn connector.Connector, vars map[string]any) any { - data, err := pj.ReadFile(ta.src, project.GetFileOption{IsTemplate: true, Role: role}) - if err != nil { - return fmt.Errorf("read file error: %w", err) - } - - result, err := tmpl.ParseString(vars, string(data)) +func (ta templateArgs) readFile(ctx context.Context, data string, mode fs.FileMode, conn connector.Connector, vars map[string]any) any { + result, err := tmpl.Parse(vars, data) if err != nil { return fmt.Errorf("parse file error: %w", err) } @@ -158,7 +169,7 @@ func (ta templateArgs) relFile(ctx context.Context, pj project.Project, role str mode = os.FileMode(*ta.mode) } - if err := conn.PutFile(ctx, []byte(result), dest, mode); err != nil { + if err := conn.PutFile(ctx, result, dest, mode); err != nil { return fmt.Errorf("copy file error: %w", err) } @@ -189,7 +200,7 @@ func (ta templateArgs) relDir(ctx context.Context, pj project.Project, role stri if err != nil { return fmt.Errorf("read file error: %w", err) } - result, err := tmpl.ParseString(vars, string(data)) + result, err := tmpl.Parse(vars, string(data)) if err != nil { return fmt.Errorf("parse file error: %w", err) } @@ -203,7 +214,7 @@ func (ta templateArgs) relDir(ctx context.Context, pj project.Project, role stri dest = filepath.Join(ta.dest, rel) } - if err := conn.PutFile(ctx, []byte(result), dest, mode); err != nil { + if err := conn.PutFile(ctx, result, dest, mode); err != nil { return fmt.Errorf("copy file error: %w", err) } @@ -215,34 +226,6 @@ func (ta templateArgs) relDir(ctx context.Context, pj project.Project, role stri return nil } -// absFile when template.src is absolute file, get file by os, parse it, and copy to remote. -func (ta templateArgs) absFile(ctx context.Context, mode fs.FileMode, conn connector.Connector, vars map[string]any) error { - data, err := os.ReadFile(ta.src) - if err != nil { - return fmt.Errorf("read file error: %w", err) - } - - result, err := tmpl.ParseString(vars, string(data)) - if err != nil { - return fmt.Errorf("parse file error: %w", err) - } - - dest := ta.dest - if strings.HasSuffix(ta.dest, "/") { - dest = filepath.Join(ta.dest, filepath.Base(ta.src)) - } - - if ta.mode != nil { - mode = os.FileMode(*ta.mode) - } - - if err := conn.PutFile(ctx, []byte(result), dest, mode); err != nil { - return fmt.Errorf("copy file error: %w", err) - } - - return nil -} - // absDir when template.src is absolute dir, get all files by os, parse it, and copy to remote. func (ta templateArgs) absDir(ctx context.Context, conn connector.Connector, vars map[string]any) error { if err := filepath.WalkDir(ta.src, func(path string, d fs.DirEntry, err error) error { @@ -267,7 +250,7 @@ func (ta templateArgs) absDir(ctx context.Context, conn connector.Connector, var if err != nil { return fmt.Errorf("read file error: %w", err) } - result, err := tmpl.ParseString(vars, string(data)) + result, err := tmpl.Parse(vars, string(data)) if err != nil { return fmt.Errorf("parse file error: %w", err) } @@ -281,7 +264,7 @@ func (ta templateArgs) absDir(ctx context.Context, conn connector.Connector, var dest = filepath.Join(ta.dest, rel) } - if err := conn.PutFile(ctx, []byte(result), dest, mode); err != nil { + if err := conn.PutFile(ctx, result, dest, mode); err != nil { return fmt.Errorf("copy file error: %w", err) } diff --git a/pkg/variable/helper.go b/pkg/variable/helper.go index 40213dc4..89038f19 100644 --- a/pkg/variable/helper.go +++ b/pkg/variable/helper.go @@ -26,6 +26,7 @@ import ( "time" kkcorev1 "github.com/kubesphere/kubekey/api/core/v1" + kkprojectv1 "github.com/kubesphere/kubekey/api/project/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/json" "k8s.io/klog/v2" @@ -175,7 +176,7 @@ func parseVariableFromMap(v any, parseTmplFunc func(string) (string, error)) err for _, kv := range reflect.ValueOf(v).MapKeys() { val := reflect.ValueOf(v).MapIndex(kv) if vv, ok := val.Interface().(string); ok { - if !tmpl.IsTmplSyntax(vv) { + if !kkprojectv1.IsTmplSyntax(vv) { continue } @@ -207,7 +208,7 @@ func parseVariableFromArray(v any, parseTmplFunc func(string) (string, error)) e for i := range reflect.ValueOf(v).Len() { val := reflect.ValueOf(v).Index(i) if vv, ok := val.Interface().(string); ok { - if !tmpl.IsTmplSyntax(vv) { + if !kkprojectv1.IsTmplSyntax(vv) { continue } @@ -274,7 +275,7 @@ func StringVar(d map[string]any, args map[string]any, key string) (string, error return "", fmt.Errorf("variable \"%s\" is not string", key) } - return tmpl.ParseString(d, sv) + return tmpl.ParseFunc(d, sv, func(b []byte) string { return string(b) }) } // StringSliceVar get string slice value by key @@ -298,7 +299,7 @@ func StringSliceVar(d map[string]any, vars map[string]any, key string) ([]string return nil, nil } - as, err := tmpl.ParseString(d, av) + as, err := tmpl.ParseFunc(d, av, func(b []byte) string { return string(b) }) if err != nil { return nil, err } @@ -308,7 +309,7 @@ func StringSliceVar(d map[string]any, vars map[string]any, key string) ([]string return ss, nil case string: - as, err := tmpl.ParseString(d, valv) + as, err := tmpl.Parse(d, valv) if err != nil { klog.V(4).ErrorS(err, "parse variable error", "key", key) @@ -316,11 +317,11 @@ func StringSliceVar(d map[string]any, vars map[string]any, key string) ([]string } var ss []string - if err := json.Unmarshal([]byte(as), &ss); err == nil { + if err := json.Unmarshal(as, &ss); err == nil { return ss, nil } - return []string{as}, nil + return []string{string(as)}, nil default: klog.V(4).ErrorS(nil, "unsupported variable type", "key", key) @@ -351,7 +352,7 @@ func IntVar(d map[string]any, vars map[string]any, key string) (*int, error) { case reflect.Float32, reflect.Float64: return ptr.To(int(v.Float())), nil case reflect.String: - vs, err := tmpl.ParseString(d, v.String()) + vs, err := tmpl.ParseFunc(d, v.String(), func(b []byte) string { return string(b) }) if err != nil { klog.V(4).ErrorS(err, "parse string variable error", "key", key) @@ -387,20 +388,14 @@ func BoolVar(d map[string]any, args map[string]any, key string) (*bool, error) { case reflect.Bool: return ptr.To(v.Bool()), nil case reflect.String: - vs, err := tmpl.ParseString(d, v.String()) + vs, err := tmpl.ParseBool(d, v.String()) if err != nil { klog.V(4).ErrorS(err, "parse string variable error", "key", key) return nil, err } - if strings.EqualFold(vs, "TRUE") { - return ptr.To(true), nil - } - - if strings.EqualFold(vs, "FALSE") { - return ptr.To(false), nil - } + return ptr.To(vs), nil } return nil, fmt.Errorf("unsupported variable \"%s\" type", key) @@ -468,10 +463,10 @@ func Extension2String(d map[string]any, ext runtime.RawExtension) (string, error input = ns } - result, err := tmpl.ParseString(d, input) + result, err := tmpl.Parse(d, input) if err != nil { return "", err } - return result, nil + return string(result), nil } diff --git a/pkg/variable/helper_test.go b/pkg/variable/helper_test.go index f25adf0e..7ca859c5 100644 --- a/pkg/variable/helper_test.go +++ b/pkg/variable/helper_test.go @@ -281,7 +281,7 @@ func TestParseVariable(t *testing.T) { t.Run(tc.name, func(t *testing.T) { err := parseVariable(tc.data, func(s string) (string, error) { // parse use total variable. the task variable should not contain template syntax. - return tmpl.ParseString(CombineVariables(tc.data, tc.base), s) + return tmpl.ParseFunc(CombineVariables(tc.data, tc.base), s, func(b []byte) string { return string(b) }) }) if err != nil { t.Fatal(err) diff --git a/pkg/variable/variable_get.go b/pkg/variable/variable_get.go index ff936d35..5498af68 100644 --- a/pkg/variable/variable_get.go +++ b/pkg/variable/variable_get.go @@ -31,7 +31,7 @@ var GetHostnames = func(name []string) GetFunc { var hs []string for _, n := range name { // try parse hostname by Config. - if pn, err := tmpl.ParseString(Extension2Variables(vv.value.Config.Spec), n); err == nil { + if pn, err := tmpl.ParseFunc(Extension2Variables(vv.value.Config.Spec), n, func(b []byte) string { return string(b) }); err == nil { n = pn } // add host to hs diff --git a/pkg/variable/variable_merge.go b/pkg/variable/variable_merge.go index 0e3b24df..284c0132 100644 --- a/pkg/variable/variable_merge.go +++ b/pkg/variable/variable_merge.go @@ -3,10 +3,7 @@ package variable import ( "errors" "fmt" - "os" - "strconv" - _const "github.com/kubesphere/kubekey/v4/pkg/const" "github.com/kubesphere/kubekey/v4/pkg/converter/tmpl" ) @@ -51,36 +48,13 @@ var MergeRuntimeVariable = func(data map[string]any, hosts ...string) MergeFunc if !ok { return errors.New("variable type error") } - - depth := 3 - if envDepth, err := strconv.Atoi(os.Getenv(_const.ENV_VARIABLE_PARSE_DEPTH)); err == nil { - if envDepth != 0 { - depth = envDepth - } - } - - for range depth { - // merge to specify host - curVariable, err := v.Get(GetAllVariable(hostname)) - if err != nil { - return err - } - // parse variable - if err := parseVariable(data, func(s string) (string, error) { - // parse use total variable. the task variable should not contain template syntax. - cv, ok := curVariable.(map[string]any) - if !ok { - return "", errors.New("variable type error") - } - - return tmpl.ParseString(CombineVariables(data, cv), s) - }); err != nil { - return err - } - hv := vv.value.Hosts[hostname] - hv.RuntimeVars = CombineVariables(hv.RuntimeVars, data) - vv.value.Hosts[hostname] = hv + // parse variable + if err := parseVariable(data, runtimeVarParser(v, hostname, data)); err != nil { + return err } + hv := vv.value.Hosts[hostname] + hv.RuntimeVars = CombineVariables(hv.RuntimeVars, data) + vv.value.Hosts[hostname] = hv } return nil @@ -94,24 +68,10 @@ var MergeAllRuntimeVariable = func(data map[string]any, hostname string) MergeFu if !ok { return errors.New("variable type error") } - // merge to specify host - curVariable, err := v.Get(GetAllVariable(hostname)) - if err != nil { - return err - } // parse variable - if err := parseVariable(data, func(s string) (string, error) { - // parse use total variable. the task variable should not contain template syntax. - cv, ok := curVariable.(map[string]any) - if !ok { - return "", errors.New("variable type error") - } - - return tmpl.ParseString(CombineVariables(data, cv), s) - }); err != nil { + if err := parseVariable(data, runtimeVarParser(v, hostname, data)); err != nil { return err } - for h := range vv.value.Hosts { if _, ok := v.(*variable); !ok { return errors.New("variable type error") @@ -124,3 +84,22 @@ var MergeAllRuntimeVariable = func(data map[string]any, hostname string) MergeFu return nil } } + +func runtimeVarParser(v Variable, hostname string, data map[string]any) func(string) (string, error) { + return func(s string) (string, error) { + curVariable, err := v.Get(GetAllVariable(hostname)) + if err != nil { + return "", fmt.Errorf("get host %s variables error: %w", hostname, err) + } + cv, ok := curVariable.(map[string]any) + if !ok { + return "", fmt.Errorf("host %s variables type error, expect map[string]any", hostname) + } + + return tmpl.ParseFunc( + CombineVariables(data, cv), + s, + func(b []byte) string { return string(b) }, + ) + } +}