Add Streaming Sentencepiece Decoder
apaniukov committed Oct 19, 2023
1 parent 96673f5 commit 0e7ae87
Showing 6 changed files with 143 additions and 11 deletions.
----------------------------------------------------------------------
@@ -79,6 +79,7 @@
     std::make_shared<ov::frontend::ConversionExtension>("Const", translate_const), \
     std::make_shared<ov::OpExtension<TemplateExtension::SentencepieceTokenizer>>(), \
     std::make_shared<ov::OpExtension<TemplateExtension::SentencepieceDetokenizer>>(), \
+    std::make_shared<ov::OpExtension<TemplateExtension::SentencepieceStreamDetokenizer>>(), \
     std::make_shared<ov::frontend::ConversionExtension>("SentencepieceOp", translate_sentencepiece_op), \
     std::make_shared<ov::frontend::ConversionExtension>("RaggedTensorToSparse", translate_sentencepiece_tokenizer),
 #else
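Registering SentencepieceStreamDetokenizer as an OpExtension next to the existing ops makes it loadable like any other custom operation. Below is a minimal sketch of loading the compiled extension library from Python; the library file name and the model path are assumptions that depend on the local build, not part of this commit.

# Minimal sketch: load the extension library so models that contain
# SentencepieceStreamDetokenizer can be read. The .so name and the model
# path are assumed, not taken from this commit.
from openvino import Core

core = Core()
core.add_extension("libuser_ov_extensions.so")
detokenizer_model = core.read_model("detokenizer.xml")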
----------------------------------------------------------------------
@@ -14,7 +14,7 @@


 def convert_tokenizer(
-    tokenizer_object: Any, number_of_inputs: int = 1, with_decoder: bool = False
+    tokenizer_object: Any, number_of_inputs: int = 1, with_decoder: bool = False, streaming_decoder: bool = False
 ) -> Union[Model, Tuple[Model, Model]]:
     # todo: add support for more than 1 input
     if number_of_inputs > 1:
@@ -32,6 +32,7 @@ def convert_tokenizer(
             tokenizer_object,
             add_attention_mask=True,
             with_decoder=with_decoder,
+            streaming_decoder=streaming_decoder,
         )
     elif isinstance(tokenizer_object, PreTrainedTokenizerFast):
         logger.info("Convert Huggingface Fast tokenizer pipeline.")
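With the streaming_decoder flag threaded through convert_tokenizer, converting a SentencePiece tokenizer together with a streaming detokenizer would look roughly like the sketch below; the import path and the checkpoint name are illustrative assumptions, and any SentencePiece-based Huggingface tokenizer should work.

# Sketch only: the import path and checkpoint name are assumptions.
from transformers import AutoTokenizer
from ov_tokenizer import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
ov_tokenizer, ov_detokenizer = convert_tokenizer(
    hf_tokenizer,
    with_decoder=True,
    streaming_decoder=True,  # new flag added in this commit
)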
----------------------------------------------------------------------
@@ -10,7 +10,8 @@

 import numpy as np
 import openvino.runtime.opset12 as opset
-from openvino.runtime import Model, PartialShape, Type, op
+from openvino import Model, PartialShape, Type
+from openvino.runtime import Node, op
 from openvino.runtime.exceptions import OVTypeError
 from openvino.runtime.utils.types import as_node, make_constant_node

@@ -273,7 +274,6 @@ def convert_fast_tokenizer(
     hf_tokenizer: "PreTrainedTokenizerBase",
     number_of_inputs: int = 1,
     with_decoder: bool = False,
-    greedy_decoder: bool = False,
 ) -> Union[Model, Tuple[Model, Model]]:
     pipeline = TransformersTokenizerPipelineParser(hf_tokenizer).parse(number_of_inputs=number_of_inputs)
     ov_tokenizer = pipeline.get_encoder_ov_subgraph()
@@ -312,6 +312,7 @@ def convert_sentencepiece_model_tokenizer(
     hf_tokenizer: "PreTrainedTokenizerBase",
     add_attention_mask: bool = True,
     with_decoder: bool = False,
+    streaming_decoder: bool = False,
 ) -> Union[Model, Tuple[Model, Model]]:
     if not is_sentencepiece_model(hf_tokenizer):
         raise OVTypeError("Cannot convert tokenizer that does not have `.model` file.")
@@ -386,16 +387,23 @@ def convert_sentencepiece_model_tokenizer(
     if not with_decoder:
         return tokenizer_encoder

-    decoder_input = op.Parameter(Type.i32, PartialShape(["?", "?"]))  # (batch, sequence)
-    token_ids = decoder_input
+    return tokenizer_encoder, get_sp_decoder(sp_model_node, streaming_decoder=streaming_decoder)
+
+
+def get_sp_decoder(sp_model_node: Node, streaming_decoder: bool = False) -> Model:
+    token_ids = op.Parameter(Type.i32, PartialShape(["?", "?"]))  # (batch, sequence)

     decoder = factory.create(
-        "SentencepieceDetokenizer",
+        "SentencepieceStreamDetokenizer" if streaming_decoder else "SentencepieceDetokenizer",
         [sp_model_node, token_ids],
-    )
-    string_output = factory.create("StringTensorPack", decoder.outputs()).outputs()
+    ).outputs()
+
+    if streaming_decoder:
+        decoder = RegexDecodingStep.replace_sp_spaces().get_ov_subgraph(decoder)
+        decoder = RegexDecodingStep.replace_sp_newlines().get_ov_subgraph(decoder)
+
+    string_output = factory.create("StringTensorPack", decoder).outputs()
     string_output[0].tensor.add_names({STRING_OUTPUT_NAME})
-    tokenizer_decoder = Model(string_output, [decoder_input], TOKENIZER_DECODER_NAME)
+    tokenizer_decoder = Model(string_output, [token_ids], TOKENIZER_DECODER_NAME)
     tokenizer_decoder.validate_nodes_and_infer_types()

-    return tokenizer_encoder, tokenizer_decoder
+    return tokenizer_decoder
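get_sp_decoder returns an ordinary OpenVINO Model that maps an i32 tensor of shape (batch, sequence) to strings, so a generation loop can detokenize each newly produced id on its own. A rough usage sketch follows; the variable names are assumptions, and how the string output is unpacked depends on the string-tensor support of the particular OpenVINO build.

# Rough sketch: run the streaming detokenizer one token per step.
# ov_detokenizer is assumed to come from
# convert_tokenizer(..., with_decoder=True, streaming_decoder=True).
import numpy as np
from openvino import Core

core = Core()  # assumes the extension library is already registered (see above)
compiled = core.compile_model(ov_detokenizer, "CPU")
for token_id in (3834, 2787, 29991):  # ids from a hypothetical generation loop
    token_ids = np.array([[token_id]], dtype=np.int32)  # shape (batch=1, sequence=1)
    chunk = compiled(token_ids)[0]  # one string per batch element
    print(chunk, end="")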
----------------------------------------------------------------------
@@ -651,6 +651,20 @@ def clean_up_tokenization_spaces(cls) -> "RegexDecodingStep":
             replace_term=r"\1",
         )

+    @classmethod
+    def replace_sp_spaces(cls) -> "RegexDecodingStep":
+        return cls(
+            regex_search_pattern="▁",
+            replace_term=" ",
+        )
+
+    @classmethod
+    def replace_sp_newlines(cls) -> "RegexDecodingStep":
+        return cls(
+            regex_search_pattern="<0x0A>",
+            replace_term="\n",
+        )
+
     def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
         input_nodes.extend(
             (
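The two new classmethods encode standard SentencePiece surface conventions: "▁" (U+2581) marks a word boundary and "<0x0A>" is the byte-fallback piece for a newline, so applying both makes raw pieces from the stream detokenizer read as normal text. A pure-Python equivalent, for illustration only:

# Illustrates what replace_sp_spaces and replace_sp_newlines do to a chunk;
# the helper name is ours, not part of the pipeline.
def postprocess_piece(piece: str) -> str:
    return piece.replace("▁", " ").replace("<0x0A>", "\n")

assert postprocess_piece("▁Hello<0x0A>▁world") == " Hello\n world"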
----------------------------------------------------------------------
@@ -306,3 +306,87 @@ bool SentencepieceDetokenizer::has_evaluate() const {
 std::shared_ptr<Node> SentencepieceDetokenizer::clone_with_new_inputs(const OutputVector& new_args) const {
     return std::make_shared<SentencepieceDetokenizer>(new_args, m_sp);
 }
+
+
+// Stream Detokenizer
+
+SentencepieceStreamDetokenizer::SentencepieceStreamDetokenizer(const OutputVector& args) :
+    m_sp(std::make_shared<SentencePieceProcessor>()), Op(args) {
+    auto sp_model_const = as_type_ptr<Constant>(args[0].get_node_shared_ptr());
+    OPENVINO_ASSERT(sp_model_const, "SentencepieceStreamDetokenizer expects SentencePiece model to be constant.");
+    auto spm_model = static_cast<const char*>(sp_model_const->get_data_ptr());
+    auto spm_model_size = sp_model_const->get_byte_size();
+
+    // configure SentencePieceProcessor
+    std::string model_proto(spm_model, spm_model_size);
+    CHECK_OK(m_sp->LoadFromSerializedProto(model_proto));
+    constructor_validate_and_infer_types();
+}
+
+SentencepieceStreamDetokenizer::SentencepieceStreamDetokenizer(const OutputVector& args, const std::shared_ptr<sentencepiece::SentencePieceProcessor>& sp) :
+    m_sp((sp == nullptr) ? std::make_shared<SentencePieceProcessor>() : sp), Op(args) {
+    // the constructor above (without the sp argument) is never called when the node
+    // is created via the Python factory, so initialize and cache m_sp here
+    if (!m_sp->status().ok()) {
+        auto sp_model_const = as_type_ptr<Constant>(args[0].get_node_shared_ptr());
+        OPENVINO_ASSERT(sp_model_const, "SentencepieceStreamDetokenizer expects SentencePiece model to be constant.");
+        auto spm_model = static_cast<const char*>(sp_model_const->get_data_ptr());
+        auto spm_model_size = sp_model_const->get_byte_size();
+
+        // configure SentencePieceProcessor
+        std::string model_proto(spm_model, spm_model_size);
+        CHECK_OK(m_sp->LoadFromSerializedProto(model_proto));
+    }
+    constructor_validate_and_infer_types();
+}
+
+void SentencepieceStreamDetokenizer::validate_and_infer_types() {
+    OPENVINO_ASSERT(get_input_size() == 2, "SentencepieceStreamDetokenizer expects two inputs: sp model and token ids");
+    OPENVINO_ASSERT(get_input_element_type(0) == element::u8, "SentencepieceStreamDetokenizer accepts sp model as the first input and it should be of type u8 tensor");
+    OPENVINO_ASSERT(get_input_partial_shape(1).size() == 2, "SentencepieceStreamDetokenizer expects 2D tensor as second input");
+
+    auto batch_size = PartialShape({get_input_partial_shape(1)[0]});
+    set_string_output(this, 0, batch_size);
+}
+
+bool SentencepieceStreamDetokenizer::visit_attributes(AttributeVisitor& visitor) {
+    return true;
+}
+
+bool SentencepieceStreamDetokenizer::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
+    auto batch_size = inputs[1].get_shape()[0];
+    auto seq_len = inputs[1].get_shape()[1];
+    auto input_data = inputs[1].data<const int32_t>();
+
+    outputs[0].set_shape({batch_size});
+    outputs[1].set_shape({batch_size});
+    outputs[2].set_shape({batch_size * seq_len * 100});  // 100 chars - max token length
+
+    auto begins = outputs[0].data<int32_t>();
+    auto ends = outputs[1].data<int32_t>();
+    auto chars = outputs[2].data<uint8_t>();
+    uint32_t char_offset = 0;
+
+    for (size_t batch = 0; batch < batch_size; ++batch) {
+        const auto start = batch * seq_len;
+
+        begins[batch] = char_offset;
+        for (size_t seq = start; seq < start + seq_len; ++seq) {
+            const auto token_id = input_data[seq];
+            const auto token = m_sp->IdToPiece(token_id);
+
+            std::copy(token.begin(), token.end(), &chars[char_offset]);
+            char_offset += token.size();
+        }
+        ends[batch] = char_offset;
+    }
+    outputs[2].set_shape({char_offset});
+    return true;
+}
+
+bool SentencepieceStreamDetokenizer::has_evaluate() const {
+    return true;
+}
+
+std::shared_ptr<Node> SentencepieceStreamDetokenizer::clone_with_new_inputs(const OutputVector& new_args) const {
+    return std::make_shared<SentencepieceStreamDetokenizer>(new_args, m_sp);
+}
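evaluate fills the decomposed string-tensor layout used throughout these extensions: outputs 0 and 1 hold per-batch begin and end offsets, and output 2 holds one flat u8 character buffer, pre-sized at 100 bytes per token and then shrunk to the bytes actually written. A small Python sketch of how such a triple maps back to strings; the helper is ours, for illustration only.

import numpy as np

# Hypothetical helper: turn the (begins, ends, chars) triple produced by
# SentencepieceStreamDetokenizer::evaluate into one string per batch item.
def unpack_string_tensor(begins, ends, chars):
    return [bytes(chars[b:e]).decode("utf-8") for b, e in zip(begins, ends)]

chars = np.frombuffer(b"HelloWorld", dtype=np.uint8)
print(unpack_string_tensor([0, 5], [5, 10], chars))  # ['Hello', 'World']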
----------------------------------------------------------------------
@@ -62,4 +62,28 @@ namespace TemplateExtension {
 private:
     std::shared_ptr<sentencepiece::SentencePieceProcessor> m_sp;
 };
+
+
+class SentencepieceStreamDetokenizer : public ov::op::Op {
+public:
+    OPENVINO_OP("SentencepieceStreamDetokenizer");
+
+    SentencepieceStreamDetokenizer() = default;
+    SentencepieceStreamDetokenizer(const ov::OutputVector& args);
+    SentencepieceStreamDetokenizer(const ov::OutputVector& args,
+                                   const std::shared_ptr<sentencepiece::SentencePieceProcessor>& sp);
+
+    bool visit_attributes(ov::AttributeVisitor& visitor) override;
+
+    void validate_and_infer_types() override;
+
+    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
+
+    bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override;
+
+    bool has_evaluate() const override;
+
+private:
+    std::shared_ptr<sentencepiece::SentencePieceProcessor> m_sp;
+};
 } // namespace TemplateExtension
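Unlike SentencepieceDetokenizer, which decodes a whole sequence at once, the stream variant maps every id to its piece independently via IdToPiece, so text can be emitted incrementally as ids arrive. The sentencepiece Python package illustrates the same distinction; this snippet is not part of the commit and the model path is an assumption.

# decode() performs full detokenization; id_to_piece() returns raw pieces
# that can be emitted one by one, which is what the stream detokenizer mirrors.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")  # path assumed
ids = sp.encode("Hello world")
print(sp.decode(ids))                    # "Hello world"
print([sp.id_to_piece(i) for i in ids])  # e.g. ['▁Hello', '▁world']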
