feat: Parse YAML parameters sequentially. (#2555)

Signed-off-by: joyceliu <joyceliu@yunify.com>
This commit is contained in:
liujian 2025-05-07 17:15:52 +08:00 committed by GitHub
parent dc8717479b
commit 9502ac5391
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1146 additions and 502 deletions

View File

@ -69,8 +69,8 @@ linters:
- nakedret
- nestif
# - nilerr
- nilnil
- nlreturn
# - nilnil
# - nlreturn
- noctx
- nolintlint
- nonamedreturns
@ -459,7 +459,7 @@ linters-settings:
# # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#enforce-map-style
- name: enforce-map-style
severity: warning
disabled: false
disabled: true
exclude: [""]
arguments:
- "make"
@ -473,7 +473,7 @@ linters-settings:
# # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#enforce-slice-style
- name: enforce-slice-style
severity: warning
disabled: false
disabled: true
exclude: [""]
arguments:
- "make"

View File

@ -16,6 +16,8 @@ limitations under the License.
package v1
import "gopkg.in/yaml.v3"
// Base defined in project.
type Base struct {
Name string `yaml:"name,omitempty"`
@ -26,7 +28,7 @@ type Base struct {
RemoteUser string `yaml:"remote_user,omitempty"`
// variables
Vars map[string]any `yaml:"vars,omitempty"`
Vars yaml.Node `yaml:"vars,omitempty"`
// module default params
//ModuleDefaults []map[string]map[string]any `yaml:"module_defaults,omitempty"`

2
go.mod
View File

@ -17,7 +17,6 @@ require (
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.31.0
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1
k8s.io/api v0.31.3
k8s.io/apimachinery v0.31.3
@ -146,6 +145,7 @@ require (
gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
k8s.io/apiextensions-apiserver v0.31.3 // indirect
k8s.io/cluster-bootstrap v0.31.3 // indirect
k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 // indirect

View File

@ -21,6 +21,7 @@ import (
"strconv"
"strings"
"gopkg.in/yaml.v3"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/json"
@ -149,3 +150,19 @@ func ConvertKKClusterToInventoryHost(kkcluster *capkkinfrav1beta1.KKCluster) (kk
return inventoryHosts, nil
}
// ConvertMap2Node converts a map[string]any to a yaml.Node by first marshaling to YAML bytes
// then unmarshaling into a Node. This allows working with the YAML node structure directly.
func ConvertMap2Node(m map[string]any) (yaml.Node, error) {
data, err := yaml.Marshal(m)
if err != nil {
return yaml.Node{}, errors.Wrap(err, "failed to marshal map to yaml")
}
var node yaml.Node
err = yaml.Unmarshal(data, &node)
if err != nil {
return yaml.Node{}, errors.Wrap(err, "failed to unmarshal yaml to node")
}
return node, nil
}

View File

@ -170,3 +170,51 @@ func TestConvertKKClusterToInventoryHost(t *testing.T) {
})
}
}
func TestConvertMap2Node(t *testing.T) {
testcases := []struct {
name string
input map[string]any
wantErr bool
}{
{
name: "simple map",
input: map[string]any{
"key1": "value1",
"key2": 123,
"key3": true,
},
},
{
name: "nested map",
input: map[string]any{
"outer": map[string]any{
"inner": "value",
},
"array": []any{"a", "b", "c"},
},
},
{
name: "empty map",
input: map[string]any{},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
node, err := ConvertMap2Node(tc.input)
if err != nil {
t.Fatal(err)
}
assert.NotNil(t, node)
// Convert back to map to verify roundtrip
var result map[string]any
err = node.Decode(&result)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, tc.input, result)
})
}
}

View File

@ -8,6 +8,7 @@ import (
kkcorev1alpha1 "github.com/kubesphere/kubekey/api/core/v1alpha1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
_const "github.com/kubesphere/kubekey/v4/pkg/const"
@ -15,21 +16,40 @@ import (
"github.com/kubesphere/kubekey/v4/pkg/variable/source"
)
func newTestOption() (*option, error) {
func newTestOption(hosts []string) (*option, error) {
var err error
// convert host to InventoryHost
inventoryHost := make(kkcorev1.InventoryHost)
for _, h := range hosts {
inventoryHost[h] = runtime.RawExtension{}
}
client := fake.NewClientBuilder().WithScheme(_const.Scheme).WithStatusSubresource(&kkcorev1.Playbook{}, &kkcorev1alpha1.Task{}).Build()
inventory := &kkcorev1.Inventory{
TypeMeta: metav1.TypeMeta{},
ObjectMeta: metav1.ObjectMeta{
GenerateName: "test-",
Namespace: corev1.NamespaceDefault,
},
Spec: kkcorev1.InventorySpec{
Hosts: inventoryHost,
},
}
if err := client.Create(context.TODO(), inventory); err != nil {
return nil, err
}
o := &option{
client: fake.NewClientBuilder().WithScheme(_const.Scheme).WithStatusSubresource(&kkcorev1.Playbook{}, &kkcorev1alpha1.Task{}).Build(),
client: client,
playbook: &kkcorev1.Playbook{
TypeMeta: metav1.TypeMeta{},
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: corev1.NamespaceDefault,
GenerateName: "test-",
Namespace: corev1.NamespaceDefault,
},
Spec: kkcorev1.PlaybookSpec{
InventoryRef: &corev1.ObjectReference{
Name: "test",
Namespace: corev1.NamespaceDefault,
Name: inventory.Name,
Namespace: inventory.Namespace,
},
},
Status: kkcorev1.PlaybookStatus{},
@ -37,17 +57,6 @@ func newTestOption() (*option, error) {
logOutput: os.Stdout,
}
if err := o.client.Create(context.TODO(), &kkcorev1.Inventory{
TypeMeta: metav1.TypeMeta{},
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: corev1.NamespaceDefault,
},
Spec: kkcorev1.InventorySpec{},
}); err != nil {
return nil, err
}
o.variable, err = variable.New(context.TODO(), o.client, *o.playbook, source.MemorySource)
if err != nil {
return nil, err

View File

@ -20,11 +20,13 @@ import (
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
_const "github.com/kubesphere/kubekey/v4/pkg/const"
"github.com/kubesphere/kubekey/v4/pkg/converter"
"github.com/kubesphere/kubekey/v4/pkg/converter/tmpl"
"github.com/kubesphere/kubekey/v4/pkg/modules"
"github.com/kubesphere/kubekey/v4/pkg/variable"
)
// taskExecutor handles the execution of a single task across multiple hosts.
type taskExecutor struct {
*option
task *kkcorev1alpha1.Task
@ -34,7 +36,8 @@ type taskExecutor struct {
taskRunTimeout time.Duration
}
// Exec and store Task
// Exec creates and executes a task, updating its status and the parent playbook's status.
// It returns an error if the task creation or execution fails.
func (e *taskExecutor) Exec(ctx context.Context) error {
// create task
if err := e.client.Create(ctx, e.task); err != nil {
@ -141,7 +144,7 @@ func (e *taskExecutor) runTaskLoop(ctx context.Context) error {
}
}
// execTask
// execTask executes the task across all specified hosts in parallel and updates the task status.
func (e *taskExecutor) execTask(ctx context.Context) {
// check task host results
wg := &wait.Group{}
@ -165,7 +168,8 @@ func (e *taskExecutor) execTask(ctx context.Context) {
}
}
// execTaskHost deal module in each host parallel.
// execTaskHost handles executing a task on a single host, including variable setup,
// condition checking, and module execution. It runs in parallel for each host.
func (e *taskExecutor) execTaskHost(i int, h string) func(ctx context.Context) {
return func(ctx context.Context) {
// task result
@ -208,28 +212,13 @@ func (e *taskExecutor) execTaskHost(i int, h string) func(ctx context.Context) {
// execute module in loop with loop item.
// if loop is empty. execute once, and the item is null
for _, item := range e.dealLoop(had) {
// set item to runtime variable
if err := e.variable.Merge(variable.MergeRuntimeVariable(map[string]any{
_const.VariableItem: item,
}, h)); err != nil {
stderr = fmt.Sprintf("set loop item to variable error: %v", err)
return
}
e.executeModule(ctx, e.task, h, &stdout, &stderr)
// delete item
if err := e.variable.Merge(variable.MergeRuntimeVariable(map[string]any{
_const.VariableItem: nil,
}, h)); err != nil {
stderr = fmt.Sprintf("clean loop item to variable error: %v", err)
return
}
e.executeModule(ctx, e.task, item, h, &stdout, &stderr)
}
}
}
// execTaskHostLogs logs for each host
// execTaskHostLogs sets up and manages progress bar logging for task execution on a host.
// It returns a cleanup function to be called when execution completes.
func (e *taskExecutor) execTaskHostLogs(ctx context.Context, h string, stdout, stderr *string) func() {
// placeholder format task log
var placeholder string
@ -287,26 +276,62 @@ func (e *taskExecutor) execTaskHostLogs(ctx context.Context, h string, stdout, s
}
}
// executeModule find register module and execute it in a single host.
func (e *taskExecutor) executeModule(ctx context.Context, task *kkcorev1alpha1.Task, host string, stdout, stderr *string) {
// get all variable. which contains item.
// executeModule executes a single module task on a specific host. It handles setting up loop item variables,
// retrieving host variables, checking failure conditions, and executing the actual module.
func (e *taskExecutor) executeModule(ctx context.Context, task *kkcorev1alpha1.Task, item any, host string, stdout, stderr *string) {
// Set loop item variable if one was provided
if item != nil {
// Convert item to runtime variable
node, err := converter.ConvertMap2Node(map[string]any{_const.VariableItem: item})
if err != nil {
*stderr = fmt.Sprintf("convert loop item error: %v", err)
return
}
// Merge item into host's runtime variables
if err := e.variable.Merge(variable.MergeRuntimeVariable(node, host)); err != nil {
*stderr = fmt.Sprintf("set loop item to variable error: %v", err)
return
}
// Clean up loop item variable after execution
defer func() {
if item == nil {
return
}
// Reset item to null
resetNode, err := converter.ConvertMap2Node(map[string]any{_const.VariableItem: nil})
if err != nil {
*stderr = fmt.Sprintf("convert loop item error: %v", err)
return
}
if err := e.variable.Merge(variable.MergeRuntimeVariable(resetNode, host)); err != nil {
*stderr = fmt.Sprintf("clean loop item to variable error: %v", err)
return
}
}()
}
// Get all variables for this host, including any loop item
ha, err := e.variable.Get(variable.GetAllVariable(host))
if err != nil {
*stderr = fmt.Sprintf("failed to get host %s variable: %v", host, err)
return
}
// convert hostVariable to map
// Convert host variables to map type
had, ok := ha.(map[string]any)
if !ok {
*stderr = fmt.Sprintf("host: %s variable is not a map", host)
return
}
// check failed when condition
// Check if task should fail based on failed_when conditions
if skip := e.dealFailedWhen(had, stdout, stderr); skip {
return
}
// Execute the actual module with the prepared context
*stdout, *stderr = modules.FindModule(task.Spec.Module.Name)(ctx, modules.ExecOptions{
Args: e.task.Spec.Module.Args,
Host: host,
@ -316,9 +341,9 @@ func (e *taskExecutor) executeModule(ctx context.Context, task *kkcorev1alpha1.T
})
}
// execLoop parse loop to item slice and execute it. if loop contains template string. convert it.
// loop is json string. try convertor to string slice by json.
// loop is normal string. set it to empty slice and return.
// dealLoop parses the loop specification into a slice of items to iterate over.
// If no loop is specified, returns a single nil item. Otherwise converts the loop
// specification from JSON into a slice of values.
func (e *taskExecutor) dealLoop(ha map[string]any) []any {
var items []any
switch {
@ -332,7 +357,8 @@ func (e *taskExecutor) dealLoop(ha map[string]any) []any {
return items
}
// dealWhen "when" argument in task.
// dealWhen evaluates the "when" conditions for a task to determine if it should be skipped.
// Returns true if the task should be skipped, false if it should proceed.
func (e *taskExecutor) dealWhen(had map[string]any, stdout, stderr *string) bool {
if len(e.task.Spec.When) > 0 {
ok, err := tmpl.ParseBool(had, e.task.Spec.When...)
@ -352,7 +378,8 @@ func (e *taskExecutor) dealWhen(had map[string]any, stdout, stderr *string) bool
return false
}
// dealFailedWhen "failed_when" argument in task.
// dealFailedWhen evaluates the "failed_when" conditions for a task to determine if it should fail.
// Returns true if the task should be marked as failed, false if it should proceed.
func (e *taskExecutor) dealFailedWhen(had map[string]any, stdout, stderr *string) bool {
if len(e.task.Spec.FailedWhen) > 0 {
ok, err := tmpl.ParseBool(had, e.task.Spec.FailedWhen...)
@ -373,7 +400,8 @@ func (e *taskExecutor) dealFailedWhen(had map[string]any, stdout, stderr *string
return false
}
// dealRegister "register" argument in task.
// dealRegister handles storing task output in a registered variable if specified.
// The output can be stored as raw string, JSON, or YAML based on the register type.
func (e *taskExecutor) dealRegister(stdout, stderr, host string) error {
if e.task.Spec.Register != "" {
var stdoutResult any = stdout
@ -387,12 +415,16 @@ func (e *taskExecutor) dealRegister(stdout, stderr, host string) error {
// store by string
}
// set variable to parent location
if err := e.variable.Merge(variable.MergeRuntimeVariable(map[string]any{
node, err := converter.ConvertMap2Node(map[string]any{
e.task.Spec.Register: map[string]any{
"stdout": stdoutResult,
"stderr": stderrResult,
},
}, host)); err != nil {
})
if err != nil {
return err
}
if err := e.variable.Merge(variable.MergeRuntimeVariable(node, host)); err != nil {
return err
}
}

View File

@ -14,15 +14,16 @@ import (
func TestTaskExecutor(t *testing.T) {
testcases := []struct {
name string
task *kkcorev1alpha1.Task
name string
hosts []string
task *kkcorev1alpha1.Task
}{
{
name: "debug module in single host",
task: &kkcorev1alpha1.Task{
TypeMeta: metav1.TypeMeta{},
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Name: "test1",
Namespace: corev1.NamespaceDefault,
},
Spec: kkcorev1alpha1.TaskSpec{
@ -35,12 +36,34 @@ func TestTaskExecutor(t *testing.T) {
Status: kkcorev1alpha1.TaskStatus{},
},
},
{
name: "debug module in single host with loop",
hosts: []string{"node1"},
task: &kkcorev1alpha1.Task{
TypeMeta: metav1.TypeMeta{},
ObjectMeta: metav1.ObjectMeta{
Name: "test2",
Namespace: corev1.NamespaceDefault,
},
Spec: kkcorev1alpha1.TaskSpec{
Hosts: []string{"node1"},
Module: kkcorev1alpha1.Module{
Name: "debug",
Args: runtime.RawExtension{Raw: []byte(`{"msg":"hello"}`)},
},
Loop: runtime.RawExtension{
Raw: []byte(string(`["a", "b"]`)),
},
},
Status: kkcorev1alpha1.TaskStatus{},
},
},
{
name: "debug module in multiple hosts",
task: &kkcorev1alpha1.Task{
TypeMeta: metav1.TypeMeta{},
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Name: "test3",
Namespace: corev1.NamespaceDefault,
},
Spec: kkcorev1alpha1.TaskSpec{
@ -56,7 +79,7 @@ func TestTaskExecutor(t *testing.T) {
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
o, err := newTestOption()
o, err := newTestOption(tc.hosts)
if err != nil {
t.Fatal(err)
}

View File

@ -18,44 +18,21 @@ package modules
import (
"context"
"encoding/json"
"fmt"
"gopkg.in/yaml.v2"
"gopkg.in/yaml.v3"
"github.com/kubesphere/kubekey/v4/pkg/converter/tmpl"
"github.com/kubesphere/kubekey/v4/pkg/variable"
)
// ModuleSetFact deal "set_fact" module
func ModuleSetFact(_ context.Context, options ExecOptions) (string, string) {
ha, err := options.getAllVariables()
if err != nil {
return "", err.Error()
var node yaml.Node
// Unmarshal the YAML document into a root node.
if err := yaml.Unmarshal(options.Args.Raw, &node); err != nil {
return "", fmt.Sprintf("failed to unmarshal YAML error: %v", err)
}
// get host variable
args := variable.Extension2Variables(options.Args)
for k, v := range args {
switch val := v.(type) {
case bool, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
args[k] = val
case string:
sv, err := tmpl.Parse(ha, val)
if err != nil {
return "", fmt.Sprintf("parse %q error: %v", k, err)
}
var ssvResult any
if json.Valid(sv) {
_ = json.Unmarshal(sv, &ssvResult)
} else {
_ = yaml.Unmarshal(sv, &ssvResult)
}
args[k] = ssvResult
default:
return "", fmt.Sprintf("only support bool, int, float64, string value for %q.", k)
}
}
if err := options.Variable.Merge(variable.MergeAllRuntimeVariable(args, options.Host)); err != nil {
if err := options.Variable.Merge(variable.MergeAllRuntimeVariable(node, options.Host)); err != nil {
return "", fmt.Sprintf("set_fact error: %v", err)
}

View File

@ -85,7 +85,7 @@ func TestSetFact(t *testing.T) {
Task: kkcorev1alpha1.Task{},
Playbook: kkcorev1.Playbook{},
},
exceptStderr: "only support bool, int, float64, string value for \"k\".",
exceptStdout: "success",
},
{
name: "array value",
@ -98,7 +98,7 @@ func TestSetFact(t *testing.T) {
Task: kkcorev1alpha1.Task{},
Playbook: kkcorev1.Playbook{},
},
exceptStderr: "only support bool, int, float64, string value for \"k\".",
exceptStdout: "success",
},
}

View File

@ -173,18 +173,23 @@ func (f *project) dealVarsFiles(p *kkprojectv1.Play, basePlaybook string) error
if file == "" {
return errors.Errorf("failed to find vars_files %q base on %q. it's should be:\n %s", varsFile, basePlaybook, PathFormatVarsFile)
}
data, err := fs.ReadFile(f.FS, file)
if err != nil {
return errors.Wrapf(err, "failed to read file %q", file)
}
var newVars map[string]any
var node yaml.Node
// Unmarshal the YAML document into a root node.
if err := yaml.Unmarshal(data, &newVars); err != nil {
if err := yaml.Unmarshal(data, &node); err != nil {
return errors.Wrap(err, "failed to failed to unmarshal YAML")
}
// store vars in play. the vars defined in file should not be repeated.
p.Vars = variable.CombineVariables(newVars, p.Vars)
if node.Kind != yaml.DocumentNode || len(node.Content) != 1 {
return errors.Errorf("unsupport vars_files format. it should be single map file")
}
// combine map node
if node.Content[0].Kind == yaml.MappingNode {
// skip empty file
p.Vars = *variable.CombineMappingNode(node.Content[0], &p.Vars)
}
}
return nil
@ -218,13 +223,19 @@ func (f *project) dealRoles(p kkprojectv1.Play, basePlaybook string) error {
return errors.Wrapf(err, "failed to read defaults variable file %q", defaults)
}
var newVars map[string]any
var node yaml.Node
// Unmarshal the YAML document into a root node.
if err := yaml.Unmarshal(data, &newVars); err != nil {
if err := yaml.Unmarshal(data, &node); err != nil {
return errors.Wrap(err, "failed to unmarshal YAML")
}
// store vars in play. the vars defined in file should not be repeated.
p.Roles[i].Vars = variable.CombineVariables(newVars, p.Roles[i].Vars)
if node.Kind != yaml.DocumentNode || len(node.Content) != 1 {
return errors.Errorf("unsupport vars_files format. it should be single map file")
}
// combine map node
if node.Content[0].Kind == yaml.MappingNode {
// skip empty file
p.Roles[i].Vars = *variable.CombineMappingNode(node.Content[0], &p.Roles[i].Vars)
}
}
}

View File

@ -7,6 +7,7 @@ import (
kkcorev1alpha1 "github.com/kubesphere/kubekey/api/core/v1alpha1"
kkprojectv1 "github.com/kubesphere/kubekey/api/project/v1"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestMarshalPlaybook(t *testing.T) {
@ -295,6 +296,220 @@ func TestMarshalPlaybook(t *testing.T) {
},
},
},
{
name: "test_vars_1",
playbook: kkcorev1.Playbook{
Spec: kkcorev1.PlaybookSpec{
Playbook: "testdata/playbook_var1.yaml",
},
},
except: &kkprojectv1.Playbook{
Play: []kkprojectv1.Play{
{
Base: kkprojectv1.Base{
Name: "playbook-var1",
Vars: yaml.Node{
Kind: yaml.MappingNode,
Tag: "!!map",
Line: 6,
Column: 5,
Content: []*yaml.Node{
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a",
Line: 6,
Column: 5,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "b",
Line: 6,
Column: 8,
},
},
},
},
PlayHost: kkprojectv1.PlayHost{
Hosts: []string{"node1"},
},
},
},
},
},
{
name: "test_vars_2",
playbook: kkcorev1.Playbook{
Spec: kkcorev1.PlaybookSpec{
Playbook: "testdata/playbook_var2.yaml",
},
},
except: &kkprojectv1.Playbook{
Play: []kkprojectv1.Play{
{
VarsFiles: []string{"vars/var1.yaml", "vars/var2.yaml"},
Base: kkprojectv1.Base{
Name: "playbook-var2",
Vars: yaml.Node{
Kind: yaml.MappingNode,
Tag: "!!map",
Line: 2,
Column: 1,
Content: []*yaml.Node{
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a1",
Line: 2,
Column: 1,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "aa",
Line: 2,
Column: 5,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a2",
Line: 3,
Column: 1,
},
{
Kind: yaml.ScalarNode,
Tag: "!!int",
Value: "1",
Line: 3,
Column: 5,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a2",
Line: 1,
Column: 1,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "aaa",
Line: 1,
Column: 5,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a3",
Line: 2,
Column: 1,
},
{
Kind: yaml.MappingNode,
Tag: "!!map",
Value: "",
Line: 3,
Column: 2,
Content: []*yaml.Node{
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "b3",
Line: 3,
Column: 2,
},
{
Kind: yaml.ScalarNode,
Tag: "!!int",
Value: "1",
Line: 3,
Column: 6,
},
},
},
},
},
},
PlayHost: kkprojectv1.PlayHost{
Hosts: []string{"node1"},
},
},
},
},
},
{
name: "test_vars_3",
playbook: kkcorev1.Playbook{
Spec: kkcorev1.PlaybookSpec{
Playbook: "testdata/playbook_var3.yaml",
},
},
except: &kkprojectv1.Playbook{
Play: []kkprojectv1.Play{
{
VarsFiles: []string{"vars/var1.yaml"},
Base: kkprojectv1.Base{
Name: "playbook-var3",
Vars: yaml.Node{
Kind: yaml.MappingNode,
Tag: "!!map",
Line: 8,
Column: 5,
Content: []*yaml.Node{
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a2",
Line: 8,
Column: 5,
},
{
Kind: yaml.ScalarNode,
Tag: "!!int",
Value: "2",
Line: 8,
Column: 9,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a1",
Line: 2,
Column: 1,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "aa",
Line: 2,
Column: 5,
},
{
Kind: yaml.ScalarNode,
Tag: "!!str",
Value: "a2",
Line: 3,
Column: 1,
},
{
Kind: yaml.ScalarNode,
Tag: "!!int",
Value: "1",
Line: 3,
Column: 5,
},
},
},
},
PlayHost: kkprojectv1.PlayHost{
Hosts: []string{"node1"},
},
},
},
},
},
}
for _, tc := range testcases {

View File

@ -0,0 +1,7 @@
---
- name: playbook-var1
hosts:
- node1
vars:
a: b

View File

@ -0,0 +1,8 @@
---
- name: playbook-var2
hosts:
- node1
vars_files:
- vars/var1.yaml
- vars/var2.yaml

View File

@ -0,0 +1,9 @@
---
- name: playbook-var3
hosts:
- node1
vars_files:
- vars/var1.yaml
vars:
a2: 2

3
pkg/project/testdata/vars/var1.yaml vendored Normal file
View File

@ -0,0 +1,3 @@
---
a1: aa
a2: 1

3
pkg/project/testdata/vars/var2.yaml vendored Normal file
View File

@ -0,0 +1,3 @@
a2: aaa
a3:
b3: 1

View File

@ -17,16 +17,14 @@ limitations under the License.
package variable
import (
"net"
"reflect"
"slices"
"strconv"
"strings"
"time"
"github.com/cockroachdb/errors"
kkcorev1 "github.com/kubesphere/kubekey/api/core/v1"
kkprojectv1 "github.com/kubesphere/kubekey/api/project/v1"
"gopkg.in/yaml.v3"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/klog/v2"
@ -36,18 +34,53 @@ import (
"github.com/kubesphere/kubekey/v4/pkg/converter/tmpl"
)
// CombineVariables merge multiple variables into one variable
// v2 will override v1 if variable is repeated
// CombineMappingNode combines two yaml.Node objects representing mapping nodes.
// If b is nil or zero, returns a.
// If both a and b are mapping nodes, appends a's content to b.
// Returns b in all other cases.
//
// Parameters:
// - a: First yaml.Node to combine
// - b: Second yaml.Node to combine
//
// Returns:
// - Combined yaml.Node, with b taking precedence
func CombineMappingNode(a, b *yaml.Node) *yaml.Node {
if b == nil || b.IsZero() {
return a
}
if a.Kind == yaml.MappingNode && b.Kind == yaml.MappingNode {
b.Content = append(b.Content, a.Content...)
}
return b
}
// CombineVariables merge multiple variables into one variable.
// It recursively combines two maps, where values from m2 override values from m1 if keys overlap.
// For nested maps, it will recursively merge their contents.
// For non-map values or when either input is nil, m2's value takes precedence.
//
// Parameters:
// - m1: The first map to merge (base map)
// - m2: The second map to merge (override map)
//
// Returns:
// - A new map containing the merged key-value pairs from both input maps
func CombineVariables(m1, m2 map[string]any) map[string]any {
var f func(val1, val2 any) any
f = func(val1, val2 any) any {
// If both values are non-nil maps, merge them recursively
if val1 != nil && val2 != nil &&
reflect.TypeOf(val1).Kind() == reflect.Map && reflect.TypeOf(val2).Kind() == reflect.Map {
mergedVars := make(map[string]any)
// Copy all values from val1 first
for _, k := range reflect.ValueOf(val1).MapKeys() {
mergedVars[k.String()] = reflect.ValueOf(val1).MapIndex(k).Interface()
}
// Merge in values from val2, recursively handling nested maps
for _, k := range reflect.ValueOf(val2).MapKeys() {
mergedVars[k.String()] = f(mergedVars[k.String()], reflect.ValueOf(val2).MapIndex(k).Interface())
}
@ -55,14 +88,19 @@ func CombineVariables(m1, m2 map[string]any) map[string]any {
return mergedVars
}
// For non-map values or nil inputs, return val2
return val2
}
// Initialize result map
mv := make(map[string]any)
// Copy all key-value pairs from m1
for k, v := range m1 {
mv[k] = v
}
// Merge in values from m2
for k, v := range m2 {
mv[k] = f(mv[k], v)
}
@ -70,6 +108,41 @@ func CombineVariables(m1, m2 map[string]any) map[string]any {
return mv
}
// CombineSlice combines two string slices while skipping duplicate values.
// It maintains the order of elements from g1 followed by unique elements from g2.
//
// Parameters:
// - g1: The first slice of strings
// - g2: The second slice of strings
//
// Returns:
// - A new slice containing unique strings from both input slices,
// preserving order with g1 elements appearing before unique g2 elements
func CombineSlice(g1, g2 []string) []string {
uniqueValues := make(map[string]bool)
mg := make([]string, 0)
// Add values from the first slice
for _, v := range g1 {
if !uniqueValues[v] {
uniqueValues[v] = true
mg = append(mg, v)
}
}
// Add values from the second slice
for _, v := range g2 {
if !uniqueValues[v] {
uniqueValues[v] = true
mg = append(mg, v)
}
}
return mg
}
// ConvertGroup converts the inventory into a map of groups with their respective hosts.
// It ensures that all hosts are included in the "all" group and adds a default localhost if not present.
// It also creates an "ungrouped" group for hosts that are not part of any specific group.
@ -120,145 +193,15 @@ func HostsInGroup(inv kkcorev1.Inventory, groupName string) []string {
if v, ok := inv.Spec.Groups[groupName]; ok {
var hosts []string
for _, cg := range v.Groups {
hosts = mergeSlice(HostsInGroup(inv, cg), hosts)
hosts = CombineSlice(HostsInGroup(inv, cg), hosts)
}
return mergeSlice(hosts, v.Hosts)
return CombineSlice(hosts, v.Hosts)
}
return nil
}
// mergeSlice with skip repeat value
func mergeSlice(g1, g2 []string) []string {
uniqueValues := make(map[string]bool)
mg := make([]string, 0)
// Add values from the first slice
for _, v := range g1 {
if !uniqueValues[v] {
uniqueValues[v] = true
mg = append(mg, v)
}
}
// Add values from the second slice
for _, v := range g2 {
if !uniqueValues[v] {
uniqueValues[v] = true
mg = append(mg, v)
}
}
return mg
}
// parseVariable parse all string values to the actual value.
func parseVariable(v any, parseTmplFunc func(string) (string, error)) error {
switch reflect.ValueOf(v).Kind() {
case reflect.Map:
if err := parseVariableFromMap(v, parseTmplFunc); err != nil {
return err
}
case reflect.Slice, reflect.Array:
if err := parseVariableFromArray(v, parseTmplFunc); err != nil {
return err
}
}
return nil
}
// parseVariableFromMap parse to variable when the v is map.
func parseVariableFromMap(v any, parseTmplFunc func(string) (string, error)) error {
for _, kv := range reflect.ValueOf(v).MapKeys() {
val := reflect.ValueOf(v).MapIndex(kv)
if vv, ok := val.Interface().(string); ok {
if !kkprojectv1.IsTmplSyntax(vv) {
continue
}
newValue, err := parseTmplFunc(vv)
if err != nil {
return err
}
switch {
case strings.EqualFold(newValue, "TRUE"):
reflect.ValueOf(v).SetMapIndex(kv, reflect.ValueOf(true))
case strings.EqualFold(newValue, "FALSE"):
reflect.ValueOf(v).SetMapIndex(kv, reflect.ValueOf(false))
default:
reflect.ValueOf(v).SetMapIndex(kv, reflect.ValueOf(newValue))
}
} else {
if err := parseVariable(val.Interface(), parseTmplFunc); err != nil {
return err
}
}
}
return nil
}
// parseVariableFromArray parse to variable when the v is slice.
func parseVariableFromArray(v any, parseTmplFunc func(string) (string, error)) error {
for i := range reflect.ValueOf(v).Len() {
val := reflect.ValueOf(v).Index(i)
if vv, ok := val.Interface().(string); ok {
if !kkprojectv1.IsTmplSyntax(vv) {
continue
}
newValue, err := parseTmplFunc(vv)
if err != nil {
return err
}
switch {
case strings.EqualFold(newValue, "TRUE"):
val.Set(reflect.ValueOf(true))
case strings.EqualFold(newValue, "FALSE"):
val.Set(reflect.ValueOf(false))
default:
val.Set(reflect.ValueOf(newValue))
}
} else {
if err := parseVariable(val.Interface(), parseTmplFunc); err != nil {
return err
}
}
}
return nil
}
// getLocalIP get the ipv4 or ipv6 for localhost machine
func getLocalIP(ipType string) string {
addrs, err := net.InterfaceAddrs()
if err != nil {
klog.ErrorS(err, "get network address error")
}
for _, addr := range addrs {
if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipType == _const.VariableIPv4 && ipNet.IP.To4() != nil {
return ipNet.IP.String()
}
if ipType == _const.VariableIPv6 && ipNet.IP.To16() != nil && ipNet.IP.To4() == nil {
return ipNet.IP.String()
}
}
}
klog.V(4).Infof("connot get local %s address", ipType)
return ""
}
// StringVar get string value by key
func StringVar(d map[string]any, args map[string]any, key string) (string, error) {
val, ok := args[key]

View File

@ -22,130 +22,9 @@ import (
kkcorev1 "github.com/kubesphere/kubekey/api/core/v1"
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/runtime"
"github.com/kubesphere/kubekey/v4/pkg/converter/tmpl"
)
func TestMergeVariable(t *testing.T) {
testcases := []struct {
name string
v1 map[string]any
v2 map[string]any
excepted map[string]any
}{
{
name: "primary variables value is empty",
v1: nil,
v2: map[string]any{
"a1": "v1",
},
excepted: map[string]any{
"a1": "v1",
},
},
{
name: "auxiliary variables value is empty",
v1: map[string]any{
"p1": "v1",
},
v2: nil,
excepted: map[string]any{
"p1": "v1",
},
},
{
name: "non-repeat value",
v1: map[string]any{
"p1": "v1",
"p2": map[string]any{
"p21": "v21",
},
},
v2: map[string]any{
"a1": "v1",
},
excepted: map[string]any{
"p1": "v1",
"p2": map[string]any{
"p21": "v21",
},
"a1": "v1",
},
},
{
name: "repeat value",
v1: map[string]any{
"p1": "v1",
"p2": map[string]any{
"p21": "v21",
"p22": "v22",
},
},
v2: map[string]any{
"a1": "v1",
"p1": "v2",
"p2": map[string]any{
"p21": "v22",
"a21": "v21",
},
},
excepted: map[string]any{
"a1": "v1",
"p1": "v2",
"p2": map[string]any{
"p21": "v22",
"a21": "v21",
"p22": "v22",
},
},
},
{
name: "repeat deep value",
v1: map[string]any{
"p1": map[string]string{
"p11": "v11",
},
"p2": map[string]any{
"p21": "v21",
"p22": "v22",
},
},
v2: map[string]any{
"p1": map[string]string{
"p21": "v21",
},
"p2": map[string]any{
"p21": map[string]any{
"p211": "v211",
},
"a21": "v21",
},
},
excepted: map[string]any{
"p1": map[string]any{
"p11": "v11",
"p21": "v21",
},
"p2": map[string]any{
"p21": map[string]any{
"p211": "v211",
},
"p22": "v22",
"a21": "v21",
},
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
v := CombineVariables(tc.v1, tc.v2)
assert.Equal(t, tc.excepted, v)
})
}
}
func TestMergeGroup(t *testing.T) {
func TestCombineSlice(t *testing.T) {
testcases := []struct {
name string
g1 []string
@ -180,118 +59,12 @@ func TestMergeGroup(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
ac := mergeSlice(tc.g1, tc.g2)
ac := CombineSlice(tc.g1, tc.g2)
assert.Equal(t, tc.except, ac)
})
}
}
func TestParseVariable(t *testing.T) {
testcases := []struct {
name string
data map[string]any
base map[string]any
except map[string]any
}{
{
name: "parse string",
data: map[string]any{
"a": "{{ .a }}",
},
base: map[string]any{
"a": "b",
},
except: map[string]any{
"a": "b",
},
},
{
name: "parse map",
data: map[string]any{
"a": "{{ .a.b }}",
},
base: map[string]any{
"a": map[string]any{
"b": "c",
},
},
except: map[string]any{
"a": "c",
},
},
{
name: "parse slice",
data: map[string]any{
"a": []string{"{{ .b }}"},
},
base: map[string]any{
"b": "c",
},
except: map[string]any{
"a": []string{"c"},
},
},
{
name: "parse map in slice",
data: map[string]any{
"a": []map[string]any{
{
"a1": []any{"{{ .b }}"},
},
},
},
base: map[string]any{
"b": "c",
},
except: map[string]any{
"a": []map[string]any{
{
"a1": []any{"c"},
},
},
},
},
{
name: "parse slice with bool value",
data: map[string]any{
"a": []any{"{{ .b }}"},
},
base: map[string]any{
"b": "true",
},
except: map[string]any{
"a": []any{true},
},
},
{
name: "parse map with bool value",
data: map[string]any{
"a": "{{ .b }}",
},
base: map[string]any{
"b": "true",
},
except: map[string]any{
"a": true,
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := parseVariable(tc.data, func(s string) (string, error) {
// parse use total variable. the task variable should not contain template syntax.
return tmpl.ParseFunc(CombineVariables(tc.data, tc.base), s, func(b []byte) string { return string(b) })
})
if err != nil {
t.Fatal(err)
}
assert.Equal(t, tc.except, tc.data)
})
}
}
func TestHostsInGroup(t *testing.T) {
testcases := []struct {
name string

View File

@ -1,6 +1,7 @@
package variable
import (
"net"
"regexp"
"slices"
"strconv"
@ -8,6 +9,7 @@ import (
"github.com/cockroachdb/errors"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/klog/v2"
_const "github.com/kubesphere/kubekey/v4/pkg/const"
"github.com/kubesphere/kubekey/v4/pkg/converter/tmpl"
@ -40,7 +42,7 @@ var GetHostnames = func(name []string) GetFunc {
for gn, gv := range ConvertGroup(vv.value.Inventory) {
if gn == n {
if gvd, ok := gv.([]string); ok {
hs = mergeSlice(hs, gvd)
hs = CombineSlice(hs, gvd)
}
break
@ -77,6 +79,30 @@ var GetHostnames = func(name []string) GetFunc {
// GetAllVariable get all variable for a given host
var GetAllVariable = func(hostname string) GetFunc {
// getLocalIP get the ipv4 or ipv6 for localhost machine
getLocalIP := func(ipType string) string {
addrs, err := net.InterfaceAddrs()
if err != nil {
klog.ErrorS(err, "get network address error")
}
for _, addr := range addrs {
if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipType == _const.VariableIPv4 && ipNet.IP.To4() != nil {
return ipNet.IP.String()
}
if ipType == _const.VariableIPv6 && ipNet.IP.To16() != nil && ipNet.IP.To4() == nil {
return ipNet.IP.String()
}
}
}
klog.V(4).Infof("cannot get local %s address", ipType)
return ""
}
// defaultHostVariable set default vars when hostname is "localhost"
defaultHostVariable := func(hostname string, hostVars map[string]any) {
if hostname == _const.VariableLocalHost {

View File

@ -255,3 +255,122 @@ func TestGetWorkdir(t *testing.T) {
})
}
}
func TestCombineVariables(t *testing.T) {
testcases := []struct {
name string
v1 map[string]any
v2 map[string]any
excepted map[string]any
}{
{
name: "primary variables value is empty",
v1: nil,
v2: map[string]any{
"a1": "v1",
},
excepted: map[string]any{
"a1": "v1",
},
},
{
name: "auxiliary variables value is empty",
v1: map[string]any{
"p1": "v1",
},
v2: nil,
excepted: map[string]any{
"p1": "v1",
},
},
{
name: "non-repeat value",
v1: map[string]any{
"p1": "v1",
"p2": map[string]any{
"p21": "v21",
},
},
v2: map[string]any{
"a1": "v1",
},
excepted: map[string]any{
"p1": "v1",
"p2": map[string]any{
"p21": "v21",
},
"a1": "v1",
},
},
{
name: "repeat value",
v1: map[string]any{
"p1": "v1",
"p2": map[string]any{
"p21": "v21",
"p22": "v22",
},
},
v2: map[string]any{
"a1": "v1",
"p1": "v2",
"p2": map[string]any{
"p21": "v22",
"a21": "v21",
},
},
excepted: map[string]any{
"a1": "v1",
"p1": "v2",
"p2": map[string]any{
"p21": "v22",
"a21": "v21",
"p22": "v22",
},
},
},
{
name: "repeat deep value",
v1: map[string]any{
"p1": map[string]string{
"p11": "v11",
},
"p2": map[string]any{
"p21": "v21",
"p22": "v22",
},
},
v2: map[string]any{
"p1": map[string]string{
"p21": "v21",
},
"p2": map[string]any{
"p21": map[string]any{
"p211": "v211",
},
"a21": "v21",
},
},
excepted: map[string]any{
"p1": map[string]any{
"p11": "v11",
"p21": "v21",
},
"p2": map[string]any{
"p21": map[string]any{
"p211": "v211",
},
"p22": "v22",
"a21": "v21",
},
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
v := CombineVariables(tc.v1, tc.v2)
assert.Equal(t, tc.excepted, v)
})
}
}

View File

@ -2,8 +2,7 @@ package variable
import (
"github.com/cockroachdb/errors"
"github.com/kubesphere/kubekey/v4/pkg/converter/tmpl"
"gopkg.in/yaml.v3"
)
// ***************************** MergeFunc ***************************** //
@ -34,8 +33,8 @@ var MergeRemoteVariable = func(data map[string]any, hostname string) MergeFunc {
}
// MergeRuntimeVariable parse variable by specific host and merge to the host.
var MergeRuntimeVariable = func(data map[string]any, hosts ...string) MergeFunc {
if len(data) == 0 || len(hosts) == 0 {
var MergeRuntimeVariable = func(node yaml.Node, hosts ...string) MergeFunc {
if node.IsZero() {
// skip
return emptyMergeFunc
}
@ -52,24 +51,14 @@ var MergeRuntimeVariable = func(data map[string]any, hosts ...string) MergeFunc
if err != nil {
return err
}
cv, ok := curVars.(map[string]any)
ctx, ok := curVars.(map[string]any)
if !ok {
return errors.Errorf("host %s variables type error, expect map[string]any", hostname)
}
parser := func(s string) (string, error) {
return tmpl.ParseFunc(
CombineVariables(data, cv),
s,
func(b []byte) string { return string(b) },
)
}
// parse variable
if err := parseVariable(data, parser); err != nil {
data, err := parseYamlNode(ctx, node)
if err != nil {
return err
}
hv := vv.value.Hosts[hostname]
hv.RuntimeVars = CombineVariables(hv.RuntimeVars, data)
vv.value.Hosts[hostname] = hv
@ -80,31 +69,29 @@ var MergeRuntimeVariable = func(data map[string]any, hosts ...string) MergeFunc
}
// MergeAllRuntimeVariable parse variable by specific host and merge to all hosts.
var MergeAllRuntimeVariable = func(data map[string]any, hostname string) MergeFunc {
var MergeAllRuntimeVariable = func(node yaml.Node, hostname string) MergeFunc {
if node.IsZero() {
// skip
return emptyMergeFunc
}
return func(v Variable) error {
vv, ok := v.(*variable)
if !ok {
return errors.New("variable type error")
}
// Avoid nested locking: prepare context for parsing outside locking region
curVars, err := v.Get(GetAllVariable(hostname))
if err != nil {
return err
}
cv, ok := curVars.(map[string]any)
ctx, ok := curVars.(map[string]any)
if !ok {
return errors.Errorf("host %s variables type error, expect map[string]any", hostname)
}
parser := func(s string) (string, error) {
return tmpl.ParseFunc(
CombineVariables(data, cv),
s,
func(b []byte) string { return string(b) },
)
}
if err := parseVariable(data, parser); err != nil {
data, err := parseYamlNode(ctx, node)
if err != nil {
return err
}
for h := range vv.value.Hosts {

View File

@ -5,6 +5,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/kubesphere/kubekey/v4/pkg/converter"
"github.com/kubesphere/kubekey/v4/pkg/variable/source"
)
@ -94,10 +95,13 @@ func TestMergeRuntimeVariable(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := tc.variable.Merge(MergeRuntimeVariable(tc.data, tc.hostname))
node, err := converter.ConvertMap2Node(tc.data)
if err != nil {
t.Fatal(err)
}
if err := tc.variable.Merge(MergeRuntimeVariable(node, tc.hostname)); err != nil {
t.Fatal(err)
}
assert.Equal(t, tc.except, *tc.variable.value)
})
@ -146,10 +150,13 @@ func TestMergeAllRuntimeVariable(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := tc.variable.Merge(MergeAllRuntimeVariable(tc.data, tc.hostname))
node, err := converter.ConvertMap2Node(tc.data)
if err != nil {
t.Fatal(err)
}
if err := tc.variable.Merge(MergeAllRuntimeVariable(node, tc.hostname)); err != nil {
t.Fatal(err)
}
assert.Equal(t, tc.except, *tc.variable.value)
})

303
pkg/variable/yaml.go Normal file
View File

@ -0,0 +1,303 @@
// Package variable provides functionality for handling variables in YAML format.
package variable
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/cockroachdb/errors"
kkprojectv1 "github.com/kubesphere/kubekey/api/project/v1"
"gopkg.in/yaml.v3"
"github.com/kubesphere/kubekey/v4/pkg/converter/tmpl"
)
// YAML tag constants used for parsing different types of nodes
const (
nullTag = "!!null" // Represents null/nil values
boolTag = "!!bool" // Boolean values
strTag = "!!str" // String values
intTag = "!!int" // Integer values
floatTag = "!!float" // Floating point values
timestampTag = "!!timestamp" // Timestamp values
seqTag = "!!seq" // Sequence/array values
mapTag = "!!map" // Map/object values
binaryTag = "!!binary" // Binary data
mergeTag = "!!merge" // Merge key indicator
)
// parseYamlNode parses a YAML node into a map[string]any.
// It handles both document nodes and other node types.
func parseYamlNode(ctx map[string]any, node yaml.Node) (map[string]any, error) {
// parse node
switch node.Kind {
case yaml.DocumentNode:
for _, dn := range node.Content {
if err := processNode(ctx, dn); err != nil {
return nil, err
}
}
default:
if err := processNode(ctx, &node); err != nil {
return nil, err
}
}
var result map[string]any
return result, errors.Wrap(node.Decode(&result), "failed to decode node to map")
}
// processNode recursively processes a YAML node and updates the context map.
// It handles mapping nodes (objects), sequence nodes (arrays), and scalar nodes (values).
func processNode(ctx map[string]any, node *yaml.Node, path ...string) error {
switch node.Kind {
case yaml.MappingNode:
if len(node.Content)%2 != 0 {
return errors.New("mapping node has odd number of content nodes")
}
for i := 0; i < len(node.Content); i += 2 {
keyNode := node.Content[i]
valueNode := node.Content[i+1]
if keyNode.Kind != yaml.ScalarNode {
return errors.New("map key must be scalar")
}
newPath := append(path, mapTag+keyNode.Value)
if err := processNode(ctx, valueNode, newPath...); err != nil {
return err
}
}
case yaml.SequenceNode:
for i, item := range node.Content {
elemPath := append(path, fmt.Sprintf("%s%d", seqTag, i))
if err := processNode(ctx, item, elemPath...); err != nil {
return err
}
}
case yaml.ScalarNode:
value, err := parseScalarValue(ctx, node)
if err != nil {
return err
}
// set context value
if err := setContextValue(ctx, value, path...); err != nil {
return err
}
default:
return errors.Errorf("unsupported node kind: %d", node.Kind)
}
return nil
}
// parseScalarValue parses a scalar YAML node into its corresponding Go value.
// It handles null, boolean, string, integer, and float values.
func parseScalarValue(ctx map[string]any, node *yaml.Node) (any, error) {
switch node.Tag {
case nullTag:
return nil, nil
case boolTag:
v, err := strconv.ParseBool(node.Value)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse %q to bool", node.Value)
}
return v, nil
case strTag, "":
if kkprojectv1.IsTmplSyntax(node.Value) {
pv, err := tmpl.ParseFunc(ctx, node.Value, func(b []byte) string { return string(b) })
if err != nil {
return nil, err
}
// change node value
node.Value = pv
return pv, nil
}
return node.Value, nil
case intTag:
v, err := strconv.Atoi(node.Value)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse %q to int", node.Value)
}
return int64(v), nil
case floatTag:
v, err := strconv.ParseFloat(node.Value, 64)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse %q to float", node.Value)
}
return float64(v), nil
default:
return node.Value, nil
}
}
// setContextValue sets a value in the context map at the specified path.
// The path is a sequence of tags and values that describe the location in the nested structure.
func setContextValue(ctx map[string]any, value any, path ...string) error {
current := reflect.ValueOf(ctx)
var parents []reflect.Value
var keys []any
for i := range path {
tag, val := path[i][:5], path[i][5:] // Split into tag and value
isLast := i == len(path)-1
// Handle interface values
current = derefInterface(current)
var err error
switch tag {
case mapTag:
current, err = handleMap(current, val, isLast, value, &parents, &keys, path, i)
case seqTag:
current, err = handleSlice(current, val, isLast, value, &parents, &keys, path, i)
default:
return fmt.Errorf("unsupported tag: %s", tag)
}
if err != nil {
return err
}
if isLast {
return updateParents(parents, keys, current)
}
}
return nil
}
// handleMap handles setting or creating map values during context updates.
// It manages the creation of new map entries and handles both terminal and non-terminal path segments.
func handleMap(current reflect.Value, key string, isLast bool, value any,
parents *[]reflect.Value, keys *[]any, path []string, i int) (reflect.Value, error) {
if current.Kind() != reflect.Map {
return reflect.Value{}, fmt.Errorf("expected map, got %s", current.Kind())
}
rKey := reflect.ValueOf(key)
existing := current.MapIndex(rKey)
if isLast {
if value == nil {
current.SetMapIndex(rKey, reflect.Zero(reflect.TypeOf((*any)(nil)).Elem()))
} else {
current.SetMapIndex(rKey, reflect.ValueOf(value))
}
return current, nil
}
// Get or create next value
var next reflect.Value
if !existing.IsValid() || isNil(existing) {
if i+1 >= len(path) {
return reflect.Value{}, fmt.Errorf("path incomplete after index %d", i)
}
if strings.HasPrefix(path[i+1], mapTag) {
next = reflect.ValueOf(make(map[string]any))
} else if strings.HasPrefix(path[i+1], seqTag) {
next = reflect.ValueOf([]any{})
} else {
next = reflect.Zero(reflect.TypeOf((*any)(nil)).Elem())
}
current.SetMapIndex(rKey, next)
} else {
next = derefInterface(existing)
}
*parents = append(*parents, current)
*keys = append(*keys, key)
return next, nil
}
// handleSlice handles setting or creating slice values during context updates.
// It manages slice growth, element creation, and handles both terminal and non-terminal path segments.
func handleSlice(current reflect.Value, val string, isLast bool, value any,
parents *[]reflect.Value, keys *[]any, path []string, i int) (reflect.Value, error) {
// Parse index from path value
index, err := strconv.Atoi(val)
if err != nil {
return reflect.Value{}, fmt.Errorf("invalid index %s: %w", val, err)
}
if current.Kind() != reflect.Slice {
return reflect.Value{}, fmt.Errorf("expected slice, got %s", current.Kind())
}
// Grow slice if requested index is beyond current length
if index >= current.Len() {
newLen := index + 1
newSlice := reflect.MakeSlice(current.Type(), newLen, newLen)
reflect.Copy(newSlice, current)
current = newSlice
if err := updateParents(*parents, *keys, current); err != nil {
return reflect.Value{}, err
}
}
// Handle setting the final value
if isLast {
if value == nil {
current.Index(index).Set(reflect.Zero(reflect.TypeOf((*any)(nil)).Elem()))
} else {
current.Index(index).Set(reflect.ValueOf(value))
}
return current, nil
}
// Get or initialize nested value
item := current.Index(index)
if isNil(item) {
if i+1 >= len(path) {
return reflect.Value{}, fmt.Errorf("path incomplete after index %d", i)
}
var newItem reflect.Value
switch {
case strings.HasPrefix(path[i+1], mapTag):
newItem = reflect.ValueOf(make(map[string]any))
case strings.HasPrefix(path[i+1], seqTag):
newItem = reflect.ValueOf([]any{})
default:
newItem = reflect.Zero(reflect.TypeOf((*any)(nil)).Elem())
}
current.Index(index).Set(newItem)
item = newItem
}
item = derefInterface(item)
*parents = append(*parents, current)
*keys = append(*keys, index)
return item, nil
}
// derefInterface dereferences an interface value to get its underlying value.
// If the value is not an interface, it returns the original value.
func derefInterface(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface {
return v.Elem()
}
return v
}
// updateParents updates all parent containers after modifying a nested value.
// It walks back up the parent chain, updating each container with the modified value.
func updateParents(parents []reflect.Value, keys []any, value reflect.Value) error {
for i := len(parents) - 1; i >= 0; i-- {
parent := parents[i]
key := keys[i]
switch parent.Kind() {
case reflect.Map:
k := reflect.ValueOf(key)
parent.SetMapIndex(k, value)
case reflect.Slice:
idx, ok := key.(int)
if !ok {
return errors.Errorf("expected int key for slice index, got %T", key)
}
parent.Index(idx).Set(value)
default:
return errors.Errorf("unexpected parent kind: %s", parent.Kind())
}
value = parent
}
return nil
}
// isNil checks if a reflect.Value is nil, handling interface, map and slice types.
// It returns true if the value is invalid or if it's a nil interface, map, or slice.
func isNil(v reflect.Value) bool {
return !v.IsValid() || (v.Kind() == reflect.Interface || v.Kind() == reflect.Map || v.Kind() == reflect.Slice) && v.IsNil()
}

122
pkg/variable/yaml_test.go Normal file
View File

@ -0,0 +1,122 @@
package variable
import (
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestProcessNode(t *testing.T) {
testcases := []struct {
name string
yamlStr string
expected map[string]any
}{
{
name: "scalar value of map",
yamlStr: `
name: alice
age: 30
`,
expected: map[string]any{
"name": "alice",
"age": int64(30),
},
},
{
name: "map value of map",
yamlStr: `
user:
name: alice
age: 30
`,
expected: map[string]any{
"user": map[string]any{
"name": "alice",
"age": int64(30),
},
},
},
{
name: "scalar value of sequence",
yamlStr: `
user:
- alice
- carol
`,
expected: map[string]any{
"user": []any{"alice", "carol"},
},
},
{
name: "map value of sequence",
yamlStr: `
user:
- name: carol
`,
expected: map[string]any{
"user": []any{map[string]any{"name": "carol"}},
},
},
{
name: "sequence of sequences",
yamlStr: `
matrix:
- [1, 2, 3]
- [4, 5, 6]
`,
expected: map[string]any{
"matrix": []any{
[]any{int64(1), int64(2), int64(3)},
[]any{int64(4), int64(5), int64(6)},
},
},
},
{
name: "deeply nested map and sequence",
yamlStr: `
app:
name: myapp
env:
- dev
- staging
- prod
config:
ports:
- 80
- 443
`,
expected: map[string]any{
"app": map[string]any{
"name": "myapp",
"env": []any{"dev", "staging", "prod"},
"config": map[string]any{
"ports": []any{int64(80), int64(443)},
},
},
},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
var node yaml.Node
err := yaml.Unmarshal([]byte(tc.yamlStr), &node)
if err != nil {
t.Fatalf("failed to unmarshal YAML: %v", err)
}
if len(node.Content) == 0 {
t.Fatalf("empty YAML content")
}
ctx := make(map[string]any)
if err = processNode(ctx, node.Content[0]); err != nil {
t.Fatalf("processNode failed: %v", err)
}
assert.Equal(t, tc.expected, ctx)
})
}
}