diff --git a/src/chart_models_chat.py b/src/chart_models_chat.py
index 9e40c09a..724f2234 100644
--- a/src/chart_models_chat.py
+++ b/src/chart_models_chat.py
@@ -20,6 +20,7 @@ def create_chat_models_comparison_plot():
         "long_context": {
             "models": [
                 "Phi 3.5 Mini - 4b",
+                # "MiniCPM3 - 4b",
                 "Qwen 2.5 - 7b",
                 "Internlm2_5 - 7b",
                 "Yi Coder - 9b",
diff --git a/src/constants.py b/src/constants.py
index f66a9832..f27c61a3 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -10,7 +10,8 @@
     'Zephyr - 3b': 4096,
     'Qwen 2.5 - 3b': 4096,
     'Llama 3.2 - 3b': 4096,
-    'Internlm2_5 - 1.8b': 4096
+    'Internlm2_5 - 1.8b': 4096,
+    # 'MiniCPM3 - 4b': 4096
 }
 
 # changes the default of 1024 in module_chat.py
@@ -21,6 +22,7 @@
     'Zephyr - 3b': 512,
     'Qwen 2.5 - 3b': 512,
     'Internlm2_5 - 1.8b': 512,
+    'MiniCPM3 - 4b': 512,
 }
 
 CHAT_MODELS = {
@@ -108,6 +110,17 @@
         'precision': 'bfloat16',
         'gated': False,
     },
+    'MiniCPM3 - 4b': {
+        'model': 'MiniCPM3 - 4b',
+        'repo_id': 'openbmb/MiniCPM3-4B',
+        'cache_dir': 'openbmb--MiniCPM3-4B',
+        'cps': 80.67,
+        'context_length': 8192,
+        'vram': 4998.10,
+        'function': 'MiniCPM3_4b',
+        'precision': 'bfloat16',
+        'gated': False,
+    },
     'Qwen 2.5 - 7b': {
         'model': 'Qwen 2.5 - 7b',
         'repo_id': 'Qwen/Qwen2.5-7B-Instruct',
diff --git a/src/module_chat.py b/src/module_chat.py
index 6e24b111..c7bc9133 100644
--- a/src/module_chat.py
+++ b/src/module_chat.py
@@ -281,6 +281,7 @@ def generate_response(self, inputs):
         generation_thread.join()
 
+
 class Qwen2_5_3b(BaseModel):
     def __init__(self, generation_settings):
         model_info = CHAT_MODELS['Qwen 2.5 - 3b']
@@ -328,6 +329,23 @@ def create_prompt(self, augmented_query):
         """
 
 
+class MiniCPM3_4b(BaseModel):
+    def __init__(self, generation_settings):
+        model_info = CHAT_MODELS['MiniCPM3 - 4b']
+        super().__init__(model_info, bnb_bfloat16_settings, generation_settings)
+
+    def create_prompt(self, augmented_query):
+        return f"""<|im_start|>user
+{augmented_query}<|im_end|>
+<|im_start|>assistant
+"""
+
+    def create_inputs(self, prompt):
+        inputs = super().create_inputs(prompt)
+        inputs['pad_token_id'] = self.tokenizer.pad_token_id
+        return inputs
+
+
 class InternLM2_5_7b(BaseModel):
     def __init__(self, generation_settings):
         model_info = CHAT_MODELS['Internlm2_5 - 7b']
diff --git a/src/setup_windows.py b/src/setup_windows.py
index f38dd320..e49f7c80 100644
--- a/src/setup_windows.py
+++ b/src/setup_windows.py
@@ -8,8 +8,8 @@
 import time
 
 # ctranslate2==4.5.0 now requires cudnn 9+, which works with CUDA 12.3+; however, torch 2.3.1 only supports up to CUDA 12.1
+# ctranslate2 also supports Phi 3.5 and Mistral Nemo. AWQ support was added in 4.4.0; FA was added in 4.3.1 but removed in 4.4.0.
 # Therefore, torch 2.4.0+ is required to use cuDNN 9+ with ctranslate2 4.5.0+.
-# This consequently requires installing at least xformers 0.0.27.post1, flash-attn 2.6.3, and triton
 
 """
 # torch 2.5.0 - supports CUDA 11.8, 12.1, and 12.4
 # torch 2.4.1 - supports CUDA 11.8, 12.1, and 12.4
@@ -27,9 +27,10 @@
 # Flash-attn 2.6.3 is currently the only build that supports torch 2.4.0, but it only supports up to CUDA 12.3 (released 7/25/2024)
 # This is problematic since some models like florence2, minicpm2.6, phi 3.5 mini, and deepseek coder require FA and can't run on SDPA...
+# The FA2 repo maintainer has said it's compatible with newer CUDA versions, so this isn't an issue.
 
 # xformers 0.0.26.post1 - requires torch 2.3.0
-# xformers 0.0.27 - requires torch 2.3.0 but also states "some might require torch 2.4".
+# xformers 0.0.27 - requires torch 2.3.0 but also states "some operation might require torch 2.4".
 # xformers 0.0.27.post1 - requires torch 2.4.0
 # xformers 0.0.27.post2 - requires torch 2.4.0
 # xformers 0.0.28.post1 (the non-post1 release was not successfully uploaded to pypi) - requires torch 2.4.1
@@ -364,7 +365,21 @@ def install_libraries(libraries):
     "xlrd==2.0.1",
     "xxhash==3.4.1",
     "yarl==1.9.4",
-    "zipp==3.19.2"
+    "zipp==3.19.2",
+    # the following are only required by the MiniCPM3 chat model
+    "argcomplete==3.5.0",
+    "black==24.8.0",
+    "datamodel_code_generator==0.26.0",
+    "dnspython==2.7.0",
+    "email-validator==2.2.0",
+    "genson==1.3.0",
+    "inflect==5.6.2",
+    "isort==5.13.2",
+    "jsonschema==4.23.0",
+    "jsonschema-specifications==2023.12.1",
+    "pathspec==0.12.1",
+    "referencing==0.35.1",
+    "rpds-py==0.20.0",
 ]
 
 full_install_libraries = [
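
For reviewers, a minimal standalone sketch of what the new MiniCPM3_4b class does: wrap the query in a single-turn ChatML template and pass pad_token_id explicitly at generation time. This is an illustration, not the repo's actual code path; the question text, max_new_tokens value, and plain bfloat16 load (in place of bnb_bfloat16_settings) are placeholders.

# Hedged sketch of the prompt format and pad_token_id handling added above,
# shown outside the repo's BaseModel plumbing.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "openbmb/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # MiniCPM3 ships custom modeling code
    device_map="auto",
)

# Mirrors MiniCPM3_4b.create_prompt(): single-turn ChatML, no system message.
prompt = """<|im_start|>user
What does the pad_token_id argument do?<|im_end|>
<|im_start|>assistant
"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Mirrors MiniCPM3_4b.create_inputs(): pass pad_token_id explicitly, which is
# typically done to silence the "Setting pad_token_id to eos_token_id" warning
# during open-ended generation.
output_ids = model.generate(
    **inputs,
    max_new_tokens=512,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:],
                       skip_special_tokens=True))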
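
The torch/cuDNN/ctranslate2 constraints described in the setup_windows.py comments can also be sanity-checked at install time. A hedged sketch of such a preflight check; the function name and messages are mine, not part of this patch, and the thresholds simply mirror the comments (torch 2.4.0+ so that cuDNN 9+ works with ctranslate2 4.5.0+).

# Hedged preflight sketch for the version constraints discussed above.
import torch

def preflight_check():
    # torch.__version__ looks like "2.4.0+cu121"; strip the local build tag.
    major, minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
    if (major, minor) < (2, 4):
        print(f"torch {torch.__version__}: too old, need 2.4.0+ for cuDNN 9+")
    else:
        print(f"torch {torch.__version__}: ok")

    cudnn = torch.backends.cudnn.version()  # e.g. 90100 == cuDNN 9.1.0
    if cudnn is None or cudnn < 90000:
        print(f"cuDNN {cudnn}: need 9.x for ctranslate2 4.5.0+")
    else:
        print(f"cuDNN {cudnn}: ok")

preflight_check()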