From 6586337271e9019d6d9de99e52f3df1da02eed73 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Fri, 3 Jan 2025 12:53:56 -0800 Subject: [PATCH] docstring; example csv; makedirs --- scripts/flops_by_perf_figure.py | 77 +++++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/scripts/flops_by_perf_figure.py b/scripts/flops_by_perf_figure.py index 0cbdfcd23..e062cdd66 100644 --- a/scripts/flops_by_perf_figure.py +++ b/scripts/flops_by_perf_figure.py @@ -1,23 +1,62 @@ -import matplotlib.pyplot as plt +""" + +Plot for the performance vs FLOPs figure. + +CSV file of results should look like: + + Model,FLOPs,Average,ARC Challenge,HSwag,WinoG,MMLU,DROP,NQ,AGIEval,GSM8k,MMLU Pro,TriviaQA + Amber-7B,5.091E+22,35.2,44.9,74.5,65.5,24.7,26.1,18.7,21.8,4.8,11.7,59.3 + DCLM-7B,1.033E+23,56.9,79.8,82.3,77.3,64.4,39.3,28.8,47.5,46.1,31.3,72.1 + Gemma-2-9B,4.436E+23,67.8,89.5,87.3,78.8,70.6,63,38,57.3,70.1,42,81.8 + Llama-2-13B,1.562E+23,54.1,67.3,83.9,74.9,55.7,45.6,38.4,41.5,28.1,23.9,81.3 + Llama-3.1-8B,7.227E+23,61.8,79.5,81.6,76.6,66.9,56.4,33.9,51.3,56.5,34.7,80.3 + MAP-Neo-7B,2.106E+23,49.6,78.4,72.8,69.2,58,39.4,28.9,45.8,12.5,25.9,65.1 + Mistral-7B-v0.3,,58.8,78.3,83.1,77.7,63.5,51.8,37.2,47.3,40.1,30,79.3 + Mistral-Nemo-Bs-12B,,66.9,85.2,85.6,81.5,69.5,69.2,39.7,54.7,62.1,36.7,84.6 + OLMo-0424-7B,8.679E+22,50.7,66.9,80.1,73.6,54.3,50,29.6,43.9,27.7,22.1,58.8 + OLMo-2-1124-13B,4.609E+23,68.3,83.5,86.4,81.5,67.5,70.7,46.7,54.2,75.1,35.1,81.9 + OLMo-2-1124-7B,1.771E+23,62.9,79.8,83.8,77.2,63.7,60.8,36.9,50.4,67.5,31,78 + OLMo-7B,1.018E+23,38.3,46.4,78.1,68.5,28.3,27.3,24.8,23.7,9.2,12.1,64.1 + Qwen-2.5-14B,1.595E+24,72.2,94.0,94,80,79.3,51.5,37.3,71,83.4,52.8,79.1 + Qwen-2.5-7B,8.225E+23,67.4,89.5,89.7,74.2,74.4,55.8,29.9,63.7,81.5,45.8,69.4 + StableLM-2-12B,2.929E+23,62.2,81.9,84.5,77.7,62.4,55.5,37.6,50.9,62,29.3,79.9 + Zamba-2-7B,,65.2,92.2,89.4,79.6,68.5,51.7,36.5,55.5,67.2,32.8,78.8 + +Invocation looks like: + + python scripts/flops_by_perf_figure.py /path/to/results.csv output/ + +@kyleclo, @soldni + +""" + import argparse +import os + +import matplotlib.pyplot as plt import numpy as np import pandas as pd -from matplotlib import font_manager from cached_path import cached_path - +from matplotlib import font_manager ap = argparse.ArgumentParser() ap.add_argument("results_data_path", type=str, help="Path to the results data CSV file.") ap.add_argument("output_dir", type=str, help="Path to the output directory") -ap.add_argument("--manrope-medium-font-path", type=str, help="Path to the Manrope Medium font file", default="https://dolma-artifacts.org/Manrope-Medium.ttf") +ap.add_argument( + "--manrope-medium-font-path", + type=str, + help="Path to the Manrope Medium font file", + default="https://dolma-artifacts.org/Manrope-Medium.ttf", +) args = ap.parse_args() # Add Manrope font font_manager.fontManager.addfont(cached_path(args.manrope_medium_font_path)) -plt.rcParams['font.family'] = 'Manrope' -plt.rcParams['font.weight'] = 'medium' +plt.rcParams["font.family"] = "Manrope" +plt.rcParams["font.weight"] = "medium" +os.makedirs(args.output_dir, exist_ok=True) OUTPUT_PATHS = [f"{args.output_dir}/olmo2.pdf", f"{args.output_dir}/olmo2.png"] df = pd.read_csv(args.results_data_path) @@ -82,7 +121,7 @@ "Partially open": AI2_DARK_TEAL, "Other fully open": AI2_DARK_TEAL, "Previous OLMo": AI2_DARK_TEAL, - "Latest OLMo": "#a51c5c", # darker pink + "Latest OLMo": "#a51c5c", # darker pink } @@ -171,19 +210,21 @@ textcoords="offset points", fontsize=FONTSIZE, alpha=1.0, - font='Manrope', - weight='medium', + font="Manrope", + weight="medium", color=category_to_text_color[model_name_to_open_status[row[MODEL_COLUMN_NAME]]], ) # x axis tick marks tick_locations = [4e22, 6e22, 8e22, 1e23, 2e23, 4e23, 6e23, 8e23, 1e24, 2e24] + def format_scientific(x): exponent = int(np.log10(x)) mantissa = x / (10**exponent) return f"{int(mantissa)}×10{str(exponent).translate(str.maketrans('0123456789', '⁰¹²³⁴⁵⁶⁷⁸⁹'))}" + tick_labels = [format_scientific(x) for x in tick_locations] plt.xticks(tick_locations, tick_labels, rotation=45, ha="right", fontsize=8) @@ -191,8 +232,8 @@ def format_scientific(x): plt.yticks(fontsize=8) # Customize the plot with Manrope Medium -plt.xlabel("Approximate FLOPs", fontsize=10, font='Manrope', weight='medium') -plt.ylabel(f"Avg Performance ({num_datasets} Benchmarks)", fontsize=10, font='Manrope', weight='medium') +plt.xlabel("Approximate FLOPs", fontsize=10, font="Manrope", weight="medium") +plt.ylabel(f"Avg Performance ({num_datasets} Benchmarks)", fontsize=10, font="Manrope", weight="medium") # Add grid with custom colors @@ -200,12 +241,12 @@ def format_scientific(x): plt.grid(True, which="minor", ls="-", color="#9fbabc", alpha=0.2) # Also set the tick colors -plt.tick_params(which='major', colors='#105257') -plt.tick_params(which='minor', colors='#9fbabc') +plt.tick_params(which="major", colors="#105257") +plt.tick_params(which="minor", colors="#9fbabc") # If you want to change the actual axis line colors as well -plt.gca().spines['left'].set_color('#105257') -plt.gca().spines['bottom'].set_color('#105257') +plt.gca().spines["left"].set_color("#105257") +plt.gca().spines["bottom"].set_color("#105257") # Add the legend below the plot handles, labels = plt.gca().get_legend_handles_labels() @@ -223,7 +264,7 @@ def format_scientific(x): handletextpad=0.05, columnspacing=0.5, frameon=False, - prop={'family': 'Manrope', 'weight': 'medium', 'size': 8} + prop={"family": "Manrope", "weight": "medium", "size": 8}, ) # Adjust the layout @@ -257,7 +298,9 @@ def format_scientific(x): X = np.append(X, [[xmin, ymax + polygon_offset]], axis=0) # Back to left # Create and add polygon -polygon = plt.Polygon(X, facecolor=AI2_YELLOW, alpha=0.2, zorder=-1, edgecolor=AI2_ORANGE, linestyle="--", linewidth=1.5) +polygon = plt.Polygon( + X, facecolor=AI2_YELLOW, alpha=0.2, zorder=-1, edgecolor=AI2_ORANGE, linestyle="--", linewidth=1.5 +) plt.gca().add_patch(polygon) # Save the figure