-
Notifications
You must be signed in to change notification settings - Fork 3k
Best Practice
骑马小猫 edited this page Nov 10, 2022
·
7 revisions
这里将介绍使用PaddleNLP过程中的最佳实践方法,形式不限于代码片段和github repo,也欢迎大家来贡献自己的实践方法。
- 文本处理
from paddlenlp.transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-nano-zh")
result = tokenizer("您好,欢迎使用PaddleNLP", max_length=30, padding=True, return_token_type_id=True, return_tensors='pd')
assert result['input_ids'].shape == [1, 13]
result = tokenizer("您好,欢迎使用PaddleNLP", max_length=30, padding="max_length", return_token_type_id=True, return_tensors='pd')
assert result['input_ids'].shape == [1, 30]
- 空格处理
在文本处理处理当中,【空格】的出现可能会让大家比价头痛,特别是做序列标注相关的任务。
头痛的原因是在于,空格经过tokenizer之后会别删掉,从而导致原始文本中的字符少了一个,这会导致字符发生偏移。可是在paddlenlp tokenizer当中完全是可以解决这些情况。解决方案分为两种:
- 解决方案一
保留原始文本中的空格:这个时候需要先: 1. 使用encode来进行编码;2. 将文本转化为list;3. 将split_into_words设置为False。
这个时候即可实现在保留空格的情况下,将原始文本中的每一个字符转化为 token id.
不过保留空格这个方法,其实在一定程度上会影响语义的编码,毕竟空格也算是一种噪声,甚至包含某些特殊的含义。不过,具体效果还是得根据数据集来。
- 解决方案二:
使用return_offsets_mapping
参数获取到token_id在原始文本中的位置信息,这个时候在做序列标注任务是可通过此信息来定位到原始文本中的相关信息。
以上两种解决方案的简要实现代码如下所示:
from paddlenlp.transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-nano-zh")
result = tokenizer("您 好", return_tensors='pd', add_special_tokens=False)
assert result['input_ids'].shape == [1, 2]
# 解决方案一
result = tokenizer.encode(list("您 好"), padding=True, split_into_words = False, return_tensors='pd', add_special_tokens=False)
assert result['input_ids'].shape == [1, 3]
# {'input_ids': [892, 39979, 170], 'token_type_ids': [0, 0, 0]}
# 解决方案二
result = tokenizer.encode("您 好", add_special_tokens=False, return_offsets_mapping=True)
assert result['input_ids'].shape == [1, 3]
# {'input_ids': [892, 170], 'token_type_ids': [0, 0], 'offset_mapping': [(0, 1), (2, 3)]}
from paddlenlp.transformers import BertConfig, BertModel, BertForTokenClassification
from paddlenlp.utils.converter import Converter, StateDictKeysChecker
config = BertConfig()
bert_model = BertModel(config)
bert_for_token_model = BertForTokenClassification(config)
# base-downstream
checker = StateDictKeysChecker(
bert_model, Converter.get_model_state_dict(bert_for_token_model))
unexpected_keys = checker.get_unexpected_keys()
assert len(unexpected_keys) == 2
mismatched_keys = checker.get_mismatched_keys()
assert len(mismatched_keys) == 0