Skip to content

Commit

Permalink
Merge pull request #294 from phunc20/fix/typo_candidiate
Browse files Browse the repository at this point in the history
possible typo: candidate => candidiate
  • Loading branch information
makcedward authored Jul 1, 2022
2 parents 14cf962 + 6689932 commit 487d9c8
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
18 changes: 9 additions & 9 deletions nlpaug/augmenter/char/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@ class RandomCharAug(CharAugmenter):
:param int aug_word_max: Maximum number of word will be augmented. If None is passed, number of augmentation is
calculated via aup_word_p. If calculated result from aug_word_p is smaller than aug_word_max, will use calculated result
from aug_word_p. Otherwise, using aug_max.
:param bool include_upper_case: If True, upper case character may be included in augmented data. If `candidiates'
:param bool include_upper_case: If True, upper case character may be included in augmented data. If `candidates'
value is provided, this param will be ignored.
:param bool include_lower_case: If True, lower case character may be included in augmented data. If `candidiates'
:param bool include_lower_case: If True, lower case character may be included in augmented data. If `candidates'
value is provided, this param will be ignored.
:param bool include_numeric: If True, numeric character may be included in augmented data. If `candidiates'
:param bool include_numeric: If True, numeric character may be included in augmented data. If `candidates'
value is provided, this param will be ignored.
:param int min_char: If word less than this value, do not draw word for augmentation
:param swap_mode: When action is 'swap', you may pass 'adjacent', 'middle' or 'random'. 'adjacent' means swap action
only consider adjacent character (within same word). 'middle' means swap action consider adjacent character but
not the first and last character of word. 'random' means swap action will be executed without constraint.
:param str spec_char: Special character may be included in augmented data. If `candidiates'
:param str spec_char: Special character may be included in augmented data. If `candidates'
value is provided, this param will be ignored.
:param list stopwords: List of words which will be skipped from augment operation.
:param str stopwords_regex: Regular expression for matching words which will be skipped from augment operation.
:param func tokenizer: Customize tokenization process
:param func reverse_tokenizer: Customize reverse of tokenization process
:param List candidiates: List of string for augmentation. E.g. ['AAA', '11', '===']. If values is provided,
:param List candidates: List of string for augmentation. E.g. ['AAA', '11', '===']. If values is provided,
`include_upper_case`, `include_lower_case`, `include_numeric` and `spec_char` will be ignored.
:param str name: Name of this augmenter.
Expand All @@ -56,7 +56,7 @@ class RandomCharAug(CharAugmenter):
def __init__(self, action=Action.SUBSTITUTE, name='RandomChar_Aug', aug_char_min=1, aug_char_max=10, aug_char_p=0.3,
aug_word_p=0.3, aug_word_min=1, aug_word_max=10, include_upper_case=True, include_lower_case=True,
include_numeric=True, min_char=4, swap_mode='adjacent', spec_char='!@#$%^&*()_+', stopwords=None,
tokenizer=None, reverse_tokenizer=None, verbose=0, stopwords_regex=None, candidiates=None):
tokenizer=None, reverse_tokenizer=None, verbose=0, stopwords_regex=None, candidates=None):
super().__init__(
action=action, name=name, min_char=min_char, aug_char_min=aug_char_min, aug_char_max=aug_char_max,
aug_char_p=aug_char_p, aug_word_min=aug_word_min, aug_word_max=aug_word_max, aug_word_p=aug_word_p,
Expand All @@ -68,7 +68,7 @@ def __init__(self, action=Action.SUBSTITUTE, name='RandomChar_Aug', aug_char_min
self.include_numeric = include_numeric
self.swap_mode = swap_mode
self.spec_char = spec_char
self.candidiates = candidiates
self.candidates = candidates

self.model = self.get_model()

Expand Down Expand Up @@ -248,8 +248,8 @@ def delete(self, data):
return self.reverse_tokenizer(doc.get_augmented_tokens())

def get_model(self):
if self.candidiates:
return self.candidiates
if self.candidates:
return self.candidates

candidates = []
if self.include_upper_case:
Expand Down
8 changes: 4 additions & 4 deletions nlpaug/model/lang_models/language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class LanguageModels:
OPTIMIZE_ATTRIBUTES = ['external_memory', 'return_proba']

def __init__(self, device='cpu', model_type='', temperature=1.0, top_k=100, top_p=0.01, batch_size=32,
def __init__(self, device='cpu', model_type='', temperature=1.0, top_k=100, top_p=0.01, batch_size=32,
optimize=None, silence=True):
try:
import torch
Expand Down Expand Up @@ -60,7 +60,7 @@ def clean(self, text):
def predict(self, text, target_word=None, n=1):
raise NotImplementedError

# for HuggingFace pipeline
# for HuggingFace pipeline
def convert_device(self, device):
if device == 'cpu' or device is None:
return -1
Expand Down Expand Up @@ -158,7 +158,7 @@ def filtering(self, logits, seed):
def pick(self, logits, idxes, target_word, n=1, include_punctuation=False):
candidate_ids, candidate_probas = self.prob_multinomial(logits, n=n*10)
candidate_ids = [idxes[candidate_id] for candidate_id in candidate_ids]
results = self.get_candidiates(candidate_ids, candidate_probas, target_word, n,
results = self.get_candidates(candidate_ids, candidate_probas, target_word, n,
include_punctuation)

return results
Expand All @@ -183,7 +183,7 @@ def prob_multinomial(self, logits, n):
def is_skip_candidate(self, candidate):
return False

def get_candidiates(self, candidate_ids, candidate_probas, target_word=None, n=1,
def get_candidates(self, candidate_ids, candidate_probas, target_word=None, n=1,
include_punctuation=False):
# To have random behavior, NO sorting for candidate_probas.
results = []
Expand Down
8 changes: 4 additions & 4 deletions test/augmenter/char/test_random_char.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ def test_swap_random(self):
self.assertNotEqual(text, augmented_text)
self.assertEqual(len(augmented_text), len(text))

def test_candidiates(self):
candidiates = ['AAA', '11', '===', '中文']
def test_candidates(self):
candidates = ['AAA', '11', '===', '中文']
text = 'quick brown jumps over lazy'
aug = RandomCharAug(min_char=4, candidiates=candidiates)
aug = RandomCharAug(min_char=4, candidates=candidates)
augmented_text = aug.augment(text)
self.assertNotEqual(text, augmented_text)

match = False
for c in candidiates:
for c in candidates:
if c in augmented_text:
match = True
break
Expand Down

0 comments on commit 487d9c8

Please sign in to comment.