pcmt/modules/hibp/schedule.go

356 lines
7.1 KiB
Go
Raw Normal View History

// Copyright 2023 wanderer <a_mirre at utb dot cz>
// SPDX-License-Identifier: AGPL-3.0-only
// This file implements request scheduling in an attempt to comply with HIBP's
// rate limiting.
package hibp
import (
"errors"
"net/http"
"os"
"os/signal"
"sync"
"time"
"golang.org/x/exp/slog"
)
// reqQueue holds the pending HIBP requests together with the channels used to
// hand their results back. The three slices are parallel: scheduleReq appends
// to all of them at once and doReqs reads the same index from each, so the
// entries at index i belong to the same scheduled request.
type reqQueue struct {
// Requests are the queued HTTP requests, in the order they were appended.
Requests []*http.Request
// RespChans[i] is sent the *http.Response obtained for Requests[i].
RespChans []*chan *http.Response
// ErrChans[i] is sent the error (possibly nil) obtained for Requests[i].
ErrChans []*chan error
}
// sch is the scheduler's name as it appears in log messages.
const sch = "HIBP requests scheduler"
// Package-level scheduler state. The mutexes at the end of the block guard the
// variables named in their respective comments.
var (
// requests-per-minute limit; consult https://haveibeenpwned.com/API/Key.
rpm = 10
// requestsFired counts the requests doReqs has sent; doReqs only sends while
// it is below rpm.
requestsFired int
// working is the "a doReqs pass is in progress" flag, accessed through
// setWorking and isWorking.
working bool
// timeToWait is how long zeroTTWOrWait pauses before a pass; doReqs sets it
// to the time remaining until nextCheckpoint when it is throttled.
timeToWait = 0 * time.Millisecond
// backOff is added on top of timeToWait by zeroTTWOrWait.
backOff = 50 * time.Millisecond
// schedulerPeriod is how long RunReqScheduler sleeps between passes.
schedulerPeriod = 2000 * time.Millisecond
// schedInitNapDuration is the delay before RunReqScheduler starts working.
schedInitNapDuration = 2000 * time.Millisecond
// nextCheckpoint is moved one minute ahead by checkNextCheckpoint whenever
// it has already passed.
nextCheckpoint time.Time
// queue holds the requests waiting to be sent.
queue = &reqQueue{}
// requestsFired mutex.
rfLock sync.RWMutex
// timeToWait mutex.
ttwLock sync.RWMutex
// requests queue mutex.
rqqLock sync.RWMutex
// nextCheckpoint mutex.
nxckpLock sync.RWMutex
// working state mutex.
wLock sync.RWMutex
)
// RunReqScheduler runs the HIBP requests scheduler, which paces the queued
// requests so that they do not exceed the rate limit of the API key in use.
//
// After an initial nap of schedInitNapDuration it repeatedly starts a doReqs
// pass and then sleeps for schedulerPeriod. It returns when something arrives
// on quit, or when a pass reports a non-nil error, which is forwarded to
// errCh. wg.Done is called exactly once on every return path. On return,
// signal delivery to quit is stopped and quit is closed.
func RunReqScheduler(quit chan os.Signal, errCh chan error, wg *sync.WaitGroup) {
	slog.Info("Hello from " + sch)

	initQueue()

	ok := make(chan error)

	defer close(quit)
	defer signal.Stop(quit)

	// NOTE(review): a doReqs goroutine that is still trying to send on ok after
	// it has been closed would panic - confirm all passes are done by then.
	if !isTesting() {
		defer close(ok)
	}

	// A timer is used for the initial nap instead of a helper goroutine feeding
	// a channel: with the channel, an early interrupt closed it on return and
	// the still-sleeping goroutine later panicked with "send on closed
	// channel". Stopping the timer needs no extra goroutine and cannot panic.
	nap := time.NewTimer(schedInitNapDuration)
	defer nap.Stop()

	select {
	// listen for signal interrupts from the start.
	case <-quit:
		slog.Info("Interrupt received, shutting down " + sch)
		wg.Done()

		return

	// after we've had a nap, it's time to work.
	case <-nap.C:
	}

	slog.Info("Starting " + sch)
	setNextCheckpoint(time.Now().Add(time.Minute))

	for done := false; !done; {
		checkNextCheckpoint()

		go doReqs(ok, wg) // nolint:wsl

		select {
		case <-quit:
			done = true

			if isTesting() {
				slog.Info("Testing, shutting down " + sch)
				wg.Done()

				return
			}

			// collect the outcome of the pass that was just started and
			// hand it to the caller before leaving the loop.
			err := <-ok
			errCh <- err

			slog.Info("Interrupt received, shutting down " + sch)

		case err := <-ok:
			if err != nil {
				// alternatively, only notify that there was an error:
				// slog.Error(sch + " encountered an error", "error", err)
				slog.Error("Shutting down "+sch+" due to an error",
					"error", err)
				errCh <- err
				wg.Done()

				return
			}

			slog.Debug("Requests scheduler sleeping")
			time.Sleep(schedulerPeriod)
			slog.Debug("Requests scheduler awake again")
		}
	}

	wg.Done()
	slog.Info(sch + " done")
}
// doReqs performs one scheduler pass. Unless a pass is already flagged as in
// progress, it waits out any pending timeToWait, then starts a goroutine
// (registered with wg) that sends the queued requests one by one while
// requestsFired is below rpm, delivering each response and error to the
// channels that were queued alongside the request.
//
// ok receives nil once the worker goroutine has been started (not once it has
// finished), and additionally receives the error if a request fails with
// http.ErrHandlerTimeout; in that case the request stays in the queue.
//
// NOTE(review): the working flag is cleared right after the goroutine is
// spawned, so it does not cover the goroutine's run time. Left as is, because
// RunReqScheduler expects a value on ok for every pass it starts.
// NOTE(review): requestsFired is never reset in this function - confirm it is
// reset elsewhere, otherwise the scheduler stays throttled once rpm is hit.
func doReqs(ok chan error, wg *sync.WaitGroup) {
	if isWorking() {
		slog.Debug("Already WIP")
		return
	}

	setWorking(true)
	zeroTTWOrWait()

	wg.Add(1)

	go func() {
		defer wg.Done()

		if zeroQueue() {
			setWorking(false)
			slog.Debug("Queue empty, nothing to do")

			return
		}

		rqqLock.Lock()
		rfLock.Lock()

		// Always serve the head of the queue and re-check the length on every
		// iteration. Ranging over the slice while deleting from it skipped
		// entries and ran past the end as soon as two requests were queued.
		for len(queue.Requests) > 0 {
			if requestsFired >= rpm {
				slog.Error("Throttled - setting time to wait")

				ttwLock.Lock()
				nxckpLock.Lock()
				timeToWait = time.Until(nextCheckpoint)
				time.Sleep(timeToWait) // yes, sleep while locked.
				ttwLock.Unlock()
				nxckpLock.Unlock()

				// leave the loop rather than return, so that the queue and
				// counter locks taken above are released below.
				break
			}

			req := queue.Requests[0]
			respCh := queue.RespChans[0]
			errCh := queue.ErrChans[0]
			requestsFired++

			slog.Debug("Sending the request")

			resp, err := client.Do(req)
			if err != nil {
				slog.Error("got err", "error", err)

				if errors.Is(err, http.ErrHandlerTimeout) {
					// report the timeout to the scheduler and keep the
					// request queued.
					ok <- err
					rqqLock.Unlock()
					rfLock.Unlock()
					setWorking(false)

					return
				}
			}

			*respCh <- resp
			*errCh <- err

			// remove the performed request and shift the remaining elements left.
			queue.Requests = append(queue.Requests[:0], queue.Requests[1:]...)
			// remove the corresponding response chan.
			queue.RespChans = append(queue.RespChans[:0], queue.RespChans[1:]...)
			// remove the corresponding error chan.
			queue.ErrChans = append(queue.ErrChans[:0], queue.ErrChans[1:]...)
		}

		rfLock.Unlock()
		rqqLock.Unlock()
	}()

	setWorking(false)
	ok <- nil
}
// scheduleReq queues an HTTP request for the scheduler. The request and its
// two result channels are appended to the queue asynchronously, so the call
// itself never waits for the queue lock. respCh and errCh are the channels
// that are later sent the *http.Response and the error, respectively.
func scheduleReq(r *http.Request, respCh *chan *http.Response, errCh *chan error) {
	enqueue := func() {
		rqqLock.Lock()

		slog.Debug("Adding req to queue", "method", r.Method, "url", r.URL, "ua", r.UserAgent())

		// the three slices are kept in step: one entry each per request.
		queue.Requests = append(queue.Requests, r)
		queue.RespChans = append(queue.RespChans, respCh)
		queue.ErrChans = append(queue.ErrChans, errCh)

		rqqLock.Unlock()

		slog.Debug("Added req to queue")
	}

	go enqueue()
}
// setWorking stores w as the current working state under the write lock.
func setWorking(w bool) {
	wLock.Lock()
	working = w
	wLock.Unlock()
}
// isWorking reports the current working state, read under the read lock.
func isWorking() bool {
	wLock.RLock()
	state := working
	wLock.RUnlock()

	return state
}
// setNextCheckpoint replaces the next checkpoint with t under the write lock.
func setNextCheckpoint(t time.Time) {
	nxckpLock.Lock()

	slog.Debug("Setting next checkpoint")

	nextCheckpoint = t

	nxckpLock.Unlock()
}
// checkNextCheckpoint moves the next checkpoint one minute past the current
// time if it has already passed; otherwise it leaves it untouched.
func checkNextCheckpoint() {
	nxckpLock.Lock()
	defer nxckpLock.Unlock()

	// still ahead of us (or exactly now): nothing to update.
	if !time.Now().After(nextCheckpoint) {
		return
	}

	slog.Debug("Checkpoint passed, updating", "checkpoint", nextCheckpoint)

	nextCheckpoint = time.Now().Add(1 * time.Minute)

	slog.Debug("Checkpoint updated", "checkpoint", nextCheckpoint)
}
// initQueue makes sure all three queue slices are allocated. A slice that is
// still nil is replaced with an empty one; when running under test every
// slice is replaced, which empties the queue.
func initQueue() {
	slog.Debug("initialising the queue")

	rqqLock.Lock()
	defer rqqLock.Unlock()

	// under test the queue is always reset, whatever it currently holds.
	reset := isTesting()

	if reset || queue.Requests == nil {
		slog.Debug("initialising queue.Requests")

		queue.Requests = []*http.Request{}
	}

	if reset || queue.RespChans == nil {
		slog.Debug("initialising queue.RespChans")

		queue.RespChans = []*chan *http.Response{}
	}

	if reset || queue.ErrChans == nil {
		slog.Debug("initialising queue.ErrChans")

		queue.ErrChans = []*chan error{}
	}
}
// zeroQueue reports whether the request queue is empty, checked under the
// queue read lock. A debug message is logged when it is not.
func zeroQueue() bool {
	rqqLock.RLock()
	defer rqqLock.RUnlock()

	empty := len(queue.Requests) == 0
	if !empty {
		slog.Debug("Something in queue")
	}

	return empty
}
// zeroTTWOrWait returns immediately when no wait is pending. Otherwise it
// blocks for timeToWait plus backOff and then resets timeToWait to zero.
//
// timeToWait is copied under the read lock and the lock is released before
// sleeping. Holding the read lock across the later ttwLock.Lock() call, as the
// code previously did, made this function deadlock on itself whenever
// timeToWait was positive.
func zeroTTWOrWait() {
	ttwLock.RLock()
	ttw := timeToWait
	ttwLock.RUnlock()

	if ttw <= 0 {
		return
	}

	slog.Debug("Waiting ttw")

	time.Sleep(ttw + backOff)

	ttwLock.Lock()
	timeToWait = 0 * time.Millisecond
	ttwLock.Unlock()

	slog.Debug("Waited ttw")
}
// isTesting reports whether the GO_ENV_TESTING environment variable is set to
// exactly "1".
func isTesting() bool {
	val, set := os.LookupEnv("GO_ENV_TESTING")

	return set && val == "1"
}