function checkNNGradients(lambda)
%CHECKNNGRADIENTS Creates a small neural network to check the
%backpropagation gradients
%   CHECKNNGRADIENTS(lambda) Creates a small neural network to check the
%   backpropagation gradients. It outputs the analytical gradients produced
%   by your backprop code and the numerical gradients (computed using
%   computeNumericalGradient). These two gradient computations should
%   result in very similar values.
%
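%   Example (illustrative calls, not part of the original file):
%       checkNNGradients        % gradient check without regularization
%       checkNNGradients(3)     % gradient check with lambda = 3
%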
if ~exist('lambda', 'var') || isempty(lambda)
    lambda = 0;
end

input_layer_size = 3;
hidden_layer_size = 5;
num_labels = 3;
m = 5;

% We generate some 'random' test data
Theta1 = debugInitializeWeights(hidden_layer_size, input_layer_size);
Theta2 = debugInitializeWeights(num_labels, hidden_layer_size);
% Reusing debugInitializeWeights to generate X
X = debugInitializeWeights(m, input_layer_size - 1);
y = 1 + mod(1:m, num_labels)';

% Unroll parameters
nn_params = [Theta1(:) ; Theta2(:)];

% Shorthand for the cost function
costFunc = @(p) nnCostFunction(p, input_layer_size, hidden_layer_size, ...
                               num_labels, X, y, lambda);

[cost, grad] = costFunc(nn_params);
numgrad = computeNumericalGradient(costFunc, nn_params);

% Visually examine the two gradient computations. The two columns
% you get should be very similar.
disp([numgrad grad]);
fprintf(['The above two columns you get should be very similar.\n' ...
         '(Left-Your Numerical Gradient, Right-Analytical Gradient)\n\n']);

% Evaluate the norm of the difference between the two solutions.
% If you have a correct implementation, and assuming you used EPSILON = 0.0001
% in computeNumericalGradient.m, then diff below should be less than 1e-9
diff = norm(numgrad - grad) / norm(numgrad + grad);

fprintf(['If your backpropagation implementation is correct, then \n' ...
         'the relative difference will be small (less than 1e-9). \n' ...
         '\nRelative Difference: %g\n'], diff);
end
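
% ---------------------------------------------------------------------------
% The two helpers used above normally live in their own files
% (computeNumericalGradient.m and debugInitializeWeights.m). The local
% functions below are minimal sketches of how such helpers are typically
% written, included here only for reference. Note that local functions take
% precedence over functions on the path, so delete these sketches if the
% real files are available.
% ---------------------------------------------------------------------------

% Sketch of a central-difference numerical gradient, assuming the
% EPSILON = 0.0001 mentioned in the comments above: perturb one parameter
% at a time and form the symmetric difference quotient.
function numgrad = computeNumericalGradient(J, theta)
numgrad = zeros(size(theta));
perturb = zeros(size(theta));
e = 1e-4;
for p = 1:numel(theta)
    perturb(p) = e;               % perturb only the p-th parameter
    loss1 = J(theta - perturb);
    loss2 = J(theta + perturb);
    % Symmetric difference quotient approximates dJ/dtheta_p
    numgrad(p) = (loss2 - loss1) / (2 * e);
    perturb(p) = 0;               % reset before the next parameter
end
end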
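
% Sketch of a deterministic 'random' initializer, assuming fixed sin-based
% values so the gradient check is repeatable across runs. The returned
% matrix includes an extra column for the bias unit, which is why X above is
% generated with fan_in = input_layer_size - 1.
function W = debugInitializeWeights(fan_out, fan_in)
W = zeros(fan_out, 1 + fan_in);
% sin gives values in [-1, 1]; dividing by 10 keeps the weights small
W = reshape(sin(1:numel(W)), size(W)) / 10;
end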