Source code for ast_toolbox.utils.exp_utils

import numpy as np


def log_mean_exp(x, dim):
    """Return ``log(mean(exp(x)))`` along axis ``dim``, computed stably.

    A mean over ``n`` entries is their sum divided by ``n``, so in log space
    it equals ``log_sum_exp(x, dim) - log(n)``. The numerically delicate part
    is delegated to ``log_sum_exp``.

    Parameters
    ----------
    x : numpy.ndarray
        Input values (must expose ``.shape``).
    dim : int
        Axis along which to reduce.

    Returns
    -------
    numpy.ndarray or numpy scalar
        ``x`` reduced along ``dim``.
    """
    log_total = log_sum_exp(x, dim)
    n_terms = x.shape[dim]
    return log_total - np.log(n_terms)


def log_sum_exp(x, dim):
    """Compute ``log(sum(exp(x), dim))`` in a numerically stable manner.

    The maximum along ``dim`` is subtracted before exponentiating so that
    ``exp`` cannot overflow, and it is added back after the logarithm.
    Slices whose maximum is non-finite are handled explicitly:

    - an all ``-inf`` slice yields ``-inf``;
    - a slice containing ``+inf`` yields ``+inf``.

    Both would otherwise produce ``nan`` via ``inf - inf``.

    Parameters
    ----------
    x : array_like
        Input values.
    dim : int
        Axis along which to reduce.

    Returns
    -------
    numpy.ndarray or numpy scalar
        ``x`` reduced along ``dim`` (one fewer dimension than ``x``).
    """
    x = np.asarray(x)
    # keepdims keeps the reduced axis (size 1) so the max broadcasts back
    # against x directly; no need for np.repeat/np.expand_dims.
    max_x = np.max(x, axis=dim, keepdims=True)
    # Shifting by a non-finite max would compute inf - inf = nan. Shifting
    # those slices by 0 instead gives the correct limit (-inf or +inf).
    shift = np.where(np.isfinite(max_x), max_x, 0)
    # For an all -inf slice the sum is 0 and log(0) = -inf is the intended
    # result, so silence the divide-by-zero warning it raises.
    with np.errstate(divide="ignore"):
        log_sum = np.log(np.sum(np.exp(x - shift), axis=dim))
    return log_sum + np.squeeze(shift, axis=dim)


def softmax(x, dim):
    """Compute softmax values for each set of scores in ``x`` along ``dim``.

    The maximum along ``dim`` is subtracted before exponentiating to avoid
    overflow; this does not change the result because softmax is invariant
    to adding a constant to every score in a slice.

    Parameters
    ----------
    x : array_like
        Scores.
    dim : int
        Axis along which the probabilities are normalized.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose slices along ``dim`` sum
        to 1.
    """
    x = np.asarray(x)
    # keepdims keeps the reduced axis (size 1) so the max and the sum
    # broadcast along ``dim``. Without it, reducing any non-leading axis
    # (e.g. the last axis of a 2-D array) broadcasts against the wrong axis.
    e_x = np.exp(x - np.max(x, axis=dim, keepdims=True))
    return e_x / np.sum(e_x, axis=dim, keepdims=True)