From 66899326e51ad5c44e5b93307f94d4944f789a23 Mon Sep 17 00:00:00 2001 From: phunc20 Date: Sun, 15 May 2022 15:39:25 +0700 Subject: [PATCH] possible typo: candidiate => candidate --- nlpaug/augmenter/char/random.py | 18 +++++++++--------- nlpaug/model/lang_models/language_models.py | 8 ++++---- test/augmenter/char/test_random_char.py | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/nlpaug/augmenter/char/random.py b/nlpaug/augmenter/char/random.py index c22df297..8d28abea 100755 --- a/nlpaug/augmenter/char/random.py +++ b/nlpaug/augmenter/char/random.py @@ -29,23 +29,23 @@ class RandomCharAug(CharAugmenter): :param int aug_word_max: Maximum number of word will be augmented. If None is passed, number of augmentation is calculated via aup_word_p. If calculated result from aug_word_p is smaller than aug_word_max, will use calculated result from aug_word_p. Otherwise, using aug_max. - :param bool include_upper_case: If True, upper case character may be included in augmented data. If `candidiates' + :param bool include_upper_case: If True, upper case character may be included in augmented data. If `candidates' value is provided, this param will be ignored. - :param bool include_lower_case: If True, lower case character may be included in augmented data. If `candidiates' + :param bool include_lower_case: If True, lower case character may be included in augmented data. If `candidates' value is provided, this param will be ignored. - :param bool include_numeric: If True, numeric character may be included in augmented data. If `candidiates' + :param bool include_numeric: If True, numeric character may be included in augmented data. If `candidates' value is provided, this param will be ignored. :param int min_char: If word less than this value, do not draw word for augmentation :param swap_mode: When action is 'swap', you may pass 'adjacent', 'middle' or 'random'. 'adjacent' means swap action only consider adjacent character (within same word).
'middle' means swap action consider adjacent character but not the first and last character of word. 'random' means swap action will be executed without constraint. - :param str spec_char: Special character may be included in augmented data. If `candidiates' + :param str spec_char: Special character may be included in augmented data. If `candidates' value is provided, this param will be ignored. :param list stopwords: List of words which will be skipped from augment operation. :param str stopwords_regex: Regular expression for matching words which will be skipped from augment operation. :param func tokenizer: Customize tokenization process :param func reverse_tokenizer: Customize reverse of tokenization process - :param List candidiates: List of string for augmentation. E.g. ['AAA', '11', '===']. If values is provided, + :param List candidates: List of string for augmentation. E.g. ['AAA', '11', '===']. If values is provided, `include_upper_case`, `include_lower_case`, `include_numeric` and `spec_char` will be ignored. :param str name: Name of this augmenter. 
@@ -56,7 +56,7 @@ class RandomCharAug(CharAugmenter): def __init__(self, action=Action.SUBSTITUTE, name='RandomChar_Aug', aug_char_min=1, aug_char_max=10, aug_char_p=0.3, aug_word_p=0.3, aug_word_min=1, aug_word_max=10, include_upper_case=True, include_lower_case=True, include_numeric=True, min_char=4, swap_mode='adjacent', spec_char='!@#$%^&*()_+', stopwords=None, - tokenizer=None, reverse_tokenizer=None, verbose=0, stopwords_regex=None, candidiates=None): + tokenizer=None, reverse_tokenizer=None, verbose=0, stopwords_regex=None, candidates=None): super().__init__( action=action, name=name, min_char=min_char, aug_char_min=aug_char_min, aug_char_max=aug_char_max, aug_char_p=aug_char_p, aug_word_min=aug_word_min, aug_word_max=aug_word_max, aug_word_p=aug_word_p, @@ -68,7 +68,7 @@ def __init__(self, action=Action.SUBSTITUTE, name='RandomChar_Aug', aug_char_min self.include_numeric = include_numeric self.swap_mode = swap_mode self.spec_char = spec_char - self.candidiates = candidiates + self.candidates = candidates self.model = self.get_model() @@ -248,8 +248,8 @@ def delete(self, data): return self.reverse_tokenizer(doc.get_augmented_tokens()) def get_model(self): - if self.candidiates: - return self.candidiates + if self.candidates: + return self.candidates candidates = [] if self.include_upper_case: diff --git a/nlpaug/model/lang_models/language_models.py b/nlpaug/model/lang_models/language_models.py index 86db2cbd..c0b66aab 100755 --- a/nlpaug/model/lang_models/language_models.py +++ b/nlpaug/model/lang_models/language_models.py @@ -13,7 +13,7 @@ class LanguageModels: OPTIMIZE_ATTRIBUTES = ['external_memory', 'return_proba'] - def __init__(self, device='cpu', model_type='', temperature=1.0, top_k=100, top_p=0.01, batch_size=32, + def __init__(self, device='cpu', model_type='', temperature=1.0, top_k=100, top_p=0.01, batch_size=32, optimize=None, silence=True): try: import torch @@ -60,7 +60,7 @@ def clean(self, text): def predict(self, text, target_word=None, 
n=1): raise NotImplementedError - # for HuggingFace pipeline + # for HuggingFace pipeline def convert_device(self, device): if device == 'cpu' or device is None: return -1 @@ -158,7 +158,7 @@ def filtering(self, logits, seed): def pick(self, logits, idxes, target_word, n=1, include_punctuation=False): candidate_ids, candidate_probas = self.prob_multinomial(logits, n=n*10) candidate_ids = [idxes[candidate_id] for candidate_id in candidate_ids] - results = self.get_candidiates(candidate_ids, candidate_probas, target_word, n, + results = self.get_candidates(candidate_ids, candidate_probas, target_word, n, include_punctuation) return results @@ -183,7 +183,7 @@ def prob_multinomial(self, logits, n): def is_skip_candidate(self, candidate): return False - def get_candidiates(self, candidate_ids, candidate_probas, target_word=None, n=1, + def get_candidates(self, candidate_ids, candidate_probas, target_word=None, n=1, include_punctuation=False): # To have random behavior, NO sorting for candidate_probas. results = [] diff --git a/test/augmenter/char/test_random_char.py b/test/augmenter/char/test_random_char.py index 5c1d788b..c379332e 100755 --- a/test/augmenter/char/test_random_char.py +++ b/test/augmenter/char/test_random_char.py @@ -135,15 +135,15 @@ def test_swap_random(self): self.assertNotEqual(text, augmented_text) self.assertEqual(len(augmented_text), len(text)) - def test_candidiates(self): - candidiates = ['AAA', '11', '===', '中文'] + def test_candidates(self): + candidates = ['AAA', '11', '===', '中文'] text = 'quick brown jumps over lazy' - aug = RandomCharAug(min_char=4, candidiates=candidiates) + aug = RandomCharAug(min_char=4, candidates=candidates) augmented_text = aug.augment(text) self.assertNotEqual(text, augmented_text) match = False - for c in candidiates: + for c in candidates: if c in augmented_text: match = True break