From b3c14c55f4ae651e9d375888d13f284fb0d0c2eb Mon Sep 17 00:00:00 2001 From: TheDude Date: Sat, 28 Dec 2024 23:30:41 +0530 Subject: [PATCH] Ref secrets loading method --- CHANGELOG.md | 1 + superduper/misc/files.py | 37 +++++++++++++++++++++++-------------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a1b5b671d..56e9e58a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add postprocess in apibase model. - Fallback for ibis drop table - Add create events waiting on db apply. +- Refactor secrets loading method. #### New Features & Functionality diff --git a/superduper/misc/files.py b/superduper/misc/files.py index f970699f9..69840339b 100644 --- a/superduper/misc/files.py +++ b/superduper/misc/files.py @@ -5,22 +5,31 @@ def load_secrets(): - """Help method to load secrets from directory.""" + """Load secrets directory into env vars.""" secrets_dir = CFG.secrets_volume + if not os.path.isdir(secrets_dir): - raise ValueError(f"The path '{secrets_dir}' is not a valid directory.") - - for root, _, files in os.walk(secrets_dir): - for file_name in files: - file_path = os.path.join(root, file_name) - try: - with open(file_path, 'r') as file: - content = file.read().strip() - - key = file_name - os.environ[key] = content - except Exception as e: - print(f"Error reading file {file_path}: {e}") + raise ValueError(f"The path '{secrets_dir}' is not a valid secrets directory.") + + for key_dir in os.listdir(secrets_dir): + key_path = os.path.join(secrets_dir, key_dir) + + if not os.path.isdir(key_path): + continue + + secret_file_path = os.path.join(key_path, 'secret_string') + + if not os.path.isfile(secret_file_path): + print(f"Warning: No 'secret_string' file found in {key_path}.") + continue + + try: + with open(secret_file_path, 'r') as file: + content = file.read().strip() + + os.environ[key_dir] = content + except Exception as e: + print(f"Error reading file {secret_file_path}: {e}") def get_file_from_uri(uri):