diff --git a/src/pycirclize/circos.py b/src/pycirclize/circos.py index b8ae6a4..164ba8b 100644 --- a/src/pycirclize/circos.py +++ b/src/pycirclize/circos.py @@ -412,7 +412,6 @@ def initialize_from_bed( circos : Circos Circos instance initialized from BED file """ - sector2clockwise = {} if sector2clockwise is None else sector2clockwise records = Bed(bed_file).records sectors = {rec.chr: rec.size for rec in records} sector2start_pos = {rec.chr: rec.start for rec in records} diff --git a/src/pycirclize/tree.py b/src/pycirclize/tree.py index 86a0bee..c8ed3e6 100644 --- a/src/pycirclize/tree.py +++ b/src/pycirclize/tree.py @@ -185,7 +185,7 @@ def load_tree(data: str | Path | Tree, format: str) -> Tree: # Load tree string return Phylo.read(io.StringIO(data), format=format) elif isinstance(data, Tree): - return data + return deepcopy(data) else: raise ValueError(f"{data=} is invalid input tree data!!") @@ -255,26 +255,23 @@ def marker( """ target_node_name = self._search_target_node_name(query) + # Set markers (x, r) coordinates (include descendent nodes) x: list[float] = [] r: list[float] = [] rmin, rmax = self.track.r_plot_lim - if descendent: - clade: Clade = next(self.tree.find_clades(target_node_name)) - descendent_nodes: list[Clade] = list(clade.find_clades()) - for descendent_node in descendent_nodes: - node_x, node_r = self.name2xr[str(descendent_node.name)] - if descendent_node.is_terminal() and self._align_leaf_label: - node_r = rmax if self._outer else rmin - x.append(node_x) - r.append(node_r) - else: - node_x, node_r = self.name2xr[target_node_name] - target_node: Clade = next(self.tree.find_clades(target_node_name)) - if target_node.is_terminal() and self._align_leaf_label: + clade: Clade = next(self.tree.find_clades(target_node_name)) + descendent_nodes: list[Clade] = list(clade.find_clades()) + for descendent_node in descendent_nodes: + node_x, node_r = self.name2xr[str(descendent_node.name)] + if descendent_node.is_terminal() and self._align_leaf_label: node_r = rmax if self._outer else rmin x.append(node_x) r.append(node_r) + # If `descendent=False`, remove descendent nodes (x, r) coordinate + if not descendent: + x, r = [x[0]], [r[0]] + self.track.scatter( x, r, s=size**2, vmin=rmin, vmax=rmax, marker=marker, **kwargs ) @@ -535,10 +532,7 @@ def _calc_name2rect(self) -> dict[str, Rectangle]: parent_xr = self.name2xr[str(parent_node.name)] upper_r = (xr[1] + parent_xr[1]) / 2 r_plot_lim = self.track.r_plot_lim - if self._align_leaf_label: - lower_r = max(r_plot_lim) if self._outer else min(r_plot_lim) - else: - lower_r = max(r_list) if self._outer else min(r_list) + lower_r = max(r_plot_lim) if self._outer else min(r_plot_lim) rmin, rmax = min(upper_r, lower_r), max(upper_r, lower_r) # Set rectangle @@ -581,8 +575,6 @@ def _plot_tree_label(self) -> None: """Plot tree label""" text_kws = dict(size=self._leaf_label_size, orientation="vertical") for node in self.tree.get_terminals(): - if self._leaf_label_size <= 0: - continue # Set label text (x, r) position label = str(node.name) x, r = self.name2xr[label] @@ -599,4 +591,6 @@ def _plot_tree_label(self) -> None: _text_kws.update(params) _text_kws.update(self._node2label_props[label]) - self.track.text(label, x, r, **_text_kws) # type: ignore + # Plot label if text size > 0 + if float(_text_kws["size"]) > 0: + self.track.text(label, x, r, **_text_kws) # type: ignore