ex1.m

%% Machine Learning Online Class - Exercise 1: Linear Regression

% Instructions
% ------------
%
% This file contains code that helps you get started on the
% linear regression exercise. You will need to complete the following
% functions in this exercise:
%
%    warmUpExercise.m
%    plotData.m
%    gradientDescent.m
%    computeCost.m
%    gradientDescentMulti.m
%    computeCostMulti.m
%    featureNormalize.m
%    normalEqn.m
%
% For this exercise, you will not need to change any code in this file,
% or any other files other than those mentioned above.
%
% x refers to the population size in 10,000s
% y refers to the profit in $10,000s
%

%% Initialization
clear ; close all; clc
%% ==================== Part 1: Basic Function ====================
% Complete warmUpExercise.m
fprintf('Running warmUpExercise ... \n');
fprintf('5x5 Identity Matrix: \n');
warmUpExercise()
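
% A minimal sketch of what warmUpExercise.m (a separate file you complete)
% might look like; per the printout above, it should return the 5x5
% identity matrix:
%
%   function A = warmUpExercise()
%   %WARMUPEXERCISE Return the 5x5 identity matrix
%   A = eye(5);
%   end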
fprintf('Program paused. Press enter to continue.\n');
pause;

%% ======================= Part 2: Plotting =======================
fprintf('Plotting Data ...\n')
data = load('ex1data1.txt');
X = data(:, 1); y = data(:, 2);
m = length(y); % number of training examples

% Plot Data
% Note: You have to complete the code in plotData.m
plotData(X, y);
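
% A minimal sketch of plotData.m (a separate file you complete), assuming
% the usual scatter plot of the training data; axis labels follow the
% units stated in the header:
%
%   function plotData(x, y)
%   %PLOTDATA Plot the data points x and y in a new figure
%   figure;                                   % open a new figure window
%   plot(x, y, 'rx', 'MarkerSize', 10);       % mark each example with a red x
%   xlabel('Population of City in 10,000s');  % x is population in 10,000s
%   ylabel('Profit in $10,000s');             % y is profit in $10,000s
%   end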
fprintf('Program paused. Press enter to continue.\n');
pause;

%% =================== Part 3: Cost and Gradient descent ===================
X = [ones(m, 1), data(:,1)]; % Add a column of ones to x
theta = zeros(2, 1); % initialize fitting parameters

% Some gradient descent settings
iterations = 1500;
alpha = 0.01;

fprintf('\nTesting the cost function ...\n')
% compute and display initial cost
J = computeCost(X, y, theta);
fprintf('With theta = [0 ; 0]\nCost computed = %f\n', J);
fprintf('Expected cost value (approx) 32.07\n');

% further testing of the cost function
J = computeCost(X, y, [-1 ; 2]);
fprintf('\nWith theta = [-1 ; 2]\nCost computed = %f\n', J);
fprintf('Expected cost value (approx) 54.24\n');

fprintf('Program paused. Press enter to continue.\n');
pause;
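
% A minimal sketch of computeCost.m (a separate file you complete),
% assuming the standard squared-error cost
% J(theta) = (1/(2m)) * sum((X*theta - y).^2), in vectorized form:
%
%   function J = computeCost(X, y, theta)
%   %COMPUTECOST Compute the cost for linear regression
%   m = length(y);                         % number of training examples
%   errors = X * theta - y;                % h_theta(x) - y for all examples
%   J = (1 / (2 * m)) * sum(errors .^ 2);  % half the mean squared error
%   end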
fprintf('\nRunning Gradient Descent ...\n')
% run gradient descent
theta = gradientDescent(X, y, theta, alpha, iterations);

% print theta to screen
fprintf('Theta found by gradient descent:\n');
fprintf('%f\n', theta);
fprintf('Expected theta values (approx)\n');
fprintf(' -3.6303\n 1.1664\n\n');
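
% A minimal sketch of gradientDescent.m (a separate file you complete),
% assuming simultaneous batch updates of the form
% theta := theta - (alpha/m) * X' * (X*theta - y):
%
%   function theta = gradientDescent(X, y, theta, alpha, num_iters)
%   %GRADIENTDESCENT Run num_iters steps of batch gradient descent
%   m = length(y);                              % number of training examples
%   for iter = 1:num_iters
%       grad = (1 / m) * X' * (X * theta - y);  % gradient of the cost
%       theta = theta - alpha * grad;           % simultaneous update
%   end
%   end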
% Plot the linear fit
hold on; % keep previous plot visible
plot(X(:,2), X*theta, '-')
legend('Training data', 'Linear regression')
hold off % don't overlay any more plots on this figure
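
% Each prediction below evaluates the hypothesis h(x) = theta(1) + theta(2)*x
% as the inner product [1, x] * theta. Since x is the population in 10,000s
% and y is the profit in $10,000s (see the header), the result is multiplied
% by 10000 to report the profit in dollars.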
% Predict values for population sizes of 35,000 and 70,000
predict1 = [1, 3.5] * theta;
fprintf('For population = 35,000, we predict a profit of %f\n', ...
    predict1*10000);
predict2 = [1, 7] * theta;
fprintf('For population = 70,000, we predict a profit of %f\n', ...
    predict2*10000);

fprintf('Program paused. Press enter to continue.\n');
pause;
%% ============= Part 4: Visualizing J(theta_0, theta_1) =============
fprintf('Visualizing J(theta_0, theta_1) ...\n')

% Grid over which we will calculate J
theta0_vals = linspace(-10, 10, 100);
theta1_vals = linspace(-1, 4, 100);

% initialize J_vals to a matrix of 0's
J_vals = zeros(length(theta0_vals), length(theta1_vals));

% Fill out J_vals
for i = 1:length(theta0_vals)
    for j = 1:length(theta1_vals)
        t = [theta0_vals(i); theta1_vals(j)];
        J_vals(i,j) = computeCost(X, y, t);
    end
end
% Because of the way meshgrids work in the surf command, we need to
% transpose J_vals before calling surf, or else the axes will be flipped
J_vals = J_vals';
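
% An equivalent vectorized sketch (an aside, not required by the exercise):
% meshgrid already lays theta0 out along columns and theta1 along rows,
% so this version needs no transpose:
%
%   [T0, T1] = meshgrid(theta0_vals, theta1_vals);
%   J_vals = arrayfun(@(t0, t1) computeCost(X, y, [t0; t1]), T0, T1);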
% Surface plot
figure;
surf(theta0_vals, theta1_vals, J_vals)
xlabel('\theta_0'); ylabel('\theta_1');

% Contour plot
figure;
% Plot J_vals as 20 contours spaced logarithmically between 0.01 and 1000
contour(theta0_vals, theta1_vals, J_vals, logspace(-2, 3, 20))
xlabel('\theta_0'); ylabel('\theta_1');
hold on;
plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);