Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tokenizers] String type support in Tokenizers #781

Merged
merged 14 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pip install -e .[all]
```python
from transformers import AutoTokenizer
from openvino import compile_model
from openvino_tokenizers import convert_tokenizer, pack_strings
from openvino_tokenizers import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ov_tokenizer = convert_tokenizer(hf_tokenizer)
Expand All @@ -40,7 +40,7 @@ compiled_tokenzier = compile_model(ov_tokenizer)
text_input = "Test string"

hf_output = hf_tokenizer([text_input], return_tensors="np")
ov_output = compiled_tokenzier(pack_strings([text_input]))
ov_output = compiled_tokenzier([text_input])

for output_name in hf_output:
print(f"OpenVINO {output_name} = {ov_output[output_name]}")
Expand All @@ -58,7 +58,7 @@ for output_name in hf_output:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from openvino import compile_model, convert_model
from openvino_tokenizers import convert_tokenizer, pack_strings, connect_models
from openvino_tokenizers import convert_tokenizer, connect_models

checkpoint = "mrm8488/bert-tiny-finetuned-sms-spam-detection"
hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
Expand All @@ -73,7 +73,7 @@ ov_model = convert_model(hf_model, example_input=hf_input.data)
combined_model = connect_models(ov_tokenizer, ov_model)
compiled_combined_model = compile_model(combined_model)

openvino_output = compiled_combined_model(pack_strings(text_input))
openvino_output = compiled_combined_model(text_input)

print(f"OpenVINO logits: {openvino_output['logits']}")
# OpenVINO logits: [[ 1.2007061 -1.4698029]]
Expand All @@ -83,12 +83,11 @@ print(f"HuggingFace logits {hf_output.logits}")

### Use Extension With Converted (De)Tokenizer or Model With (De)Tokenizer

To work with converted tokenizer you need `pack_strings`/`unpack_strings` functions.
To work with converted tokenizer and detokenizer, numpy string tensors are used.

```python
import numpy as np
from openvino import Core
from openvino_tokenizers import unpack_strings

core = Core()

Expand All @@ -98,7 +97,7 @@ compiled_detokenizer = core.compile_model("detokenizer.xml")
token_ids = np.random.randint(100, 1000, size=(3, 5))
openvino_output = compiled_detokenizer(token_ids)

