math-optim/algo/algo.go
leo b742f0e091
All checks were successful
continuous-integration/drone/push Build is passing
go(algo,de): implement jDE (wip)
2023-01-21 02:35:29 +01:00

415 lines
10 KiB
Go

// Copyright 2023 wanderer <a_mirre at utb dot cz>
// SPDX-License-Identifier: GPL-3.0-or-later
package algo
import (
"fmt"
"log"
"sort"
"sync"
"git.dotya.ml/wanderer/math-optim/algo/de"
"git.dotya.ml/wanderer/math-optim/bench"
"git.dotya.ml/wanderer/math-optim/report"
"git.dotya.ml/wanderer/math-optim/stats"
)
// var Algos = []string{"Random Search", "Stochastic Hill Climbing"}

var (
	// mu guards meanStats (see saveAlgoMeans).
	mu sync.Mutex
	// mCoMPL guards comparisonOfMeansPicList.
	mCoMPL sync.Mutex

	// meanStats accumulates the per-algorithm bench means handed to
	// saveAlgoMeans.
	meanStats = &stats.MeanStats{}
	// comparisonOfMeansPicList holds the pics of the "Comparison of Means"
	// plots (presumably appended to by the plotting code - not visible here).
	comparisonOfMeansPicList = &report.PicList{Algo: "Comparison of Means"}
)
// getComparisonOfMeansPics sorts the Pics field of the package global
// 'comparisonOfMeansPicList' in place and returns it.
//
// NOTE(review): comparisonOfMeansPicList is documented as guarded by mCoMPL,
// but this function does not take that lock. Within this file it is only
// called after wg.Wait() in PrepComparisonOfMeans, i.e. once the plotting
// goroutines are done - keep it that way or lock around the call.
func getComparisonOfMeansPics() []report.Pic {
	// note: sorting by filename (dimens value being 0-padded at generation
	// time), relying on this as a hack so that we didn't have to implement our
	// own natural-order sorter.
	sort.Sort(comparisonOfMeansPicList.Pics)

	return comparisonOfMeansPicList.Pics
}
// saveAlgoMeans appends sabm to the package global 'meanStats' while holding
// mu, so concurrent callers do not race on the AlgoMeans slice.
func saveAlgoMeans(sabm stats.AlgoBenchMean) {
	mu.Lock()
	defer mu.Unlock()

	meanStats.AlgoMeans = append(meanStats.AlgoMeans, sabm)
}
// GetMeanStats sorts the package global 'meanStats' in place and returns a
// pointer to it (of type *stats.MeanStats).
//
// NOTE(review): the sort runs without holding mu, which saveAlgoMeans uses to
// guard meanStats; call this only once no more means are being saved.
func GetMeanStats() *stats.MeanStats {
	ms := meanStats
	sort.Sort(ms)

	return ms
}
// PrepComparisonOfMeans generates the "Comparison of Means" plots from the
// saved algo means and returns a pointer to a slice of pics (of type
// report.PicList) and an integer - the count of unique benchmarking functions
// used.
//
// One plot is started per (dimension, bench function) pair, each in its own
// goroutine registered on wg; wg is waited on before the pics are collected.
// If no algo means have been saved, no plots are generated and the function
// returns early instead of indexing into an empty slice.
func PrepComparisonOfMeans(wg *sync.WaitGroup) (*report.PicList, int) {
	pL := report.NewPicList()
	allMeans := GetMeanStats()

	// without Rastrigin in active duty for the moment.
	benchCount := len(bench.Functions) - 1

	if len(allMeans.AlgoMeans) == 0 {
		log.Println(`no algo means saved, skipping "Comparison of Means" plots`)

		pL.Pics = getComparisonOfMeansPics()

		return pL, benchCount
	}

	// learn how many algos were processed based on the data, keeping the
	// order in which each algo first appears.
	algos := make([]string, 0)
	seen := make(map[string]struct{})

	for _, v := range allMeans.AlgoMeans {
		if _, ok := seen[v.Algo]; ok {
			continue
		}

		seen[v.Algo] = struct{}{}
		algos = append(algos, v.Algo)
	}

	// construct title consisting of names of all involved algorithms.
	for _, v := range algos {
		switch pL.Algo {
		case "":
			pL.Algo = v
		default:
			pL.Algo += " vs " + v
		}
	}

	log.Println(`generating "Comparison of Means" plots`)

	algoCount := len(algos)
	dimLen := len(bench.Dimensions)

	// NOTE(review): AlgoMeans is indexed as [d+j] (d: dimension loop index,
	// j: per-algo offset stepping by benchCount) while BenchMeans is indexed
	// by the bench loop index i. This relies on a specific layout of the
	// sorted AlgoMeans and on dimLen/benchCount matching it; any other layout
	// panics with an index out of range - confirm against the producers.
	for d := 0; d < dimLen; d++ {
		for i := 0; i < benchCount; i++ {
			dimXAlgoMeanVals := make([]stats.AlgoMeanVals, 0, algoCount)

			for j := 0; j < algoCount*benchCount; j += benchCount {
				am := allMeans.AlgoMeans[d+j]
				bm := am.BenchMeans[i]

				neighbInfo := ""
				// only add info about neighbours if it was changed (from
				// the default of -1).
				if n := bm.Neighbours; n != -1 {
					neighbInfo = fmt.Sprintf(" (N: %d)", n)
				}

				dimXAlgoMeanVals = append(dimXAlgoMeanVals, stats.AlgoMeanVals{
					Title:    am.Algo + neighbInfo,
					MeanVals: bm.MeanVals,
				})
			}

			ref := allMeans.AlgoMeans[d].BenchMeans[i]

			wg.Add(1)

			// construct plots concurrently.
			go PlotMeanValsMulti(
				wg, ref.Dimens, ref.Iterations, ref.Bench, "plot-", ".pdf",
				dimXAlgoMeanVals...,
			)
		}
	}

	// wait for all the plotting to conclude.
	wg.Wait()

	pL.Pics = getComparisonOfMeansPics()

	return pL, benchCount
}
// DoRandomSearch executes a search using the 'Random search' method.
//
// RandomSearchNG is run concurrently for each of the first
// len(bench.Functions)-1 bench functions (Rastrigin is left out for the
// moment). The stats are then plotted for all dimensions, and the pics and
// the statistics table are saved while holding m. wg.Done is called on
// return.
func DoRandomSearch(wg *sync.WaitGroup, m *sync.Mutex) {
	defer wg.Done()

	printRandomSearch("starting...")

	// number of bench functions in use; without Rastrigin for the moment.
	funcCount := len(bench.Functions) - 1

	// per-bench-function stats of this algo (RandomSearch).
	algoStats := make([][]stats.Stats, funcCount)

	results := make(chan []stats.Stats, funcCount)
	defer close(results)

	for fn := 0; fn < funcCount; fn++ {
		go RandomSearchNG(10000, 30, bench.Dimensions, bench.FuncNames[fn], results)
	}

	// stats arrive in completion order, which need not match FuncNames order.
	for idx := range algoStats {
		algoStats[idx] = <-results
	}

	picBuf := funcCount * len(bench.Dimensions)
	picsCh := make(chan report.PicList, picBuf)
	meanPicsCh := make(chan report.PicList, picBuf)

	defer close(picsCh)
	defer close(meanPicsCh)

	for idx := range algoStats {
		go plotAllDims(algoStats[idx], "plot", ".pdf", picsCh, meanPicsCh)
	}

	pLs := make([]report.PicList, 0, funcCount)
	pLsMean := make([]report.PicList, 0, funcCount)

	for range algoStats {
		pLs = append(pLs, <-picsCh)
		pLsMean = append(pLsMean, <-meanPicsCh)
	}

	const algoName = "Random Search"

	// the report/table output is shared between algos.
	m.Lock()
	report.SavePicsToFile(pLs, pLsMean, algoName)
	stats.SaveTable(algoName, algoStats)
	m.Unlock()
}
// TODO(me): split this package into multiple packages - one package per algo, with the common code kept here.
// TODO(me): implement Simulated Annealing.
// TODO(me): implement a variant of Stochastic Hill Climber that tweaks its
// Neighbourhood size or MaxNeighbourVariancePercent based on the
// latest 5 values: if they don't change, the params get tweaked to
// broaden the search space, making sure the search does not get stuck
// in a local extremum.
// DoStochasticHillClimbing performs a search using the 'Stochastic Hill
// Climbing' method with 10 neighbours.
//
// HillClimb is run concurrently for each of the first
// len(bench.Functions)-1 bench functions (Rastrigin is left out for the
// moment). The stats are then plotted for all dimensions, and the pics and
// the statistics table are saved while holding m. wg.Done is called on
// return.
func DoStochasticHillClimbing(wg *sync.WaitGroup, m *sync.Mutex) {
	defer wg.Done()

	printSHC("starting...")

	// number of bench functions in use; without Rastrigin for the moment.
	funcCount := len(bench.Functions) - 1

	// per-bench-function stats of this algo (StochasticHillClimber).
	algoStats := make([][]stats.Stats, funcCount)

	results := make(chan []stats.Stats, funcCount)
	defer close(results)

	for fn := 0; fn < funcCount; fn++ {
		// HillClimb params:
		//   maxFES, benchMinIters, neighbours int,
		//   theD []int,
		//   benchFunc string,
		//   ch chan []stats.Stats
		go HillClimb(10000, 30, 10, bench.Dimensions, bench.FuncNames[fn], results)
	}

	// stats arrive in completion order, which need not match FuncNames order.
	for idx := range algoStats {
		algoStats[idx] = <-results
	}

	picBuf := funcCount * len(bench.Dimensions)
	picsCh := make(chan report.PicList, picBuf)
	meanPicsCh := make(chan report.PicList, picBuf)

	defer close(picsCh)
	defer close(meanPicsCh)

	for idx := range algoStats {
		go plotAllDims(algoStats[idx], "plot", ".pdf", picsCh, meanPicsCh)
	}

	pLs := make([]report.PicList, 0, funcCount)
	pLsMean := make([]report.PicList, 0, funcCount)

	for range algoStats {
		pLs = append(pLs, <-picsCh)
		pLsMean = append(pLsMean, <-meanPicsCh)
	}

	const algoName = "Stochastic Hill Climbing"

	// the report/table output is shared between algos.
	m.Lock()
	report.SavePicsToFile(pLs, pLsMean, algoName)
	stats.SaveTable(algoName, algoStats)
	m.Unlock()
}
// DoStochasticHillClimbing100Neigh performs a search using the 'Stochastic
// Hill Climbing' method with 100 neighbours.
//
// HillClimb is run concurrently for each of the first
// len(bench.Functions)-1 bench functions (Rastrigin is left out for the
// moment). The stats are then plotted for all dimensions, and the pics and
// the statistics table are saved while holding m. wg.Done is called on
// return.
func DoStochasticHillClimbing100Neigh(wg *sync.WaitGroup, m *sync.Mutex) {
	defer wg.Done()

	printSHC("starting...")

	// number of bench functions in use; without Rastrigin for the moment.
	funcCount := len(bench.Functions) - 1

	// per-bench-function stats of this algo (StochasticHillClimber).
	algoStats := make([][]stats.Stats, funcCount)

	results := make(chan []stats.Stats, funcCount)

	for fn := 0; fn < funcCount; fn++ {
		go HillClimb(10000, 30, 100, bench.Dimensions, bench.FuncNames[fn], results)
	}

	// stats arrive in completion order, which need not match FuncNames order.
	for idx := range algoStats {
		algoStats[idx] = <-results
	}

	picBuf := funcCount * len(bench.Dimensions)
	picsCh := make(chan report.PicList, picBuf)
	meanPicsCh := make(chan report.PicList, picBuf)

	for idx := range algoStats {
		go plotAllDims(algoStats[idx], "plot", ".pdf", picsCh, meanPicsCh)
	}

	pLs := make([]report.PicList, 0, funcCount)
	pLsMean := make([]report.PicList, 0, funcCount)

	for range algoStats {
		pLs = append(pLs, <-picsCh)
		pLsMean = append(pLsMean, <-meanPicsCh)
	}

	const algoName = "Stochastic Hill Climbing 100 Neighbours"

	// the report/table output is shared between algos.
	m.Lock()
	report.SavePicsToFile(pLs, pLsMean, algoName)
	stats.SaveTable(algoName, algoStats)
	m.Unlock()
}
// DojDE runs the self-adapting Differential Evolution algorithm (jDE, work in
// progress) on every bench function in bench.Functions, for the dimensions in
// bench.DimensionsGA.
//
// One jDE instance per bench function is initialised and run in its own
// goroutine. Each instance reports its stats on one channel and its algo bench
// means on another; the means are stored via saveAlgoMeans. The stats are then
// plotted, and the pics and the statistics table are saved while holding m.
// wg.Done is called on return.
//
// NOTE(review): unlike the other Do* funcs, this one includes the last bench
// function (len(bench.Functions), no -1), while PrepComparisonOfMeans assumes
// len(bench.Functions)-1 bench functions - confirm this is intended.
func DojDE(wg *sync.WaitGroup, m *sync.Mutex) {
	defer wg.Done()

	de.LogPrintln("starting")

	// number of bench functions available and tested.
	funcCount := len(bench.Functions)

	// per-bench-function stats of this algo.
	algoStats := make([][]stats.Stats, funcCount)

	results := make(chan []stats.Stats, funcCount)
	meansCh := make(chan *stats.AlgoBenchMean, funcCount)

	defer close(results)
	defer close(meansCh)

	// jDE params: initial population size, differential weight and cr.
	const (
		np = 50
		f  = 0.5
		cr = 0.9
	)

	for fn := 0; fn < funcCount; fn++ {
		solver := de.NewjDE()

		// Init params, in order: generations (-1 disables the limit),
		// minimum bench iterations (n > 0), mutation strategy (0..17),
		// parameter self-adaptation scheme (0..1), initial population size
		// (np >= 4), differential weight, cr, dimensions, bench name, and
		// the stats and algo-means channels.
		solver.Init(-1, 30, 0, 0, np, f, cr, bench.DimensionsGA, bench.FuncNames[fn], results, meansCh)

		go solver.Run()
	}

	// stats and means arrive in completion order; they are stored
	// independently of each other.
	for idx := range algoStats {
		algoStats[idx] = <-results
		saveAlgoMeans(*<-meansCh)
	}

	picBuf := funcCount * len(bench.DimensionsGA)
	picsCh := make(chan report.PicList, picBuf)
	meanPicsCh := make(chan report.PicList, picBuf)

	for idx := range algoStats {
		go plotAllDims(algoStats[idx], "plot", ".pdf", picsCh, meanPicsCh)
	}

	pLs := make([]report.PicList, 0, funcCount)
	pLsMean := make([]report.PicList, 0, funcCount)

	for range algoStats {
		pLs = append(pLs, <-picsCh)
		pLsMean = append(pLsMean, <-meanPicsCh)
	}

	const algoName = "Self-adapting Differential Evolution"

	// the report/table output is shared between algos.
	m.Lock()
	report.SavePicsToFile(pLs, pLsMean, algoName)
	stats.SaveTable(algoName, algoStats)
	m.Unlock()
}