From 528d459afdd1283c38582d423a5d8e506efc38e6 Mon Sep 17 00:00:00 2001 From: Homayoon Alimohammadi Date: Fri, 8 Nov 2024 18:36:25 +0400 Subject: [PATCH] Refactor file writing logic (#778) --- src/k8s/cmd/k8s/k8s_bootstrap_test.go | 3 +- src/k8s/cmd/k8s/k8s_x_capi.go | 5 +- src/k8s/cmd/k8sd/k8sd_cluster_recover.go | 2 +- src/k8s/pkg/docgen/json_struct.go | 4 +- src/k8s/pkg/k8sd/setup/certificates.go | 3 +- src/k8s/pkg/k8sd/setup/containerd.go | 2 +- src/k8s/pkg/k8sd/setup/containerd_test.go | 2 +- src/k8s/pkg/k8sd/setup/k8s_dqlite.go | 2 +- src/k8s/pkg/k8sd/setup/util_extra_files.go | 3 +- src/k8s/pkg/proxy/config.go | 4 +- src/k8s/pkg/snap/util/arguments.go | 2 +- src/k8s/pkg/snap/util/arguments_test.go | 4 +- src/k8s/pkg/utils/file.go | 32 ++++++++ src/k8s/pkg/utils/file_test.go | 92 +++++++++++++++++++++- 14 files changed, 143 insertions(+), 17 deletions(-) diff --git a/src/k8s/cmd/k8s/k8s_bootstrap_test.go b/src/k8s/cmd/k8s/k8s_bootstrap_test.go index 1ab92aa34..ca24ef624 100644 --- a/src/k8s/cmd/k8s/k8s_bootstrap_test.go +++ b/src/k8s/cmd/k8s/k8s_bootstrap_test.go @@ -3,7 +3,6 @@ package k8s import ( "bytes" _ "embed" - "os" "path/filepath" "testing" @@ -109,7 +108,7 @@ var testCases = []testCase{ func mustAddConfigToTestDir(t *testing.T, configPath string, data string) { t.Helper() // Create the cluster bootstrap config file - err := os.WriteFile(configPath, []byte(data), 0o644) + err := utils.WriteFile(configPath, []byte(data), 0o644) if err != nil { t.Fatal(err) } diff --git a/src/k8s/cmd/k8s/k8s_x_capi.go b/src/k8s/cmd/k8s/k8s_x_capi.go index 232132dd9..2c4658fd8 100644 --- a/src/k8s/cmd/k8s/k8s_x_capi.go +++ b/src/k8s/cmd/k8s/k8s_x_capi.go @@ -1,10 +1,9 @@ package k8s import ( - "os" - apiv1 "github.com/canonical/k8s-snap-api/api/v1" cmdutil "github.com/canonical/k8s/cmd/util" + "github.com/canonical/k8s/pkg/utils" "github.com/spf13/cobra" ) @@ -48,7 +47,7 @@ func newXCAPICmd(env cmdutil.ExecutionEnvironment) *cobra.Command { return } - if err := os.WriteFile(env.Snap.NodeTokenFile(), []byte(token), 0o600); err != nil { + if err := utils.WriteFile(env.Snap.NodeTokenFile(), []byte(token), 0o600); err != nil { cmd.PrintErrf("Error: Failed to write the node token to file.\n\nThe error was: %v\n", err) env.Exit(1) return diff --git a/src/k8s/cmd/k8sd/k8sd_cluster_recover.go b/src/k8s/cmd/k8sd/k8sd_cluster_recover.go index 7c1f74838..983aa52b2 100644 --- a/src/k8s/cmd/k8sd/k8sd_cluster_recover.go +++ b/src/k8s/cmd/k8sd/k8sd_cluster_recover.go @@ -346,7 +346,7 @@ func yamlEditorGuide( newContent = removeEmptyLines(newContent) if applyChanges { - err = os.WriteFile(path, newContent, os.FileMode(0o644)) + err = utils.WriteFile(path, newContent, os.FileMode(0o644)) if err != nil { return nil, fmt.Errorf("could not write file: %s, error: %w", path, err) } diff --git a/src/k8s/pkg/docgen/json_struct.go b/src/k8s/pkg/docgen/json_struct.go index 7a65365fe..5dc5e5a67 100644 --- a/src/k8s/pkg/docgen/json_struct.go +++ b/src/k8s/pkg/docgen/json_struct.go @@ -5,6 +5,8 @@ import ( "os" "reflect" "strings" + + "github.com/canonical/k8s/pkg/utils" ) type JsonTag struct { @@ -55,7 +57,7 @@ func MarkdownFromJsonStructToFile(i any, outFilePath string, projectDir string) return err } - err = os.WriteFile(outFilePath, []byte(content), 0o644) + err = utils.WriteFile(outFilePath, []byte(content), 0o644) if err != nil { return fmt.Errorf("failed to write markdown documentation to %s: %w", outFilePath, err) } diff --git a/src/k8s/pkg/k8sd/setup/certificates.go b/src/k8s/pkg/k8sd/setup/certificates.go index 8508d9fc5..c8e67d227 100644 --- a/src/k8s/pkg/k8sd/setup/certificates.go +++ b/src/k8s/pkg/k8sd/setup/certificates.go @@ -8,6 +8,7 @@ import ( "github.com/canonical/k8s/pkg/k8sd/pki" "github.com/canonical/k8s/pkg/snap" + "github.com/canonical/k8s/pkg/utils" ) // ensureFile creates fname with the specified contents, mode and owner bits. @@ -39,7 +40,7 @@ func ensureFile(fname string, contents string, uid, gid int, mode fs.FileMode) ( var contentChanged bool if contents != string(origContent) { - if err := os.WriteFile(fname, []byte(contents), mode); err != nil { + if err := utils.WriteFile(fname, []byte(contents), mode); err != nil { return false, fmt.Errorf("failed to write: %w", err) } contentChanged = true diff --git a/src/k8s/pkg/k8sd/setup/containerd.go b/src/k8s/pkg/k8sd/setup/containerd.go index 15227e6cb..84da35f6d 100644 --- a/src/k8s/pkg/k8sd/setup/containerd.go +++ b/src/k8s/pkg/k8sd/setup/containerd.go @@ -108,7 +108,7 @@ func Containerd(snap snap.Snap, extraContainerdConfig map[string]any, extraArgs return fmt.Errorf("failed to render containerd config.toml: %w", err) } - if err := os.WriteFile(filepath.Join(snap.ContainerdConfigDir(), "config.toml"), b, 0o600); err != nil { + if err := utils.WriteFile(filepath.Join(snap.ContainerdConfigDir(), "config.toml"), b, 0o600); err != nil { return fmt.Errorf("failed to write config.toml: %w", err) } diff --git a/src/k8s/pkg/k8sd/setup/containerd_test.go b/src/k8s/pkg/k8sd/setup/containerd_test.go index 66bfd2714..30693bfba 100644 --- a/src/k8s/pkg/k8sd/setup/containerd_test.go +++ b/src/k8s/pkg/k8sd/setup/containerd_test.go @@ -20,7 +20,7 @@ func TestContainerd(t *testing.T) { dir := t.TempDir() - g.Expect(os.WriteFile(filepath.Join(dir, "mockcni"), []byte("echo hi"), 0o600)).To(Succeed()) + g.Expect(utils.WriteFile(filepath.Join(dir, "mockcni"), []byte("echo hi"), 0o600)).To(Succeed()) s := &mock.Snap{ Mock: mock.Mock{ diff --git a/src/k8s/pkg/k8sd/setup/k8s_dqlite.go b/src/k8s/pkg/k8sd/setup/k8s_dqlite.go index 92fc4812d..1ee6a5cf9 100644 --- a/src/k8s/pkg/k8sd/setup/k8s_dqlite.go +++ b/src/k8s/pkg/k8sd/setup/k8s_dqlite.go @@ -32,7 +32,7 @@ func K8sDqlite(snap snap.Snap, address string, cluster []string, extraArgs map[s return fmt.Errorf("failed to create init.yaml file for address=%s cluster=%v: %w", address, cluster, err) } - if err := os.WriteFile(filepath.Join(snap.K8sDqliteStateDir(), "init.yaml"), b, 0o600); err != nil { + if err := utils.WriteFile(filepath.Join(snap.K8sDqliteStateDir(), "init.yaml"), b, 0o600); err != nil { return fmt.Errorf("failed to write init.yaml: %w", err) } diff --git a/src/k8s/pkg/k8sd/setup/util_extra_files.go b/src/k8s/pkg/k8sd/setup/util_extra_files.go index 2376b5645..163562ea5 100644 --- a/src/k8s/pkg/k8sd/setup/util_extra_files.go +++ b/src/k8s/pkg/k8sd/setup/util_extra_files.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/canonical/k8s/pkg/snap" + "github.com/canonical/k8s/pkg/utils" ) // ExtraNodeConfigFiles writes the file contents to the specified filenames in the snap.ExtraFilesDir directory. @@ -20,7 +21,7 @@ func ExtraNodeConfigFiles(snap snap.Snap, files map[string]string) error { filePath := filepath.Join(snap.ServiceExtraConfigDir(), filename) // Write the content to the file - if err := os.WriteFile(filePath, []byte(content), 0o400); err != nil { + if err := utils.WriteFile(filePath, []byte(content), 0o400); err != nil { return fmt.Errorf("failed to write to file %s: %w", filePath, err) } diff --git a/src/k8s/pkg/proxy/config.go b/src/k8s/pkg/proxy/config.go index 3417016ec..450f2a689 100644 --- a/src/k8s/pkg/proxy/config.go +++ b/src/k8s/pkg/proxy/config.go @@ -5,6 +5,8 @@ import ( "fmt" "os" "sort" + + "github.com/canonical/k8s/pkg/utils" ) // Configuration is the format of the apiserver proxy endpoints config file. @@ -33,7 +35,7 @@ func WriteEndpointsConfig(endpoints []string, file string) error { return fmt.Errorf("failed to marshal configuration: %w", err) } - if err := os.WriteFile(file, b, 0o600); err != nil { + if err := utils.WriteFile(file, b, 0o600); err != nil { return fmt.Errorf("failed to write configuration file %s: %w", file, err) } return nil diff --git a/src/k8s/pkg/snap/util/arguments.go b/src/k8s/pkg/snap/util/arguments.go index 12d3011be..02aad87fb 100644 --- a/src/k8s/pkg/snap/util/arguments.go +++ b/src/k8s/pkg/snap/util/arguments.go @@ -103,7 +103,7 @@ func UpdateServiceArguments(snap snap.Snap, serviceName string, updateMap map[st // sort arguments so that output is consistent sort.Strings(newArguments) - if err := os.WriteFile(argumentsFile, []byte(strings.Join(newArguments, "\n")+"\n"), 0o600); err != nil { + if err := utils.WriteFile(argumentsFile, []byte(strings.Join(newArguments, "\n")+"\n"), 0o600); err != nil { return false, fmt.Errorf("failed to write arguments for service %s: %w", serviceName, err) } return changed, nil diff --git a/src/k8s/pkg/snap/util/arguments_test.go b/src/k8s/pkg/snap/util/arguments_test.go index 52fe4a1b3..0d23ccd6a 100644 --- a/src/k8s/pkg/snap/util/arguments_test.go +++ b/src/k8s/pkg/snap/util/arguments_test.go @@ -2,12 +2,12 @@ package snaputil_test import ( "fmt" - "os" "path/filepath" "testing" "github.com/canonical/k8s/pkg/snap/mock" snaputil "github.com/canonical/k8s/pkg/snap/util" + "github.com/canonical/k8s/pkg/utils" . "github.com/onsi/gomega" ) @@ -32,7 +32,7 @@ func TestGetServiceArgument(t *testing.T) { --key=value-of-service-two `, } { - g.Expect(os.WriteFile(filepath.Join(dir, svc), []byte(args), 0o600)).To(Succeed()) + g.Expect(utils.WriteFile(filepath.Join(dir, svc), []byte(args), 0o600)).To(Succeed()) } for _, tc := range []struct { diff --git a/src/k8s/pkg/utils/file.go b/src/k8s/pkg/utils/file.go index 1b6728334..63cfad1bf 100644 --- a/src/k8s/pkg/utils/file.go +++ b/src/k8s/pkg/utils/file.go @@ -258,3 +258,35 @@ func CreateTarball(tarballPath string, rootDir string, walkDir string, excludeFi return nil } + +// WriteFile writes data to a file with the given name and permissions. +// The file is written to a temporary file in the same directory as the target file +// and then renamed to the target file to avoid partial writes in case of a crash. +func WriteFile(name string, data []byte, perm fs.FileMode) error { + dir := filepath.Dir(name) + tmpFile, err := os.CreateTemp(dir, "tmp-*") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.Write(data); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to write to temp file: %w", err) + } + + if err := tmpFile.Chmod(perm); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to set permissions on temp file: %w", err) + } + + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + if err := os.Rename(tmpFile.Name(), name); err != nil { + return fmt.Errorf("failed to rename temp file to target file: %w", err) + } + + return nil +} diff --git a/src/k8s/pkg/utils/file_test.go b/src/k8s/pkg/utils/file_test.go index 3efa8da76..db6437036 100644 --- a/src/k8s/pkg/utils/file_test.go +++ b/src/k8s/pkg/utils/file_test.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "sync" "testing" "github.com/canonical/k8s/pkg/utils" @@ -88,7 +89,7 @@ func TestParseArgumentFile(t *testing.T) { g := NewWithT(t) filePath := filepath.Join(t.TempDir(), tc.name) - err := os.WriteFile(filePath, []byte(tc.content), 0o755) + err := utils.WriteFile(filePath, []byte(tc.content), 0o755) if err != nil { t.Fatalf("Failed to setup testfile: %v", err) } @@ -182,3 +183,92 @@ func TestGetMountPropagationType(t *testing.T) { g.Expect(err).ToNot(HaveOccurred()) g.Expect(mountType).To(Equal(utils.MountPropagationShared)) } + +func TestWriteFile(t *testing.T) { + t.Run("PartialWrites", func(t *testing.T) { + g := NewWithT(t) + + name := filepath.Join(t.TempDir(), "testfile") + + const ( + numWriters = 200 + numIterations = 200 + ) + + var wg sync.WaitGroup + wg.Add(numWriters) + + expContent := "key: value" + expPerm := os.FileMode(0o644) + + for i := 0; i < numWriters; i++ { + go func(writerID int) { + defer wg.Done() + + for j := 0; j < numIterations; j++ { + g.Expect(utils.WriteFile(name, []byte(expContent), expPerm)).To(Succeed()) + + content, err := os.ReadFile(name) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(string(content)).To(Equal(expContent)) + + fileInfo, err := os.Stat(name) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(fileInfo.Mode().Perm()).To(Equal(expPerm)) + } + }(i) + } + + wg.Wait() + }) + + tcs := []struct { + name string + expContent []byte + expPerm os.FileMode + }{ + { + name: "test1", + expContent: []byte("key: value"), + expPerm: os.FileMode(0o644), + }, + { + name: "test2", + expContent: []byte(""), + expPerm: os.FileMode(0o600), + }, + { + name: "test3", + expContent: []byte("key: value"), + expPerm: os.FileMode(0o755), + }, + { + name: "test4", + expContent: []byte("key: value"), + expPerm: os.FileMode(0o777), + }, + { + name: "test5", + expContent: []byte("key: value"), + expPerm: os.FileMode(0o400), + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + name := filepath.Join(t.TempDir(), tc.name) + + g.Expect(utils.WriteFile(name, tc.expContent, tc.expPerm)).To(Succeed()) + + content, err := os.ReadFile(name) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(string(content)).To(Equal(string(tc.expContent))) + + fileInfo, err := os.Stat(name) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(fileInfo.Mode().Perm()).To(Equal(tc.expPerm)) + }) + } +}