print(unpack_strings(openvino_output["string_output"]))
print(openvino_output["string_output"])
# ['sc�ouition�', 'intvenord hasient', 'g shouldwer M more']
```

Expand All @@ -108,12 +107,7 @@ print(unpack_strings(openvino_output["string_output"]))
import numpy as np
from openvino import compile_model, convert_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from openvino_tokenizers import (
add_greedy_decoding,
convert_tokenizer,
pack_strings,
unpack_strings,
)
from openvino_tokenizers import add_greedy_decoding, convert_tokenizer

# Use different repo for the tokenizer because the original repo doesn't have .model file
# Sentencepiece(Unigram) tokenizer supported only with .model file
Expand All @@ -128,7 +122,7 @@ ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=
compiled_tokenizer = compile_model(ov_tokenizer)

# transform input text into tokens
ov_input = compiled_tokenizer(pack_strings(text_input))
ov_input = compiled_tokenizer(text_input)
hf_input = hf_tokenizer(text_input, return_tensors="pt")

# convert Pytorch model to OpenVINO IR and add greedy decoding pipeline to it
Expand Down Expand Up @@ -158,7 +152,7 @@ hf_token_ids = hf_model.generate(

# decode model output
compiled_detokenizer = compile_model(ov_detokenizer)
ov_output = unpack_strings(compiled_detokenizer(ov_token_ids)["string_output"])
ov_output = compiled_detokenizer(ov_token_ids)["string_output"]
hf_output = hf_tokenizer.batch_decode(hf_token_ids, skip_special_tokens=True)
print(f"OpenVINO output string: `{ov_output}`")
# OpenVINO output string: `['Quick brown fox was walking through the forest. He was looking for something']`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def convert_sentencepiece_model_tokenizer(
hf_slow_tokenizer = hf_tokenizer.slow_tokenizer_class.from_pretrained(tmp)
fairseq_offset = getattr(hf_slow_tokenizer, "fairseq_offset", None)

input_node = op.Parameter(Type.u8, PartialShape(["?"]))
input_node = op.Parameter(Type.string, PartialShape(["?"]))
input_node.set_friendly_name("string_input")

if is_chatglm:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def __getitem__(self, item: int) -> BasePipelineStep:
return self.steps[item]

def get_tokenizer_ov_subgraph(self) -> Model:
string_inputs = [op.Parameter(Type.u8, PartialShape(["?"])) for _ in range(self.number_of_inputs)]
string_inputs = [op.Parameter(Type.string, PartialShape(["?"])) for _ in range(self.number_of_inputs)]

processing_outputs = []
for input_node in string_inputs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@

import numpy as np
import pytest
from openvino import Core
from openvino_tokenizers import (
convert_tokenizer,
pack_strings,
unpack_strings,
)
from openvino import Core, Tensor
from openvino_tokenizers import convert_tokenizer
from transformers import AutoTokenizer


# Left these two methods for convenient transition from legay u8 representation to native string tensors
# TODO: Remove the methods when transition is over
def pack_strings(strings):
return strings

def unpack_strings(strings):
return list(strings)


Comment on lines +12 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Del all (un)pack_strings functions from tests. I think we also should delete them from openvino_tokenizers.__init__.py file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do it in a separate commit please. I'm on vacation. This testing functionality -- I really found it useful to keep these two functions for debugging purposes as gates for all string tensors.

core = Core()

eng_test_strings = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,9 @@ void SentencepieceTokenizer::validate_and_infer_types() {
FRONT_END_GENERAL_CHECK(get_input_size() == 2, "SentencepieceTokenizer expects two inputs: sp model and input sentences");
FRONT_END_GENERAL_CHECK(get_input_element_type(0) == element::u8, "SentencepieceTokenizer accepts sp model as the first input and it should be of type u8 tensor");

#if USE_STRING_TENSORS

#if OPENVINO_USE_INPUT_OUTPUT_STRING_TENSOR_HACK
FRONT_END_GENERAL_CHECK(
get_input_element_type(1) == element::string || get_input_element_type(1) == element::u8,
"SentencepieceTokenizer accepts sentences as the second input and it should be of type u8 or string depending on the current stage of model preparation");
#else
FRONT_END_GENERAL_CHECK(
get_input_element_type(1) == element::string,
"SentencepieceTokenizer accepts sentences as the second input and it should be of type string tensor");
#endif

#else

#if 0 // change to 0 when compiled with master and the bug with data propagation from within inline context is not solved
FRONT_END_GENERAL_CHECK(
get_input_element_type(1) == element::u8,
"SentencepieceTokenizer accepts sentences as the second input and it should be of type u8 tensor, but got " +
get_input_element_type(1).get_type_name());
#endif

#endif
get_input_element_type(1) == element::string || get_input_element_type(1) == element::u8,
"SentencepieceTokenizer accepts sentences as the second input and it should be of type string tensor");

#endif

Expand Down Expand Up @@ -161,37 +142,37 @@ bool SentencepieceTokenizer::evaluate(TensorVector& outputs, const TensorVector&

#else

#if USE_STRING_TENSORS

#if OPENVINO_USE_INPUT_OUTPUT_STRING_TENSOR_HACK
const ov::Tensor& strings_tensor = **reinterpret_cast<ov::Tensor**>(inputs[1].data<uint8_t>());
#else
const ov::Tensor& strings_tensor = inputs[1];
#endif

const std::string* strings = strings_tensor.data<std::string>();
size_t batch_size = ov::shape_size(strings_tensor.get_shape());
auto input_element_type = get_input_element_type(1);
int32_t batch_size;

#else
// used in case of string tensors
const std::string* strings;

int32_t batch_size;
// used in case of u8 packed representation
const int32_t* begin_ids;
const int32_t* end_ids;
const uint8_t* data;
parse_packed_strings(inputs[1], batch_size, begin_ids, end_ids, data);

#endif
if(input_element_type == ov::element::string) {
strings = inputs[1].data<const std::string>();
batch_size = static_cast<int32_t>(ov::shape_size(inputs[1].get_shape()));
} else if(input_element_type == ov::element::u8) {
parse_packed_strings(inputs[1], batch_size, begin_ids, end_ids, data);
}

#endif

size_t max_token_id = 0;
for (size_t batch_ind = 0; batch_ind < batch_size; ++batch_ind) {
#if USE_STRING_TENSORS && !SENTENCE_PIECE_EXTENSION_DECOMPOSED_STRINGS
const std::string& sentence = strings[batch_ind];
#else
auto begin_ind = begin_ids[batch_ind];
auto end_ind = end_ids[batch_ind];
absl::string_view sentence((const char*)data + begin_ind, end_ind - begin_ind);
#endif
absl::string_view sentence;
if(input_element_type == ov::element::string) {
sentence = strings[batch_ind];
} else if(input_element_type == ov::element::u8) {
auto begin_ind = begin_ids[batch_ind];
auto end_ind = end_ids[batch_ind];
sentence = absl::string_view((const char*)data + begin_ind, end_ind - begin_ind);
}

std::vector<int32_t> ids;
CHECK_OK(m_sp->SampleEncode(sentence, m_nbest_size, m_alpha, &ids));
// put into resulted vectors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,46 +11,23 @@ using namespace ov;
void StringTensorPack::validate_and_infer_types() {
OPENVINO_ASSERT(m_mode == "begins_ends", "StringTensorPack supports only 'begins_ends' mode, but get " + m_mode);
check_string_input(this, 0);
#if USE_STRING_TENSORS
set_output_type(0, element::string, get_input_partial_shape(0));
#else
set_output_type(0, element::u8, PartialShape{Dimension()});
#endif
}

bool StringTensorPack::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
#if USE_STRING_TENSORS
// TODO
return false;
#else
auto rank = inputs[0].get_shape().size();
if (rank != 1) {
std::cerr << "[ WARNING ] StringTensorPack ignores the rank " << rank << " of input tensor and set rank=1 in the output\n";
}

auto num_elements = shape_size(inputs[0].get_shape());
auto num_chars = shape_size(inputs[2].get_shape());
auto num_output_elements = 4*(1 + 1 + num_elements) + num_chars;
outputs[0].set_shape(Shape{num_output_elements});
auto num_strings = outputs[0].get_size();
OPENVINO_ASSERT(inputs[0].get_size() == num_strings);
OPENVINO_ASSERT(inputs[1].get_size() == num_strings);

// FIXME: Do the repacking, otherwise cannot handle string tensors with gaps between strings
//auto begins = inputs[0].data<const int32_t>(); // this is not needed as no repacking happens in this version of code
auto begins = inputs[0].data<const int32_t>();
auto ends = inputs[1].data<const int32_t>();
auto chars = inputs[2].data<const uint8_t>();

auto output = outputs[0].data<uint8_t>();
auto output_int32 = reinterpret_cast<int32_t*>(output);
auto chars = reinterpret_cast<const char*>(inputs[2].data<const uint8_t>());

*output_int32++ = num_elements;
*output_int32++ = 0;
output_int32 = std::copy(ends, ends + num_elements, output_int32);
output = reinterpret_cast<uint8_t*>(output_int32);
output = std::copy(chars, chars + num_chars, output);
auto strings = outputs[0].data<std::string>();

OPENVINO_ASSERT(num_output_elements == output - outputs[0].data<uint8_t>(), "[ INTERNAL ERROR ] StringTensorPack output tensor is corrupted");

// WARNING! Chars are not repacked. If there are gaps between strings, they will remain.
for(size_t i = 0; i < num_strings; ++i) {
strings[i].assign(chars + begins[i], chars + ends[i]);
}

return true;
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

#include <openvino/op/op.hpp>

// Having a decomposed representation for a tensor, converts it to a single string tensor
// (packed u8 or natively supported element::string depending on whether or not USE_STRING_TENSORS defined).
// Having a decomposed representation for a tensor, converts it to a single string tensor with element::string element type.
class StringTensorPack : public ov::op::Op {
public:
OPENVINO_OP("StringTensorPack");
Expand Down
Loading
Loading