-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathglobal_functions.py
69 lines (63 loc) · 2.92 KB
/
global_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python
# File with most used
# functions and derived
# variables.
import numpy as np
import pandas as pd
from pycaret import classification as pyc
from pycaret import regression as pyr
import os
import subprocess
def predict_AGN_gal(catalog_df,
                    AGN_gal_model,
                    cal_AGN_gal_model,
                    threshold,
                    cal_threshold,
                    raw_score=True):
    """Run the AGN/galaxy classifier and append calibrated predictions.

    Parameters
    ----------
    catalog_df : pandas.DataFrame
        Catalogue passed as ``data`` to
        ``pycaret.classification.predict_model``.
    AGN_gal_model :
        Trained PyCaret classification model.
    cal_AGN_gal_model :
        Calibration model; its ``predict`` method is applied to the rounded
        ``Score_AGN`` column.
    threshold : float
        ``probability_threshold`` forwarded to ``predict_model``.
    cal_threshold : float
        Cut applied to the calibrated output to build ``pred_class_cal``.
    raw_score : bool, optional
        Forwarded to ``predict_model``. The code below requires ``Score_0``
        and ``Score_1`` columns in the output, so it presumably only works
        with the default ``True`` -- TODO confirm.

    Returns
    -------
    pandas.DataFrame
        The scored catalogue with ``Score_0`` dropped, ``Label`` renamed to
        ``pred_class``, ``Score_1`` renamed to ``Score_AGN`` (rounded to 8
        decimals), plus ``Prob_AGN`` (output of the calibration model) and
        ``pred_class_cal`` (1 where ``Prob_AGN >= cal_threshold``, else 0).
    """
    scored = pyc.predict_model(AGN_gal_model,
                               data=catalog_df,
                               probability_threshold=threshold,
                               raw_score=raw_score,
                               round=10)
    # Keep only the positive-class score and give the outputs AGN-specific names.
    scored = (scored
              .drop(columns=['Score_0'])
              .rename(columns={'Label': 'pred_class',
                               'Score_1': 'Score_AGN'}))
    scored.loc[:, 'Score_AGN'] = scored.loc[:, 'Score_AGN'].round(8)

    calibrated = cal_AGN_gal_model.predict(scored.loc[:, 'Score_AGN'])
    is_agn = np.array(calibrated >= cal_threshold).astype(int)

    scored['Prob_AGN'] = calibrated
    scored['pred_class_cal'] = is_agn
    return scored
def predict_radio_det(catalog_df,
                      radio_model,
                      cal_radio_model,
                      threshold,
                      cal_threshold,
                      raw_score=True):
    """Run the radio-detection classifier and append calibrated predictions.

    Parameters
    ----------
    catalog_df : pandas.DataFrame
        Catalogue passed as ``data`` to
        ``pycaret.classification.predict_model``.
    radio_model :
        Trained PyCaret classification model.
    cal_radio_model :
        Calibration model; its ``predict`` method is applied to the rounded
        ``Score_radio_AGN`` column.
    threshold : float
        ``probability_threshold`` forwarded to ``predict_model``.
    cal_threshold : float
        Cut applied to the calibrated output to build ``pred_radio_cal_AGN``.
    raw_score : bool, optional
        Forwarded to ``predict_model``. The code below requires ``Score_0``
        and ``Score_1`` columns in the output, so it presumably only works
        with the default ``True`` -- TODO confirm.

    Returns
    -------
    pandas.DataFrame
        The scored catalogue with ``Score_0`` dropped, ``Label`` renamed to
        ``pred_radio_AGN``, ``Score_1`` renamed to ``Score_radio_AGN``
        (rounded to 8 decimals), plus ``Prob_radio_AGN`` (output of the
        calibration model) and ``pred_radio_cal_AGN`` (1 where
        ``Prob_radio_AGN >= cal_threshold``, else 0).
    """
    scored = pyc.predict_model(radio_model,
                               data=catalog_df,
                               probability_threshold=threshold,
                               raw_score=raw_score,
                               round=10)
    # Keep only the positive-class score and give the outputs radio-specific names.
    scored = (scored
              .drop(columns=['Score_0'])
              .rename(columns={'Label': 'pred_radio_AGN',
                               'Score_1': 'Score_radio_AGN'}))
    scored.loc[:, 'Score_radio_AGN'] = scored.loc[:, 'Score_radio_AGN'].round(8)

    calibrated = cal_radio_model.predict(scored.loc[:, 'Score_radio_AGN'])
    is_radio = np.array(calibrated >= cal_threshold).astype(int)

    scored['Prob_radio_AGN'] = calibrated
    scored['pred_radio_cal_AGN'] = is_radio
    return scored
def predict_z(catalog_df,
              redshift_model):
    """Predict redshifts with a PyCaret regression model.

    Parameters
    ----------
    catalog_df : pandas.DataFrame
        Catalogue passed as ``data`` to ``pycaret.regression.predict_model``.
    redshift_model :
        Trained PyCaret regression model.

    Returns
    -------
    pandas.DataFrame
        The scored catalogue with the ``Label`` column renamed to
        ``pred_Z_rAGN`` and rounded to 4 decimals.
    """
    z_column = 'pred_Z_rAGN'
    predictions = pyr.predict_model(redshift_model, data=catalog_df, round=10)
    predictions = predictions.rename(columns={'Label': z_column})
    predictions.loc[:, z_column] = predictions.loc[:, z_column].round(4)
    return predictions
def download_from_zenodo(file_name, output_path):
    """Download ``file_name`` from Zenodo record 10220009 into ``output_path``.

    The download is skipped when the destination file already exists.

    Parameters
    ----------
    file_name : str
        Name of the file inside the Zenodo record; also used as the local
        file name.
    output_path : str or os.PathLike
        Directory that receives the file. A trailing separator is optional.

    Returns
    -------
    bool
        True if the file is present at the destination after the call,
        False if ``wget`` reported a failure.

    Raises
    ------
    FileNotFoundError
        If the ``wget`` executable cannot be found.
    """
    destination = os.path.join(output_path, file_name)
    # Check the actual download target, not a path relative to the CWD.
    if os.path.exists(destination):
        print(f'File {file_name} has been already downloaded')
        return True
    url = 'https://zenodo.org/records/10220009/files/' + file_name
    result = subprocess.run(['wget', url, '-O', destination])
    if result.returncode != 0:
        # With -O, wget may leave an empty or partial file behind on failure.
        # Remove it so the existence check above does not later mistake it
        # for a complete download.
        try:
            os.remove(destination)
        except FileNotFoundError:
            pass
        print(f'File {file_name} could not be downloaded '
              f'(wget exit status {result.returncode})')
        return False
    print(f'File {file_name} has been downloaded')
    return True