Skip to content

Commit

Permalink
Merge pull request #18 from thomas-saigre/16-integer-marker-bug
Browse files Browse the repository at this point in the history
16 integer marker bug
  • Loading branch information
thomas-saigre authored Jul 17, 2024
2 parents 5a9a6fb + 0e8ca06 commit fdd9646
Show file tree
Hide file tree
Showing 13 changed files with 245 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "tikzplotly"
version = "0.1.6"
version = "0.1.7"
description = "Convert plotly figures to LaTeX / tikz figures"
readme = "README.md"
authors = [{name = "Thomas Saigre", email = "[email protected]"}]
Expand Down
2 changes: 1 addition & 1 deletion src/tikzplotly/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.6"
__version__ = "0.1.7"
111 changes: 109 additions & 2 deletions src/tikzplotly/_marker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,96 @@
from warnings import warn

# Source : https://github.com/plotly/plotly.py/blob/51eb5ea9fefda27bccfdb21e660b8d4035cef3b0/packages/python/plotly/plotly/graph_objs/box/_marker.py#L256-L344
AUTHORIZED_SYMBOLS = [0, '0', 'circle', 100, '100', 'circle-open', 200, '200',
'circle-dot', 300, '300', 'circle-open-dot', 1, '1',
'square', 101, '101', 'square-open', 201, '201',
'square-dot', 301, '301', 'square-open-dot', 2, '2',
'diamond', 102, '102', 'diamond-open', 202, '202',
'diamond-dot', 302, '302', 'diamond-open-dot', 3, '3',
'cross', 103, '103', 'cross-open', 203, '203',
'cross-dot', 303, '303', 'cross-open-dot', 4, '4', 'x',
104, '104', 'x-open', 204, '204', 'x-dot', 304, '304',
'x-open-dot', 5, '5', 'triangle-up', 105, '105',
'triangle-up-open', 205, '205', 'triangle-up-dot', 305,
'305', 'triangle-up-open-dot', 6, '6', 'triangle-down',
106, '106', 'triangle-down-open', 206, '206',
'triangle-down-dot', 306, '306', 'triangle-down-open-dot',
7, '7', 'triangle-left', 107, '107', 'triangle-left-open',
207, '207', 'triangle-left-dot', 307, '307',
'triangle-left-open-dot', 8, '8', 'triangle-right', 108,
'108', 'triangle-right-open', 208, '208',
'triangle-right-dot', 308, '308',
'triangle-right-open-dot', 9, '9', 'triangle-ne', 109,
'109', 'triangle-ne-open', 209, '209', 'triangle-ne-dot',
309, '309', 'triangle-ne-open-dot', 10, '10',
'triangle-se', 110, '110', 'triangle-se-open', 210, '210',
'triangle-se-dot', 310, '310', 'triangle-se-open-dot', 11,
'11', 'triangle-sw', 111, '111', 'triangle-sw-open', 211,
'211', 'triangle-sw-dot', 311, '311',
'triangle-sw-open-dot', 12, '12', 'triangle-nw', 112,
'112', 'triangle-nw-open', 212, '212', 'triangle-nw-dot',
312, '312', 'triangle-nw-open-dot', 13, '13', 'pentagon',
113, '113', 'pentagon-open', 213, '213', 'pentagon-dot',
313, '313', 'pentagon-open-dot', 14, '14', 'hexagon', 114,
'114', 'hexagon-open', 214, '214', 'hexagon-dot', 314,
'314', 'hexagon-open-dot', 15, '15', 'hexagon2', 115,
'115', 'hexagon2-open', 215, '215', 'hexagon2-dot', 315,
'315', 'hexagon2-open-dot', 16, '16', 'octagon', 116,
'116', 'octagon-open', 216, '216', 'octagon-dot', 316,
'316', 'octagon-open-dot', 17, '17', 'star', 117, '117',
'star-open', 217, '217', 'star-dot', 317, '317',
'star-open-dot', 18, '18', 'hexagram', 118, '118',
'hexagram-open', 218, '218', 'hexagram-dot', 318, '318',
'hexagram-open-dot', 19, '19', 'star-triangle-up', 119,
'119', 'star-triangle-up-open', 219, '219',
'star-triangle-up-dot', 319, '319',
'star-triangle-up-open-dot', 20, '20',
'star-triangle-down', 120, '120',
'star-triangle-down-open', 220, '220',
'star-triangle-down-dot', 320, '320',
'star-triangle-down-open-dot', 21, '21', 'star-square',
121, '121', 'star-square-open', 221, '221',
'star-square-dot', 321, '321', 'star-square-open-dot', 22,
'22', 'star-diamond', 122, '122', 'star-diamond-open',
222, '222', 'star-diamond-dot', 322, '322',
'star-diamond-open-dot', 23, '23', 'diamond-tall', 123,
'123', 'diamond-tall-open', 223, '223',
'diamond-tall-dot', 323, '323', 'diamond-tall-open-dot',
24, '24', 'diamond-wide', 124, '124', 'diamond-wide-open',
224, '224', 'diamond-wide-dot', 324, '324',
'diamond-wide-open-dot', 25, '25', 'hourglass', 125,
'125', 'hourglass-open', 26, '26', 'bowtie', 126, '126',
'bowtie-open', 27, '27', 'circle-cross', 127, '127',
'circle-cross-open', 28, '28', 'circle-x', 128, '128',
'circle-x-open', 29, '29', 'square-cross', 129, '129',
'square-cross-open', 30, '30', 'square-x', 130, '130',
'square-x-open', 31, '31', 'diamond-cross', 131, '131',
'diamond-cross-open', 32, '32', 'diamond-x', 132, '132',
'diamond-x-open', 33, '33', 'cross-thin', 133, '133',
'cross-thin-open', 34, '34', 'x-thin', 134, '134',
'x-thin-open', 35, '35', 'asterisk', 135, '135',
'asterisk-open', 36, '36', 'hash', 136, '136',
'hash-open', 236, '236', 'hash-dot', 336, '336',
'hash-open-dot', 37, '37', 'y-up', 137, '137',
'y-up-open', 38, '38', 'y-down', 138, '138',
'y-down-open', 39, '39', 'y-left', 139, '139',
'y-left-open', 40, '40', 'y-right', 140, '140',
'y-right-open', 41, '41', 'line-ew', 141, '141',
'line-ew-open', 42, '42', 'line-ns', 142, '142',
'line-ns-open', 43, '43', 'line-ne', 143, '143',
'line-ne-open', 44, '44', 'line-nw', 144, '144',
'line-nw-open', 45, '45', 'arrow-up', 145, '145',
'arrow-up-open', 46, '46', 'arrow-down', 146, '146',
'arrow-down-open', 47, '47', 'arrow-left', 147, '147',
'arrow-left-open', 48, '48', 'arrow-right', 148, '148',
'arrow-right-open', 49, '49', 'arrow-bar-up', 149, '149',
'arrow-bar-up-open', 50, '50', 'arrow-bar-down', 150,
'150', 'arrow-bar-down-open', 51, '51', 'arrow-bar-left',
151, '151', 'arrow-bar-left-open', 52, '52',
'arrow-bar-right', 152, '152', 'arrow-bar-right-open', 53,
'53', 'arrow', 153, '153', 'arrow-open', 54, '54',
'arrow-wide', 154, '154', 'arrow-wide-open']

