Figure 7: 4S Saliency Maps#
1. Imports#
[9]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from fours.models.noise import FourSNoise
from fours.utils.data_handling import read_fours_root_dir
2. Load 4S noise models#
[2]:
# Resolve the data root once; every result file below lives underneath it.
root_dir = Path(read_fours_root_dir())

# Pickled noise models of the two datasets compared in the figure
# (HD 22049 and HD 40136). The `lambda_...` part of the file name presumably
# encodes the regularization strength used during training - TODO confirm.
noise_model_file1 = root_dir.joinpath(
    "70_results/x1_fake_planet_experiments/HD22049_351_096_C-0679_A_/scratch/tensorboard_S4/models/noise_model_0000_lambda_010000.pkl"
)
noise_model_file2 = root_dir.joinpath(
    "70_results/x1_fake_planet_experiments/HD40136_333_1101_C-0092_C_/scratch/tensorboard_S4/models/noise_model_0000_lambda_001000.pkl"
)
Data in the FOURS_ROOT_DIR found. Location: /fast/mbonse/s4
[3]:
# Restore both trained noise models from disk, in the same order as the
# files were declared (model 1: HD 22049, model 2: HD 40136).
# NOTE(review): the files are .pkl, so `FourSNoise.load` presumably unpickles
# them - only load files from a trusted source.
s4_noise_model_1, s4_noise_model_2 = [
    FourSNoise.load(model_file)
    for model_file in (noise_model_file1, noise_model_file2)
]
3. Get input gradients / Saliency Map#
[4]:
# `compute_betas()` is called first so that `betas` is available on the model
# (presumably it (re)builds the attribute - confirm in FourSNoise).
s4_noise_model_1.compute_betas()

# Saliency = magnitude of the betas, arranged as a stack of 91 x 91 maps.
input_gradients1 = np.reshape(
    np.abs(s4_noise_model_1.betas.detach().numpy()),
    (-1, 91, 91),
)
[5]:
# Same procedure for the second model: make sure `betas` is computed, ...
s4_noise_model_2.compute_betas()

# ... then take the magnitudes and arrange them as a stack of 91 x 91 maps.
input_gradients2 = np.reshape(
    np.abs(s4_noise_model_2.betas.detach().numpy()),
    (-1, 91, 91),
)
4. Plot the result#
[6]:
def plot_saliency_map(
        axis_in,
        input_gradients,
        position):
    """
    Draw the saliency map belonging to one pixel onto the given axis.

    The frame size is read from ``input_gradients`` instead of being
    hard-coded, so the function works for any map size (for the 91 x 91 maps
    used in this notebook the result is unchanged).

    Args:
        axis_in: Matplotlib axis (any object providing ``imshow``, ``axis``
            and ``scatter``) to draw on.
        input_gradients: Array of shape (n_maps, n_rows, n_cols). The map of
            pixel (row, col) is taken from index ``row * n_cols + col``, i.e.
            the maps are assumed to be stored in row-major pixel order.
        position: ``(row, col)`` of the pixel whose saliency map is shown.

    Raises:
        ValueError: If ``position`` lies outside the frame. Without this
            check an out-of-frame position would silently select the map of
            a different pixel.
    """
    n_rows, n_cols = input_gradients.shape[-2:]
    row, col = position[0], position[1]

    if not (0 <= row < n_rows and 0 <= col < n_cols):
        raise ValueError(
            f"position {tuple(position)} is outside the "
            f"{n_rows} x {n_cols} frame")

    # Flattened (row-major) index of the pixel selects its saliency map
    idx = row * n_cols + col
    axis_in.imshow(input_gradients[idx])
    axis_in.axis("off")

    # White cross: the pixel the map belongs to (scatter expects x=col, y=row)
    axis_in.scatter(col, row, color="white", marker="+", s=50)

    # Red star: center of the frame (presumably the position of the star)
    axis_in.scatter(n_cols // 2, n_rows // 2, color="red", marker="*", s=50)
[8]:
# (row, col) pixel positions whose saliency maps are shown, one per row
position1 = (49, 55)
position2 = (53, 30)
position3 = (22, 22)

# 1.) Create the Plot Layout ------------------------------
fig = plt.figure(
    constrained_layout=False,
    figsize=(8, 12))

gs0 = fig.add_gridspec(3, 2, width_ratios=[1, 1])
gs0.update(wspace=0.05, hspace=0.07)

# Saliency map panels: column a = first model, column b = second model
example_1a = fig.add_subplot(gs0[0, 0])
example_2a = fig.add_subplot(gs0[1, 0])
example_3a = fig.add_subplot(gs0[2, 0])

example_1b = fig.add_subplot(gs0[0, 1])
example_2b = fig.add_subplot(gs0[1, 1])
example_3b = fig.add_subplot(gs0[2, 1])

# 2.) Fill the panels -------------------------------------
# Every column shows the same three positions for one noise model.
for axes_column, gradients in (
        ((example_1a, example_2a, example_3a), input_gradients1),
        ((example_1b, example_2b, example_3b), input_gradients2)):
    for panel_axis, panel_position in zip(
            axes_column, (position1, position2, position3)):
        plot_saliency_map(panel_axis, gradients, panel_position)

# 3.) Titles ----------------------------------------------
# Column titles (dataset name and observation date)
example_1a.set_title(
    "#4 (HD 22049) - Dec 2015",
    fontsize=14,
    fontweight="bold",
    y=1.01)

example_1b.set_title(
    "#7 (HD 40136) - Nov 2018",
    fontsize=14,
    fontweight="bold",
    y=1.01)

# Add Figure Title
fig_title = fig.suptitle(
    "4S Saliency Map",
    size=16, fontweight="bold", y=0.94)

# 4.) Save the figure -------------------------------------
# White background so the saved figure is not transparent
fig.patch.set_facecolor('white')

# NOTE(review): the file name keeps its historical "sailency" spelling so
# that anything consuming the PDF keeps working; rename only together with
# its consumers.
output_file = Path("./final_plots/04_s4_sailency_map.pdf")

# savefig does not create missing directories; make sure the target exists
# so the notebook also runs on a fresh checkout.
output_file.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_file,
            bbox_inches='tight')