From 39618a78a6e0cf146122fbc89b21f72f487700b3 Mon Sep 17 00:00:00 2001 From: moshi Date: Sun, 7 Apr 2024 16:16:03 +0900 Subject: [PATCH] Add vmin arg option to `Circos.radar_chart()` --- src/pycirclize/circos.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/pycirclize/circos.py b/src/pycirclize/circos.py index 99ebe00..594bf83 100644 --- a/src/pycirclize/circos.py +++ b/src/pycirclize/circos.py @@ -180,6 +180,7 @@ def radar_chart( table: str | Path | pd.DataFrame | RadarTable, *, r_lim: tuple[float, float] = (0, 100), + vmin: float = 0, vmax: float = 100, fill: bool = True, marker_size: int = 0, @@ -203,6 +204,8 @@ def radar_chart( Table file or Table dataframe or RadarTable instance r_lim : tuple[float, float], optional Radar chart radius limit region (0 - 100) + vmin : float, optional + Min value vmax : float, optional Max value fill : bool, optional @@ -244,6 +247,10 @@ def radar_chart( circos : Circos Circos instance initialized for radar chart """ + if not vmin < vmax: + raise ValueError(f"vmax must be larger than vmin ({vmin=}, {vmax=})") + size = vmax - vmin + # Setup default properties grid_line_kws = {} if grid_line_kws is None else deepcopy(grid_line_kws) for k, v in dict(color="grey", ls="dashed", lw=0.5).items(): @@ -269,11 +276,12 @@ def radar_chart( if not 0 < grid_interval_ratio <= 1.0: raise ValueError(f"{grid_interval_ratio=} is invalid.") # Plot horizontal grid line & label - stop, step = vmax + (vmax / 1000), vmax * grid_interval_ratio - for v in np.arange(0, stop, step): - track.line(x, [v] * len(x), vmax=vmax, arc=circular, **grid_line_kws) + stop, step = vmax + (size / 1000), size * grid_interval_ratio + for v in np.arange(vmin, stop, step): + y = [v] * len(x) + track.line(x, y, vmin=vmin, vmax=vmax, arc=circular, **grid_line_kws) if show_grid_label: - r = track._y_to_r(v, 0, vmax) + r = track._y_to_r(v, vmin, vmax) # Format grid label if grid_label_formatter: text = grid_label_formatter(v) @@ -283,7 +291,7 @@ def radar_chart( track.text(text, 0, r, **grid_label_kws) # Plot vertical grid line for p in x[:-1]: - track.line([p, p], [0, vmax], vmax=vmax, **grid_line_kws) + track.line([p, p], [vmin, vmax], vmin=vmin, vmax=vmax, **grid_line_kws) # Plot radar charts if isinstance(cmap, str): @@ -296,15 +304,16 @@ def radar_chart( line_kws = line_kws_handler(row_name) if line_kws_handler else {} line_kws.setdefault("lw", 1.0) line_kws.setdefault("label", row_name) - track.line(x, y, vmax=vmax, arc=False, color=color, **line_kws) + track.line(x, y, vmin=vmin, vmax=vmax, arc=False, color=color, **line_kws) if marker_size > 0: marker_kws = marker_kws_handler(row_name) if marker_kws_handler else {} marker_kws.setdefault("marker", "o") marker_kws.setdefault("zorder", 2) marker_kws.update(s=marker_size**2) - track.scatter(x, y, vmax=vmax, color=color, **marker_kws) + track.scatter(x, y, vmin=vmin, vmax=vmax, color=color, **marker_kws) if fill: - track.fill_between(x, y, vmax=vmax, arc=False, color=color, alpha=0.5) + fill_kws = dict(arc=False, color=color, alpha=0.5) + track.fill_between(x, y, y2=vmin, vmin=vmin, vmax=vmax, **fill_kws) # type:ignore # Plot column names for idx, col_name in enumerate(radar_table.col_names):