From d17b33131c14864bd1eae275f49a3f148e21cf29 Mon Sep 17 00:00:00 2001 From: Leo Chan Date: Thu, 22 Oct 2020 01:53:21 -0400 Subject: Squashed commit of the sb-vbs branch. Includes the SD-VBS benchmarks modified to: - Use libextra to loop as realtime jobs - Preallocate memory before starting their main computation - Accept input via stdin instead of via argc Does not include the SD-VBS matlab code. Fixes libextra execution in LITMUS^RT. --- .../common/toolbox/toolbox_basic/textons/kmeans2.m | 127 +++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100755 SD-VBS/common/toolbox/toolbox_basic/textons/kmeans2.m (limited to 'SD-VBS/common/toolbox/toolbox_basic/textons/kmeans2.m') diff --git a/SD-VBS/common/toolbox/toolbox_basic/textons/kmeans2.m b/SD-VBS/common/toolbox/toolbox_basic/textons/kmeans2.m new file mode 100755 index 0000000..0bd87b2 --- /dev/null +++ b/SD-VBS/common/toolbox/toolbox_basic/textons/kmeans2.m @@ -0,0 +1,127 @@ +function [centres, options, d2, post, errlog] = kmeans2(centres, data, options) +%KMEANS Trains a k means cluster model. +% +% Description +% CENTRES = KMEANS(CENTRES, DATA, OPTIONS) uses the batch K-means +% algorithm to set the centres of a cluster model. The matrix DATA +% represents the data which is being clustered, with each row +% corresponding to a vector. The sum of squares error function is used. +% The point at which a local minimum is achieved is returned as +% CENTRES. The error value at that point is returned in OPTIONS(8). +% +% [CENTRES, OPTIONS, POST, ERRLOG] = KMEANS(CENTRES, DATA, OPTIONS) +% also returns the cluster number (in a one-of-N encoding) for each +% data point in POST and a log of the error values after each cycle in +% ERRLOG. The optional parameters have the following +% interpretations. +% +% OPTIONS(1) is set to 1 to display error values; also logs error +% values in the return argument ERRLOG. If OPTIONS(1) is set to 0, then +% only warning messages are displayed. If OPTIONS(1) is -1, then +% nothing is displayed. +% +% OPTIONS(2) is a measure of the absolute precision required for the +% value of CENTRES at the solution. If the absolute difference between +% the values of CENTRES between two successive steps is less than +% OPTIONS(2), then this condition is satisfied. +% +% OPTIONS(3) is a measure of the precision required of the error +% function at the solution. If the absolute difference between the +% error functions between two successive steps is less than OPTIONS(3), +% then this condition is satisfied. Both this and the previous +% condition must be satisfied for termination. +% +% OPTIONS(14) is the maximum number of iterations; default 100. +% +% See also +% GMMINIT, GMMEM +% + +% Copyright (c) Christopher M Bishop, Ian T Nabney (1996, 1997) + +[ndata, data_dim] = size(data); +[ncentres, dim] = size(centres); + +if dim ~= data_dim + error('Data dimension does not match dimension of centres') +end + +if (ncentres > ndata) + error('More centres than data') +end + +% Sort out the options +if (options(14)) + niters = options(14); +else + niters = 100; +end + +store = 0; +if (nargout > 3) + store = 1; + errlog = zeros(1, niters); +end + +% Check if centres and posteriors need to be initialised from data +if (options(5) == 1) + % Do the initialisation + perm = randperm(ndata); + perm = perm(1:ncentres); + + % Assign first ncentres (permuted) data points as centres + centres = data(perm, :); +end +% Matrix to make unit vectors easy to construct +id = eye(ncentres); + +% Main loop of algorithm +for n = 1:niters + + % Save old centres to check for termination + old_centres = centres; + + % Calculate posteriors based on existing centres + d2 = dist2(data, centres); + % Assign each point to nearest centre + [minvals, index] = min(d2', [], 1); + post = id(index,:); + + num_points = sum(post, 1); + % Adjust the centres based on new posteriors + for j = 1:ncentres + if (num_points(j) > 0) + centres(j,:) = sum(data(find(post(:,j)),:), 1)/num_points(j); + end + end + + % Error value is total squared distance from cluster centres + e = sum(minvals); + tmp = sort(minvals); + e95 = sqrt(tmp(round(length(tmp) * 0.95))); + erms = sqrt(e/ndata); + if store + errlog(n) = e; + end + if options(1) > 0 + fprintf(1, ' Cycle %4d RMS Error %11.6f 95-tier Error %11.6f\n', n, erms,e95); + end + + if n > 1 + % Test for termination + if max(max(abs(centres - old_centres))) < options(2) & ... + abs(old_e - e) < options(3) + options(8) = e; + return; + end + end + old_e = e; +end + +% If we get here, then we haven't terminated in the given number of +% iterations. +options(8) = e; +%if (options(1) >= 0) +% disp('Warning: Maximum number of iterations has been exceeded'); +%end + -- cgit v1.2.2