Figure 3 & 4: PCA Saliency Map#

1. Imports#

[1]:
from pathlib import Path
import json
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from fours.utils.data_handling import read_fours_root_dir, load_adi_data

2. Load the dataset#

[2]:
# Root of the FOURS data tree (the helper also reports where it found it).
root_dir = Path(read_fours_root_dir())

# Name of the dataset; the matching parameter file lives in 30_data/.
dataset_name = "HD22049_351_096_C-0679_A_"
json_file = root_dir / "30_data" / f"{dataset_name}.json"
Data in the FOURS_ROOT_DIR found. Location: /fast/mbonse/s4
[3]:
# Read the per-dataset processing parameters. JSON is UTF-8 by specification,
# so do not rely on the platform's default locale encoding.
with json_file.open(encoding="utf-8") as config_file:
    parameter_config = json.load(config_file)

dit_psf_template = float(parameter_config["dit_psf"])
dit_science = float(parameter_config["dit_science"])
fwhm = float(parameter_config["fwhm"])
scaling_factor = float(parameter_config["nd_scaling"])
lambda_reg = float(parameter_config["lambda_reg"])
svd_approx = int(parameter_config["svd_approx"])

# Detector pixel scale, hard-coded rather than read from the config.
# NOTE(review): presumably arcsec per pixel — confirm against the instrument docs.
pixel_scale = 0.02718
[4]:
# HDF5 file holding the science cube, PSF frames and header data.
dataset_file = root_dir / "30_data" / f"{dataset_name}.hdf5"

# Per-dataset results directory. parents=True also creates the intermediate
# 70_results/x1_fake_planet_experiments/ directories on a fresh checkout
# instead of raising FileNotFoundError.
experiment_root_dir = (
    root_dir / "70_results" / "x1_fake_planet_experiments" / dataset_name)
experiment_root_dir.mkdir(parents=True, exist_ok=True)
[5]:
# Load the stacked science frames, the angles (from the PARANG header entry)
# and the raw PSF template frames from the HDF5 file.
science_data, angles, raw_psf_template_data = load_adi_data(
    dataset_file,
    data_tag="object_stacked_05",
    psf_template_tag="psf_template",
    para_tag="header_object_stacked_05/PARANG",
)

# Collapse the PSF frames into one template with a pixel-wise median
# (assumes the individual PSF frames are stacked along axis 0).
psf_template = np.median(
    raw_psf_template_data,
    axis=0,
)
[6]:
# Crop every frame to a 91 x 91 pixel window (45 px radius, slightly larger
# than 1.2 arcsec). The window is placed so that pixel n // 2 of the input
# lands on index 45 of the crop (assumed to be the star position).
# Slicing with explicit start/stop always gives exactly 91 pixels. It also
# makes the cell idempotent: re-running it on an already-cropped cube is a
# no-op instead of producing an empty array via `[0:-0]`.
crop_size = 91
n_rows, n_cols = science_data.shape[1], science_data.shape[2]
if n_rows < crop_size or n_cols < crop_size:
    raise ValueError(
        f"frames of shape {(n_rows, n_cols)} are smaller than the "
        f"{crop_size} x {crop_size} crop")

cut_off = n_rows // 2 - crop_size // 2      # first kept row
col_cut_off = n_cols // 2 - crop_size // 2  # first kept column
science_data = science_data[
    :,
    cut_off:cut_off + crop_size,
    col_cut_off:col_cut_off + crop_size]

3. Build PCA basis#

[7]:
# 1.) Convert images to torch tensor
im_shape = science_data.shape
images_torch = torch.from_numpy(science_data)

# 2.) remove the mean frame, as needed for PCA
images_torch = images_torch - images_torch.mean(dim=0)

# 3.) flatten every frame into one row: (n_frames, n_pixels)
images_torch = images_torch.reshape(im_shape[0], im_shape[1] * im_shape[2])

