-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpostprocessing.py
67 lines (54 loc) · 2.02 KB
/
postprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from cgitb import reset
from pathlib import Path
import csv
import pandas as pd
def get_age_group(age):
    """Map a numeric age (or anything ``float()`` accepts) to its age-group label.

    Upper bounds are inclusive: <=24, <=34 and <=49 map to the first three
    groups. Every other value, including NaN, maps to '50-xx'.
    """
    years = float(age)
    brackets = (
        (24, 'xx-24'),
        (34, '25-34'),
        (49, '35-49'),
    )
    for upper_bound, label in brackets:
        if years <= upper_bound:
            return label
    # Fall-through: above 49, or a value (e.g. NaN) that fails every comparison.
    return '50-xx'
def write_xml(path: Path, data: pd.DataFrame):
    """Write one XML file per row of *data* into the directory *path*.

    The DataFrame's index is first turned back into a regular column. This is
    presumably so that an index such as 'userid' is visible to ``row_to_xml``;
    confirm against callers. The caller's DataFrame is left unmodified.

    Args:
        path: Output directory. It is created, including any missing parent
            directories, if it does not exist.
        data: Rows to export. Each row is handed to ``row_to_xml``.
    """
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # check-then-create race of exists()/mkdir().
    os.makedirs(path, exist_ok=True)
    # Non-inplace reset. The previous inplace=True mutated the caller's frame,
    # so every repeated call inserted yet another index column into it.
    data = data.reset_index()
    print(data.columns)
    for _, row in data.iterrows():
        row_to_xml(row, path)
def row_to_xml(row: pd.Series, path: Path):
    """Write *row* to ``<path>/<userid>.xml`` as a single self-closing ``<user/>`` element.

    The file is opened in exclusive-create mode ("x"), so an existing file
    raises FileExistsError instead of being overwritten. The XML string is
    also printed to stdout.

    Args:
        row: Must provide the keys 'userid', 'age_range', 'gender', 'ext',
            'neu', 'agr', 'con' and 'ope'. A gender equal to 0 is written as
            "male"; any other value is written as "female".
        path: Existing directory to write the file into.
    """
    from xml.sax.saxutils import escape

    def attr(value):
        # Escape &, <, > and " so arbitrary values cannot break the markup.
        return escape(f"{value}", {'"': '&quot;'})

    xml_string = (f"<user id=\"{attr(row['userid'])}\" "
                  f"age_group=\"{attr(row['age_range'])}\" "
                  f"gender=\"{'male' if row['gender'] == 0 else 'female'}\" "
                  f"extrovert=\"{attr(row['ext'])}\" "
                  f"neurotic=\"{attr(row['neu'])}\" "
                  f"agreeable=\"{attr(row['agr'])}\" "
                  f"conscientiousness=\"{attr(row['con'])}\" "
                  f"open=\"{attr(row['ope'])}\" />")
    print(xml_string)
    target = os.path.join(path, f"{row['userid']}.xml")
    with open(target, "x", encoding="utf-8") as f:
        f.write(xml_string)
def export_confusion_matrices(results, output_dir='./data/output/'):
    """Save each result's confusion-matrix DataFrame as a CSV file in *output_dir*.

    Args:
        results: Mapping of prediction type to a dict. The inner dict is
            expected to hold a pandas DataFrame under 'confusion_matrix'.
            Entries that are not dicts, or that lack such a DataFrame, are
            skipped with a printed warning.
        output_dir: Destination directory. It is created if missing.

    Returns:
        List of the paths (built with os.path.join) that were written
        successfully. A failed write is reported on stdout and does not stop
        the remaining exports.
    """
    os.makedirs(output_dir, exist_ok=True)
    written = []
    for kind, payload in results.items():
        # None also fails the isinstance test, so one check covers both cases.
        if not isinstance(payload, dict):
            print(f"Warning: Skipping {kind} - invalid data format")
            continue
        matrix = payload.get('confusion_matrix')
        if not isinstance(matrix, pd.DataFrame):
            print(f"Warning: No valid DataFrame for {kind}")
            continue
        try:
            target = os.path.join(output_dir, f'{kind}_confusion_matrix.csv')
            matrix.to_csv(target, index=True)
        except Exception as err:
            # Best effort: report the failure and carry on with the next matrix.
            print(f"Error saving confusion matrix for {kind}: {err}")
        else:
            written.append(target)
    return written