function [model] = svmTrain(X, Y, C, kernelFunction, ...
                            tol, max_passes)
%SVMTRAIN Trains an SVM classifier using a simplified version of the SMO
%algorithm.
%   [model] = SVMTRAIN(X, Y, C, kernelFunction, tol, max_passes) trains an
%   SVM classifier and returns the trained model. X is the matrix of
%   training examples: each row is a training example, and the jth column
%   holds the jth feature. Y is a column vector containing 1 for positive
%   examples and 0 for negative examples. C is the standard SVM
%   regularization parameter. tol is a tolerance value used for
%   determining equality of floating point numbers (default 1e-3).
%   max_passes controls the number of iterations over the dataset
%   (without changes to alpha) before the algorithm quits (default 5).
%
%   The returned model struct contains the support vectors (model.X,
%   model.y, model.alphas), the intercept model.b, the kernel handle
%   model.kernelFunction, and the primal weight vector model.w (only
%   meaningful for the linear kernel).
%
%   Note: This is a simplified version of the SMO algorithm for training
%   SVMs. In practice, if you want to train an SVM classifier, we
%   recommend using an optimized package such as:
%
%       LIBSVM   (http://www.csie.ntu.edu.tw/~cjlin/libsvm/)
%       SVMLight (http://svmlight.joachims.org/)

% Default parameters
if ~exist('tol', 'var') || isempty(tol)
    tol = 1e-3;
end
if ~exist('max_passes', 'var') || isempty(max_passes)
    max_passes = 5;
end

% Data parameters
m = size(X, 1);

% The SMO derivation assumes labels in {-1, +1}; map 0 -> -1.
Y(Y == 0) = -1;

% Variables
alphas = zeros(m, 1);
b = 0;
E = zeros(m, 1);
passes = 0;

% Pre-compute the Kernel Matrix since our dataset is small
% (in practice, optimized SVM packages that handle large datasets
% gracefully will _not_ do this).
%
% We have implemented optimized vectorized versions of the kernels here so
% that the SVM training will run faster.
if strcmp(func2str(kernelFunction), 'linearKernel')
    % Vectorized computation for the linear kernel: K(i,j) = x_i . x_j.
    % This is equivalent to computing the kernel on every pair of examples.
    K = X * X';
elseif strfind(func2str(kernelFunction), 'gaussianKernel')
    % Vectorized RBF kernel: build the matrix of pairwise squared
    % distances, then use exp(-d2/(2*sigma^2)) == kernelFunction(1, 0) .^ d2
    % (since kernelFunction(1, 0) = exp(-1/(2*sigma^2))).
    X2 = sum(X.^2, 2);
    K = bsxfun(@plus, X2, bsxfun(@plus, X2', -2 * (X * X')));
    K = kernelFunction(1, 0) .^ K;
else
    % Generic kernel: evaluate on every pair of examples.
    % The following can be slow due to the lack of vectorization.
    K = zeros(m);
    for i = 1:m
        for j = i:m
            K(i, j) = kernelFunction(X(i, :)', X(j, :)');
            K(j, i) = K(i, j);  % the kernel matrix is symmetric
        end
    end
end

% Train
fprintf('\nTraining ...');
dots = 12;
while passes < max_passes

    num_changed_alphas = 0;
    for i = 1:m

        % Calculate Ei = f(x(i)) - y(i) using (2).
        E(i) = b + sum(alphas .* Y .* K(:, i)) - Y(i);

        if ((Y(i) * E(i) < -tol && alphas(i) < C) || ...
            (Y(i) * E(i) > tol && alphas(i) > 0))

            % In practice, there are many heuristics one can use to select
            % the i and j. In this simplified code, we select them randomly.
            j = ceil(m * rand());
            while j == i  % make sure i ~= j
                j = ceil(m * rand());
            end

            % Calculate Ej = f(x(j)) - y(j) using (2).
            E(j) = b + sum(alphas .* Y .* K(:, j)) - Y(j);

            % Save old alphas
            alpha_i_old = alphas(i);
            alpha_j_old = alphas(j);

            % Compute L and H by (10) or (11).
            if (Y(i) == Y(j))
                L = max(0, alphas(j) + alphas(i) - C);
                H = min(C, alphas(j) + alphas(i));
            else
                L = max(0, alphas(j) - alphas(i));
                H = min(C, C + alphas(j) - alphas(i));
            end

            if (L == H)
                % continue to next i
                continue;
            end

            % Compute eta by (14).
            eta = 2 * K(i, j) - K(i, i) - K(j, j);
            if (eta >= 0)
                % continue to next i
                continue;
            end

            % Compute and clip new value for alpha j using (12) and (15).
            alphas(j) = alphas(j) - (Y(j) * (E(i) - E(j))) / eta;
            alphas(j) = min(H, alphas(j));
            alphas(j) = max(L, alphas(j));

            % Check if change in alpha is significant
            if (abs(alphas(j) - alpha_j_old) < tol)
                % continue to next i, restoring alpha j
                alphas(j) = alpha_j_old;
                continue;
            end

            % Determine value for alpha i using (16).
            alphas(i) = alphas(i) + Y(i) * Y(j) * (alpha_j_old - alphas(j));

            % Compute b1 and b2 using (17) and (18) respectively.
            % BUGFIX: equation (17) multiplies the alpha-i term by K(i,i),
            % not K(i,j); the previous code used K(i,j) for both terms of
            % b1, biasing the intercept whenever (19) selected b1.
            b1 = b - E(i) ...
                 - Y(i) * (alphas(i) - alpha_i_old) * K(i, i) ...
                 - Y(j) * (alphas(j) - alpha_j_old) * K(i, j);
            b2 = b - E(j) ...
                 - Y(i) * (alphas(i) - alpha_i_old) * K(i, j) ...
                 - Y(j) * (alphas(j) - alpha_j_old) * K(j, j);

            % Compute b by (19).
            if (0 < alphas(i) && alphas(i) < C)
                b = b1;
            elseif (0 < alphas(j) && alphas(j) < C)
                b = b2;
            else
                b = (b1 + b2) / 2;
            end

            num_changed_alphas = num_changed_alphas + 1;
        end

    end

    if (num_changed_alphas == 0)
        passes = passes + 1;
    else
        passes = 0;
    end

    % Progress indicator: print a dot per sweep, wrapping at column 78.
    fprintf('.');
    dots = dots + 1;
    if dots > 78
        dots = 0;
        fprintf('\n');
    end
    if exist('OCTAVE_VERSION')
        fflush(stdout);  % Octave buffers stdout; flush so dots appear live
    end
end
fprintf(' Done! \n\n');

% Save the model: keep only the support vectors (alphas > 0).
idx = alphas > 0;
model.X = X(idx, :);
model.y = Y(idx);
model.kernelFunction = kernelFunction;
model.b = b;
model.alphas = alphas(idx);
% Primal weight vector w = sum_i alpha_i y_i x_i (for the linear kernel).
model.w = ((alphas .* Y)' * X)';

end