# 4.) compute the PCA basis. V has one row per pixel and (up to q) columns,
#     i.e. the right singular vectors. svd_lowrank is a randomized algorithm,
#     so fix the seed to make the basis, and every figure derived from it,
#     reproducible on Restart & Run All.
torch.manual_seed(0)
_, _, V = torch.svd_lowrank(images_torch, niter=1, q=1000)
[8]:
pca_number: int = 300  # number of leading principal components (columns of V) kept
[9]:
# 5.) compute input gradients. The projection x -> x V_k V_k^T is linear, so
#     its Jacobian is the (symmetric) matrix V_k V_k^T itself. Row j, reshaped
#     to a frame, is the saliency map of pixel j; absolute values are kept.
#     Note: this is an (n_pixels x n_pixels) dense matrix.
input_gradients = torch.matmul(V[:, :pca_number], V[:, :pca_number].T)
input_gradients = np.abs(input_gradients.detach().numpy())
# Use the actual cropped frame shape rather than a hard-coded 91.
input_gradients = input_gradients.reshape(-1, im_shape[1], im_shape[2])
[10]:
# 6.) compute input gradients (|V_k V_k^T|, one saliency map per pixel) for
#     several numbers of principal components k.
input_gradients_dict = dict()

for tmp_pca_number in (10, 100, 300):
    tmp_input_gradients = torch.matmul(V[:, :tmp_pca_number], V[:, :tmp_pca_number].T)
    tmp_input_gradients = np.abs(tmp_input_gradients.detach().numpy())
    # Reshape with the actual cropped frame shape rather than a hard-coded 91.
    tmp_input_gradients = tmp_input_gradients.reshape(-1, im_shape[1], im_shape[2])

    input_gradients_dict[tmp_pca_number] = tmp_input_gradients

4. Get an example image and projection#

[11]:
from matplotlib.colors import LogNorm, SymLogNorm
[12]:
# Project every (mean-subtracted, flattened) frame onto the first `pca_number`
# principal components and map it back into pixel space.
pca_rep = torch.matmul(images_torch, V[:, :pca_number])
noise_estimate = torch.matmul(pca_rep, V[:, :pca_number].T)

# One example frame and its PCA reconstruction as 2D images of the crop size.
# The misspelt name `noise_estiamte` is kept because later cells use it; it is
# distinct from the full `noise_estimate` tensor above.
idx = 300
example_image = images_torch[idx].reshape(im_shape[1], im_shape[2]).numpy()
noise_estiamte = noise_estimate[idx].reshape(im_shape[1], im_shape[2]).numpy()
[13]:
def _log_stretch(image):
    """Return ``log(image - offset)`` with ``offset`` just below ``min(image)``.

    The offset lies 10 % of |min| below the minimum, so every pixel stays
    strictly positive whatever the sign of the minimum (only an exact 0
    minimum still maps to -inf). For a negative minimum this is the same
    shift as ``image - 1.1 * min(image)``. For a positive minimum that
    older form would produce negative values and hence NaN.
    """
    image_min = np.min(image)
    return np.log(image - (image_min - 0.1 * np.abs(image_min)))


# NOTE: these overwrite their inputs, so re-running this cell on its own
# applies the stretch twice.
example_image = _log_stretch(example_image)
noise_estiamte = _log_stretch(noise_estiamte)

5. Create the First Plot#

