1- import io
2- import sys
31from typing import Dict , List , Tuple , Union
42
5- from IPython .display import Image , display
63from mat3ra .made .material import Material
74from mat3ra .made .tools .analyze .interface import ZSLMatchHolder
85from mat3ra .made .tools .analyze .rdf import RadialDistributionFunction
9- from mat3ra .utils .jupyterlite .plot import plot_distribution_function , scatter_plot_2d
6+ from mat3ra .utils .jupyterlite .plot import plot_distribution_function , render_figure , scatter_plot_2d
107from matplotlib import pyplot as plt
118
129
@@ -43,7 +40,7 @@ def plot_strain_vs_area(matches: List["ZSLMatchHolder"], settings: Dict[str, Uni
4340 }
4441
4542 fig = scatter_plot_2d (x_values , y_values , hover_texts , plot_settings , trace_names )
46- fig . show ( )
43+ render_figure ( fig )
4744
4845
4946def plot_twisted_interface_solutions (interfaces : List ["Material" ]) -> None :
@@ -70,31 +67,18 @@ def plot_twisted_interface_solutions(interfaces: List["Material"]) -> None:
7067 plot_settings = {"x_title" : "Twist Angle (°)" , "y_title" : "Number of Atoms" , "title" : "Twisted Interface Solutions" }
7168
7269 fig = scatter_plot_2d (x_values , y_values , hover_texts , plot_settings , trace_names )
73- fig . show ( )
70+ render_figure ( fig )
7471
7572
7673def plot_rdf (material : "Material" , cutoff : float = 10.0 , bin_size : float = 0.1 ) -> None :
7774 """
7875 Plot RDF for a material.
7976 """
80- is_pyodide = sys .platform == "emscripten"
81- if is_pyodide :
82- # This is needed so that plt is adjusted before import to work in Pyodide environment
83- plt .switch_backend ("Agg" )
84-
8577 rdf = RadialDistributionFunction .from_material (material , cutoff = cutoff , bin_size = bin_size )
8678 plot_distribution_function (
8779 rdf .bin_centers , rdf .rdf , xlabel = "Distance (Å)" , ylabel = "g(r)" , title = "Radial Distribution Function (RDF)"
8880 )
8981
90- if is_pyodide :
91- # Necessary to display the plot in Pyodide environment
92- buf = io .BytesIO ()
93- plt .savefig (buf , format = "png" )
94- buf .seek (0 )
95- display (Image (buf .read ()))
96- plt .close ()
97-
9882
9983def plot_series (
10084 series : List [Dict ],
@@ -124,12 +108,12 @@ def plot_series(
124108 x_labels = [str (item [x_key ]) for item in series ]
125109 y_values = [item [y_key ] for item in series ]
126110 x_indices = list (range (len (series )))
127- _ , ax = plt .subplots (figsize = figsize )
111+ figure , ax = plt .subplots (figsize = figsize )
128112 ax .plot (x_indices , y_values , marker = marker )
129113 ax .set_xticks (x_indices )
130114 ax .set_xticklabels (x_labels , rotation = rotation , ha = "right" )
131115 ax .set_xlabel (xlabel )
132116 ax .set_ylabel (ylabel )
133117 ax .set_title (title )
134118 plt .tight_layout ()
135- plt . show ( )
119+ render_figure ( figure )
0 commit comments