diff --git a/nlpaug/augmenter/char/random.py b/nlpaug/augmenter/char/random.py index c22df29..8d28abe 100755 --- a/nlpaug/augmenter/char/random.py +++ b/nlpaug/augmenter/char/random.py @@ -29,23 +29,23 @@ class RandomCharAug(CharAugmenter): :param int aug_word_max: Maximum number of word will be augmented. If None is passed, number of augmentation is calculated via aup_word_p. If calculated result from aug_word_p is smaller than aug_word_max, will use calculated result from aug_word_p. Otherwise, using aug_max. - :param bool include_upper_case: If True, upper case character may be included in augmented data. If `candidiates' + :param bool include_upper_case: If True, upper case character may be included in augmented data. If `candidates' value is provided, this param will be ignored. - :param bool include_lower_case: If True, lower case character may be included in augmented data. If `candidiates' + :param bool include_lower_case: If True, lower case character may be included in augmented data. If `candidates' value is provided, this param will be ignored. - :param bool include_numeric: If True, numeric character may be included in augmented data. If `candidiates' + :param bool include_numeric: If True, numeric character may be included in augmented data. If `candidates' value is provided, this param will be ignored. :param int min_char: If word less than this value, do not draw word for augmentation :param swap_mode: When action is 'swap', you may pass 'adjacent', 'middle' or 'random'. 'adjacent' means swap action only consider adjacent character (within same word). 'middle' means swap action consider adjacent character but not the first and last character of word. 'random' means swap action will be executed without constraint. - :param str spec_char: Special character may be included in augmented data. If `candidiates' + :param str spec_char: Special character may be included in augmented data. If `candidates' value is provided, this param will be ignored. :param list stopwords: List of words which will be skipped from augment operation. :param str stopwords_regex: Regular expression for matching words which will be skipped from augment operation. :param func tokenizer: Customize tokenization process :param func reverse_tokenizer: Customize reverse of tokenization process - :param List candidiates: List of string for augmentation. E.g. ['AAA', '11', '===']. If values is provided, + :param List candidates: List of string for augmentation. E.g. ['AAA', '11', '===']. If values is provided, `include_upper_case`, `include_lower_case`, `include_numeric` and `spec_char` will be ignored. :param str name: Name of this augmenter. @@ -56,7 +56,7 @@ class RandomCharAug(CharAugmenter): def __init__(self, action=Action.SUBSTITUTE, name='RandomChar_Aug', aug_char_min=1, aug_char_max=10, aug_char_p=0.3, aug_word_p=0.3, aug_word_min=1, aug_word_max=10, include_upper_case=True, include_lower_case=True, include_numeric=True, min_char=4, swap_mode='adjacent', spec_char='!@#$%^&*()_+', stopwords=None, - tokenizer=None, reverse_tokenizer=None, verbose=0, stopwords_regex=None, candidiates=None): + tokenizer=None, reverse_tokenizer=None, verbose=0, stopwords_regex=None, candidates=None): super().__init__( action=action, name=name, min_char=min_char, aug_char_min=aug_char_min, aug_char_max=aug_char_max, aug_char_p=aug_char_p, aug_word_min=aug_word_min, aug_word_max=aug_word_max, aug_word_p=aug_word_p, @@ -68,7 +68,7 @@ def __init__(self, action=Action.SUBSTITUTE, name='RandomChar_Aug', aug_char_min self.include_numeric = include_numeric self.swap_mode = swap_mode self.spec_char = spec_char - self.candidiates = candidiates + self.candidates = candidates self.model = self.get_model() @@ -248,8 +248,8 @@ def delete(self, data): return self.reverse_tokenizer(doc.get_augmented_tokens()) def get_model(self): - if self.candidiates: - return self.candidiates + if self.candidates: + return self.candidates candidates = [] if self.include_upper_case: diff --git a/nlpaug/model/lang_models/language_models.py b/nlpaug/model/lang_models/language_models.py index 86db2cb..c0b66aa 100755 --- a/nlpaug/model/lang_models/language_models.py +++ b/nlpaug/model/lang_models/language_models.py @@ -13,7 +13,7 @@ class LanguageModels: OPTIMIZE_ATTRIBUTES = ['external_memory', 'return_proba'] - def __init__(self, device='cpu', model_type='', temperature=1.0, top_k=100, top_p=0.01, batch_size=32, + def __init__(self, device='cpu', model_type='', temperature=1.0, top_k=100, top_p=0.01, batch_size=32, optimize=None, silence=True): try: import torch @@ -60,7 +60,7 @@ def clean(self, text): def predict(self, text, target_word=None, n=1): raise NotImplementedError - # for HuggingFace pipeline + # for HuggingFace pipeline def convert_device(self, device): if device == 'cpu' or device is None: return -1 @@ -158,7 +158,7 @@ def filtering(self, logits, seed): def pick(self, logits, idxes, target_word, n=1, include_punctuation=False): candidate_ids, candidate_probas = self.prob_multinomial(logits, n=n*10) candidate_ids = [idxes[candidate_id] for candidate_id in candidate_ids] - results = self.get_candidiates(candidate_ids, candidate_probas, target_word, n, + results = self.get_candidates(candidate_ids, candidate_probas, target_word, n, include_punctuation) return results @@ -183,7 +183,7 @@ def prob_multinomial(self, logits, n): def is_skip_candidate(self, candidate): return False - def get_candidiates(self, candidate_ids, candidate_probas, target_word=None, n=1, + def get_candidates(self, candidate_ids, candidate_probas, target_word=None, n=1, include_punctuation=False): # To have random behavior, NO sorting for candidate_probas. results = [] diff --git a/test/augmenter/char/test_random_char.py b/test/augmenter/char/test_random_char.py index 5c1d788..c379332 100755 --- a/test/augmenter/char/test_random_char.py +++ b/test/augmenter/char/test_random_char.py @@ -135,15 +135,15 @@ def test_swap_random(self): self.assertNotEqual(text, augmented_text) self.assertEqual(len(augmented_text), len(text)) - def test_candidiates(self): - candidiates = ['AAA', '11', '===', '中文'] + def test_candidates(self): + candidates = ['AAA', '11', '===', '中文'] text = 'quick brown jumps over lazy' - aug = RandomCharAug(min_char=4, candidiates=candidiates) + aug = RandomCharAug(min_char=4, candidates=candidates) augmented_text = aug.augment(text) self.assertNotEqual(text, augmented_text) match = False - for c in candidiates: + for c in candidates: if c in augmented_text: match = True break