diff --git a/docs/llamacpp.md b/docs/llamacpp.md index 137e2ff..ed94745 100644 --- a/docs/llamacpp.md +++ b/docs/llamacpp.md @@ -1,48 +1,126 @@ # LLAMA.CPP -Run transformer models using a Llama.cpp binary and checkpoint. This model can then be used with chatting or benchmarks such as MMLU. +Run transformer models using llama.cpp. This integration allows you to: +1. Load and run llama.cpp models +2. Benchmark model performance +3. Use the models with other tools like chat or MMLU accuracy testing ## Prerequisites -This flow has been verified with a generic Llama.cpp model. +You need: +1. A compiled llama.cpp executable (llama-cli or llama-cli.exe) +2. A GGUF model file -These instructions are only for linux or Windows with wsl. It may be necessary to be running WSL in an Administrator command prompt. +### Building llama.cpp (if needed) -These instructions also assumes that lemonade has been installed. +#### Linux +```bash +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp +make +``` + +#### Windows +```bash +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp +cmake -B build +cmake --build build --config Release +``` +The executable will be in `build/bin/Release/llama-cli.exe` on Windows or `llama-cli` in the root directory on Linux. -### Set up Environment (Assumes TurnkeyML is already installed) +## Usage -Build or obtain the Llama.cpp model and desired checkpoint. -For example (see the [llama.cpp](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md -) source for more details): -1. cd ~ -1. git clone https://github.com/ggerganov/llama.cpp -1. cd llama.cpp -1. make -1. cd models -1. wget https://huggingface.co/TheBloke/Dolphin-Llama2-7B-GGUF/resolve/main/dolphin-llama2-7b.Q5_K_M.gguf +### Loading a Model +Use the `load-llama-cpp` tool to load a model: -## Usage +```bash +lemonade -i MODEL_NAME load-llama-cpp \ + --executable PATH_TO_EXECUTABLE \ + --model-binary PATH_TO_GGUF_FILE +``` -The Llama.cpp tool currently supports the following parameters +Parameters: +| Parameter | Required | Default | Description | +|--------------|----------|---------|-------------------------------------------------------| +| executable | Yes | - | Path to llama-cli/llama-cli.exe | +| model-binary | Yes | - | Path to .gguf model file | +| threads | No | 1 | Number of threads for generation | +| context-size | No | 512 | Context window size | +| output-tokens| No | 512 | Maximum number of tokens to generate | -| Parameter | Definition | Default | -| --------- | ---------------------------------------------------- | ------- | -| executable | Path to the Llama.cpp-generated application binary | None | -| model-binary | Model checkpoint (do not use if --input is passed to lemonade) | None | -| threads | Number of threads to use for computation | 1 | -| context-size | Maximum context length | 512 | -| temp | Temperature to use for inference (leave out to use the application default) | None | +### Benchmarking -### Example (assuming Llama.cpp built and a checkpoint loaded as above) +After loading a model, you can benchmark it using `llama-cpp-bench`: ```bash -lemonade --input ~/llama.cpp/models/dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~/llama.cpp/llama-cli accuracy-mmlu --ntrain 5 +lemonade -i MODEL_NAME \ + load-llama-cpp \ + --executable PATH_TO_EXECUTABLE \ + --model-binary PATH_TO_GGUF_FILE \ + llama-cpp-bench ``` -On windows, the llama.cpp binary might be in a different location (such as llama.cpp\build\bin\Release\), in which case the command mgiht be 
something like: +Benchmark parameters: +| Parameter | Default | Description | +|------------------|----------------------------|-------------------------------------------| +| prompt | "Hello, I am conscious and"| Input prompt for benchmarking | +| context-size | 512 | Context window size | +| output-tokens | 512 | Number of tokens to generate | +| iterations | 1 | Number of benchmark iterations | +| warmup-iterations| 0 | Number of warmup iterations (not counted) | + +The benchmark will measure and report: +- Time to first token (prompt evaluation time) +- Token generation speed (tokens per second) + +### Example Commands + +#### Windows Example ```bash -lemonade --input ~\llama.cpp\models\dolphin-llama2-7b.Q5_K_M.gguf load-llama-cpp --executable ~\llama.cpp\build\bin\Release\llama-cli accuracy-mmlu --ntrain 5 +# Load and benchmark a model +lemonade -i Qwen/Qwen2.5-0.5B-Instruct-GGUF \ + load-llama-cpp \ + --executable "C:\work\llama.cpp\build\bin\Release\llama-cli.exe" \ + --model-binary "C:\work\llama.cpp\models\qwen2.5-0.5b-instruct-fp16.gguf" \ + llama-cpp-bench \ + --iterations 3 \ + --warmup-iterations 1 + +# Run MMLU accuracy test +lemonade -i Qwen/Qwen2.5-0.5B-Instruct-GGUF \ + load-llama-cpp \ + --executable "C:\work\llama.cpp\build\bin\Release\llama-cli.exe" \ + --model-binary "C:\work\llama.cpp\models\qwen2.5-0.5b-instruct-fp16.gguf" \ + accuracy-mmlu \ + --tests management \ + --max-evals 2 ``` + +#### Linux Example +```bash +# Load and benchmark a model +lemonade -i Qwen/Qwen2.5-0.5B-Instruct-GGUF \ + load-llama-cpp \ + --executable "./llama-cli" \ + --model-binary "./models/qwen2.5-0.5b-instruct-fp16.gguf" \ + llama-cpp-bench \ + --iterations 3 \ + --warmup-iterations 1 +``` + +## Integration with Other Tools + +After loading with `load-llama-cpp`, the model can be used with any tool that supports the ModelAdapter interface, including: +- accuracy-mmlu +- llm-prompt +- accuracy-humaneval +- and more + +The integration provides: +- Platform-independent path handling (works on both Windows and Linux) +- Proper error handling with detailed messages +- Performance metrics collection +- Configurable generation parameters (temperature, top_p, top_k) diff --git a/src/lemonade/cli.py b/src/lemonade/cli.py index 5417673..66768be 100644 --- a/src/lemonade/cli.py +++ b/src/lemonade/cli.py @@ -14,7 +14,7 @@ from lemonade.tools.huggingface_bench import HuggingfaceBench from lemonade.tools.ort_genai.oga_bench import OgaBench - +from lemonade.tools.llamacpp_bench import LlamaCppBench from lemonade.tools.llamacpp import LoadLlamaCpp import lemonade.cache as cache @@ -30,6 +30,7 @@ def main(): tools = [ HuggingfaceLoad, LoadLlamaCpp, + LlamaCppBench, AccuracyMMLU, AccuracyHumaneval, AccuracyPerplexity, diff --git a/src/lemonade/tools/llamacpp.py b/src/lemonade/tools/llamacpp.py index 2e9ca76..2eb8dfc 100644 --- a/src/lemonade/tools/llamacpp.py +++ b/src/lemonade/tools/llamacpp.py @@ -1,77 +1,121 @@ import argparse import os -import subprocess from typing import Optional - +import subprocess from turnkeyml.state import State +import turnkeyml.common.status as status from turnkeyml.tools import FirstTool - -import turnkeyml.common.build as build -from .adapter import PassthroughTokenizer, ModelAdapter - - -def llamacpp_dir(state: State): - return os.path.join(build.output_dir(state.cache_dir, state.build_name), "llamacpp") +from lemonade.tools.adapter import PassthroughTokenizer, ModelAdapter +from lemonade.cache import Keys class LlamaCppAdapter(ModelAdapter): - unique_name = 
"llama-cpp-adapter" - - def __init__(self, executable, model, tool_dir, context_size, threads, temp): + def __init__(self, model, output_tokens, context_size, threads, executable): super().__init__() - self.executable = executable - self.model = model - self.tool_dir = tool_dir + self.model = os.path.normpath(model) + self.output_tokens = output_tokens self.context_size = context_size self.threads = threads - self.temp = temp + self.executable = os.path.normpath(executable) - def generate(self, input_ids: str, max_new_tokens: Optional[int] = None): + def generate( + self, + input_ids: str, + max_new_tokens: Optional[int] = None, + temperature: float = 0.8, + top_p: float = 0.95, + top_k: int = 40, + **kwargs, # pylint: disable=unused-argument + ): """ Pass a text prompt into the llamacpp inference CLI. The input_ids arg here should receive the original text that would normally be encoded by a tokenizer. + + Args: + input_ids: The input text prompt + max_new_tokens: Maximum number of tokens to generate + temperature: Temperature for sampling (0.0 = greedy) + top_p: Top-p sampling threshold + top_k: Top-k sampling threshold + **kwargs: Additional arguments (ignored) + + Returns: + List containing a single string with the generated text """ + prompt = input_ids + n_predict = max_new_tokens if max_new_tokens is not None else self.output_tokens + cmd = [ self.executable, + "-m", + self.model, + "--ctx-size", + str(self.context_size), + "-n", + str(n_predict), + "-t", + str(self.threads), + "-p", + prompt, + "--temp", + str(temperature), + "--top-p", + str(top_p), + "--top-k", + str(top_k), "-e", ] - optional_params = { - "ctx-size": self.context_size, - "n-predict": max_new_tokens, - "threads": self.threads, - "model": self.model, - "prompt": input_ids, - "temp": self.temp, - } - - for flag, value in optional_params.items(): - if value is not None: - cmd.append(f"--{flag} {value}") - cmd = [str(m) for m in cmd] - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True, - ) - - raw_output, raw_err = process.communicate() - - if process.returncode != 0: - raise subprocess.CalledProcessError( - process.returncode, process.args, raw_output, raw_err + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + errors="replace", ) + raw_output, stderr = process.communicate() + if process.returncode != 0: + error_msg = f"llama.cpp failed with return code {process.returncode}.\n" + error_msg += f"Command: {' '.join(cmd)}\n" + error_msg += f"Error output:\n{stderr}\n" + error_msg += f"Standard output:\n{raw_output}" + raise Exception(error_msg) + + if raw_output is None: + raise Exception("No output received from llama.cpp process") + + # Parse timing information + for line in raw_output.splitlines(): + if "llama_perf_context_print: eval time =" in line: + parts = line.split("(")[1].strip() + parts = parts.split(",") + ms_per_token = float(parts[0].split("ms per token")[0].strip()) + self.tokens_per_second = ( + 1000 / ms_per_token if ms_per_token > 0 else 0 + ) + if "llama_perf_context_print: prompt eval time =" in line: + parts = line.split("=")[1].split("/")[0] + time_to_first_token_ms = float(parts.split("ms")[0].strip()) + self.time_to_first_token = time_to_first_token_ms / 1000 + + except Exception as e: + error_msg = f"Failed to run llama.cpp command: {str(e)}\n" + error_msg += f"Command: {' '.join(cmd)}" + raise Exception(error_msg) + + # Find where the 
prompt ends and the generated text begins prompt_found = False output_text = "" - prompt_first_line = input_ids.split("\n")[0] + prompt_first_line = prompt.split("\n")[0] for line in raw_output.splitlines(): if prompt_first_line in line: prompt_found = True @@ -82,6 +126,7 @@ def generate(self, input_ids: str, max_new_tokens: Optional[int] = None): if not prompt_found: raise Exception("Prompt not found in result, this is a bug in lemonade.") + # Return list containing the generated text return [output_text] @@ -102,7 +147,7 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser: "--executable", required=True, type=str, - help="Executable name", + help="Path to the llama.cpp executable (e.g., llama-cli or llama-cli.exe)", ) default_threads = 1 @@ -123,17 +168,20 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser: help=f"Context size of the prompt (default: {context_size})", ) + output_tokens = 512 parser.add_argument( - "--model-binary", + "--output-tokens", required=False, - help="Path to a .gguf model to use with benchmarking.", + type=int, + default=output_tokens, + help=f"Maximum number of output tokens the LLM should make (default: {output_tokens})", ) parser.add_argument( - "--temp", - type=float, - required=False, - help="Temperature", + "--model-binary", + required=True, + type=str, + help="Path to a .gguf model file", ) return parser @@ -141,32 +189,33 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser: def run( self, state: State, - input: str = None, - context_size: int = None, - threads: int = None, + input: str = "", + context_size: int = 512, + threads: int = 1, + output_tokens: int = 512, + model_binary: Optional[str] = None, executable: str = None, - model_binary: str = None, - temp: float = None, ) -> State: """ - Create a tokenizer instance and model instance in `state` that support: - - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - response = model.generate(input_ids, max_new_tokens=1) - response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip() + Load a llama.cpp model """ if executable is None: - raise Exception(f"{self.__class__.unique_name} requires an executable") + raise Exception(f"{self.__class__.unique_name} requires an executable path") + + # Convert paths to platform-specific format + executable = os.path.normpath(executable) - if input is not None and input != "": + if model_binary: + model_to_use = os.path.normpath(model_binary) + else: model_binary = input + model_to_use = os.path.normpath(model_binary) if model_binary else None - # Save execution parameters - state.save_stat("context_size", context_size) - state.save_stat("threads", threads) + if not model_binary: + model_to_use = state.get(Keys.MODEL) - if model_binary is None: + if model_to_use is None: raise Exception( f"{self.__class__.unique_name} requires the preceding tool to pass a " "Llamacpp model, " @@ -174,13 +223,16 @@ def run( ) state.model = LlamaCppAdapter( - executable=executable, - model=model_binary, - tool_dir=llamacpp_dir(state), + model=model_to_use, + output_tokens=output_tokens, context_size=context_size, threads=threads, - temp=temp, + executable=executable, ) state.tokenizer = PassthroughTokenizer() + # Save stats about the model + state.save_stat(Keys.CHECKPOINT, model_to_use) + status.add_to_state(state=state, name=input, model=model_to_use) + return state diff --git a/src/lemonade/tools/llamacpp_bench.py b/src/lemonade/tools/llamacpp_bench.py new file mode 100644 index 0000000..d3a08a6 --- /dev/null +++ 
b/src/lemonade/tools/llamacpp_bench.py @@ -0,0 +1,235 @@ +import argparse +import os +import subprocess +import statistics +import tqdm +from turnkeyml.state import State +from turnkeyml.tools import Tool +from lemonade.cache import Keys +import lemonade.tools.ort_genai.oga_bench as general +from lemonade.tools.llamacpp import LlamaCppAdapter + + +class LlamaCppBench(Tool): + unique_name = "llama-cpp-bench" + + def __init__(self): + super().__init__(monitor_message="Benchmarking LlamaCPP model") + self.status_stats = [ + Keys.SECONDS_TO_FIRST_TOKEN, + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, + ] + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Benchmark a LLM via llama.cpp", + add_help=add_help, + ) + + parser.add_argument( + "--prompt", + "-p", + required=False, + default=general.default_prompt, + help="Input prompt to the LLM. Three formats are supported. " + f"1) integer (default: {general.default_prompt}): " + "use a synthetic prompt with the specified length. " + "2) str: use a user-provided prompt string " + "3) path/to/prompt.txt: load the prompt from a text file.", + ) + + context_size = 512 + parser.add_argument( + "--context-size", + required=False, + type=int, + default=context_size, + help=f"Context size of the prompt (default: {context_size})", + ) + + output_tokens = 512 + parser.add_argument( + "--output-tokens", + required=False, + type=int, + default=output_tokens, + help=f"Maximum number of output tokens the LLM should make (default: {output_tokens})", + ) + + default_iterations = 1 + parser.add_argument( + "--iterations", + "-i", + required=False, + type=int, + default=default_iterations, + help=f"Number of benchmarking iterations to run (default: {default_iterations})", + ) + + default_warmup_runs = 0 + parser.add_argument( + "--warmup-iterations", + "-w", + required=False, + type=int, + default=default_warmup_runs, + help="Number of benchmarking iterations to use for cache warmup " + "(the results of these iterations " + f"are not included in the results; default: {default_warmup_runs})", + ) + + return parser + + def parse(self, state: State, args, known_only=True) -> argparse.Namespace: + """ + Helper function to parse CLI arguments into the args expected + by run() + """ + + parsed_args = super().parse(state, args, known_only) + + # Decode prompt arg into a string prompt + if parsed_args.prompt.isdigit(): + # Generate a prompt with the requested length + length = int(parsed_args.prompt) + parsed_args.prompt = "word " * (length - 2) + + elif os.path.exists(parsed_args.prompt): + with open(parsed_args.prompt, "r", encoding="utf-8") as f: + parsed_args.prompt = f.read() + + else: + # No change to the prompt + pass + + return parsed_args + + def run( + self, + state: State, + prompt: str = general.default_prompt, + context_size: int = len(general.default_prompt), + output_tokens: int = general.default_output_tokens, + iterations: int = general.default_iterations, + warmup_iterations: int = general.default_warmup_runs, + ) -> State: + """ + Benchmark llama.cpp model that was loaded by LoadLlamaCpp. 
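+        Reports the mean time to first token and mean tokens per second
+        measured over the non-warmup iterations.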
+ """ + + # Save benchmarking parameters + state.save_stat("prompt", prompt) + state.save_stat("output_tokens", output_tokens) + state.save_stat("context_size", context_size) + state.save_stat("iterations", iterations) + state.save_stat("warmup_iterations", warmup_iterations) + + if not hasattr(state, "model") or not isinstance(state.model, LlamaCppAdapter): + raise Exception( + f"{self.__class__.unique_name} requires a LlamaCppAdapter model to be " + "loaded first. Please run load-llama-cpp before this tool." + ) + + iteration_tokens_per_second = [] + iteration_time_to_first_token = [] + + for iteration in tqdm.tqdm( + range(iterations), desc="iterations", disable=iterations < 2 + ): + cmd = [ + state.model.executable, + "-m", + state.model.model, + "--ctx-size", + str(context_size), + "-n", + str(output_tokens), + "-t", + str(state.model.threads), + "-p", + prompt, + "-e", + ] + + cmd = [str(m) for m in cmd] + + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + errors="replace", + ) + + raw_output, stderr = process.communicate() + if process.returncode != 0: + error_msg = ( + f"llama.cpp failed with return code {process.returncode}.\n" + ) + error_msg += f"Command: {' '.join(cmd)}\n" + error_msg += f"Error output:\n{stderr}\n" + error_msg += f"Standard output:\n{raw_output}" + raise Exception(error_msg) + + if raw_output is None: + raise Exception("No output received from llama.cpp process") + + except Exception as e: + error_msg = f"Failed to run llama.cpp command: {str(e)}\n" + error_msg += f"Command: {' '.join(cmd)}" + raise Exception(error_msg) + + ms_per_token = None + time_to_first_token_ms = None + for line in raw_output.splitlines(): + if "llama_perf_context_print: eval time =" in line: + parts = line.split("(")[1].strip() + parts = parts.split(",") + ms_per_token = float(parts[0].split("ms per token")[0].strip()) + if "llama_perf_context_print: prompt eval time =" in line: + parts = line.split("=")[1].split("/")[0] + time_to_first_token_ms = float(parts.split("ms")[0].strip()) + + if ms_per_token is None or time_to_first_token_ms is None: + # Look in stderr as well since some versions of llama.cpp output timing there + for line in stderr.splitlines(): + if "llama_perf_context_print: eval time =" in line: + parts = line.split("(")[1].strip() + parts = parts.split(",") + ms_per_token = float(parts[0].split("ms per token")[0].strip()) + if "llama_perf_context_print: prompt eval time =" in line: + parts = line.split("=")[1].split("/")[0] + time_to_first_token_ms = float(parts.split("ms")[0].strip()) + + if ms_per_token is None or time_to_first_token_ms is None: + error_msg = "Could not find timing information in llama.cpp output.\n" + error_msg += "Raw output:\n" + raw_output + "\n" + error_msg += "Error output:\n" + stderr + raise Exception(error_msg) + + # When output_tokens is set to 1 for accuracy tests, ms_per_token tends to 0 + # and causes a divide-by-zero error. Set tokens_per_second to 0 in such cases + # as performance data for generating a few tokens is not relevant. 
+ tokens_per_second = 0 + if output_tokens > 5 and ms_per_token > 0: + tokens_per_second = 1000 / ms_per_token + time_to_first_token = time_to_first_token_ms / 1000 + + if iteration > warmup_iterations - 1: + iteration_tokens_per_second.append(tokens_per_second) + iteration_time_to_first_token.append(time_to_first_token) + + token_generation_tokens_per_second = statistics.mean( + iteration_tokens_per_second + ) + mean_time_to_first_token = statistics.mean(iteration_time_to_first_token) + + state.save_stat( + Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, token_generation_tokens_per_second + ) + state.save_stat(Keys.SECONDS_TO_FIRST_TOKEN, mean_time_to_first_token) + + return state diff --git a/test/llm_api.py b/test/llm_api.py index 8f7c639..3f57b27 100644 --- a/test/llm_api.py +++ b/test/llm_api.py @@ -2,6 +2,10 @@ import shutil import os import urllib3 +import platform +import zipfile +import requests +import logging from turnkeyml.state import State import turnkeyml.common.filesystem as fs import turnkeyml.common.test_helpers as common @@ -10,10 +14,27 @@ from lemonade.tools.mmlu import AccuracyMMLU from lemonade.tools.humaneval import AccuracyHumaneval from lemonade.tools.chat import LLMPrompt +from lemonade.tools.llamacpp import LoadLlamaCpp +from lemonade.tools.llamacpp_bench import LlamaCppBench from lemonade.cache import Keys +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + ci_mode = os.getenv("LEMONADE_CI_MODE", False) +# Get cache directory from environment or create a new one +cache_dir = os.getenv('LEMONADE_CACHE_DIR') +if not cache_dir: + cache_dir, _ = common.create_test_dir("lemonade_api") + os.environ['LEMONADE_CACHE_DIR'] = cache_dir + +logger.info(f"Using cache directory: {cache_dir}") + try: url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" resp = urllib3.request("GET", url, preload_content=False) @@ -26,6 +47,190 @@ eecs_berkeley_edu_cannot_be_reached = True +def download_llamacpp_binary(): + """Download the appropriate llama.cpp binary for the current platform""" + logger.info("Starting llama.cpp binary download...") + + # Get latest release info + releases_url = "https://api.github.com/repos/ggerganov/llama.cpp/releases/latest" + try: + response = requests.get(releases_url) + response.raise_for_status() + latest_release = response.json() + logger.info(f"Found latest release: {latest_release.get('tag_name', 'unknown')}") + except Exception as e: + logger.error(f"Failed to fetch latest release info: {str(e)}") + raise + + # Determine platform-specific binary pattern + system = platform.system().lower() + machine = platform.machine().lower() + logger.info(f"Detected platform: {system} {machine}") + + if system == "windows": + # Windows uses AVX2 by default + asset_pattern = "win-avx2-x64" + elif system == "linux": + asset_pattern = "ubuntu-x64" + else: + error_msg = f"Unsupported platform: {system}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + # Find matching asset + matching_assets = [ + asset for asset in latest_release["assets"] + if ( + asset["name"].lower().startswith("llama-") and + asset_pattern in asset["name"].lower() + ) + ] + + if not matching_assets: + error_msg = ( + f"No matching binary found for {system} {machine}. 
" + f"Looking for pattern: {asset_pattern}" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + + asset = matching_assets[0] + logger.info(f"Found matching asset: {asset['name']}") + + # Create binaries directory + binary_dir = os.path.join(cache_dir, "llama_cpp_binary") + os.makedirs(binary_dir, exist_ok=True) + logger.info(f"Created binary directory: {binary_dir}") + + # Download and extract + zip_path = os.path.join(binary_dir, asset["name"]) + try: + response = requests.get(asset["browser_download_url"]) + response.raise_for_status() + + with open(zip_path, "wb") as f: + f.write(response.content) + logger.info(f"Downloaded binary to: {zip_path}") + + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(binary_dir) + logger.info("Extracted binary archive") + except Exception as e: + logger.error(f"Failed to download or extract binary: {str(e)}") + raise + + # Find the executable + if system == "windows": + executable = os.path.join(binary_dir, "llama-cli.exe") + else: + executable = os.path.join(binary_dir, "llama-cli") + # Make executable on Linux + os.chmod(executable, 0o755) + + if not os.path.exists(executable): + error_msg = ( + f"Expected executable not found at {executable} after extraction. " + f"Contents of {binary_dir}: {os.listdir(binary_dir)}" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info(f"Successfully prepared executable at: {executable}") + return executable + + +class TestLlamaCpp(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Download llama.cpp binary once for all tests""" + logger.info("Setting up TestLlamaCpp class...") + try: + cls.executable = download_llamacpp_binary() + except Exception as e: + error_msg = f"Failed to download llama.cpp binary: {str(e)}" + logger.error(error_msg) + raise unittest.SkipTest(error_msg) + + # Use a small GGUF model for testing + cls.model_name = "Qwen/Qwen2.5-0.5B-Instruct-GGUF" + cls.model_file = "qwen2.5-0.5b-instruct-fp16.gguf" + logger.info(f"Using test model: {cls.model_name}/{cls.model_file}") + + # Download the model file + try: + model_url = f"https://huggingface.co/{cls.model_name}/resolve/main/{cls.model_file}" + cls.model_path = os.path.join(cache_dir, cls.model_file) + + if not os.path.exists(cls.model_path): + logger.info(f"Downloading model from: {model_url}") + response = requests.get(model_url) + response.raise_for_status() + with open(cls.model_path, "wb") as f: + f.write(response.content) + logger.info(f"Model downloaded to: {cls.model_path}") + else: + logger.info(f"Using existing model at: {cls.model_path}") + except Exception as e: + error_msg = f"Failed to download test model: {str(e)}" + logger.error(error_msg) + raise unittest.SkipTest(error_msg) + + def setUp(self): + self.state = State( + cache_dir=cache_dir, + build_name="test_llamacpp", + ) + + def test_001_load_model(self): + """Test loading a model with llama.cpp""" + state = LoadLlamaCpp().run( + self.state, + executable=self.executable, + model_binary=self.model_path, + context_size=512, + threads=1 + ) + + self.assertIsNotNone(state.model) + + def test_002_generate_text(self): + """Test text generation with llama.cpp""" + state = LoadLlamaCpp().run( + self.state, + executable=self.executable, + model_binary=self.model_path + ) + + prompt = "What is the capital of France?" 
+ state = LLMPrompt().run(state, prompt=prompt, max_new_tokens=20) + + self.assertIsNotNone(state.response) + self.assertGreater(len(state.response), len(prompt)) + + def test_003_benchmark(self): + """Test benchmarking with llama.cpp""" + state = LoadLlamaCpp().run( + self.state, + executable=self.executable, + model_binary=self.model_path + ) + + # Use longer output tokens to ensure we get valid performance metrics + state = LlamaCppBench().run( + state, + iterations=2, + warmup_iterations=1, + output_tokens=128, + prompt="Hello, I am a test prompt that is long enough to get meaningful metrics." + ) + + stats = fs.Stats(state.cache_dir, state.build_name).stats + + # Check if we got valid metrics + self.assertIn(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND, stats) + self.assertIn(Keys.SECONDS_TO_FIRST_TOKEN, stats) + + class Testing(unittest.TestCase): def setUp(self) -> None: shutil.rmtree(cache_dir, ignore_errors=True) @@ -109,4 +314,12 @@ def test_001_huggingface_bench(self): if __name__ == "__main__": cache_dir, _ = common.create_test_dir("lemonade_api") - unittest.main() + + # Create test suite with all test classes + suite = unittest.TestSuite() + suite.addTests(unittest.TestLoader().loadTestsFromTestCase(Testing)) + suite.addTests(unittest.TestLoader().loadTestsFromTestCase(TestLlamaCpp)) + + # Run the test suite + runner = unittest.TextTestRunner() + runner.run(suite)