marker_symbol_dict = {
"circle": ("*", None),
"circle-open": ("o", None),
Expand Down Expand Up @@ -116,6 +207,22 @@
}

def marker_symbol_to_tex(symbol):
if "-dot" in symbol:

if symbol not in AUTHORIZED_SYMBOLS:
warn(f"Symbol '{symbol}' not supported, defaulting to '*'")
return "*", None

# Explanation : with plotly, there is a list of predefined symbols that can be given with a string or a corresponding integer.
# In the list AUTHORIZED_SYMBOLS (taken from plotly sources, see above), the symbols are given in the following order:
# the integer, then the interger as a string, then the symbol name. (e.g. 0, '0', 'circle')
idx = AUTHORIZED_SYMBOLS.index(symbol)
if idx % 3 == 0:
symbol_name = AUTHORIZED_SYMBOLS[idx + 2]
elif idx % 3 == 1:
symbol_name = AUTHORIZED_SYMBOLS[idx + 1]
else:
symbol_name = AUTHORIZED_SYMBOLS[idx]

if "-dot" in symbol_name:
warn("Dotted markers are not supported (yet), the symbol without dot will be used instead.")
return marker_symbol_dict.get(symbol.replace("-dot", ""), ("*", None))
return marker_symbol_dict.get(symbol_name.replace("-dot", ""), ("*", None))
17 changes: 17 additions & 0 deletions src/tikzplotly/_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,19 @@ def get_tikz_code(
warn("No data in figure.")

for trace in figure_data:

if trace.type == "scatter":
# Handle the case where x or y is empty
if trace.x is None and trace.y is None:
warn("Adding empty trace.")
data_str.append( "\\addplot coordinates {};\n" )
continue
else:
if trace.x is None:
trace.x = list(range(len(trace.y)))
if trace.y is None:
trace.y = list(range(len(trace.x)))

data_name_macro, y_name = data_container.addData(trace.x, trace.y, trace.name)
data_str.append( draw_scatter2d(data_name_macro, trace, y_name, axis, colors_set) )
if trace.name and trace['showlegend'] != False:
Expand All @@ -65,6 +77,11 @@ def get_tikz_code(
colors_set.add(convert_color(trace.fillcolor)[:3])

elif trace.type == "heatmap":
# Handle the case where x, y or z is empty
if trace.z is None:
warn("Adding empty trace.")
data_str.append( "\\addplot coordinates {};\n" )
continue
data_str.append( draw_heatmap(trace, fig, img_name, axis) )

else:
Expand Down
28 changes: 26 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Adapted from https://github.com/nschloe/tikzplotlib/blob/450712b4014799ec5f151f234df84335c90f4b9d/tests/helpers.py

import tikzplotly
import re
from math import isclose


# https://stackoverflow.com/a/845432/353337
Expand All @@ -12,9 +14,31 @@ def _unidiff_output(expected, actual):
diff = difflib.unified_diff(expected, actual)
return "".join(diff)

def assert_equality(fig, target_file, **kwargs):
def extract_floats_from_string(s):
"""Extract all floating-point numbers from a string."""
float_pattern = re.compile(r"[-+]?\d*\.\d+|\d+")
floats = [float(num) for num in float_pattern.findall(s)]
return floats

def assert_equality(fig, target_file, tolerance=1e-9, **kwargs):
tikz_code = tikzplotly.get_tikz_code(fig, include_disclamer=False, **kwargs)

with open(target_file, encoding="utf-8") as f:
reference = f.read()
assert reference == tikz_code, target_file + "\n" + _unidiff_output(reference, tikz_code)

reference_floats = extract_floats_from_string(reference)
tikz_floats = extract_floats_from_string(tikz_code)

if len(reference_floats) != len(tikz_floats):
assert False, "Number of floats in the reference and tikz code differ.\n" + _unidiff_output(reference, tikz_code)

for ref, tikz in zip(reference_floats, tikz_floats):
if not isclose(ref, tikz, rel_tol=tolerance, abs_tol=tolerance):
assert False, f"Values differ: {ref} vs {tikz}\n" + _unidiff_output(reference, tikz_code)

# If all floating-point comparisons pass, ensure the structures are the same.
reference_non_floats = re.sub(r"[-+]?\d*\.\d+|\d+", "FLOAT", reference)
tikz_non_floats = re.sub(r"[-+]?\d*\.\d+|\d+", "FLOAT", tikz_code)

assert reference_non_floats == tikz_non_floats, target_file + "\n" + _unidiff_output(reference, tikz_code)

11 changes: 10 additions & 1 deletion tests/test_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def plot_4():

return fig

def plot_5():
fig = px.imshow([[1, 20, 30],
[20, 1, 60],
[30, 60, 1]])
fig.data[0].z = None
return fig

def test_1():
assert_equality(plot_1(), os.path.join(this_dir, test_name, test_name + "_1_reference.tex"), img_name="/tmp/tikzplotly/fig1.png")
Expand All @@ -62,4 +68,7 @@ def test_3():
assert_equality(plot_3(), os.path.join(this_dir, test_name, test_name + "_3_reference.tex"), img_name="/tmp/tikzplotly/fig3.png")

def test_4():
assert_equality(plot_4(), os.path.join(this_dir, test_name, test_name + "_4_reference.tex"), img_name="/tmp/tikzplotly/fig4.png")
assert_equality(plot_4(), os.path.join(this_dir, test_name, test_name + "_4_reference.tex"), img_name="/tmp/tikzplotly/fig4.png")

def test_5():
assert_equality(plot_5(), os.path.join(this_dir, test_name, test_name + "_5_reference.tex"))
9 changes: 9 additions & 0 deletions tests/test_heatmap/test_heatmap_5_reference.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
\begin{tikzpicture}


\begin{axis}[
y dir=reverse
]
\addplot coordinates {};
\end{axis}
\end{tikzpicture}
9 changes: 6 additions & 3 deletions tests/test_markers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import plotly.express as px
import plotly.graph_objects as go
import pytest
import numpy as np
import os
from .helpers import assert_equality
Expand All @@ -8,12 +9,13 @@
this_dir = pathlib.Path(__file__).resolve().parent
test_name = "test_markers"

def plot_1():
def plot_1(symbol):

df = px.data.iris()
fig = px.scatter(df, x="sepal_width", y="sepal_length", color="species")

fig.update_traces(marker=dict(size=12,
symbol = symbol,
line=dict(width=2,
color='DarkSlateGrey')),
selector=dict(mode='markers'))
Expand Down Expand Up @@ -79,8 +81,9 @@ def plot_3():

return fig

def test_1():
assert_equality(plot_1(), os.path.join(this_dir, test_name, test_name + "_1_reference.tex"))
@pytest.mark.parametrize("symbol", ["circle", 0, "0"])
def test_1(symbol):
assert_equality(plot_1(symbol), os.path.join(this_dir, test_name, test_name + "_1_reference.tex"))

def test_2():
assert_equality(plot_2(), os.path.join(this_dir, test_name, test_name + "_2_reference.tex"))
Expand Down
13 changes: 12 additions & 1 deletion tests/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from .helpers import assert_equality
import pathlib
import pytest

this_dir = pathlib.Path(__file__).resolve().parent
test_name = "test_scatter"
Expand Down Expand Up @@ -136,6 +137,12 @@ def plot_5():
fig = px.line(df, x='date', y="GOOG")
return fig

def plot_6(x=True, y=True):
fig = px.scatter(x=[0, 1, 2, 3, 4], y=[0, 1, 4, 9, 16])
if x: fig.data[0].x = None
if y: fig.data[0].y = None
return fig

def test_1():
assert_equality(plot_1(), os.path.join(this_dir, test_name, test_name + "_1_reference.tex"))

Expand All @@ -149,4 +156,8 @@ def test_4():
assert_equality(plot_4(), os.path.join(this_dir, test_name, test_name + "_4_reference.tex"))

def test_5():
assert_equality(plot_5(), os.path.join(this_dir, test_name, test_name + "_5_reference.tex"))
assert_equality(plot_5(), os.path.join(this_dir, test_name, test_name + "_5_reference.tex"))

@pytest.mark.parametrize("x, y", [(True, True), (True, False), (False, True)])
def test_6(x, y):
assert_equality(plot_6(x, y), os.path.join(this_dir, test_name, test_name + f"_6_{x}_{y}_reference.tex"))
19 changes: 19 additions & 0 deletions tests/test_scatter/test_scatter_6_False_True_reference.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
\pgfplotstableread{data0 y0
0 0
1 1
2 2
3 3
4 4
}\dataZ

\begin{tikzpicture}

\definecolor{636efa}{HTML}{636efa}

\begin{axis}[
xlabel=x,
ylabel=y
]
\addplot+ [mark=*, only marks, mark options={solid, fill=636efa}, forget plot] table[y=y0] {\dataZ};
\end{axis}
\end{tikzpicture}
19 changes: 19 additions & 0 deletions tests/test_scatter/test_scatter_6_True_False_reference.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
\pgfplotstableread{data0 y0
0 0
1 1
2 4
3 9
4 16
}\dataZ

\begin{tikzpicture}

\definecolor{636efa}{HTML}{636efa}

\begin{axis}[
xlabel=x,
ylabel=y
]
\addplot+ [mark=*, only marks, mark options={solid, fill=636efa}, forget plot] table[y=y0] {\dataZ};
\end{axis}
\end{tikzpicture}
10 changes: 10 additions & 0 deletions tests/test_scatter/test_scatter_6_True_True_reference.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
\begin{tikzpicture}


\begin{axis}[
xlabel=x,
ylabel=y
]
\addplot coordinates {};
\end{axis}
\end{tikzpicture}
6 changes: 6 additions & 0 deletions tests/test_tikzplotly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import tikzplotly
import plotly.express as px

def test_tikzplotly():
fig = px.scatter(x=[1, 2, 3], y=[1, 2, 3])
tikzplotly.save("/tmp/tikzplotly/test_tikzplotly.tex", fig)

0 comments on commit fdd9646

Please sign in to comment.