Skip to content

Commit

Permalink
Refactor codes
Browse files Browse the repository at this point in the history
  • Loading branch information
moshi4 committed Nov 9, 2023
1 parent f4c8a9d commit 8e0c74e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 22 deletions.
1 change: 0 additions & 1 deletion src/pycirclize/circos.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,6 @@ def initialize_from_bed(
circos : Circos
Circos instance initialized from BED file
"""
sector2clockwise = {} if sector2clockwise is None else sector2clockwise
records = Bed(bed_file).records
sectors = {rec.chr: rec.size for rec in records}
sector2start_pos = {rec.chr: rec.start for rec in records}
Expand Down
36 changes: 15 additions & 21 deletions src/pycirclize/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def load_tree(data: str | Path | Tree, format: str) -> Tree:
# Load tree string
return Phylo.read(io.StringIO(data), format=format)
elif isinstance(data, Tree):
return data
return deepcopy(data)
else:
raise ValueError(f"{data=} is invalid input tree data!!")

Expand Down Expand Up @@ -255,26 +255,23 @@ def marker(
"""
target_node_name = self._search_target_node_name(query)

# Set markers (x, r) coordinates (include descendent nodes)
x: list[float] = []
r: list[float] = []
rmin, rmax = self.track.r_plot_lim
if descendent:
clade: Clade = next(self.tree.find_clades(target_node_name))
descendent_nodes: list[Clade] = list(clade.find_clades())
for descendent_node in descendent_nodes:
node_x, node_r = self.name2xr[str(descendent_node.name)]
if descendent_node.is_terminal() and self._align_leaf_label:
node_r = rmax if self._outer else rmin
x.append(node_x)
r.append(node_r)
else:
node_x, node_r = self.name2xr[target_node_name]
target_node: Clade = next(self.tree.find_clades(target_node_name))
if target_node.is_terminal() and self._align_leaf_label:
clade: Clade = next(self.tree.find_clades(target_node_name))
descendent_nodes: list[Clade] = list(clade.find_clades())
for descendent_node in descendent_nodes:
node_x, node_r = self.name2xr[str(descendent_node.name)]
if descendent_node.is_terminal() and self._align_leaf_label:
node_r = rmax if self._outer else rmin
x.append(node_x)
r.append(node_r)

# If `descendent=False`, remove descendent nodes (x, r) coordinate
if not descendent:
x, r = [x[0]], [r[0]]

self.track.scatter(
x, r, s=size**2, vmin=rmin, vmax=rmax, marker=marker, **kwargs
)
Expand Down Expand Up @@ -535,10 +532,7 @@ def _calc_name2rect(self) -> dict[str, Rectangle]:
parent_xr = self.name2xr[str(parent_node.name)]
upper_r = (xr[1] + parent_xr[1]) / 2
r_plot_lim = self.track.r_plot_lim
if self._align_leaf_label:
lower_r = max(r_plot_lim) if self._outer else min(r_plot_lim)
else:
lower_r = max(r_list) if self._outer else min(r_list)
lower_r = max(r_plot_lim) if self._outer else min(r_plot_lim)
rmin, rmax = min(upper_r, lower_r), max(upper_r, lower_r)

# Set rectangle
Expand Down Expand Up @@ -581,8 +575,6 @@ def _plot_tree_label(self) -> None:
"""Plot tree label"""
text_kws = dict(size=self._leaf_label_size, orientation="vertical")
for node in self.tree.get_terminals():
if self._leaf_label_size <= 0:
continue
# Set label text (x, r) position
label = str(node.name)
x, r = self.name2xr[label]
Expand All @@ -599,4 +591,6 @@ def _plot_tree_label(self) -> None:
_text_kws.update(params)
_text_kws.update(self._node2label_props[label])

self.track.text(label, x, r, **_text_kws) # type: ignore
# Plot label if text size > 0
if float(_text_kws["size"]) > 0:
self.track.text(label, x, r, **_text_kws) # type: ignore

0 comments on commit 8e0c74e

Please sign in to comment.