VBC_aux/c8_nelder_mead.m
2021-11-13 13:14:12 +01:00

88 lines
3.4 KiB
Matlab

clc;clear;clf;
% functions
% Rosenbrock function f(x,y) = (1-x)^2 + 100*(y-x^2)^2 (minimum 0 at x=1, y=1).
% Element-wise operators (.^) make it work on scalars and on whole arrays alike.
RBeval=@(x,y) (1-x).^2+100*(y-x.^2).^2;
% draw black contour lines of z at fixed levels on the current axes
RBplot=@(x,y,z)contour(x,y,z,[.05,2,10,75,333,1666],'Color',[0,0,0],'LineWidth',1.33);
% plot rosen function on a 256x256 grid over x in [-2,2], y in [-1,3]
[X,Y]=meshgrid(linspace(-2,2,256), linspace(-1,3,256));
z=RBeval(X,Y);   % one vectorized call instead of arrayfun over every grid point
RBplot(X(1,:),Y(:,1),z);
% constants: Nelder-Mead reflection (alpha), expansion (gamma), contraction (rho)
% and shrink (sigma) coefficients. Note: `gamma` shadows MATLAB's built-in gamma
% function for the rest of this script.
alpha=1; gamma=2; rho=1/2; sigma=1/2;
% variables
iter=0;
% randomize starting points: each vertex is a row [x, y, f(x,y)] with x and y drawn
% uniformly from [startLo, startHi]; the random third entry is overwritten by f(x,y)
startLo = -2; startHi = 3;
p0 = startLo+(startHi-startLo)*rand(1,3); p0(3) = RBeval(p0(1),p0(2));
p1 = startLo+(startHi-startLo)*rand(1,3); p1(3) = RBeval(p1(1),p1(2));
p2 = startLo+(startHi-startLo)*rand(1,3); p2(3) = RBeval(p2(1),p2(2));
% simplex as a 3x3 matrix (one vertex per row); the *_prev values start at Inf so
% the first convergence test (absolute difference below a tolerance) cannot pass
BMW = [p0;p1;p2]; B_prev = Inf; M_prev = Inf; W_prev = Inf;
% prepare visualization stuff
hold on
% draw the starting simplex as a closed triangle (first vertex repeated at the end),
% the same way the loop draws each later simplex
anim=animatedline([p0(1),p1(1),p2(1),p0(1)],[p0(2),p1(2),p2(2),p0(2)],[p0(3),p1(3),p2(3),p0(3)],'Color','b','LineWidth',1.33);
% marker handles for best/middle/worst and reflected/expanded/contracted points;
% empty until first drawn
bplot=[]; mplot=[]; wplot=[]; rplot=[]; eplot=[]; cplot=[];
% Main Nelder-Mead loop. BMW holds the three simplex vertices as rows [x, y, f(x,y)].
% Each pass sorts the vertices, tests the stopping criteria, then replaces the worst
% vertex (reflection / expansion / contraction) or shrinks the simplex toward the best
% vertex, and finally redraws the simplex.
maxIter = 500;   % iteration cap
tol = eps;       % absolute tolerance on the change of each sorted vertex value between passes
while true
    disp(['iteration: ',num2str(iter)]);
    % sort by function value (column 3), best first
    BMW = sortrows(BMW,3);
    % assign points to separate vars: best, middle (second-best), worst
    B = BMW(1,:); M = BMW(2,:); W = BMW(3,:);
    % converged when no sorted vertex value moved by tol or more since the last pass
    converged = (abs(B_prev-B(3)) < tol) && (abs(M_prev-M(3)) < tol) && (abs(W_prev-W(3)) < tol);
    if converged || iter >= maxIter
        if converged
            disp("done");
            resultLabel = 'found minimum ';
            resultMarker = 'g*';
        else
            disp("max iteration count exceeded")
            % still report the best vertex so the result of the run is not lost
            resultLabel = 'best value found (not converged) ';
            resultMarker = 'r*';
        end
        disp([resultLabel,num2str(B(3)),' at [',num2str(B(1)),',',num2str(B(2)),']']);
        % remove per-iteration markers and the simplex outline, then mark the result
        delete(bplot); delete(mplot); delete(wplot); delete(rplot); delete(eplot); delete(cplot);
        clearpoints(anim); scatter3(B(1),B(2),B(3),80,resultMarker,'LineWidth',3);
        break
    end
    disp("sorted")
    % save function values for comparison in the next pass
    B_prev = B(3); M_prev = M(3); W_prev = W(3);
    % centroid of all vertices except the worst (the midpoint of B and M for a triangle).
    % Only its x,y matter: every point derived from it (R, E, C) re-evaluates its own
    % third component, so the centroid itself is not evaluated.
    BM_c = (B(:)+M(:)).'/2;
    disp("got centroid")
    % reflect the worst vertex through the centroid and evaluate
    R = BM_c + alpha*(BM_c-W); R(3) = RBeval(R(1),R(2));
    delete(rplot); rplot = scatter3(R(1),R(2),R(3),'mo');
    disp("got reflected pt")
    % check the generated point and decide on progress
    if (B(3) <= R(3)) && (R(3) < M(3))
        % R is not a new best but beats the second-best: accept it
        W = R;
        disp("replaced worst with reflected")
    elseif R(3) < B(3)
        disp("reflected better")
        % R is a new best: try going further in the same direction
        E = BM_c + gamma*(R-BM_c); E(3) = RBeval(E(1),E(2));
        delete(eplot); eplot = scatter3(E(1),E(2),E(3),'ko');
        disp("got expanded pt")
        if E(3) < R(3)
            W = E;
            disp("expanded better")
        else
            W = R;
            disp("reflected still best")
        end
    else
        disp("contraction")
        % R is no better than M: contract the worst vertex toward the centroid
        C = BM_c + rho*(W-BM_c); C(3) = RBeval(C(1),C(2));
        delete(cplot); cplot = scatter3(C(1),C(2),C(3),'co');
        if C(3) < W(3)
            W = C;
            disp("contracted better")
        else
            % shrink: pull M and W toward B by factor sigma and re-evaluate them
            M = B + sigma*(M-B); M(3) = RBeval(M(1),M(2));
            W = B + sigma*(W-B); W(3) = RBeval(W(1),W(2));
            disp("shrunk")
        end
    end
    BMW = [B;M;W];
    iter=iter+1;
    % redraw: B/M/W markers (green/yellow/red) and the closed simplex outline
    clearpoints(anim);
    delete(bplot); delete(mplot); delete(wplot);
    bplot=scatter3(B(1),B(2),B(3),'go'); mplot=scatter3(M(1),M(2),M(3),'yo'); wplot=scatter3(W(1),W(2),W(3),'ro');
    addpoints(anim,[B(1),M(1),W(1)],[B(2),M(2),W(2)],[B(3),M(3),W(3)]); addpoints(anim,B(1),B(2),B(3));
    drawnow; pause(0.05);
end