-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlabel_encoder.py
37 lines (27 loc) · 1.16 KB
/
label_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# -*- coding: utf-8 -*-
from collections import Counter
class LabelEncoder(object):
    """Two-way mapping between labels and contiguous integer ids.

    Ids start at 0. Reserved labels come first, in the order given, followed
    by the labels seen during ``fit`` ordered by descending frequency (ties
    broken by ascending label).
    """

    def __init__(self):
        # id -> label
        self.__values = {}
        # label -> id
        self.__indices = {}

    def __check_fitted(self):
        """Raise ``AssertionError`` if no vocabulary has been fitted yet."""
        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept for callers
        # that already catch AssertionError.
        if not self.__indices:
            raise AssertionError(
                "This {} instance is not fitted yet.".format(type(self).__name__)
            )

    def fit(self, labels, reserved_labels=None, min_freq=1):
        """Build the vocabulary from ``labels``.

        Args:
            labels: Iterable of hashable, mutually orderable labels.
            reserved_labels: Optional iterable of labels that receive the
                lowest ids, in the order given, regardless of their frequency
                in ``labels``. ``None`` (the default) means no reserved labels.
            min_freq: Labels occurring fewer than ``min_freq`` times in
                ``labels`` are left out of the vocabulary.

        Raises:
            AssertionError: If this instance already holds a vocabulary.

        Note:
            A label is assigned exactly one id: a label that is reserved and
            also present in ``labels`` keeps its reserved id. If the resulting
            vocabulary is empty the instance is still considered unfitted.
        """
        if self.__indices:
            raise AssertionError(
                "This {} instance has already fitted.".format(type(self).__name__)
            )
        if reserved_labels is None:
            reserved_labels = ()

        by_frequency = sorted(
            Counter(labels).items(), key=lambda item: (-item[1], item[0])
        )
        vocabulary = list(reserved_labels)
        vocabulary.extend(
            label for label, count in by_frequency if count >= min_freq
        )

        # Build into locals and commit at the end so a failure part-way
        # through never leaves the encoder half-populated.
        values = {}
        indices = {}
        for label in vocabulary:
            if label in indices:
                # Already has an id (e.g. reserved and also seen in the data);
                # a second id would make the two maps disagree.
                continue
            index = len(values)
            values[index] = label
            indices[label] = index

        self.__values = values
        self.__indices = indices

    def transform(self, label, unknown_label=None):
        """Return the id of ``label``.

        Args:
            label: The label to encode.
            unknown_label: Fallback label whose id is returned when ``label``
                is not in the vocabulary.

        Returns:
            The integer id of ``label``, or of ``unknown_label`` when
            ``label`` is unknown.

        Raises:
            AssertionError: If the encoder has not been fitted.
            KeyError: If neither ``label`` nor ``unknown_label`` is in the
                vocabulary.
        """
        self.__check_fitted()
        if label in self.__indices:
            return self.__indices[label]
        if unknown_label in self.__indices:
            return self.__indices[unknown_label]
        raise KeyError(
            "label {!r} is not in the vocabulary and fallback {!r} is not "
            "available".format(label, unknown_label)
        )

    def inverse_transform(self, _id):
        """Return the label stored under ``_id``.

        Raises:
            AssertionError: If the encoder has not been fitted.
            KeyError: If ``_id`` is not an assigned id.
        """
        self.__check_fitted()
        return self.__values[_id]

    def __len__(self):
        """Return the number of labels in the vocabulary."""
        return len(self.__values)