diff --git a/raster_loader/cli/snowflake.py b/raster_loader/cli/snowflake.py index 98b0255..37beeeb 100644 --- a/raster_loader/cli/snowflake.py +++ b/raster_loader/cli/snowflake.py @@ -53,6 +53,7 @@ def snowflake(args=None): default=None, ) @click.option("--role", help="The role to use for the file upload.", default=None) +@click.option("--warehouse", help="Name of the default warehouse to use.", default=None) @click.option( "--file_path", help="The path to the raster file.", required=False, default=None ) @@ -119,6 +120,7 @@ def upload( private_key_path, private_key_passphrase, role, + warehouse, file_path, file_url, database, @@ -189,6 +191,7 @@ def upload( database=database, schema=schema, role=role, + warehouse=warehouse, ) source = file_path if is_local_file else file_url @@ -254,6 +257,7 @@ def upload( default=None, ) @click.option("--role", help="The role to use for the file upload.", default=None) +@click.option("--warehouse", help="Name of the default warehouse to use.", default=None) @click.option("--database", help="The name of the database.", required=True) @click.option("--schema", help="The name of the schema.", required=True) @click.option("--table", help="The name of the table.", required=True) @@ -266,6 +270,7 @@ def describe( private_key_path, private_key_passphrase, role, + warehouse, database, schema, table, @@ -298,6 +303,7 @@ def describe( database=database, schema=schema, role=role, + warehouse=warehouse, ) df = connector.get_records(fqn, limit) print(f"Table: {fqn}") diff --git a/raster_loader/io/snowflake.py b/raster_loader/io/snowflake.py index 80c60f0..dc06746 100644 --- a/raster_loader/io/snowflake.py +++ b/raster_loader/io/snowflake.py @@ -44,6 +44,7 @@ def __init__( private_key_path, private_key_passphrase, role, + warehouse, ): if not _has_snowflake: import_error_snowflake() @@ -57,6 +58,7 @@ def __init__( database=database.upper(), schema=schema.upper(), role=role.upper() if role is not None else None, + warehouse=warehouse, ) elif private_key_path is not None: self.client = snowflake.connector.connect( @@ -68,6 +70,7 @@ def __init__( database=database.upper(), schema=schema.upper(), role=role.upper() if role is not None else None, + warehouse=warehouse, ) else: self.client = snowflake.connector.connect( @@ -77,6 +80,7 @@ def __init__( database=database.upper(), schema=schema.upper(), role=role.upper() if role is not None else None, + warehouse=warehouse, ) def band_rename_function(self, band_name: str):