# Copyright 2015 ibu radempa <ibu@radempa.de>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
Generate an entity-relationship diagram from an extended JSON table schema.
`JSON table schema`_ is a simple schema for describing the structure of
tabular data. It can be extended to allow for a comprehensive representation
of an SQL relational database schema.
Starting from such a description this python module generates visualizations
of the database schema using `graphviz`_ via `PyGraphviz`_.
.. _`JSON table schema`: http://dataprotocols.org/json-table-schema/
.. _`graphviz`: http://graphviz.org/
.. _`PyGraphviz`: http://pygraphviz.github.io/
"""
import html
import textwrap

import pygraphviz as pgv
options_defaults = {
    # colors
    'html_color_default': '#ccff99',
    'html_color_highlight': '#33cc99',
    # fonts
    'fontname': 'Helvetica',
    'fontsize': 8,
    'fontsize_title': 10,
    'fontsize_label': 6,
    'bgcolor_indexes': '#ccccff',
    # layout
    'rankdir': 'LR',
    'edge_thickness': 1.0,
    # what to show
    'display_columns': True,
    'display_indexes': True,
    'display_crowfoots': True,
    'omit_isolated_tables': False,
}
"""
Options and their default values.

Options:

  * **html_color_default**: background color of ordinary column rows
  * **html_color_highlight**: background color of primary key column rows
  * **fontname**: font used for the graph, the table nodes and edge labels
  * **fontsize**: base font size (graph, table nodes, table comments,
    index definitions)
  * **fontsize_title**: font size of the table name
  * **fontsize_label**: font size of edge labels
  * **bgcolor_indexes**: background color of the 'Extra indexes' row
  * **rankdir**: 'LR' or 'RL'; whether dependent tables appear on the
    right (left) hand side
  * **edge_thickness**: pen width of foreign key edges
  * **display_columns**: bool; show one row per column and attach foreign
    key edges to the column rows
  * **display_indexes**: bool; show a row listing the non-unique indexes
  * **display_crowfoots**: bool; draw cardinalities as crowfoot arrows
  * **omit_isolated_tables**: bool; skip tables which take part in no
    foreign key relation
