Training
This notebook demonstrates how we train a Generative Adversarial Network (GAN) to generate stellar core-collapse signals. The GAN is trained on a dataset of 1764 stellar core-collapse signals, each sampled at 256 timestamps.
%load_ext autoreload
%autoreload 2
%matplotlib inline
! pip install starccato -q
Load training data
from starccato.training.training_data import TrainingData
training_data = TrainingData()
training_data.summary()
Signal Dataset mean: -0.516 +/- 39.724
Signal Dataset scaling factor (to match noise in generator): 5
Signal Dataset shape: (256, 1684)
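The summary reports a dataset mean, a spread, and a scaling factor. As a rough illustration of what standardising with those quantities could look like, here is a minimal NumPy sketch on synthetic data. The helpers `standardize` and `unstandardize` below are hypothetical and are not starccato's implementation; in particular, dividing by the scaling factor (rather than multiplying) is an assumption.

```python
import numpy as np

# Synthetic stand-in for the dataset: 256 timestamps x 8 signals.
# (Illustrative only; the real signals come from TrainingData.)
rng = np.random.default_rng(0)
raw = rng.normal(loc=-0.5, scale=40.0, size=(256, 8))


def standardize(signals, scaling_factor=5.0):
    """Hypothetical helper: remove the global mean, divide by the global
    standard deviation, then shrink by a constant scaling factor."""
    mean = signals.mean()
    std = signals.std()
    return (signals - mean) / std / scaling_factor, mean, std


def unstandardize(scaled, mean, std, scaling_factor=5.0):
    """Invert the hypothetical standardize() above."""
    return scaled * scaling_factor * std + mean


scaled, mean, std = standardize(raw)
restored = unstandardize(scaled, mean, std)
```

Under these assumptions the scaled signals have zero mean and a standard deviation of one over the scaling factor, and the transform is exactly invertible, so generated signals can be mapped back to physical units.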
from starccato.plotting import overplot_signals
signals = training_data.standardize(training_data.raw_signals)[:, 130:-50]
fig = overplot_signals(signals, color="k", alpha=0.01, linewidth=0.2)
fig.axes[0].set_axis_off()
fig.axes[0].grid(False)
_ = fig.suptitle("Standardised Stellar Core Collapse Signals [Training Data]")
from starccato.plotting import plot_stacked_signals
fig = plot_stacked_signals(signals, norm="linear", cmap="inferno_r")
fig, axes = training_data.plot_waveforms(standardised=True)
Train the generator and discriminator models
For details on the model, see the model architecture.
For details on the training, see the training code.
from starccato.training import train
result = train(num_epochs=8)
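For readers unfamiliar with adversarial training, the two losses that a GAN typically optimises can be sketched in a few lines of NumPy. This assumes the standard binary cross-entropy objective with the non-saturating generator loss; the loss used inside `starccato.training.train` may differ, and the function names and discriminator outputs below are hypothetical.

```python
import numpy as np


def bce(probs, targets, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    probs = np.clip(probs, eps, 1.0 - eps)
    return float(
        -np.mean(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))
    )


def discriminator_loss(d_real, d_fake):
    # The discriminator should score real signals as 1 and generated signals as 0.
    return bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))


def generator_loss(d_fake):
    # Non-saturating form: the generator wants its samples scored as real (1).
    return bce(d_fake, np.ones_like(d_fake))


# Hypothetical discriminator outputs for 4 real and 4 generated signals.
d_real = np.array([0.9, 0.8, 0.95, 0.7])
d_fake = np.array([0.1, 0.3, 0.2, 0.05])

loss_d = discriminator_loss(d_real, d_fake)
loss_g = generator_loss(d_fake)
```

In this example the discriminator is confidently correct, so its loss is small while the generator's loss is large. When the discriminator cannot tell the two apart (all outputs equal to 0.5), the generator loss equals log 2.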
Plots
Signals
Training Loss plot
Gradients
| Generator Gradients | Discriminator Gradients |
|---|---|