"""
calculate_loss.py

Date: 2013-02-26
Author: Naftali Harris

Plot the risk (expected squared error) of a thresholding estimator of a
normal mean against the risk of the plain sample mean, for several sample
sizes.
"""

from __future__ import division

import numpy as np
from scipy.stats import norm


def risk(n, m):
    """Compute the risk for n samples at mean m.

    The sample mean xbar is modelled as N(m, 1/n). The estimator studied
    here reports 0 when |xbar| < n**(-1/4) and xbar otherwise (Hodges'
    estimator). Its squared-error loss is therefore m**2 inside the
    threshold band and (xbar - m)**2 outside it.

    The expectation is evaluated in closed form rather than by numerical
    quadrature. The loss jumps at +/- n**(-1/4) while the sampling density
    narrows like 1/sqrt(n), and adaptive integration over an infinite
    domain can resolve that combination poorly.

    Parameters
    ----------
    n : positive number of samples.
    m : true mean; a scalar or any array-like (evaluated elementwise).

    Returns
    -------
    The risk: a float for scalar m, or an ndarray shaped like m.
    """
    m = np.asarray(m, dtype=float)
    border = n ** (-0.25)
    scale = np.sqrt(1 / n)

    # Standardize: Z = (xbar - m) / scale ~ N(0, 1). The band |xbar| < border
    # becomes lo < Z < hi.
    lo = (-border - m) / scale
    hi = (border - m) / scale

    # Probability that xbar lands inside the band, where the estimate is 0
    # and the loss is m**2.
    p_inside = norm.cdf(hi) - norm.cdf(lo)

    # The antiderivative of z**2 * phi(z) is Phi(z) - z * phi(z), so
    #   E[Z**2; lo < Z < hi] = p_inside - (hi * phi(hi) - lo * phi(lo)),
    # and the second moment outside the band is 1 minus that.
    tail_second_moment = 1 - p_inside + hi * norm.pdf(hi) - lo * norm.pdf(lo)

    return m ** 2 * p_inside + scale ** 2 * tail_second_moment


def plot_loss(n, mu_left, mu_right, num_points=300):
    """Plot the thresholding estimator's risk against the sample mean's risk.

    Parameters
    ----------
    n : number of samples.
    mu_left, mu_right : range of true means shown on the x-axis.
    num_points : number of grid points across that range (default 300).
    """
    # Imported here so that risk() can be used without a plotting backend.
    import matplotlib.pyplot as plt

    mu = np.linspace(mu_left, mu_right, num_points)
    fancyloss = risk(n, mu)
    # xbar ~ N(m, 1/n) is unbiased, so the sample mean's risk is flat at 1/n.
    meanloss = np.full_like(mu, 1 / n)

    plt.plot(mu, fancyloss, label="Thresholding estimator")
    plt.plot(mu, meanloss, label="Sample mean")
    plt.xlabel("Normal Mean, mu")
    plt.ylabel("Risk (Expected Squared Error)")
    plt.title("Risk for n = %d" % n)
    plt.legend()
    plt.show()


if __name__ == "__main__":
    # n = k**4 for k = 1..6, so the threshold n**(-1/4) is exactly 1/k.
    for n in [1, 16, 81, 256, 625, 1296]:
        plot_loss(n, -1.5, 1.5)