add color themes - drafted #84

Open
johndpope opened this issue Mar 27, 2024 · 2 comments

Comments


johndpope commented Mar 27, 2024

First, this is a great library. I added it to an ML library to illustrate the attention mechanism and was blown away.
xmu-xiaoma666/External-Attention-pytorch#115

The gray colors look a bit tired. I pasted the dot code into ChatGPT and asked it to add color theming from a third-party library; there are 4 or 5 off-the-shelf libraries that can handle this.
palettable by @jiffyclub looked fine: https://github.com/jiffyclub/palettable

ChatGPT produced the code below to add support.

It would add about 300 KB, but that is a small price to pay for clarity.

# Unused imports dropped; distutils was removed from the standard library in Python 3.12
from graphviz import Digraph
import torch
import palettable

# Use a color palette from Palettable (RGB tuples with channels in 0..1)
palette = palettable.colorbrewer.qualitative.Set1_7.mpl_colors

def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
    """ Produces Graphviz representation of PyTorch autograd graph using Palettable for colors.
    """
    if params is not None:
        # Variable was merged into Tensor in PyTorch 0.4, so check against torch.Tensor
        assert all(isinstance(p, torch.Tensor) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}
    else:
        param_map = {}

    node_attr = dict(style='filled', shape='box', align='left', fontsize='10', ranksep='0.1', height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def color_to_hex(color):
        # Round instead of truncating so float error cannot push a channel one step low
        return '#%02x%02x%02x' % tuple(round(c * 255) for c in color[:3])

    # Define color scheme
    colors = {
        'tensor': color_to_hex(palette[0]),
        'operation': color_to_hex(palette[1]),
        'saved_tensor': color_to_hex(palette[2]),
        'param_tensor': color_to_hex(palette[3]),
        'base_tensor': color_to_hex(palette[4]),
        'view_tensor': color_to_hex(palette[5])
    }

    def get_var_name(var, name=None):
        if not name:
            name = param_map[id(var)] if id(var) in param_map else ''
        return f'{name}\n {var.size()}'

    def add_nodes(fn):
        if fn in seen:
            return
        seen.add(fn)

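        # Add nodes for saved tensors. Only custom autograd.Function nodes expose
        # `saved_tensors`; built-in nodes use `_saved_*` attributes instead.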
        if show_saved and hasattr(fn, 'saved_tensors'):
            for t in fn.saved_tensors:
                if t not in seen:
                    seen.add(t)
                    dot.node(str(id(t)), get_var_name(t), fillcolor=colors['saved_tensor'])
                    dot.edge(str(id(t)), str(id(fn)), dir="none")

        # Add the node for this grad_fn
        fn_name = str(type(fn).__name__)
        dot.node(str(id(fn)), fn_name, fillcolor=colors['operation'])

        # Recurse for next functions
        if hasattr(fn, 'next_functions'):
            for u in fn.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(fn)))
                    add_nodes(u[0])

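    def add_base_tensor(var, is_base=False):
        if var in seen:
            return
        seen.add(var)

        # Views, the bases they alias, and ordinary tensors each get their own color
        if var._is_view():
            color = colors['view_tensor']
        elif is_base:
            color = colors['base_tensor']
        else:
            color = colors['tensor']
        dot.node(str(id(var)), get_var_name(var), fillcolor=color)

        if var.grad_fn:
            add_nodes(var.grad_fn)
            dot.edge(str(id(var.grad_fn)), str(id(var)))

        if var._is_view():
            base_var = var._base
            add_base_tensor(base_var, is_base=True)
            # Edge lines are colored with `color`; `fillcolor` does not affect them
            dot.edge(str(id(base_var)), str(id(var)), style="dotted", color=colors['view_tensor'])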

    # handle multiple outputs
    if isinstance(var, tuple):
        for v in var:
            add_base_tensor(v)
    else:
        add_base_tensor(var)

    resize_graph(dot)

    return dot

def resize_graph(dot, size_per_element=0.15, min_size=12):
    num_rows = len(dot.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    dot.graph_attr.update(size=size_str)
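
Since the extra dependency is the main cost here, one option is to make palettable optional. Below is a minimal sketch of that idea, not part of torchviz: `load_palette` and `FALLBACK_RGB` are hypothetical names, and the fallback values are placeholders rather than the library's current grays. It also rounds each channel when converting to hex, which guards against floating-point error that truncation would not.

```python
# Hypothetical fallback colors (RGB floats in 0..1), used only when palettable is absent.
FALLBACK_RGB = [(0.83, 0.83, 0.83), (0.68, 0.85, 0.90), (1.0, 0.65, 0.0)]

def load_palette():
    """Return a list of RGB float triples, preferring palettable if it is installed."""
    try:
        from palettable.colorbrewer.qualitative import Set1_7
        return list(Set1_7.mpl_colors)
    except ImportError:
        return FALLBACK_RGB

def color_to_hex(color):
    """Convert an RGB triple of floats in 0..1 to '#rrggbb', rounding each channel."""
    r, g, b = (max(0, min(255, round(c * 255))) for c in color[:3])
    return f"#{r:02x}{g:02x}{b:02x}"

palette = load_palette()
hex_colors = [color_to_hex(c) for c in palette]
```

The resulting `hex_colors` strings can be passed directly as `fillcolor` values to `dot.node(...)`.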

UPDATE
Maybe it's better to just use Google's theme; we don't need 10 million styles.
Just borrow their blue, green, and font, and hey presto.

<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg width="901pt" height="1998pt" viewBox="0.00 0.00 901.00 1998.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1994)">
<title>%3</title>
<polygon fill="white" stroke="transparent" points="-4,4 -4,-1994 897,-1994 897,4 -4,4"/>

<!-- Example of a node in Google's blue -->
<g id="node1" class="node">
<title>137319056510800</title>
<polygon fill="#4285F4" stroke="black" points="249,-31 148,-31 148,0 249,0 249,-31"/>
<text text-anchor="middle" x="198.5" y="-7" font-family="Arial, Helvetica, sans-serif" font-size="10.00" fill="white"> (50, 64, 512)</text>
</g>

<!-- Example of an edge with Google's style -->
<g id="edge80" class="edge">
<title>137319055447216&#45;&gt;137319056510800</title>
<path fill="none" stroke="#34A853" d="M150.92,-72.73C158.26,-64.06 169.79,-50.43 179.66,-38.76"/>
<polygon fill="#34A853" stroke="#34A853" points="182.37,-40.98 186.16,-31.08 177.03,-36.46 182.37,-40.98"/>
</g>

<!-- Add more nodes and edges here with similar styles -->
</g>
</svg>
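
Rather than editing the rendered SVG, the same look could be set as Graphviz defaults when the graph is built. This is a rough sketch under assumptions: `GOOGLE_THEME`, `attrs_to_dot`, and `themed_dot` are hypothetical names, the hex values and white text are copied from the SVG above, and "Arial" stands in for its font stack. It emits plain DOT text so it runs without any dependencies.

```python
# Hypothetical theme table; colors taken from the SVG above.
GOOGLE_THEME = {
    "node": {"style": "filled", "shape": "box", "fillcolor": "#4285F4",
             "fontcolor": "white", "fontname": "Arial", "fontsize": "10"},
    "edge": {"color": "#34A853"},
}

def attrs_to_dot(attrs):
    """Format a dict as a DOT attribute list, e.g. [color="#34A853"]."""
    inner = ", ".join(f'{key}="{value}"' for key, value in attrs.items())
    return f"[{inner}]"

def themed_dot(edges, theme=GOOGLE_THEME):
    """Build DOT source whose default node and edge styles come from the theme."""
    lines = [
        "digraph {",
        "  node " + attrs_to_dot(theme["node"]) + ";",
        "  edge " + attrs_to_dot(theme["edge"]) + ";",
    ]
    for src, dst in edges:
        lines.append(f'  "{src}" -> "{dst}";')
    lines.append("}")
    return "\n".join(lines)

print(themed_dot([("MulBackward0", "(50, 64, 512)")]))
```

With the graphviz Python package, the same two dicts could be passed as `Digraph(node_attr=..., edge_attr=...)` instead of being formatted by hand.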

Screenshot from 2024-03-27 14-35-37

@johndpope johndpope changed the title add color themes - drafted PR add color themes - drafted Mar 27, 2024
@leo-ware

@johndpope I forked the repo and republished as torchviz2. Do you want to open a pull request on the new repo?

https://github.com/leo-ware/torchviz2

@johndpope

Thanks for the update. Feel free to cherry-pick whatever.
