svmTrain.m

function [model] = svmTrain(X, Y, C, kernelFunction, ...
                            tol, max_passes)
%SVMTRAIN Trains an SVM classifier using a simplified version of the SMO
%algorithm.
%   [model] = SVMTRAIN(X, Y, C, kernelFunction, tol, max_passes) trains an
%   SVM classifier and returns the trained model. X is the matrix of
%   training examples: each row is a training example, and the jth column
%   holds the jth feature. Y is a column vector containing 1 for positive
%   examples and 0 for negative examples. C is the standard SVM
%   regularization parameter. tol is a tolerance value used for determining
%   equality of floating-point numbers. max_passes controls the number of
%   iterations over the dataset (without changes to the alphas) before the
%   algorithm quits.
%
%   Note: This is a simplified version of the SMO algorithm for training
%   SVMs. In practice, if you want to train an SVM classifier, we
%   recommend using an optimized package such as:
%
%       LIBSVM   (http://www.csie.ntu.edu.tw/~cjlin/libsvm/)
%       SVMLight (http://svmlight.joachims.org/)
%
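%   Example (a minimal sketch; assumes a linearKernel function computing
%   K(x1, x2) = x1' * x2 is on the path, as referenced below; the values
%   C = 1, tol = 1e-3, max_passes = 20 are illustrative):
%       model = svmTrain(X, Y, 1, @linearKernel, 1e-3, 20);
%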
if ~exist('tol', 'var') || isempty(tol)
    tol = 1e-3;
end

if ~exist('max_passes', 'var') || isempty(max_passes)
    max_passes = 5;
end
% Data parameters
m = size(X, 1);
n = size(X, 2);

% Map 0 to -1, since the SMO derivation assumes labels in {-1, +1}
Y(Y==0) = -1;

% Variables
alphas = zeros(m, 1);
b = 0;
E = zeros(m, 1);
passes = 0;
eta = 0;
L = 0;
H = 0;
% Pre-compute the Kernel Matrix since our dataset is small
% (in practice, optimized SVM packages that handle large datasets
% gracefully will _not_ do this)
%
% We have implemented optimized, vectorized versions of the kernels here
% so that the SVM training will run faster.
if strcmp(func2str(kernelFunction), 'linearKernel')
    % Vectorized computation for the linear kernel.
    % This is equivalent to computing the kernel on every pair of examples.
    K = X * X';
elseif ~isempty(strfind(func2str(kernelFunction), 'gaussianKernel'))
    % Vectorized RBF kernel.
    % This is equivalent to computing the kernel on every pair of examples.
    X2 = sum(X.^2, 2);
    K = bsxfun(@plus, X2, bsxfun(@plus, X2', - 2 * (X * X')));
    K = kernelFunction(1, 0) .^ K;
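    % The two lines above rely on the expansion
    %   ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * x_i' * x_j,
    % so K first holds all pairwise squared distances. Assuming the
    % Gaussian kernel exp(-||x1 - x2||^2 / (2*sigma^2)) with sigma baked
    % in, kernelFunction(1, 0) = exp(-1 / (2*sigma^2)), and raising it
    % elementwise to the squared distances yields the full kernel matrix.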
else
    % Pre-compute the Kernel Matrix.
    % The following can be slow due to the lack of vectorization.
    K = zeros(m);
    for i = 1:m
        for j = i:m
            K(i,j) = kernelFunction(X(i,:)', X(j,:)');
            K(j,i) = K(i,j); % the matrix is symmetric
        end
    end
end
% Train
fprintf('\nTraining ...');
dots = 12;
while passes < max_passes

    num_changed_alphas = 0;
    for i = 1:m

        % Calculate Ei = f(x(i)) - y(i) using (2).
        % E(i) = b + sum (X(i, :) * (repmat(alphas.*Y,1,n).*X)') - Y(i);
        E(i) = b + sum(alphas.*Y.*K(:,i)) - Y(i);
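        % Here f(x) = sum_k alphas(k)*Y(k)*K(x_k, x) + b is the SVM
        % decision function; the check below picks out alphas(i) that
        % violate the KKT conditions by more than tol, i.e. the
        % candidates worth optimizing.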
        if ((Y(i)*E(i) < -tol && alphas(i) < C) || (Y(i)*E(i) > tol && alphas(i) > 0))

            % In practice, there are many heuristics one can use to select
            % the i and j. In this simplified code, we select them randomly.
            j = ceil(m * rand());
            while j == i  % Make sure i \neq j
                j = ceil(m * rand());
            end

            % Calculate Ej = f(x(j)) - y(j) using (2).
            E(j) = b + sum(alphas.*Y.*K(:,j)) - Y(j);

            % Save old alphas
            alpha_i_old = alphas(i);
            alpha_j_old = alphas(j);
            % Compute L and H by (10) or (11).
            if (Y(i) == Y(j))
                L = max(0, alphas(j) + alphas(i) - C);
                H = min(C, alphas(j) + alphas(i));
            else
                L = max(0, alphas(j) - alphas(i));
                H = min(C, C + alphas(j) - alphas(i));
            end
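            % L and H bound alphas(j) so that, after the paired update
            % below, both alphas remain inside the box [0, C].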
            if (L == H)
                % continue to next i.
                continue;
            end
            % Compute eta by (14).
            eta = 2 * K(i,j) - K(i,i) - K(j,j);
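            % eta is the second derivative of the dual objective along the
            % line that trades alphas(i) against alphas(j); it must be
            % negative for the update below to be a maximizer, so skip the
            % pair otherwise.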
            if (eta >= 0)
                % continue to next i.
                continue;
            end
            % Compute and clip new value for alpha j using (12) and (15).
            alphas(j) = alphas(j) - (Y(j) * (E(i) - E(j))) / eta;

            % Clip
            alphas(j) = min(H, alphas(j));
            alphas(j) = max(L, alphas(j));

            % Check if change in alpha is significant; if not, restore the
            % old value and continue to the next i.
            if (abs(alphas(j) - alpha_j_old) < tol)
                alphas(j) = alpha_j_old;
                continue;
            end
            % Determine value for alpha i using (16).
            alphas(i) = alphas(i) + Y(i)*Y(j)*(alpha_j_old - alphas(j));
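            % This paired update keeps Y(i)*alphas(i) + Y(j)*alphas(j)
            % constant, preserving the dual equality constraint
            % sum_k alphas(k)*Y(k) = 0.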
            % Compute b1 and b2 using (17) and (18) respectively.
            b1 = b - E(i) ...
                 - Y(i) * (alphas(i) - alpha_i_old) * K(i,i) ...
                 - Y(j) * (alphas(j) - alpha_j_old) * K(i,j);
            b2 = b - E(j) ...
                 - Y(i) * (alphas(i) - alpha_i_old) * K(i,j) ...
                 - Y(j) * (alphas(j) - alpha_j_old) * K(j,j);
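            % b1 is exact when alphas(i) ends up non-bound (strictly
            % between 0 and C); likewise b2 for alphas(j). When both are
            % at a bound, any b between b1 and b2 is consistent, so take
            % the average.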
            % Compute b by (19).
            if (0 < alphas(i) && alphas(i) < C)
                b = b1;
            elseif (0 < alphas(j) && alphas(j) < C)
                b = b2;
            else
                b = (b1+b2)/2;
            end
            num_changed_alphas = num_changed_alphas + 1;

        end

    end
    if (num_changed_alphas == 0)
        passes = passes + 1;
    else
        passes = 0;
    end

    fprintf('.');
    dots = dots + 1;
    if dots > 78
        dots = 0;
        fprintf('\n');
    end
    if exist('OCTAVE_VERSION')
        fflush(stdout);
    end

end
fprintf(' Done! \n\n');

% Save the model, keeping only the support vectors (alphas > 0)
idx = alphas > 0;
model.X = X(idx,:);
model.y = Y(idx);
model.kernelFunction = kernelFunction;
model.b = b;
model.alphas = alphas(idx);
model.w = ((alphas.*Y)'*X)';
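% Note: model.w = sum_k alphas(k)*Y(k)*X(k,:)' is the primal weight
% vector; it is only meaningful when the linear kernel is used.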

end