diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index 7fb91610b..1dc2761f6 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -45,10 +45,10 @@ def onnx_exists(onnx_file_path: str) -> bool:
     )
 
 
-def main(
+def infer_api(
     model_name: str,
     num_cores: int,
-    prompt: str,
+    prompt: str = Constants.input_str,
     aic_enable_depth_first: bool = False,
     mos: int = -1,
     cache_dir: str = Constants.CACHE_DIR,
@@ -60,7 +60,8 @@ def main(
     device_group: List[int] = [
         0,
     ],
-) -> None:
+    skip_stats: bool = False,
+) -> str:
     # Make
     model_card_dir = os.path.join(QEFF_MODELS_DIR, str(model_name))
     os.makedirs(model_card_dir, exist_ok=True)
@@ -76,21 +77,24 @@ def main(
     onnx_dir_path = os.path.join(model_card_dir, "onnx")
     onnx_model_path = os.path.join(onnx_dir_path, model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")
 
+    # skip model download if the qpc exists and we do not need stats
+    if not qpc_exists(qpc_dir_path) or not skip_stats:
     # Get tokenizer
-    if hf_token is not None:
-        login(hf_token)
-    model_hf_path = hf_download(
-        repo_id=model_name,
-        cache_dir=cache_dir,
-        ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf"],
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_hf_path, use_cache=True, padding_side="left")
+        if hf_token is not None:
+            login(hf_token)
+        model_hf_path = hf_download(
+            repo_id=model_name,
+            cache_dir=cache_dir,
+            ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf"],
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_hf_path, use_cache=True, padding_side="left")
 
     if qpc_exists(qpc_dir_path):
         # execute
         logger.info("Pre-compiled qpc found! Trying to execute with given prompt")
-        latency_stats_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
-        return
+        if not skip_stats:
+            latency_stats_kv(tokenizer=tokenizer, qpc=qpc_dir_path, device_id=device_group, prompt=prompt)
+        return qpc_dir_path
 
     if onnx_exists(onnx_model_path):
         # Compile -> execute
@@ -110,8 +114,9 @@ def main(
         assert (
             generated_qpc_path == qpc_dir_path
         ), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
-        latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
-        return
+        if not skip_stats:
+            latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
+        return generated_qpc_path
 
     #############################################
     # hf model -> export -> compile -> execute
@@ -156,9 +161,43 @@ def main(
     ), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
     logger.info(f"Compiled qpc files can be found at : {generated_qpc_path}")
 
-    # Execute
-    latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
+    if not skip_stats:
+        latency_stats_kv(tokenizer=tokenizer, qpc=generated_qpc_path, device_id=device_group, prompt=prompt)
+
+    return generated_qpc_path
+
+
+def main(
+    model_name: str,
+    num_cores: int,
+    prompt: str,
+    aic_enable_depth_first: bool = False,
+    mos: int = -1,
+    cache_dir: str = Constants.CACHE_DIR,
+    hf_token: str = None,
+    batch_size: int = 1,
+    prompt_len: int = 32,
+    ctx_len: int = 128,
+    mxfp6: bool = False,
+    device_group: List[int] = [
+        0,
+    ],
+) -> None:
+    _ = infer_api(
+        model_name=model_name,
+        num_cores=num_cores,
+        prompt=prompt,
+        aic_enable_depth_first=aic_enable_depth_first,
+        mos=mos,
+        cache_dir=cache_dir,
+        hf_token=hf_token,
+        batch_size=batch_size,
+        prompt_len=prompt_len,
+        ctx_len=ctx_len,
+        mxfp6=mxfp6,
+        device_group=device_group,
+    )
+
+    return
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
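The refactor above splits the CLI `main` from a reusable `infer_api` that returns the compiled qpc directory. A minimal sketch of driving it from Python, assuming an illustrative model card and core count (both placeholders, not from the diff), could look like this:

```python
# Hypothetical usage sketch (not part of the diff): compile a model card and
# reuse the returned qpc path without running latency stats.
from QEfficient.cloud.infer import infer_api

qpc_path = infer_api(
    model_name="gpt2",   # placeholder HF model card
    num_cores=16,        # placeholder core count for the target AI 100 card
    prompt="",           # prompt is unused when skip_stats=True
    prompt_len=32,
    ctx_len=128,
    mxfp6=True,
    skip_stats=True,     # skip latency_stats_kv and just return the qpc dir
)
print(f"Compiled qpc available at: {qpc_path}")
```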
diff --git a/QEfficient/generation/llm_generator.py b/QEfficient/generation/llm_generator.py
new file mode 100644
index 000000000..068c2dadd
--- /dev/null
+++ b/QEfficient/generation/llm_generator.py
@@ -0,0 +1,288 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import os
+import torch
+import numpy as np
+import torch.nn as nn
+
+from typing import List, Optional, Union
+from threading import Thread
+
+from transformers import (
+    AutoTokenizer,
+    TextIteratorStreamer,
+    TextStreamer,
+    LogitsProcessorList,
+    MinLengthLogitsProcessor,
+    TopKLogitsWarper,
+    TemperatureLogitsWarper,
+    StoppingCriteriaList,
+    MaxLengthCriteria,
+)
+
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+
+class LLMGenerator:
+    def __init__(
+        self,
+        qpc_path: str,
+        model_name: str,
+        device_id: Optional[List[int]] = [0],
+        prompt_len: Optional[int] = 32,
+        ctx_len: Optional[int] = 128,
+        streamer: Optional[Union[TextStreamer, TextIteratorStreamer]] = None,
+        retained_state: bool = True,
+    ):
+        self.session = None
+        self.tokenizer = None
+        self.is_first_prompt = False
+        self.model_name = model_name
+        self.device_id = device_id
+        self.curr_cache_index = 0
+        self.ctx_len = ctx_len
+        self.prompt_len = prompt_len
+        self.generated_ids = []
+        self.inputs = None
+        self.stop_indicator = True
+        self.retained_state = retained_state
+
+        if not os.path.exists(qpc_path):
+            raise OSError(f"{qpc_path} not found!")
+        self.qpc_path = qpc_path
+
+        try:
+            self.session = QAICInferenceSession(
+                self.qpc_path, self.device_id, enable_debug_logs=False
+            )
+            if self.retained_state:
+                self.session.skip_buffers(
+                    [x for x in self.session.input_names if x.startswith("past_")]
+                )
+                self.session.skip_buffers(
+                    [
+                        x
+                        for x in self.session.output_names
+                        if x.endswith("_RetainedState")
+                    ]
+                )
+
+        except Exception as err:
+            raise RuntimeError(f"Unable to load qpc on device, {err}")
+
+        try:
+            hf_token = None
+            if os.getenv("HF_TOKEN") is not None:
+                hf_token = os.getenv("HF_TOKEN")
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name, padding_side="left", token=hf_token
+            )
+            if tokenizer.pad_token_id is None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            self.tokenizer = tokenizer
+        except Exception as err:
+            raise RuntimeError(f"Unable to load tokenizer, {err}")
+
+        if streamer:
+            self.streamer = streamer(
+                self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=None
+            )
+        else:
+            self.streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        # instantiate default logits processor and warper here
+        # TODO : change default values with temperature and top_p
+        # instantiate logits processors
+        self.logits_processor = LogitsProcessorList(
+            [
+                MinLengthLogitsProcessor(
+                    15, eos_token_id=self.tokenizer.eos_token_id
+                ),
+            ]
+        )
+
+        # instantiate logits warpers
+        self.logits_warper = LogitsProcessorList(
+            [
+                TopKLogitsWarper(50),
+                TemperatureLogitsWarper(0.7),
+            ]
+        )
+
+        self.stopping_criteria = StoppingCriteriaList(
+            [MaxLengthCriteria(max_length=ctx_len)]
+        )
+
+    def _generate_next_token(self, outputs, sample=False):
+        logits = outputs["logits"]
+
+        if sample:
+            # pre-process distribution
+            input_ids = torch.Tensor(self.inputs["input_ids"])
+            next_token_logits = torch.from_numpy(logits)
+
+            # Qeff is maintaining [1,1,VOCAB_SIZE]
+            if len(next_token_logits.shape) == 3:
+                next_token_logits = next_token_logits.squeeze(0)
+            next_token_scores = self.logits_warper(input_ids, next_token_logits)
+
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            next_token_id = next_tokens.numpy().reshape(1, 1)
+        else:
+            # greedy search
+            if len(logits.shape) == 2:
+                logits = np.expand_dims(logits, 1)
+            next_token_id = logits.argmax(2)
+
+        return next_token_id
+
+    def _stopping_criteria(self, next_token_id, max_new_tokens=None):
+        if self.curr_cache_index >= self.ctx_len - 1:
+            print("self.curr_cache_index reach limit")
+            return True
+
+        if max_new_tokens:
+            if len(self.generated_ids) > max_new_tokens:
+                print(
+                    "len(self.generated_ids) > max_new_tokens",
+                    len(self.generated_ids) > max_new_tokens,
+                )
+                return True
+
+        if next_token_id == self.tokenizer.eos_token_id:
+            print(
+                next_token_id == self.tokenizer.eos_token_id,
+                "next_token_id == self.tokenizer.eos_token_id",
+            )
+            return True
+
+        # llama3
+        if next_token_id == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+            print(
+                next_token_id == self.tokenizer.eos_token_id,
+                "next_token_id == <|eot_id|>",
+            )
+            return True
+
+        return False
+
+    def prepare_inputs_for_inference(self, prompt):
+        # prepare inputs for prefill part
+        inputs = self.tokenizer(
+            prompt,
+            return_tensors="np",
+            padding="max_length",
+            max_length=self.prompt_len,
+        )
+        batch_size, prompt_len = inputs["input_ids"].shape
+
+        ctx_len = self.ctx_len
+
+        inputs["position_ids"] = (np.cumsum(inputs["attention_mask"], 1) - 1) * inputs[
+            "attention_mask"
+        ]
+        inputs["attention_mask"] = np.concatenate(
+            [
+                inputs["attention_mask"].astype(bool),
+                np.zeros((batch_size, ctx_len - prompt_len), dtype=bool),
+            ],
+            1,
+        )
+        cache_index = np.array([0])
+        inputs["cache_index"] = cache_index
+
+        return inputs, prompt_len
+
+    def update_inputs_for_inference(self, inputs, next_token_id):
+        _, prompt_len = inputs["input_ids"].shape
+        inputs["cache_index"] += prompt_len
+        inputs["input_ids"] = next_token_id
+        if "attention_mask" in inputs.keys():
+            inputs["position_ids"] = inputs.pop("attention_mask").sum(1, keepdims=True)
+        else:
+            inputs["position_ids"] += 1
+        _, prompt_len = inputs["input_ids"].shape
+        return inputs, prompt_len
+
+    def generate(self, prompt: str, sample: bool = False, max_new_tokens: int = None):
+        session = self.session
+
+        multi_turn_input_ids = []
+
+        if self.curr_cache_index == 0:
+            self.inputs, prompt_len = self.prepare_inputs_for_inference(prompt)
+            outputs = session.run(self.inputs)
+            self.curr_cache_index += prompt_len
+            session.skip_buffers(["attention_mask"])
+
+        else:
+            multi_turn_input_ids = self.tokenizer(
+                prompt,
+                return_tensors="np",
+            ).input_ids
+            self.generated_ids = []
+
+        while self.stop_indicator:
+            if len(multi_turn_input_ids) == 0:
+                next_token_id = self._generate_next_token(outputs, sample)
+                # next_token_id will be from prompt till prompt
+                self.generated_ids.append(next_token_id)
+
+                if self.streamer:
+                    self.streamer.put(next_token_id[0])
+
+                if self._stopping_criteria(next_token_id, max_new_tokens):
+                    print("Stopping criteria hit")
+                    break
+            elif (
+                len(multi_turn_input_ids.shape) == 2
+                and multi_turn_input_ids.shape[1] > 0
+            ):
+                next_token_id, multi_turn_input_ids = (
+                    multi_turn_input_ids[:, 0],
+                    multi_turn_input_ids[:, 1:],
+                )
+                next_token_id = np.expand_dims(next_token_id, 1)
+            elif (
+                len(multi_turn_input_ids.shape) == 2
+                and multi_turn_input_ids.shape[1] == 0
+            ):
+                multi_turn_input_ids = []
+
+            self.inputs, next_prompt_len = self.update_inputs_for_inference(
+                self.inputs, next_token_id
+            )
+            outputs = session.run(self.inputs)
+            self.curr_cache_index += next_prompt_len
+
+        if self.streamer:
+            return self.streamer.end()
+        else:
+            return ""
+
+    def stream(self, prompt: str, sample: bool = False, max_new_tokens: int = None):
+        generate_args = {
+            "prompt": prompt,
+            "sample": sample,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        t = Thread(target=self.generate, kwargs=generate_args)
+        t.start()
+
+        outputs = []
+        for text in self.streamer:
+            outputs.append(text)
+            yield "".join(outputs)
+
+        print("".join(outputs))
+
+    def apply_chat_template(self, chat):
+        return self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
\ No newline at end of file
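The `LLMGenerator` class introduced above wraps a compiled qpc in a `QAICInferenceSession` plus a Hugging Face tokenizer and streamer. A minimal usage sketch, with a placeholder qpc path and model card and the `TextIteratorStreamer` wiring that `app/utils.py` uses, might look like this:

```python
# Hypothetical usage sketch (not part of the diff); qpc path and model card
# are placeholders that must match an actual compiled binary.
from transformers import TextIteratorStreamer
from QEfficient.generation.llm_generator import LLMGenerator

generator = LLMGenerator(
    qpc_path="/path/to/qpcs",   # placeholder: directory produced by infer_api
    model_name="gpt2",          # placeholder HF model card (tokenizer only)
    device_id=[0],
    prompt_len=32,
    ctx_len=128,
    streamer=TextIteratorStreamer,
)

# stream() runs generate() on a background thread and yields the text so far.
for partial_text in generator.stream("Hello, how are you?", sample=False, max_new_tokens=32):
    print(partial_text)
```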
diff --git a/app/Readme.md b/app/Readme.md
new file mode 100644
index 000000000..2567661d1
--- /dev/null
+++ b/app/Readme.md
@@ -0,0 +1,92 @@
+
+# Developer Applications on Cloud AI 100 using Transformers Library
+
+
+### Instructions to launch the app
+1. System Dependencies
+   - `sudo apt-get install ffmpeg openssl`
+   - the rest are the same as for `efficient-transformers`
+2. Clone the repo: `git clone https://github.com/quic/efficient-transformers.git`
+3. Change directory: `cd app`
+   - create `app_config.json` inside this directory
+   - update the information in `app_config.json` as described in the section below
+   - if you have an HF token for accessing models that require login, create a `.env` file and add the line below
+   ```
+   HF_TOKEN=
+   ```
+4. Update pip: `pip install -U pip`
+5. Install dependencies
+   - Install python requirements: `pip install -r requirements.txt`
+   - Install the Efficient Transformers library: `pip install -e ..`
+   - Generate key and cert files: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 365 -nodes`
+   - Fill in the details in the interactive session to generate the keys
+6. Run `python app.py`
+7. Open a browser at `https://<server_name_or_ip>:8085` (the port configured in `app.py`)
+8. Accept the certificate
+
+
+### Interaction of Developer Application and QEfficient Transformers Library
+![Workflow](./img/full.png "Workflow of DevApp and QEfficient Interaction")
+
+
+
+### Format of "app_config.json"
+
+Please modify `app_config.json` as shown below.
+- You can add any number of entries
+- Each top-level entry name in `app_config.json` appears in the tasks dropdown
+- Each entry name inside a task appears in the models dropdown
+- `qpc_path` : can be left empty, or set to the path where the compiled binary should be placed after compilation
+- `model_name` : required HF model card name for each dropdown entry
+
+```json
+{
+    "text-generation" : {
+        "codellama" : {
+            "qpc_path" : "",
+            "model_name" : "",
+            "prompt_len" : 128,
+            "ctx_len" : 1024,
+            "device_id" : [0],
+            "num_cores" : 16
+        },
+        "mpt" : {
+            "qpc_path" : "",
+            "model_name" : "",
+            "prompt_len" : 128,
+            "ctx_len" : 1024,
+            "device_id" : [1],
+            "num_cores" : 16
+        },
+        "llama" : {
+            "qpc_path" : "",
+            "model_name" : "",
+            "prompt_len" : 128,
+            "ctx_len" : 1024,
+            "device_id" : [2],
+            "num_cores" : 16
+        },
+        "mistral" : {
+            "qpc_path" : "",
+            "model_name" : "",
+            "prompt_len" : 128,
+            "ctx_len" : 1024,
+            "device_id" : [3],
+            "num_cores" : 16
+        }
+    },
+    "question-answering" : {
+    },
+    "image-generation" : {
+    },
+    "multi-modalities" : {
+    }
+
+}
+
+```
diff --git a/app/app.py b/app/app.py
new file mode 100755
index 000000000..2a1512cb8
--- /dev/null
+++ b/app/app.py
@@ -0,0 +1,301 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+import os
+import time
+import json
+import gradio as gr
+
+from pathlib import Path
+from threading import Thread
+from typing import List, Tuple
+from dotenv import load_dotenv
+from jinja2.exceptions import TemplateError
+
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    TextStreamer,
+)
+
+from utils import (
+    get_list_of_tasks,
+    get_list_of_models_all,
+    get_data,
+    get_generator,
+    load_models_artifacts,
+    get_app_config,
+)
+
+from QEfficient.cloud.infer import infer_api
+
+# Load .env file
+load_dotenv()
+
+# Load app config
+app_config = get_app_config()
+list_of_tasks = get_list_of_tasks()
+list_of_models = get_list_of_models_all()
+
+load_models_artifacts()
+
+# Global variables for bookkeeping
+qeff_generator_model = None
+qeff_flags = set()
+summary_text = ""
+previous_current_ctx_len = 0
+last_prompt = ""
+last_state_generation_ids = []
+disable_multiturn = True
+
+# main title of app
+title = """
+# Developer Applications on Cloud AI 100 using Transformers Library
+"""
+# title for left container of app
+subtitle_left = """
+## Developer Application
+"""
+# title for right container of app
+subtitle_right = """
+## Optimizing and Compiling Model using Qualcomm Transformers Library
+
+"""
+
+
+def update_model(task, model):
+    global qeff_generator_model, app_config
+    new_obj = get_generator(task, model, app_config)
+    if new_obj is not None:
+        qeff_generator_model = new_obj
+        print("Updating qeff generator, ", qeff_generator_model.model_name)
+
+
+def get_prompt(message: str, system_prompt: str):
+    prompt = message
+    chat = []
+    if system_prompt:
+        chat.append({"role": "system", "content": f"{system_prompt}"})
+    chat.append({"role": "user", "content": f"{message}"})
+
+    try:
+        prompt = qeff_generator_model.tokenizer.apply_chat_template(
+            chat, tokenize=False
+        )
+    except TemplateError:
+        # fall back to dropping the system message if the chat template rejects it
+        prompt = qeff_generator_model.tokenizer.apply_chat_template(
+            chat[1:], tokenize=False
+        )
+    except Exception as err:
+        print(err)
+
+    return prompt
+
+
+def run_qeff_check(task, model_name, progress=gr.Progress()):
+    global summary_text, qeff_flags, app_config
+    summary_text = ""
+
+    model_info = get_data(task, model_name)
+
+    if model_name not in qeff_flags:
+        qeff_flags.add(model_name)
+
+        summary_text += f"$ Downloaded {model_name} from cache directory\n"
+
+        progress(0, desc="Optimizing and Compiling...")
+        time.sleep(0.5)
+        for i in progress.tqdm(range(100), desc="Optimizing and Compiling..."):
+            time.sleep(0.04)
+
+        # call the infer API directly to get the qpc_path
+        app_config[task][model_name]['qpc_path'] = infer_api(
+            model_name=model_info['model_name'],
+            num_cores=model_info['num_cores'],
+            prompt_len=model_info['prompt_len'],
+            ctx_len=model_info['ctx_len'],
+            skip_stats=True,
+            prompt="",
+            mxfp6=model_info.get('mxfp6', False),
+        )
+
+        if not os.path.exists(app_config[task][model_name]['qpc_path']):
+            raise RuntimeError(f"qpc path not found for {task} {model_name}")
+
+        summary_text += f"$ Optimized {model_name}\n"
+
+        progress(0, desc="Generating Inference Container...")
+        for i in progress.tqdm(range(100), desc="Generating Inference Container..."):
+            pass
+
+        summary_text += f"$ Compiled {model_name} and generated inference container\n"
+
+    update_model(task, model_name)
+    print(qeff_generator_model.model_name)
+
+    return Path("./img/box.png")
+
+
+def summary():
+    return summary_text
+
+
+def infer_prompt(msg, chat_history, task, model):
+    global last_prompt, previous_current_ctx_len, last_state_generation_ids
+
+    qeff_generator_model.stop_indicator = True
+
+    if disable_multiturn:
+        qeff_generator_model.curr_cache_index = 0
+        qeff_generator_model.generated_ids = []
+
+    # in case of multi-turn, reset when the ctx length is exhausted
+    if qeff_generator_model.curr_cache_index >= qeff_generator_model.ctx_len - 1:
+        qeff_generator_model.curr_cache_index = 0
+
+    output = ""
+    yield "", chat_history + [(msg, output)]
+
+    generate_args = {
+        "prompt": get_prompt(msg, "Give a brief answer."),
+        "sample": True,
+        "max_new_tokens": None,
+    }
+
+    t = Thread(target=qeff_generator_model.generate, kwargs=generate_args)
+    t.start()
+
+    for each in qeff_generator_model.streamer:
+        output += each
+        yield "", chat_history + [(msg, output)]
+
+    t.join()
+
+
+def stop():
+    qeff_generator_model.stop_indicator = False
+    return
+
+
+def run_clear():
+    global qeff_flags
+    qeff_generator_model.curr_cache_index = 0
+    qeff_generator_model.generated_ids = []
+    qeff_flags = set()
+    return
+
+
+def clear_img(img):
+    img.clear()
+
+
+# Combined Interface
+with gr.Blocks(theme=gr.themes.Soft(), css="demo.css") as demo:
+    gr.Markdown(title)
+
+    with gr.Row():
+
+        with gr.Column(scale=7, variant="compact"):
+            gr.Markdown(subtitle_left)
+
+            dropdown1 = gr.Dropdown(
+                list_of_tasks,
+                value=list_of_tasks[0],
+                label="Developer Use Case",
+                elem_id="task_id",
+            )
+
+            with gr.Row():
+                textbox = gr.Textbox(
+                    container=False,
+                    show_label=False,
+                    placeholder="Type your prompt here...",
+                    interactive=True,
+                    lines=2,
+                )
+
+            with gr.Row():
+                chat = gr.Button("Launch on AI 100", variant="primary", size="sm")
+
+                clear = gr.Button("Reset", size="sm")
+
+                stop_btn = gr.Button("Stop", size="sm")
+            with gr.Column():
+                # with gr.Group():
+                chatbot = gr.Chatbot(
+                    label="Response",
elem_id="chuanhu_chatbot", + ) + with gr.Column(variant="compact", scale=3, elem_id="qeff_id"): + gr.Markdown(subtitle_right) + + dropdown2 = gr.Dropdown( + list_of_models, + value=list_of_models[-1], + label="Pretrained model catalogue from Qualcomm Transformers Library", + elem_id="model_id", + ) + img = gr.Image( + show_label=False, + show_download_button=False, + container=True, + height=260, + width=480, + elem_id="qpc_id", + ) + # "block-size: inherit;" + qeff_output = gr.Textbox( + container=True, + show_label=False, + lines=4, + ) + with gr.Row(): + gr.Image( + "./img/full.png", + show_label=False, + show_download_button=False, + container=False, + ) + + chat.click(run_qeff_check, inputs=[dropdown1, dropdown2], outputs=[img]).then( + summary, inputs=[], outputs=[qeff_output] + ).then( + infer_prompt, + inputs=[textbox, chatbot, dropdown1, dropdown2], + outputs=[textbox, chatbot], + ) + + textbox.submit(run_qeff_check, inputs=[dropdown1, dropdown2], outputs=[img]).then( + summary, inputs=[], outputs=[qeff_output] + ).then( + infer_prompt, + inputs=[textbox, chatbot, dropdown1, dropdown2], + outputs=[textbox, chatbot], + ) + + stop_btn.click(fn=stop) + + clear.click(lambda: None, None, chatbot, queue=False).then( + lambda x: gr.update(value=""), [], [textbox] + ).then(lambda x: gr.update(value=""), [], [qeff_output]).then(fn=run_clear).then( + lambda: None, None, img, queue=False + ) + dropdown2.change(lambda x: gr.update(value=""), [], [qeff_output]).then( + lambda: None, None, img, queue=False + ) + + +demo.queue() +demo.launch( + server_name="0.0.0.0", + server_port=8085, + ssl_certfile="cert.pem", + ssl_keyfile="key.pem", + ssl_verify=False, + allowed_paths=[f"{os.getcwd()}"], +) diff --git a/app/demo.css b/app/demo.css new file mode 100644 index 000000000..0bed50d1c --- /dev/null +++ b/app/demo.css @@ -0,0 +1,116 @@ +/* +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+*/
+h1 {
+    text-align: center;
+}
+
+h2 {
+    text-align: center;
+    font-family: "Shantell Sans";
+}
+
+/* #duplicate-button {
+    margin: auto;
+    color: white;
+    background: #1565c0;
+    border-radius: 100vh;
+}
+
+#component-0 {
+    max-width: 900px;
+    margin: auto;
+    padding-top: 1.5rem;
+} */
+
+#qeff_id {
+    /* background-color: var(--block-label-background-fill); */
+    background-color: #3253DC;
+}
+
+
+#banner-image {
+    /* animation: animName 4s linear infinite; */
+    width: 480px !important;
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+    background-size: contain;
+    /* height: 20vh; */
+    /* background: #0091EA; */
+    /* margin: 20px; */
+}
+
+
+img.svelte-1ucs3qg {
+    width: 480px;
+    height: auto;
+    object-fit: contain;
+    display: block;
+    border-radius: var(--radius-lg);
+}
+
+@keyframes animName {
+    0% {
+        rotate: y 0deg;
+    }
+    50% {
+        rotate: y 45deg;
+    }
+    75% {
+        rotate: y -45deg;
+    }
+    100% {
+        rotate: y 0deg;
+    }
+}
+
+
+#chuanhu_chatbot {
+    /* height: 40vh !important; */
+    height: 100%;
+    height: 350px !important;
+}
+
+[class = "image-container svelte-1l6wqyv"] {
+    height: inherit !important;
+}
+
+#qpc_id {
+    block-size: inherit;
+}
+
+
+[class *= "message"] {
+    /* border-radius: var(--radius-xl) !important; */
+    /* border: none; */
+    padding: var(--spacing-xl) !important;
+    font-size: var(--text-md) !important;
+    line-height: var(--line-md) !important;
+    min-height: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
+    min-width: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
+}
+
+[class = "message-wrap svelte-12dsd9j bubble-gap"] {
+    gap: 0px !important;
+    /* max-width: 85%;
+    border-bottom-left-radius: 0 !important; */
+}
+
+
+[class = "message user svelte-12dsd9j message-bubble-border"] {
+    text-align: right;
+    width: fit-content !important;
+    /* border-bottom-right-radius: 0 !important; */
+}
+
+
+[class = "gradio-container gradio-container-4-8-0 svelte-1kyws56 app"] {
+    max-width: 75vw !important;
+}
\ No newline at end of file
diff --git a/app/img/box.png b/app/img/box.png
new file mode 100644
index 000000000..cc88444c5
Binary files /dev/null and b/app/img/box.png differ
diff --git a/app/img/full.png b/app/img/full.png
new file mode 100644
index 000000000..ea8589769
Binary files /dev/null and b/app/img/full.png differ
diff --git a/app/requirements.txt b/app/requirements.txt
new file mode 100755
index 000000000..cee88ba2d
--- /dev/null
+++ b/app/requirements.txt
@@ -0,0 +1,3 @@
+gradio==4.29
+json5
+python-dotenv
\ No newline at end of file
diff --git a/app/utils.py b/app/utils.py
new file mode 100644
index 000000000..74b2ea37e
--- /dev/null
+++ b/app/utils.py
@@ -0,0 +1,91 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import json5 as json
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.generation.llm_generator import LLMGenerator
+
+from transformers import TextIteratorStreamer
+
+generator_hub = {}
+
+
+def get_app_config():
+    with open("app_config.json") as f:
+        app_config = json.load(f)
+    return app_config
+
+
+def get_list_of_tasks(app_config=None):
+    if app_config is None:
+        app_config = get_app_config()
+    return list(app_config.keys())
+
+
+def get_list_of_models_all(app_config=None):
+    if app_config is None:
+        app_config = get_app_config()
+    list_of_models = []
+    for task in app_config:
+        for model in app_config[task].keys():
+            list_of_models.append(model)
+    return list_of_models
+
+
+def get_list_of_models_task(app_config, task):
+    return list(app_config[task].keys())
+
+
+def get_data(task, model, app_config=None):
+    if app_config:
+        return app_config[task][model]
+
+    app_config = get_app_config()
+    return app_config[task][model]
+
+
+def load_models_artifacts():
+    app_config = get_app_config()
+    for task in app_config:
+        generator_hub[task] = {}
+        for model in app_config[task].keys():
+            data = app_config[task][model]
+            try:
+                generator_hub[task][model] = LLMGenerator(
+                    qpc_path=data["qpc_path"],
+                    model_name=data["model_name"],
+                    device_id=data["device_id"],
+                    prompt_len=data["prompt_len"],
+                    ctx_len=data["ctx_len"],
+                    streamer=TextIteratorStreamer,
+                )
+            except Exception as err:
+                print(err)
+                generator_hub[task][model] = None
+
+    print(generator_hub)
+
+
+def get_generator(task, model, app_config=None):
+    if app_config is None:
+        app_config = get_app_config()
+
+    if task in generator_hub.keys():
+        if model in generator_hub[task].keys():
+            if generator_hub[task][model] is None:
+                data = app_config[task][model]
+                generator_hub[task][model] = LLMGenerator(
+                    qpc_path=data["qpc_path"],
+                    model_name=data["model_name"],
+                    device_id=data["device_id"],
+                    prompt_len=data["prompt_len"],
+                    ctx_len=data["ctx_len"],
+                    streamer=TextIteratorStreamer,
+                )
+            return generator_hub[task][model]
+    return None
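Taken together, `app/utils.py` resolves entries from `app_config.json` into ready `LLMGenerator` objects. A minimal sketch of that lookup flow, assuming the task and model names from the example config in `app/Readme.md`:

```python
# Hypothetical usage sketch (not part of the diff); "text-generation"/"llama"
# must exist in app_config.json with a valid qpc_path and model_name.
from utils import get_app_config, get_list_of_tasks, get_generator, load_models_artifacts

app_config = get_app_config()
print(get_list_of_tasks(app_config))   # e.g. ["text-generation", "question-answering", ...]

load_models_artifacts()                # pre-build an LLMGenerator per config entry
generator = get_generator("text-generation", "llama", app_config)
if generator is not None:
    print(generator.model_name)
```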