Skip to content

Commit 20f2c5f

Browse files
committed
add plotting script
1 parent a57acfe commit 20f2c5f

1 file changed

Lines changed: 253 additions & 0 deletions

File tree

src/amuse_metisse/plot_hr.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
"""
2+
Plot a Hertzsprung-Russell diagram of a star cluster, for a number of
3+
snapshots, and make a movie.
4+
5+
Uses data from
6+
https://astronomy.stackexchange.com/questions/39994/what-is-the-rgb-curve-for-blackbodies
7+
to convert temperature to RGB.
8+
9+
Shows a progress bar.
10+
"""
11+
12+
import sys
13+
import os.path
14+
import argparse
15+
import numpy as np
16+
import matplotlib.pyplot as plt
17+
from matplotlib import animation
18+
from amuse.io import read_set_from_file
19+
from amuse.units import units, constants
20+
21+
# package for progress bar
22+
from tqdm import tqdm
23+
24+
25+
def lumrad_to_temp(luminosity, radius):
26+
temperature = ((
27+
luminosity
28+
/ (constants.four_pi_stefan_boltzmann * radius**2)
29+
)**0.25).in_(units.K)
30+
return temperature
31+
32+
33+
def lumtemp_to_rad(luminosity, temperature):
34+
radius = ((luminosity / (constants.four_pi_stefan_boltzmann * temperature**4))**0.5).in_(units.RSun)
35+
return radius
36+
37+
38+
def templum_to_xyz(temperature, luminosity):
39+
log_temperature = np.nan_to_num(np.log10(temperature.value_in(units.K)))
40+
log_luminosity = np.nan_to_num(np.log10(luminosity.value_in(units.LSun)))
41+
color = temp_to_rgb(temperature)
42+
return log_temperature, log_luminosity, color
43+
44+
45+
def temp_to_rgb(temperature):
46+
temp = temperature.value_in(units.K)
47+
logT = np.log(temp)
48+
logT1000 = np.log(temp - 1000.0)
49+
rgb = np.zeros((len(temp), 3))
50+
rgb[:, 0] = 1.0
51+
rgb[:, 1] = 0.390081972 * logT - 2.427925631
52+
rgb[:, 2] = 0.543206396 * logT1000 - 3.698136688
53+
54+
t6600 = temp > 6600
55+
rgb[t6600, 0] = 2.4054 * (temp[t6600]-6000)**(-0.1332047592)
56+
rgb[t6600, 1] = 1.6 * (temp[t6600]-6000)**(-0.0755148492)
57+
rgb[t6600, 2] = 1.0
58+
59+
rgb = np.clip(rgb, 0, 1)
60+
return rgb
61+
62+
63+
64+
class StarHRPlotter:
65+
def __init__(self, name, extension="amuse"):
66+
self.fig = plt.figure(figsize=(10, 10))
67+
self.ax = self.fig.add_subplot(111)
68+
self.ax.set_xlabel("log(Teff)")
69+
self.ax.set_ylabel("log(L)")
70+
self.temperature_range = [5.5, 3]
71+
self.luminosity_range = [-5, 9]
72+
self.ax.set_xlim(self.temperature_range)
73+
self.ax.set_ylim(self.luminosity_range)
74+
self.ax.set_facecolor("k")
75+
self.scatter = None
76+
self.scatter2 = None
77+
self.name = name
78+
self.extension = extension
79+
self.ndigit = 4
80+
81+
82+
def make_movie(self, start, end):
83+
"Find all snapshots, and make a movie"
84+
def update(frame):
85+
"""
86+
update frame and update progress bar. The progress bar doesn't work
87+
yet so printing dots too.
88+
"""
89+
print(".", end="", flush=True)
90+
i = start + frame
91+
92+
filename = f"{self.name}{i:0{self.ndigit}d}.{self.extension}"
93+
stars = read_set_from_file(filename)
94+
size = 4 * stars.radius.value_in(units.RSun)**0.5
95+
x, y, color = templum_to_xyz(stars.temperature, stars.luminosity)
96+
self.ax.set_title(f"Snapshot {i}")
97+
self.scatter.set_offsets(np.array([x, y,]).T)
98+
self.scatter.set_sizes(size)
99+
self.scatter.set_facecolors(color)
100+
101+
102+
i = start
103+
filename = f"{self.name}{i:0{self.ndigit}d}.{self.extension}"
104+
stars = read_set_from_file(filename)
105+
x, y, color = templum_to_xyz(stars.temperature, stars.luminosity)
106+
# radius_is_zero = stars.radius == 0 | units.RSun
107+
# print(radius_is_zero)
108+
# stars[radius_is_zero] = lumtemp_to_rad(
109+
# stars[radius_is_zero].luminosity,
110+
# stars[radius_is_zero].temperature,
111+
# )
112+
size = 4 * stars.radius.value_in(units.RSun)**0.5
113+
self.scatter = self.ax.scatter(x, y, s=size, c=color, edgecolor="none")
114+
anim = animation.FuncAnimation(
115+
self.fig,
116+
update,
117+
frames=tqdm(range(end - start), position=0, file=sys.stdout),
118+
interval=30,
119+
repeat=False,
120+
)
121+
anim.save(f"{self.name}.mp4", dpi=150, writer=animation.FFMpegWriter(fps=25))
122+
123+
124+
def templum_to_xy(self):
125+
log_temperature = np.nan_to_num(np.log10(stars.temperature.value_in(units.K)))
126+
log_luminosity = np.nan_to_num(np.log10(stars.luminosity.value_in(units.LSun)))
127+
128+
129+
def plot_hr(self, stars, stars2=None):
130+
log_temperature = np.nan_to_num(np.log10(stars.temperature.value_in(units.K)))
131+
log_luminosity = np.nan_to_num(np.log10(stars.luminosity.value_in(units.LSun)))
132+
col = temp_to_rgb(stars.temperature)
133+
if not self.scatter:
134+
self.scatter = self.ax.scatter(
135+
log_temperature,
136+
log_luminosity,
137+
s=42,
138+
c=temp_to_rgb(stars.temperature),
139+
edgecolor="none",
140+
)
141+
else:
142+
self.scatter.set_offsets(
143+
np.array(
144+
[
145+
log_temperature,
146+
log_luminosity,
147+
]
148+
).T
149+
)
150+
self.scatter.set_facecolors(col)
151+
if stars2:
152+
153+
log_temperature2 = np.nan_to_num(np.log10(stars2.temperature.value_in(units.K)))
154+
log_luminosity2 = np.nan_to_num(np.log10(stars2.luminosity.value_in(units.LSun)))
155+
if not self.scatter2:
156+
self.scatter2 = self.ax.scatter(
157+
log_temperature2,
158+
log_luminosity2,
159+
s=2,
160+
)
161+
else:
162+
self.scatter2.set_offsets(
163+
np.array(
164+
[
165+
log_temperature2,
166+
log_luminosity2,
167+
]
168+
).T
169+
)
170+
171+
172+
def savefig(self, filename):
173+
self.fig.savefig(filename)
174+
175+
176+
177+
def new_argument_parser():
178+
"Parse command line arguments, show defaults"
179+
parser = argparse.ArgumentParser(
180+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
181+
)
182+
parser.add_argument(
183+
"-i",
184+
"--infile",
185+
type=str,
186+
default="",
187+
help="The first snapshot to plot",
188+
)
189+
parser.add_argument(
190+
"-I",
191+
"--infile2",
192+
type=str,
193+
default="",
194+
help="Second set of snapshots to plot",
195+
)
196+
parser.add_argument(
197+
"-n",
198+
"--number",
199+
type=int,
200+
default=1,
201+
help="The number of snapshots to plot",
202+
)
203+
return parser
204+
205+
206+
def main():
207+
args = new_argument_parser().parse_args()
208+
filename_template = args.infile
209+
# check extension of the file
210+
filename_template = filename_template.split(".")
211+
extension = filename_template[-1]
212+
name = filename_template[0]
213+
if args.infile2:
214+
filename_template2 = args.infile2
215+
filename_template2 = filename_template2.split(".")
216+
extension2 = filename_template2[-1]
217+
name2 = filename_template2[0]
218+
# check if the name ends with a number, if so, check how many characters
219+
# are there and store it and strip it
220+
i = 0
221+
while name[-i - 1].isdigit():
222+
i += 1
223+
if i > 1:
224+
name = name[:-i]
225+
if args.infile2:
226+
name2 = name2[:-i]
227+
ndigit = i
228+
else:
229+
raise ValueError("The name of the first snapshot should end with a number")
230+
231+
# set up plotter
232+
plotter = StarHRPlotter(name)
233+
plotter.make_movie(0, args.number)
234+
sys.exit()
235+
236+
# read the snapshots
237+
for i in range(args.number):
238+
filename = f"{name}{i:0{ndigit}}.{extension}"
239+
snapshot = read_set_from_file(filename)
240+
if args.infile2:
241+
filename2 = f"{name2}{i:0{ndigit}}.{extension}"
242+
snapshot2 = read_set_from_file(filename2)
243+
print(filename, filename2)
244+
plotter.plot_hr(snapshot, snapshot2)
245+
plotter.savefig(f"{name}{name2}{i:0{ndigit}}.png")
246+
else:
247+
print(filename)
248+
plotter.plot_hr(snapshot)
249+
plotter.savefig(f"rgb{name}{i:0{ndigit}}.png")
250+
251+
252+
if __name__ == "__main__":
253+
main()

0 commit comments

Comments
 (0)