Add nelder-mead in matlab
This commit is contained in:
parent
8a12d6462e
commit
4c1c89cb21
|
@ -0,0 +1,88 @@
|
|||
clc;clear;clf;
|
||||
% functions
|
||||
RBeval=@(x,y) (1-x)^2+100*(y-x^2)^2;
|
||||
RBplot=@(x,y,z)contour(x,y,z,[.05,2,10,75,333,1666],'Color',[0,0,0],'LineWidth',1.33);
|
||||
% plot rosen function
|
||||
[X,Y]=meshgrid(linspace(-2,2,256), linspace(-1,3,256)); z=arrayfun(RBeval,X,Y);
|
||||
RBplot(X(1,:),Y(:,1),z);
|
||||
% constants
|
||||
alpha=1; gamma=2; rho=1/2; sigma=1/2;
|
||||
% variables
|
||||
iter=0;
|
||||
% randomize starting points
|
||||
p0 = (3+2)*rand(1,3)-2; p0(3) = RBeval(p0(1),p0(2));
|
||||
p1 = (3+2)*rand(1,3)-2; p1(3) = RBeval(p1(1),p1(2));
|
||||
p2 = (3+2)*rand(1,3)-2; p2(3) = RBeval(p2(1),p2(2));
|
||||
BMW = [p0;p1;p2]; B_prev = Inf; M_prev = Inf; W_prev = Inf;
|
||||
% prepare visualization stuff
|
||||
hold on
|
||||
anim=animatedline([p0(1),p1(1),p2(1)],[p0(2),p1(2),p2(2)],[p0(3),p1(3),p2(3)],'Color','b','LineWidth',1.33);
|
||||
bplot=[]; mplot=[]; wplot=[]; rplot=[]; eplot=[]; cplot=[];
|
||||
|
||||
while true
|
||||
disp(['iteration: ',num2str(iter)]);
|
||||
% sort by function value
|
||||
BMW = sortrows(BMW,3);
|
||||
% assign points to separate vars
|
||||
B = BMW(1,:); M = BMW(2,:); W = BMW(3,:);
|
||||
if (abs(B_prev-B(3)) < eps) && (abs(M_prev-M(3)) < eps) && (abs(W_prev-W(3)) < eps)
|
||||
disp("done");
|
||||
disp(['found minimum ',num2str(B(3)),' at [',num2str(B(1)),',',num2str(B(2)),']']);
|
||||
delete(bplot); delete(mplot); delete(wplot); delete(rplot); delete(eplot); delete(cplot);
|
||||
clearpoints(anim); scatter3(B(1),B(2),B(3),80,'g*','LineWidth',3);
|
||||
break
|
||||
elseif iter >= 500
|
||||
disp("max iteration count exceeded")
|
||||
break
|
||||
end
|
||||
disp("sorted")
|
||||
% save function values for comparison
|
||||
B_prev = B(3); M_prev = M(3); W_prev = W(3);
|
||||
% compute centroid between best and second-best points and evaluate
|
||||
BM_c = (B(:)+M(:)).'/2; BM_c(3) = RBeval(BM_c(1),BM_c(2));
|
||||
disp("got centroid")
|
||||
% compute reflected point and evaluate
|
||||
R = BM_c + alpha*(BM_c-W); R(3) = RBeval(R(1),R(2));
|
||||
delete(rplot); rplot = scatter3(R(1),R(2),R(3),'mo');
|
||||
disp("got reflected pt")
|
||||
% check the generated point and decide on progress
|
||||
if (B(3) <= R(3)) && (R(3) < M(3))
|
||||
W = R;
|
||||
disp("replaced worst with reflected")
|
||||
elseif R(3) < B(3)
|
||||
disp("reflected better")
|
||||
% compute expanded point and evaluate
|
||||
E = BM_c + gamma*(R-BM_c); E(3) = RBeval(E(1),E(2));
|
||||
delete(eplot); eplot = scatter3(E(1),E(2),E(3),'ko');
|
||||
disp("got expanded pt")
|
||||
if E(3) < R(3)
|
||||
W = E;
|
||||
disp("expanded better")
|
||||
else
|
||||
W = R;
|
||||
disp("reflected still best")
|
||||
end
|
||||
else
|
||||
disp("contraction")
|
||||
% compute contracted point and evaluate
|
||||
C = BM_c + rho*(W-BM_c); C(3) = RBeval(C(1),C(2));
|
||||
delete(cplot); cplot = scatter3(C(1),C(2),C(3),'co');
|
||||
if C(3) < W(3)
|
||||
W = C;
|
||||
disp("contracted better")
|
||||
else
|
||||
% recalculate second-best and worst (shrink)
|
||||
M = B + sigma*(M-B); M(3) = RBeval(M(1),M(2));
|
||||
W = B + sigma*(W-B); W(3) = RBeval(W(1),W(2));
|
||||
disp("shrunk")
|
||||
end
|
||||
end
|
||||
BMW = [B;M;W];
|
||||
iter=iter+1;
|
||||
clearpoints(anim);
|
||||
delete(bplot); delete(mplot); delete(wplot);
|
||||
bplot=scatter3(B(1),B(2),B(3),'go'); mplot=scatter3(M(1),M(2),M(3),'yo'); wplot=scatter3(W(1),W(2),W(3),'ro');
|
||||
addpoints(anim,[B(1),M(1),W(1)],[B(2),M(2),W(2)],[B(3),M(3),W(3)]); addpoints(anim,B(1),B(2),B(3));
|
||||
drawnow; pause(0.05);
|
||||
end
|
||||
|
Loading…
Reference in New Issue