Skip to content

Commit

Permalink
Merge pull request #169 from databricks-industry-solutions/feature/SF…
Browse files Browse the repository at this point in the history
…E-3483v2

Feature/sfe 3483v2
  • Loading branch information
arunpamulapati authored Oct 22, 2024
2 parents 061c031 + e36807a commit d1b40ef
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
12 changes: 9 additions & 3 deletions src/securityanalysistoolproject/clientpkgs/ws_settings_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
'''Workspace settings module'''
from core.dbclient import SatDBClient
import json
from core.logging_utils import LoggingUtils


LOGGR=None

if LOGGR is None:
LOGGR = LoggingUtils.get_logger()

class WSSettingsClient(SatDBClient):
'''workspace setting helper'''

def get_wssettings_list(self):
"""
Returns an array of json objects for workspace settings.
Expand Down Expand Up @@ -50,7 +56,6 @@ def get_wssettings_list(self):
{"name": "enableLibraryAndInitScriptOnSharedCluster", "defn":"Enable libraries and init scripts on shared Unity Catalog clusters"}
]
# pylint: enable=line-too-long

for keyn in ws_keymap:
valn={}
try:
Expand All @@ -62,7 +67,8 @@ def get_wssettings_list(self):
valins = {}
valins['name']=keyn['name']
valins['defn']=keyn['defn']
valins['value']=None if valn[keyn['name']] is None else valn[keyn['name']]
#fixed feature/SFE-3483
valins['value']=None if keyn['name'] not in valn or valn[keyn['name']] is None else valn[keyn['name']]
all_result.append(valins)
return all_result

Expand Down
34 changes: 22 additions & 12 deletions src/securityanalysistoolproject/core/dbclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, inp_configs):
self._use_mastercreds = pars.str2bool(configs['use_mastercreds'])
else:
self._use_mastercreds = configs['use_mastercreds']

self._proxies = configs.get('proxies', {})

#Azure
if 'azure' in self._cloud_type:
Expand Down Expand Up @@ -130,22 +130,32 @@ def _update_token(self):
"User-Agent": "databricks-sat/0.1.0"
}
return None


def get_temporary_oauth_token(self):
self._update_token()
if self._token is None:
return None
temptok = self._token.get("Authorization","").split(' ')
if(len(temptok) > 1):
return (temptok[1])
else:
return None


def test_connection(self, master_acct=False):
'''test connection to workspace and master account'''
if master_acct: #master acct may use a different credential
self._update_token_master()
if (self._cloud_type == 'azure'):
results = requests.get(f'{self._url}/subscriptions/{self._subscription_id}/providers/Microsoft.Databricks/workspaces?api-version=2018-04-01',
headers=self._token, timeout=60)
headers=self._token, timeout=60, proxies=self._proxies)
else:
results = requests.get(f'{self._url}/api/2.0/accounts/{self._account_id}/workspaces',
headers=self._token, timeout=60)
headers=self._token, timeout=60, proxies=self._proxies)
else:
self._update_token()
results = requests.get(f'{self._url}/api/2.0/clusters/spark-versions',
headers=self._token, timeout=60)
headers=self._token, timeout=60, proxies=self._proxies)
http_status_code = results.status_code
if http_status_code != 200:
LOGGR.info("Error. Either the credentials have expired or the \
Expand Down Expand Up @@ -173,9 +183,9 @@ def get(self, endpoint, json_params=None, version='2.0', master_acct=False):
LOGGR.debug(f"Get: {full_endpoint}")

if json_params:
raw_results = requests.get(full_endpoint, headers=self._token, params=json_params, timeout=60)
raw_results = requests.get(full_endpoint, headers=self._token, params=json_params, timeout=60, proxies=self._proxies)
else:
raw_results = requests.get(full_endpoint, headers=self._token, timeout=60)
raw_results = requests.get(full_endpoint, headers=self._token, timeout=60, proxies=self._proxies)

http_status_code = raw_results.status_code
if http_status_code in SatDBClient.http_error_codes:
Expand Down Expand Up @@ -203,16 +213,16 @@ def http_req(self, http_type, endpoint, json_params, version='2.0', files_json=N
if http_type == 'post':
if files_json:
raw_results = requests.post(full_endpoint, headers=self._token,
data=json_params, files=files_json, timeout=60)
data=json_params, files=files_json, timeout=60, proxies=self._proxies)
else:
raw_results = requests.post(full_endpoint, headers=self._token,
json=json_params, timeout=60)
json=json_params, timeout=60, proxies=self._proxies)
if http_type == 'put':
raw_results = requests.put(full_endpoint, headers=self._token,
json=json_params, timeout=60)
json=json_params, timeout=60, proxies=self._proxies)
if http_type == 'patch':
raw_results = requests.patch(full_endpoint, headers=self._token,
json=json_params, timeout=60)
json=json_params, timeout=60, proxies=self._proxies)
else:
LOGGR.info("Must have a payload in json_args param.")
return {}
Expand Down Expand Up @@ -416,7 +426,7 @@ def getAWSTokenwithOAuth(self, baccount, client_id, client_secret):
full_endpoint = f'{self._raw_url}/oidc/v1/token'

response = requests.post(full_endpoint, headers=oidc_token,
auth=user_pass, data=json_params, timeout=60)
auth=user_pass, data=json_params, timeout=60, proxies=self._proxies)

if response is not None and response.status_code == 200:
return response.json()['access_token']
Expand Down
2 changes: 1 addition & 1 deletion src/securityanalysistoolproject/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

DESCRIPTION = "Databricks Security Analysis Tool"

__version__ = "0.1.34"
__version__ = "0.1.37"

this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
Expand Down

0 comments on commit d1b40ef

Please sign in to comment.