"""
def _as_column_list(fields):
    """
    Return *fields* as a list of column names.

    A single column may be given as a plain string instead of a list; in
    that case it is wrapped into a one-element list. Anything else is
    returned unchanged.
    """
    if isinstance(fields, str):
        return [fields]
    return fields


def _foreign_key_texts(rankdir, reference, tail_table_name, tail_column_names,
                       head_table_name, head_column_names):
    """
    Return the tuple (label, tooltip) describing a foreign key edge.

    With *rankdir* 'RL' the head side is named first, otherwise the tail
    side. A missing cardinality is rendered as an empty string.
    """
    card_self = reference.get('cardinalitySelf')
    card_ref = reference.get('cardinalityRef')
    tail_text = '%s(%s)' % (tail_table_name, ', '.join(tail_column_names))
    head_text = '%s(%s)' % (head_table_name, ', '.join(head_column_names))
    reverse = rankdir == 'RL'
    if card_self or card_ref:
        cards = (card_ref, card_self) if reverse else (card_self, card_ref)
        # avoid printing the literal text 'None' for an unknown cardinality
        label = '%s \u2194 %s' % tuple(card or '' for card in cards)
    else:
        label = ''
    sides = (head_text, tail_text) if reverse else (tail_text, head_text)
    tooltip = '%s %s \u2194 %s' % ((label,) + sides)
    extra_label = reference.get('label')
    if extra_label:
        label += '\n' + extra_label
        tooltip += ' ' + extra_label
    else:
        edge_name = reference.get('name')
        if edge_name:
            label += ' ' + edge_name
            tooltip += ' ' + edge_name
    return label.strip(), tooltip.strip()


def get_graph(json_database_schema, **options):
    """
    Create and return a graph from the given *json_database_schema*.

    All keys from :any:`options_defaults` are allowed in *options*; values
    given there override the defaults.

    The schema must contain the keys 'database_name',
    'generation_begin_time' and 'datapackages'. Each datapackage has a name
    ('datapackage') and a list of tables ('resources').

    A :class:`KeyError` is raised if a foreign key references a table which
    is not part of the schema.
    """
    opt = options_defaults.copy()
    opt.update(options)
    database = json_database_schema['database_name']
    generated_at = json_database_schema['generation_begin_time']
    namespaces = json_database_schema['datapackages']
    schema_graph = pgv.AGraph(
        strict=False,
        directed=True,
        name='Postgres database %s (as of %s)' % (database, generated_at),
        rankdir=opt['rankdir'],
        fontname=opt['fontname'],
        fontsize=opt['fontsize'],
        splines=True,
        overlap='scale'
    )

    # inventory: all tables, and those taking part in a foreign key relation
    present_tables = {}
    tables_with_edges = set()
    for namespace in namespaces:
        namespace_name = namespace['datapackage']
        for table in namespace['resources']:
            table_key = (namespace_name, table['name'])
            present_tables[table_key] = table
            for foreign_key in table.get('foreignKeys') or ():
                reference = foreign_key['reference']
                tables_with_edges.add(table_key)
                tables_with_edges.add((reference['datapackage'],
                                       reference['resource']))

    # add table nodes
    for namespace in namespaces:
        namespace_name = namespace['datapackage']
        for table in namespace['resources']:
            has_edge = (namespace_name, table['name']) in tables_with_edges
            if not opt['omit_isolated_tables'] or has_edge:
                _graph_add_table(opt, schema_graph, namespace_name, table)

    # add foreign key edges
    for namespace in namespaces:
        # table-to-table edges, used only if columns are not displayed;
        # a set, so that several foreign keys between the same two tables
        # result in just one edge
        table_edges = set()
        for tail_table in namespace['resources']:
            tail_table_name = tail_table['name']
            for foreign_key in tail_table.get('foreignKeys') or ():
                reference = foreign_key['reference']
                head_table_name = reference['resource']
                # both sides may name a single column as a plain string
                tail_column_names = _as_column_list(foreign_key['fields'])
                head_column_names = _as_column_list(reference['fields'])
                head_table = present_tables[
                    (reference['datapackage'], head_table_name)
                ]
                if not opt['display_columns']:
                    table_edges.add((tail_table_name, head_table_name))
                    continue
                label, tooltip = _foreign_key_texts(
                    opt['rankdir'],
                    reference,
                    tail_table_name,
                    tail_column_names,
                    head_table_name,
                    head_column_names
                )
                enforced = foreign_key.get('enforced', True)
                _add_foreign_key_edge(
                    schema_graph,
                    tail_table_name,
                    head_table_name,
                    tail_table,
                    head_table,
                    tail_column_names,
                    head_column_names,
                    label,
                    tooltip,
                    opt,
                    'black' if enforced else 'blue',
                    reference.get('cardinalitySelf'),
                    reference.get('cardinalityRef')
                )
        for tail_table_name, head_table_name in table_edges:
            schema_graph.add_edge(
                tail_table_name,
                head_table_name,
                color='black'
            )
    return schema_graph
def save_svg(json_database_schema, filepath, **options):
    """
    Write an ERD in SVG format for a database to a file.

    *json_database_schema* must be compatible with what pg_jts produces.

    *filepath* should end in '.svg'. The output format is requested
    explicitly, so the file content is SVG regardless of the extension.

    *options* are passed on to :func:`get_graph`.
    """
    schema_graph = get_graph(json_database_schema, **options)
    # alternative layout programs: neato, dot, twopi, circo, fdp, nop, wc,
    # acyclic, gvpr, gvcolor, ccomps, sccmap, tred, sfdp
    schema_graph.layout(prog='dot')
    schema_graph.draw(filepath, format='svg')
def _graph_add_table(opt, graph, namespace_name, table,
default_namespace_name='public'):
"""
Add a record-shaped node to *graph* with information on a *table*.
All keys from `options_defaults` are allowed in *opt*.
"""
table_name = table['name']
table_comment = table.get('description', '')
display = ['name', 'type', 'combined']
title = (namespace_name + '.' if namespace_name != default_namespace_name
else '') + table_name
html_row0 = '<TR>\n <TD COLOR="black" BGCOLOR="lightgrey"'\
' COLSPAN="%s"><FONT POINT-SIZE="%s"><b>%s</b></FONT>'\
'<FONT POINT-SIZE="%s"><BR/>%s</FONT></TD>\n</TR>\n'\
% (str(len(display)), opt['fontsize_title'],
title, opt['fontsize'], table_comment)
html_rows = [html_row0]
if opt['display_columns']:
if 'primaryKey' in table:
pk = table['primaryKey']
for i, col_name in enumerate(pk):
col = [c for c in table['fields'] if c['name'] == col_name][0]
col_display = _get_column_display(display, table, col)
table_row_html = _get_table_row_html(
opt, display, i + 1, col_display, highlight=True)
html_rows.append(table_row_html)
else:
pk = []
columns = [c for c in table['fields'] if c['name'] not in pk]
#sorted_columns = sorted(columns, key=lambda c: c['pos'])
for col_i, col in enumerate(columns):
col_display = _get_column_display(display, table, col)
html_row = _get_table_row_html(opt, display, col_i + len(pk) + 1,
col_display)
html_rows.append(html_row)
if opt['display_indexes'] and 'indexes' in table:
indexes = [i for i in table['indexes'] if not i.get('unique')]
if indexes:
index_definitions = ['<FONT POINT-SIZE="%s">%s</FONT>' %
(opt['fontsize'], index['definition'])
for index in indexes]
html_index_definitions = '<BR/>'.join(sorted(index_definitions))
html_row = '<TR>\n <TD COLOR="black" BGCOLOR="%s"'\
' ALIGN="LEFT" COLSPAN="%s">Extra indexes:</TD>\n'\
' <TD COLOR="black" BGCOLOR="%s"'\
' ALIGN="LEFT" BALIGN="LEFT">%s</TD>\n</TR>\n'\
% (opt['bgcolor_indexes'], str(len(display) - 1),
opt['bgcolor_indexes'], html_index_definitions)
html_rows.append(html_row)
html_table = '<TABLE ID="%s" ALIGN="LEFT" BORDER="0" CELLBORDER="0"'\
' CELLSPACING="0" BGCOLOR="%s">\n%s</TABLE>'\
% ('table__' + table_name, 'black', ''.join(html_rows))
label = '<\n%s\n>' % html_table
graph.add_node(
table_name,
id=table_name,
label=label,
style='filled',
color='white',
fontname=opt['fontname'],
fontsize=opt['fontsize'],
shape='plaintext',
tooltip=table_comment or 'Table ' + table_name
)
def _get_column_display(display, table, column, pk=False):
"""
Return a list of strings describing a column.
The returned attributes and their order are given by *display*; allowed
attributes are:
* name
* type
* combined (combined str with unique constraint information,
default value and description texts)
"""
res = []
for d in display:
if d == 'name':
res.append(column['name'])
elif d == 'type':
res.append(column['type'])
elif d == 'combined':
vals = []
if column.get('constraints'):
constr = column['constraints']
if 'required' in constr:
vals.append(_format_attribute('null', constr['required']))
uniques = []
table_unique = table.get('unique')
if table_unique:
for t_u_i, t_u in enumerate(table_unique):
if column['name'] in t_u['fields']:
i = t_u['fields'].index(column['name'])
if len(t_u['fields']) == 1:
uniques.append('UNIQ')
else:
uniques.append('UNIQ%s:%s'
% (str(t_u_i + 1), str(i + 1)))
if column.get('constraints'):
column_unique = column['constraints'].get('unique')
if 'UNIQ' not in uniques and column_unique:
uniques.append('UNIQ')
vals.append('; '.join(uniques))
default_value = 'DEFAULT=' + column['default_value']\
if 'default_value' in column else ''
vals.append(default_value)
description = column.get('description', '')
vals.append(description)
text = '; '.join([v for v in vals if v]).replace('\n', '; ')
wrapped_text = '<BR/>\n'.join(textwrap.wrap(text, width=50))
res.append(wrapped_text)
return res
def _get_table_row_html(opt, display, port, table_cols,
                        align='LEFT', highlight=False):
    """
    Return a graphviz HTML string for a table row describing a column.

    Each entry of *table_cols* becomes one cell, formatted according to the
    attribute type found at the same position in *display*.

    If *port* is not None, graphviz PORT names are added: *port* prepended
    with 'i' for the leftmost cell and with 'f' for the rightmost cell (if
    there is only one cell, it gets the 'f' port).
    """
    if highlight:
        bgcolor = opt['html_color_highlight']
    else:
        bgcolor = opt['html_color_default']
    last = len(table_cols) - 1
    cells = []
    for position, raw_value in enumerate(table_cols):
        value = _format_attribute(display[position], raw_value)
        port_attr = ''
        if port is not None:
            if position == 0:
                port_attr = ' PORT="i%s"' % str(port)
            # checked second, so that 'f' wins for a single-cell row
            if position == last:
                port_attr = ' PORT="f%s"' % str(port)
        cells.append(
            '<TD BGCOLOR="%s" ALIGN="%s" BALIGN="%s"%s>%s</TD>'
            % (bgcolor, align, align, port_attr, value)
        )
    return '<TR>\n %s\n</TR>\n' % ''.join(cells)
def _format_attribute(attribute_type, attribute_value):
"""
Return *attribute_value*, except for special *attribute_type*s.
For special *attribute_type*s the given string *attribute_value* is
modified, depending on the *attribute_type*:
* **null**
* **name**
* **default**
"""
if attribute_type.lower() == 'null':
if attribute_value:
return ''
else:
return '<s>NULL</s>'
elif attribute_type.lower() == 'name':
return '<b>%s</b>' % attribute_value
elif attribute_type == 'default':
if attribute_value is not None:
if attribute_value.lower().startswith('nextval('):
return '[sequence]'
else:
return attribute_value
else:
return ''
else:
return attribute_value
def _add_foreign_key_edge(schema_graph, tail_table_name, head_table_name,
                          tail_table, head_table, tail_column_names,
                          head_column_names, label, tooltip, opt, color,
                          card_tail, card_head):
    """
    Modify *schema_graph* by adding edges (for a foreign key relation).

    A single-column side is connected directly at the port of its column.
    For a side with several columns an intermediate point-shaped node is
    added, which is connected to the port of each of the columns; the main
    edge then ends at that intermediate node.

    The main edge carries *label*, *tooltip* and the crowfoot arrows for
    *card_tail* and *card_head*.
    """
    # port name prefixes of the side facing the head table (outer) and of
    # the side facing the tail table (inner); swapped for rankdir 'RL'
    if opt['rankdir'] == 'RL':
        inner_prefix, outer_prefix = 'f', 'i'
    else:
        inner_prefix, outer_prefix = 'i', 'f'

    # tail side
    if len(tail_column_names) > 1:
        tail_node = 'tail agg %s%s->%s' % (
            tail_table_name, str(tail_column_names), head_table_name)
        schema_graph.add_node(
            tail_node,
            id=tail_table_name,
            label='',
            style='filled',
            color='red',
            arrowtail=None,
            arrowhead=None,
            shape='point'
        )
        for column_name in tail_column_names:
            column_port = outer_prefix + str(_get_port(tail_table,
                                                       column_name))
            schema_graph.add_edge(
                tail_table_name,
                tail_node,
                tailport=column_port,
                penwidth=opt['edge_thickness'],
                color=color,
                dir='none'
            )
        tail_port = ''
    else:
        tail_node = tail_table_name
        tail_port = outer_prefix + str(_get_port(tail_table,
                                                 tail_column_names[0]))

    # head side
    if len(head_column_names) > 1:
        # the tail columns are part of the name of this node, too
        head_node = 'head agg %s->%s%s' % (
            tail_table_name, head_table_name, str(tail_column_names))
        schema_graph.add_node(
            head_node,
            id=head_table_name,
            label='',
            style='filled',
            color='red',
            arrowtail=None,
            arrowhead=None,
            shape='point'
        )
        for column_name in head_column_names:
            column_port = inner_prefix + str(_get_port(head_table,
                                                       column_name))
            schema_graph.add_edge(
                head_node,
                head_table_name,
                headport=column_port,
                penwidth=opt['edge_thickness'],
                color=color,
                dir='none'
            )
        head_port = ''
    else:
        head_node = head_table_name
        head_port = inner_prefix + str(_get_port(head_table,
                                                 head_column_names[0]))

    # main edge
    schema_graph.add_edge(
        tail_node,
        head_node,
        tailport=tail_port,
        headport=head_port,
        penwidth=opt['edge_thickness'],
        color=color,
        label=label,
        fontname=opt['fontname'],
        fontsize=opt['fontsize_label'],
        fontcolor=color,
        arrowtail=_get_crowfoot(card_tail, opt),
        arrowhead=_get_crowfoot(card_head, opt),
        tooltip=tooltip,
        labeltooltip=tooltip,
        dir='both'
    )
def _get_port(table, column):
"""
Return the port number of a table column.
The port number is the row number in the html table, counting from 0.
Row 0 is the row containing the table name. It is followed by
rows describing primary key columns and then by all other columns.
"""
if 'primaryKey' in table:
pk = table['primaryKey']
if column in pk:
return int(pk.index(column)) + 1
offset = len(pk)
else:
pk = []
offset = 0
columns_non_pk = [c['name'] for c in table['fields']
if c['name'] not in pk]
return columns_non_pk.index(column) + offset + 1
def _get_crowfoot(cardinality, opt):
"""
Return the arrow name for a crowfoot with given *cardinality*.
Cardinalities are:
* 0..1
* 1
* 0..N
* 1..N
"""
if not opt['display_crowfoots']:
return 'none'
if cardinality == '0..1':
return 'teeodot'
if cardinality == '1':
return 'teetee'
if cardinality == '0..N':
return 'crowodot'
if cardinality == '1..N':
return 'crowtee'
return 'none'