add minicpm3 chat model
BBC-Esq authored Oct 27, 2024
1 parent 40421b2 commit f0b2e0f
Showing 4 changed files with 51 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/chart_models_chat.py
@@ -20,6 +20,7 @@ def create_chat_models_comparison_plot():
"long_context": {
"models": [
"Phi 3.5 Mini - 4b",
# "MiniCPM3 - 4b",
"Qwen 2.5 - 7b",
"Internlm2_5 - 7b",
"Yi Coder - 9b",
15 changes: 14 additions & 1 deletion src/constants.py
@@ -10,7 +10,8 @@
'Zephyr - 3b': 4096,
'Qwen 2.5 - 3b': 4096,
'Llama 3.2 - 3b': 4096,
'Internlm2_5 - 1.8b': 4096
'Internlm2_5 - 1.8b': 4096,
# 'MiniCPM3 - 4b': 4096
}

# changes the default of 1024 in module_chat.py
@@ -21,6 +22,7 @@
'Zephyr - 3b': 512,
'Qwen 2.5 - 3b': 512,
'Internlm2_5 - 1.8b': 512,
'MiniCPM3 - 4b': 512,
}

CHAT_MODELS = {
@@ -108,6 +110,17 @@
'precision': 'bfloat16',
'gated': False,
},
'MiniCPM3 - 4b': {
'model': 'MiniCPM3 - 4b',
'repo_id': 'openbmb/MiniCPM3-4B',
'cache_dir': 'openbmb--MiniCPM3-4B',
'cps': 80.67,
'context_length': 8192,
'vram': 4998.10,
'function': 'MiniCPM3_4b',
'precision': 'bfloat16',
'gated': False,
},
'Qwen 2.5 - 7b': {
'model': 'Qwen 2.5 - 7b',
'repo_id': 'Qwen/Qwen2.5-7B-Instruct',
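The new CHAT_MODELS entry is the hook the rest of the app keys off of: 'function' names the MiniCPM3_4b class added in src/module_chat.py below, while 'repo_id' and 'cache_dir' drive the Hugging Face download. As a minimal sketch of loading that repo with the stock transformers API (MiniCPM3 ships custom modeling code, so trust_remote_code=True is needed; the names here are illustrative, not this repo's loader):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "openbmb/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,  # matches 'precision': 'bfloat16' in the entry above
    trust_remote_code=True,
)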
18 changes: 18 additions & 0 deletions src/module_chat.py
@@ -281,6 +281,7 @@ def generate_response(self, inputs):

generation_thread.join()


class Qwen2_5_3b(BaseModel):
def __init__(self, generation_settings):
model_info = CHAT_MODELS['Qwen 2.5 - 3b']
@@ -328,6 +329,23 @@ def create_prompt(self, augmented_query):
"""


class MiniCPM3_4b(BaseModel):
def __init__(self, generation_settings):
model_info = CHAT_MODELS['MiniCPM3 - 4b']
super().__init__(model_info, bnb_bfloat16_settings, generation_settings)

def create_prompt(self, augmented_query):
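# ChatML-style template: <s> BOS token, then <|im_start|>/<|im_end|> turn markers.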
return f"""<s><|im_start|>user
{augmented_query}<|im_end|>
<|im_start|>assistant
"""

def create_inputs(self, prompt):
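# Include the tokenizer's pad token id in the inputs passed to generation, so it does not have to be inferred (with a warning) at runtime.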
inputs = super().create_inputs(prompt)
inputs['pad_token_id'] = self.tokenizer.pad_token_id
return inputs


class InternLM2_5_7b(BaseModel):
def __init__(self, generation_settings):
model_info = CHAT_MODELS['Internlm2_5 - 7b']
21 changes: 18 additions & 3 deletions src/setup_windows.py
@@ -8,8 +8,8 @@
import time

# ctranslate2==4.5.0 now requires cudnn 9+, which works with CUDA 12.3+; however, torch 2.3.1 only supports up to CUDA 12.1
# SUPPORTS Phi3.5 and Mistral Nemo...AWQ support was added in 4.4.0. FA was added in 4.3.1 but removed in 4.4.0
# Therefore, torch 2.4.0+ is required to use cuDNN 9+ with ctranslate2 4.5.0+.
# This consequently requires installing at least xformers 0.0.27.post1, flash-attn 2.6.3, and triton
"""
# torch 2.5.0 - supports CUDA 11.8, 12.1, and 12.4
# torch 2.4.1 - supports CUDA 11.8, 12.1, and 12.4
Expand All @@ -27,9 +27,10 @@
# Flash-attn 2.6.3 is currently the only build that supports torch 2.4.0, but it only supports up to CUDA 12.3 (released 7/25/2024)
# This is problematic since some models like florence2, minicpm2.6, phi 3.5 mini, and deepseek coder require FA and can't run on SDPA...
# The FA2 repo maintainer said it's compatible with newer versions and this isn't an issue.
# xformers 0.0.26.post1 - requires torch 2.3.0
# xformers 0.0.27 - requires torch 2.3.0 but also states "some might require torch 2.4".
# xformers 0.0.27 - requires torch 2.3.0 but also states "some operation might require torch 2.4".
# xformers 0.0.27.post1 - requires torch 2.4.0
# xformers 0.0.27.post2 - requires torch 2.4.0
# xformers 0.0.28.post1 (non-post1 release was not successfully uploaded to pypi) - requires torch 2.4.1
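Given how tightly these pins interlock, it is worth confirming after setup that the installed stack actually matches the matrix described above. A minimal sketch using only the standard library and torch's own version attributes (illustrative, not part of this commit):

import importlib.metadata as md

import torch

print("torch", torch.__version__, "built for CUDA", torch.version.cuda)
for pkg in ("ctranslate2", "xformers", "flash-attn", "triton"):
    try:
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "not installed")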
@@ -364,7 +365,21 @@ def install_libraries(libraries):
"xlrd==2.0.1",
"xxhash==3.4.1",
"yarl==1.9.4",
"zipp==3.19.2"
"zipp==3.19.2",
# the following are only required by minicpm3 chat model
"argcomplete==3.5.0",
"black==24.8.0",
"datamodel_code_generator==0.26.0",
"dnspython==2.7.0",
"email-validator==2.2.0",
"genson==1.3.0",
"inflect==5.6.2",
"isort==5.13.2",
"jsonschema==4.23.0",
"jsonschema-specifications==2023.12.1",
"pathspec==0.12.1",
"referencing==0.35.1",
"rpds-py==0.20.0",
]

full_install_libraries = [
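The body of install_libraries() sits outside this hunk, so the diff does not show how the MiniCPM3-only pins above actually get installed; the common pattern for a list like this is one pip subprocess call per pin. A hypothetical sketch, not the file's real implementation:

import subprocess
import sys

def install_libraries(libraries):
    # Hypothetical sketch: install each pinned requirement with the current interpreter's pip.
    for lib in libraries:
        subprocess.run([sys.executable, "-m", "pip", "install", lib], check=True)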