[14]:
# Panel accent colours. In the saliency figure below: [0] frames the first
# saliency map, [1] the second, [2] is used for the zoom box, its connectors
# and the zoom panel frame; [3] is not used there.
colors = [
    "dimgray",
    "aqua",
    "darkorange",
    "crimson",
]
[15]:
def plot_saliency_map(
    axis_in,
    input_gradients,
    position):
    """Show the saliency map of one pixel and mark the frame centre.

    Parameters
    ----------
    axis_in : matplotlib Axes (or anything with ``imshow``/``axis``/``scatter``)
        Axis to draw on; its frame is switched off.
    input_gradients : array-like of shape (n_pixels, height, width)
        Stack of saliency maps; entry ``k`` belongs to the pixel with
        row-major flat index ``k`` in a ``height x width`` frame.
    position : (int, int)
        ``(row, column)`` of the pixel whose saliency map is shown.
    """
    height, width = input_gradients.shape[-2], input_gradients.shape[-1]

    # Row-major flat index of ``position``, derived from the actual frame
    # width instead of a hard-coded 91.
    flat_index = position[0] * width + position[1]
    axis_in.imshow(input_gradients[flat_index])
    axis_in.axis("off")

    # Red star at the frame centre (column width // 2, row height // 2),
    # presumably the star position; this is (45, 45) for a 91 x 91 crop.
    axis_in.scatter(width // 2, height // 2, color="red", marker="*", s=50)
[16]:
def add_marker_circle(axis_in, position, color, size=80, lw=1):
    """Draw a hollow, half-transparent circle around a pixel.

    ``position`` is ``(row, column)``; the scatter call takes ``(x, y)``, so
    the column is passed first. ``color`` is the circle's edge colour,
    ``size`` the marker area and ``lw`` its line width.
    """
    row, column = position[0], position[1]
    circle_style = dict(
        edgecolors=color,
        facecolors='none',
        alpha=0.5,
        marker="o",
        lw=lw,
        s=size,
    )
    axis_in.scatter(column, row, **circle_style)
[17]:
# (row, column) coordinates of the two example pixels in the 91 x 91 crop.
position1 = (55, 37)
position2 = (70, 30)

# Row-major flat indices of those pixels (frame width 91 px).
idx1, idx2 = (row * 91 + column for row, column in (position1, position2))
[18]:
# --------------------------------------------------------------------
# Styling helpers shared by several panels
def _hide_ticks(ax):
    """Remove the x and y ticks (and tick labels) of ``ax``."""
    ax.set_xticks([])
    ax.set_yticks([])


def _frame_axis(ax, color):
    """Draw a solid 2-pt frame in ``color`` around ``ax`` via its spines."""
    plt.setp(ax.spines.values(),
             color=color,
             linewidth=2,
             linestyle="-")


# --------------------------------------------------------------------
# 1.) Create plot layout: two saliency maps on top; zoom-in and PSF below
fig = plt.figure(constrained_layout=False, figsize=(8, 7.15))
gs01 = fig.add_gridspec(
    2, 2,
    wspace=0.05, hspace=0.2,
    height_ratios=[1, 0.7])

ax_example_1 = fig.add_subplot(gs01[0, 0])
ax_example_2 = fig.add_subplot(gs01[0, 1])
ax_zoom_in = fig.add_subplot(gs01[1, 0])
ax_psf = fig.add_subplot(gs01[1, 1])

# --------------------------------------------------------------------
# 2.) Plot the saliency maps of the two example pixels
ax_example_1.imshow(input_gradients[idx1])
_hide_ticks(ax_example_1)
_frame_axis(ax_example_1, colors[0])

ax_example_2.imshow(input_gradients[idx2])
_hide_ticks(ax_example_2)
_frame_axis(ax_example_2, colors[1])

# Circle the pixel that each saliency map belongs to
add_marker_circle(ax_example_1, position1, "white", 250, 1.5)
add_marker_circle(ax_example_2, position2, "white", 250, 1.5)

# --------------------------------------------------------------------
# 3.) Zoom in on the 21 x 21 pixel region centred on position1
zoom_half_width = 10
zoom_size = 2 * zoom_half_width + 1
zoom_rows = slice(position1[0] - zoom_half_width, position1[0] + zoom_half_width + 1)
zoom_cols = slice(position1[1] - zoom_half_width, position1[1] + zoom_half_width + 1)

# imshow centres pixel (r, c) on data coordinate (c, r), so the outer edges
# of the zoomed region lie half a pixel beyond its outermost pixel centres.
zoom_left = zoom_cols.start - 0.5
zoom_top = zoom_rows.start - 0.5
zoom_bottom = zoom_top + zoom_size  # y grows downwards with imshow's default origin

ax_example_1.add_patch(Rectangle(
    (zoom_left, zoom_top),
    zoom_size, zoom_size,
    fill=False,
    edgecolor=colors[2],
    lw=2, ls="dotted"))

arrowKwargs = {
    'arrowstyle': '-',
    'linestyle': 'dotted',
    'color': colors[2],
    'linewidth': 2}

# Dotted connectors from the bottom corners of the box down towards the zoom
# panel; the xytext end points are hand-placed data coordinates below the image.
ax_example_1.annotate(
    '',
    xy=(zoom_left, zoom_bottom),
    xytext=(13, 105),
    arrowprops=arrowKwargs)
ax_example_1.annotate(
    '',
    xy=(zoom_left + zoom_size, zoom_bottom),
    xytext=(77, 105),
    arrowprops=arrowKwargs)

ax_zoom_in.imshow(input_gradients[idx1][zoom_rows, zoom_cols])
_hide_ticks(ax_zoom_in)
_frame_axis(ax_zoom_in, colors[2])

# --------------------------------------------------------------------
# 4.) PSF template panel
ax_psf.imshow(psf_template)
_hide_ticks(ax_psf)

# --------------------------------------------------------------------
# 5.) Save with a white background; create the output directory if missing
fig.patch.set_facecolor('white')
output_dir = Path("./final_plots")
output_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(output_dir / "03_pca_saliency_map.pdf", bbox_inches='tight')
../../_images/04_use_the_fours_paper_experiments_03_pca_saliency_map_23_0.png