runkMeans.m 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  function [centroids, idx] = runkMeans(X, initial_centroids, ...
                                        max_iters, plot_progress)
  %RUNKMEANS runs the K-Means algorithm on data matrix X, where each row of X
  %is a single example
  %   [centroids, idx] = RUNKMEANS(X, initial_centroids, max_iters, ...
  %   plot_progress) runs the K-Means algorithm on data matrix X, where each
  %   row of X is a single example. initial_centroids (one centroid per row)
  %   is used as the starting set of centroids. max_iters specifies the total
  %   number of iterations of K-Means to execute. plot_progress is a
  %   true/false flag that indicates if the function should also plot its
  %   progress as the learning happens; it defaults to false when omitted or
  %   empty. runkMeans returns centroids, a K x n matrix of the computed
  %   centroids, and idx, an m x 1 vector of centroid assignments (i.e. each
  %   entry in range [1..K]).
  %
  %   Notes:
  %   - idx is produced by findClosestCentroids at the start of the final
  %     iteration, so it is the assignment with respect to the centroids as
  %     they were *before* the last computeCentroids update.
  %   - If max_iters < 1 the loop body never runs: centroids is returned as
  %     initial_centroids and idx is an m x 1 vector of zeros.
  %   - When plot_progress is true, the function opens a new figure and
  %     pauses for a key press after every iteration.
  %

    % Set default value for plot progress
    if ~exist('plot_progress', 'var') || isempty(plot_progress)
      plot_progress = false;
    end

    % Plot the data if we are plotting progress
    if plot_progress
      figure;
      hold on;
    end

    % Initialize values (only the number of examples is needed here)
    m = size(X, 1);
    K = size(initial_centroids, 1);
    centroids = initial_centroids;
    % Used only by plotProgresskMeans to draw centroid movement
    previous_centroids = centroids;
    idx = zeros(m, 1);

    % Run K-Means
    for i = 1:max_iters

      % Output progress
      fprintf('K-Means iteration %d/%d...\n', i, max_iters);
      % Octave buffers stdout; flush so progress appears immediately.
      % The 'builtin' qualifier avoids matching a variable/file of that name.
      if exist('OCTAVE_VERSION', 'builtin')
        fflush(stdout);
      end

      % For each example in X, assign it to the closest centroid
      idx = findClosestCentroids(X, centroids);

      % Optionally, plot progress here
      if plot_progress
        plotProgresskMeans(X, centroids, previous_centroids, idx, K, i);
        previous_centroids = centroids;
        fprintf('Press enter to continue.\n');
        pause;
      end

      % Given the memberships, compute new centroids
      centroids = computeCentroids(X, idx, K);
    end

    % Hold off if we are plotting progress
    if plot_progress
      hold off;
    end
  end