Source code for exonamd.plot

import os
from pathlib import Path
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator, FuncFormatter, NullFormatter

from loguru import logger

from exonamd.utils import ROOT
from exonamd.solve import solve_namd_mc

# set the default fontsizes
plt.rcParams.update(
    {
        "font.size": 14,
        "axes.titlesize": 20,
        "axes.labelsize": 16,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "legend.fontsize": 14,
        "figure.titlesize": 22,
    }
)


[docs]@logger.catch def simple_plot( df, kind, title="", which="namd", ylabel="Frequency", xlabel=None, scale="linear", bins=50, out_path=None, figsize=None, xlim=(None, None), ): samples = df[f"{which}_{kind}_mc"] q50 = df[f"{which}_{kind}_q50"] q16 = df[f"{which}_{kind}_q16"] q84 = df[f"{which}_{kind}_q84"] if xlabel is None: xlabel = rf"{kind[0].upper()}-{which.upper()}" if scale == "log": samples = np.log10(samples) q16, q50, q84 = np.percentile(samples, [16, 50, 84]) xlabel = rf"log$\,${xlabel}" errup = q84 - q50 errdown = q50 - q16 title = f"{title}: " + rf"${q50:.2f}^{{+{errup:.2f}}}_{{-{errdown:.2f}}}$" plt.figure(figsize=figsize) plt.title(title) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.hist( samples, bins=bins, histtype="step", weights=np.ones_like(samples) / len(samples), ) plt.grid(which="both", linestyle="--", alpha=0.5) plt.vlines( [q16, q50, q84], 0, plt.ylim()[1], color=["red", "black", "red"], linestyles="dashed", ) plt.xlim(xlim) if out_path: Path(out_path).parent.mkdir(parents=True, exist_ok=True) plt.savefig(out_path, bbox_inches="tight", dpi=300, format="pdf") plt.show()
[docs]@logger.catch def pop_plot( df, kind, title="", which="namd", yscale="log", xoffs=0.3, out_path=None, replace_nan=False, ): # Plot the values vs multiplicity and color by their relative uncertainty df = df.sort_values(by="sy_pnum") sy_pnum = df["sy_pnum"] nanidx = df[f"{which}_{kind}_q50"].isnull() replaced_idx = nanidx.copy() if replace_nan: df.loc[nanidx, f"{which}_{kind}_q50"] = df.loc[nanidx, f"{which}_{kind}"] df.loc[nanidx, f"{which}_{kind}_q16"] = df.loc[nanidx, f"{which}_{kind}"] df.loc[nanidx, f"{which}_{kind}_q84"] = df.loc[nanidx, f"{which}_{kind}"] q50 = df[f"{which}_{kind}_q50"] q16 = df[f"{which}_{kind}_q16"] q84 = df[f"{which}_{kind}_q84"] nanidx = q50.isnull() sy_pnum = sy_pnum[~nanidx] q50 = q50[~nanidx] q16 = q16[~nanidx] q84 = q84[~nanidx] errup = q84 - q50 errdown = q50 - q16 iq = (q84 - q16) / 2 sigma_rel = iq / q50 bad_idx = sigma_rel > 1.0 ylabel = rf"{kind[0].upper()}-{which.upper()}" coeffs = np.polyfit(sy_pnum, q50, 1) line = np.polyval(coeffs, np.array(list(set(sy_pnum)))) if yscale == "log": coeffs = np.polyfit(sy_pnum, np.log10(q50), 1) line = 10 ** np.polyval(coeffs, np.array(list(set(sy_pnum)))) plt.figure(figsize=(6, 6)) plt.plot( np.array(list(set(sy_pnum))), line, "k--", alpha=0.5, lw=1.5, zorder=10, ) if xoffs > 0.0: M = set(sy_pnum) n_list = [] for m in M: idx = sy_pnum == m n_list.append(idx.sum()) n_list = np.array(n_list) xoffs = xoffs * n_list / n_list.max() for i, m in enumerate(M): idx = sy_pnum == m sy_pnum[idx] += np.linspace(-xoffs[i], xoffs[i], idx.sum()) plt.errorbar( sy_pnum[~replaced_idx], q50[~replaced_idx], yerr=[errdown[~replaced_idx], errup[~replaced_idx]], fmt="none", c="k", alpha=0.8, lw=0.5, capsize=2, zorder=10, ) s = plt.scatter( sy_pnum[~bad_idx], q50[~bad_idx], c=sigma_rel[~bad_idx], cmap="coolwarm", zorder=0, ) plt.colorbar(s, label="Relative uncertainty") plt.clim(0, 1) plt.scatter( sy_pnum[bad_idx], q50[bad_idx], color="w", edgecolors="k", linewidths=0.5, facecolors="none", zorder=0, ) plt.xlabel("Multiplicity") plt.yscale(yscale) if yscale == "log": ax = plt.gca() ax.yaxis.set_major_locator(LogLocator(base=10)) ax.yaxis.set_major_formatter( FuncFormatter(lambda y, pos: f"{int(np.log10(y))}" if y > 0 else "") ) ax.yaxis.set_minor_formatter(NullFormatter()) ylabel = rf"log$\,${ylabel}" plt.ylabel(ylabel) plt.title(title) plt.gca().xaxis.set_major_locator(plt.MaxNLocator(integer=True)) plt.grid(which="both", linestyle="--", alpha=0.5) if out_path: Path(out_path).parent.mkdir(parents=True, exist_ok=True) plt.savefig(out_path, bbox_inches="tight", dpi=300, format="pdf") plt.show()
[docs]def plot_host_namd( df: pd.DataFrame, hostname: str, kind: str, Npt: int = 100000, threshold: int = 1000, out_path: str = None, ): """ Plot the NAMD for a given host. Parameters ---------- df : pd.DataFrame The NAMD database. If None, the function reloads the database from disk. hostname : str The hostname to plot. kind : str Which type of NAMD to plot. One of 'rel' (relative NAMD) or 'abs' (absolute NAMD). Npt : int Number of Monte Carlo samples. threshold : int Minimum number of valid samples required. Returns ------- None """ # Task 1: reload database if df is None: logger.info("Reloading the database") df = pd.read_csv(os.path.join(ROOT, "data", "exo_namd.csv")) logger.info("Database reloaded") # Task 1: sample the NAMD for a given host logger.info(f"Selecting the host: {hostname}") host = df[df["hostname"] == hostname] logger.info("Host selected") logger.info("Computing the Monte Carlo relative NAMD") retval = solve_namd_mc( host=host, kind=f"{kind}", Npt=Npt, threshold=threshold, use_trunc_normal=True, full=True, ) logger.info("Values computed") # Task 2: plot the NAMD for a given host logger.info("Plotting the relative NAMD distribution") simple_plot( df=retval, kind=f"{kind}", title=hostname, which="namd", scale="log", out_path=out_path, ) logger.info("Plot done")
[docs]def plot_sample_namd( df: pd.DataFrame, title: str, kind: str = "rel", out_path: str = None ): """ Plot the sample NAMD against the multiplicity. If df is None, the function reloads the database from disk. Parameters ---------- df : pd.DataFrame The NAMD database. title : str The title of the plot. kind : str Which type of NAMD to plot. One of 'rel' (relative NAMD) or 'abs' (absolute NAMD). out_path : str The path to save the plot. Returns ------- None """ # Task 1: reload database if df is None: logger.info("Reloading the database") df = pd.read_csv(os.path.join(ROOT, "data", "exo_namd.csv")) logger.info("Database reloaded") # Task 2: plot the sample NAMD logger.info("Plotting the NAMD vs. multiplicity") pop_plot( df=df.groupby("hostname").apply( lambda g: g.select_dtypes(exclude=["object"]).mean(), ), kind=kind, title=title, which="namd", yscale="log", out_path=out_path, ) logger.info("Plot done")