diff --git a/examples/offline_inference_fakehpu.py b/examples/offline_inference_fakehpu.py
index c533bb7192d64..e1b2d611a7a8d 100644
--- a/examples/offline_inference_fakehpu.py
+++ b/examples/offline_inference_fakehpu.py
@@ -2,13 +2,21 @@
 
 # Sample prompts.
 prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
+    "Berlin is the capital city of ",
+    "Louvre is located in the city called ",
+    "Barack Obama was the 44th president of ",
+    "Warsaw is the capital city of ",
+    "Gniezno is a city in ",
+    "Hebrew is an official state language of ",
+    "San Francisco is located in the state of ",
+    "Llanfairpwllgwyngyll is located in country of ",
+]
+ref_answers = [
+    "Germany", "Paris", "United States", "Poland", "Poland", "Israel",
+    "California", "Wales"
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams()
+sampling_params = SamplingParams(temperature=0, n=1, use_beam_search=False)
 
 # Create an LLM.
 llm = LLM(model="facebook/opt-125m", max_model_len=32, max_num_seqs=4)
@@ -16,7 +24,10 @@
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
 # Print the outputs.
-for output in outputs:
+for output, answer in zip(outputs, ref_answers):
     prompt = output.prompt
     generated_text = output.outputs[0].text
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    assert answer in generated_text, (
+        f"The generated text does not contain the correct answer: {answer}")
+print('PASSED')
diff --git a/vllm/utils.py b/vllm/utils.py
index ce6c0f621c263..21f1b39d4c3dd 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -212,7 +212,8 @@ def is_hpu() -> bool:
 
 @lru_cache(maxsize=None)
 def is_fake_hpu() -> bool:
-    return not _is_habana_frameworks_installed() and _is_built_for_hpu()
+    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0' or (
+        not _is_habana_frameworks_installed() and _is_built_for_hpu())
 
 
 @lru_cache(maxsize=None)
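
Note: a minimal usage sketch (not part of the diff above), assuming vllm.utils.is_fake_hpu is importable in the current environment; it only illustrates the effect of the VLLM_USE_FAKE_HPU override added in vllm/utils.py:

import os

# The flag must be set before the first call, because is_fake_hpu() is wrapped
# in lru_cache and its result is memoized.
os.environ['VLLM_USE_FAKE_HPU'] = '1'

from vllm.utils import is_fake_hpu

# With the override set, is_fake_hpu() returns True regardless of whether
# habana_frameworks is installed or the package was built for HPU.
assert is_fake_hpu()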