Skip to content

Commit 9cfa8aa

Browse files
authored
Merge pull request #296 from Exabyte-io/feature/SOF-7887
update: handle plotting in pyodide
2 parents d090615 + 9818f3a commit 9cfa8aa

1 file changed

Lines changed: 5 additions & 21 deletions

File tree

utils/plot.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import io
2-
import sys
31
from typing import Dict, List, Tuple, Union
42

5-
from IPython.display import Image, display
63
from mat3ra.made.material import Material
74
from mat3ra.made.tools.analyze.interface import ZSLMatchHolder
85
from mat3ra.made.tools.analyze.rdf import RadialDistributionFunction
9-
from mat3ra.utils.jupyterlite.plot import plot_distribution_function, scatter_plot_2d
6+
from mat3ra.utils.jupyterlite.plot import plot_distribution_function, render_figure, scatter_plot_2d
107
from matplotlib import pyplot as plt
118

129

@@ -43,7 +40,7 @@ def plot_strain_vs_area(matches: List["ZSLMatchHolder"], settings: Dict[str, Uni
4340
}
4441

4542
fig = scatter_plot_2d(x_values, y_values, hover_texts, plot_settings, trace_names)
46-
fig.show()
43+
render_figure(fig)
4744

4845

4946
def plot_twisted_interface_solutions(interfaces: List["Material"]) -> None:
@@ -70,31 +67,18 @@ def plot_twisted_interface_solutions(interfaces: List["Material"]) -> None:
7067
plot_settings = {"x_title": "Twist Angle (°)", "y_title": "Number of Atoms", "title": "Twisted Interface Solutions"}
7168

7269
fig = scatter_plot_2d(x_values, y_values, hover_texts, plot_settings, trace_names)
73-
fig.show()
70+
render_figure(fig)
7471

7572

7673
def plot_rdf(material: "Material", cutoff: float = 10.0, bin_size: float = 0.1) -> None:
7774
"""
7875
Plot RDF for a material.
7976
"""
80-
is_pyodide = sys.platform == "emscripten"
81-
if is_pyodide:
82-
# This is needed so that plt is adjusted before import to work in Pyodide environment
83-
plt.switch_backend("Agg")
84-
8577
rdf = RadialDistributionFunction.from_material(material, cutoff=cutoff, bin_size=bin_size)
8678
plot_distribution_function(
8779
rdf.bin_centers, rdf.rdf, xlabel="Distance (Å)", ylabel="g(r)", title="Radial Distribution Function (RDF)"
8880
)
8981

90-
if is_pyodide:
91-
# Necessary to display the plot in Pyodide environment
92-
buf = io.BytesIO()
93-
plt.savefig(buf, format="png")
94-
buf.seek(0)
95-
display(Image(buf.read()))
96-
plt.close()
97-
9882

9983
def plot_series(
10084
series: List[Dict],
@@ -124,12 +108,12 @@ def plot_series(
124108
x_labels = [str(item[x_key]) for item in series]
125109
y_values = [item[y_key] for item in series]
126110
x_indices = list(range(len(series)))
127-
_, ax = plt.subplots(figsize=figsize)
111+
figure, ax = plt.subplots(figsize=figsize)
128112
ax.plot(x_indices, y_values, marker=marker)
129113
ax.set_xticks(x_indices)
130114
ax.set_xticklabels(x_labels, rotation=rotation, ha="right")
131115
ax.set_xlabel(xlabel)
132116
ax.set_ylabel(ylabel)
133117
ax.set_title(title)
134118
plt.tight_layout()
135-
plt.show()
119+
render_figure(figure)

0 commit comments

Comments
 (0)