"""Temperature scaling and calibration-metric utilities (dialogy.utils.temperature_scaling)."""

import os
from typing import Any, Tuple

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor

# Module-wide default torch device: CUDA when torch reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def calc_bins(
    preds: npt.NDArray[np.float64],
    labels_oneh: npt.NDArray[np.float64],
    num_bins: int = 10,
) -> Any:
    """Group predictions into equal-width confidence bins and summarise each bin.

    The bin edges are ``np.linspace(1 / num_bins, 1, num_bins)``, i.e. the
    upper edges 0.1, 0.2, ..., 1.0 for the default ten bins. Bin ``i``
    collects predictions with ``bins[i - 1] <= p < bins[i]``. Bin 0 also takes
    everything below ``bins[0]``, and the last bin also takes predictions
    ``>= bins[-1]`` (notably confidences of exactly 1.0).

    Args:
        preds: Predicted confidences. Any shape; every element is binned.
        labels_oneh: Targets aligned element-wise with ``preds``; presumably
            0/1 one-hot correctness indicators (TODO confirm with callers).
        num_bins: Number of equal-width bins. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        A tuple ``(bins, binned, bin_accs, bin_confs, bin_sizes)``:

        - ``bins``: the edge array described above.
        - ``binned``: the bin index of every element of ``preds``.
        - ``bin_accs``, ``bin_confs``, ``bin_sizes``: ``num_bins``-long float
          arrays holding each bin's mean label, mean confidence and element
          count. All three are 0 for empty bins.

    Raises:
        ValueError: If ``num_bins`` is less than 1.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")

    bins = np.linspace(1.0 / num_bins, 1, num_bins)
    # np.digitize returns len(bins) (== num_bins) for values >= the last edge.
    # No bin below reads that index, so a confidence of exactly 1.0 would
    # silently vanish from every statistic. Fold it into the last bin instead.
    binned = np.minimum(np.digitize(preds, bins), num_bins - 1)

    bin_accs = np.zeros(num_bins)
    bin_confs = np.zeros(num_bins)
    bin_sizes = np.zeros(num_bins)
    for idx in range(num_bins):
        in_bin = binned == idx
        bin_sizes[idx] = len(preds[in_bin])
        if bin_sizes[idx] > 0:
            bin_accs[idx] = labels_oneh[in_bin].sum() / bin_sizes[idx]
            bin_confs[idx] = preds[in_bin].sum() / bin_sizes[idx]
    return bins, binned, bin_accs, bin_confs, bin_sizes


def get_metrics(
    preds: npt.NDArray[np.float64], labels_oneh: npt.NDArray[np.float64]
) -> Tuple[float, float]:
    """Compute the Expected and Maximum Calibration Error (ECE, MCE).

    Predictions are binned with ``calc_bins``. The per-bin gap is
    ``|accuracy - confidence|``.

    - ECE is the bin-size-weighted mean of the gaps.
    - MCE is the largest gap.

    Empty bins have accuracy and confidence 0, so their gap is 0. Their
    weight in the ECE is also 0.

    Args:
        preds: Predicted confidences, passed straight to ``calc_bins``.
        labels_oneh: Targets aligned with ``preds``, passed straight to
            ``calc_bins``.

    Returns:
        ``(ece, mce)`` as fractions in ``[0, 1]``. When ``calc_bins`` reports
        no samples at all (e.g. empty input), the result is ``(0.0, 0.0)``.
        Without that guard, dividing by the zero total would give a nan ECE
        and a RuntimeWarning.
    """
    _, _, bin_accs, bin_confs, bin_sizes = calc_bins(preds, labels_oneh)
    total = float(np.sum(bin_sizes))
    if total == 0:
        return 0.0, 0.0

    gaps = np.abs(np.asarray(bin_accs) - np.asarray(bin_confs))
    ece = float(np.sum(np.asarray(bin_sizes) / total * gaps))
    mce = float(np.max(gaps))
    return ece, mce


def save_reliability_graph(
    preds: npt.NDArray[np.float64],
    labels_oneh: npt.NDArray[np.float64],
    dir_path: str,
    prefix: str,
) -> None:
    """Plot a reliability diagram and save it as ``<prefix>_reliability_graph.png``.

    The figure contains, for each confidence bin produced by ``calc_bins``:

    - the bin's accuracy as a blue bar;
    - a translucent, hatched red bar whose height is the bin's upper edge.

    It also draws a dashed identity line and a legend with the ECE and MCE
    from ``get_metrics``, shown as percentages. The PNG is written into
    ``dir_path``, which is not created here.

    The figure is always closed afterwards, even if saving raises, so
    repeated calls do not pile up open pyplot figures.

    NOTE(review): bars are centred on the upper bin edges in ``bins``, hence
    the 1.05 x-limit. Each bar therefore sits half a bin to the right of the
    confidence interval it summarises. Confirm that this is intended.
    """
    ece, mce = get_metrics(preds, labels_oneh)
    bins, _, bin_accs, _, _ = calc_bins(preds, labels_oneh)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # x/y limits and labels
        ax.set_xlim(0, 1.05)
        ax.set_ylim(0, 1)
        ax.set_xlabel("Confidence")
        ax.set_ylabel("Accuracy")

        # Grid drawn underneath the bars
        ax.set_axisbelow(True)
        ax.grid(color="gray", linestyle="dashed")

        # Reference ("gap") bars, then the measured accuracy bars and identity line
        ax.bar(bins, bins, width=0.1, alpha=0.3, edgecolor="black", color="r", hatch="\\")
        ax.bar(bins, bin_accs, width=0.1, alpha=1, edgecolor="black", color="b")
        ax.plot([0, 1], [0, 1], "--", color="gray", linewidth=2)

        # Equally scaled axes
        ax.set_aspect("equal", adjustable="box")

        # ECE and MCE legend
        ece_patch = mpatches.Patch(color="green", label="ECE = {:.2f}%".format(ece * 100))
        mce_patch = mpatches.Patch(color="red", label="MCE = {:.2f}%".format(mce * 100))
        ax.legend(handles=[ece_patch, mce_patch])

        fig.savefig(
            os.path.join(dir_path, f"{prefix}_reliability_graph.png"), bbox_inches="tight"
        )
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def T_scaling(logits: Tensor, temperature: Tensor) -> Tensor:
    """Apply temperature scaling: return ``logits`` divided by ``temperature``.

    Division is elementwise true division with torch broadcasting, so a
    one-element ``temperature`` rescales every logit by the same factor.
    Integer inputs produce a floating-point result.
    """
    scaled = logits / temperature
    return scaled


def fit_ts_parameter(
    logits_list: npt.NDArray[np.float64],
    labels_list: npt.NDArray[np.int64],
    lr: float = 0.001,
    max_iter: int = 10000,
    device: torch.device = DEVICE,
) -> float:
    """Fit a single temperature for temperature scaling using L-BFGS.

    A one-element temperature parameter, initialised to 1.0, is optimised to
    minimise ``nn.CrossEntropyLoss`` between ``T_scaling(logits, temperature)``
    and the labels.

    Args:
        logits_list: Unnormalised model scores. Presumably shaped
            ``(n_samples, n_classes)`` as ``nn.CrossEntropyLoss`` expects
            (TODO confirm with callers).
        labels_list: Class indices aligned with ``logits_list``. They are cast
            to int64 because ``nn.CrossEntropyLoss`` requires that dtype for
            class-index targets.
        lr: Learning rate passed to ``optim.LBFGS``.
        max_iter: ``max_iter`` passed to ``optim.LBFGS``. ``step`` is called
            once here, so this bounds the total number of iterations.
        device: Device on which the tensors and the temperature are placed.

    Returns:
        The fitted temperature, rounded to 4 decimal places.
    """
    logits_tensor = torch.from_numpy(logits_list).to(device)
    labels_tensor = torch.from_numpy(labels_list).long().to(device)
    temperature = nn.Parameter(torch.ones(1, device=device))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.LBFGS(
        [temperature], lr=lr, max_iter=max_iter, line_search_fn="strong_wolfe"
    )

    def _eval() -> Tensor:
        # L-BFGS evaluates this closure many times within a single ``step``:
        # once initially, then once for every line-search probe. If the
        # gradient is not cleared first, ``backward`` adds onto the previous
        # evaluation's gradient and the optimizer follows a corrupted direction.
        optimizer.zero_grad()
        loss = criterion(T_scaling(logits_tensor, temperature), labels_tensor)
        loss.backward()
        return loss

    optimizer.step(_eval)
    return round(temperature.item(), 4)