// Copyright 2023 wanderer // SPDX-License-Identifier: AGPL-3.0-only // implement request scheduling in order to attempt to comply with the // HIBP's rate-limiting. package hibp import ( "errors" "net/http" "os" "os/signal" "sync" "time" "golang.org/x/exp/slog" ) type reqQueue struct { Requests []*http.Request RespChans []*chan *http.Response ErrChans []*chan error } const sch = "HIBP requests scheduler" var ( // requests-per-minute limit; consult https://haveibeenpwned.com/API/Key. rpm = 10 requestsFired int working bool timeToWait = 0 * time.Millisecond backOff = 50 * time.Millisecond schedulerPeriod = 2000 * time.Millisecond schedInitNapDuration = 2000 * time.Millisecond nextCheckpoint time.Time queue = &reqQueue{} // requestsFired mutex. rfLock sync.RWMutex // timeToWait mutex. ttwLock sync.RWMutex // requests queue mutex. rqqLock sync.RWMutex // nextCheckpoint mutex. nxckpLock sync.RWMutex // working state mutex. wLock sync.RWMutex ) // RunReqScheduler runs the HIBP requests scheduler, which schedules the // request in such a fashion that it does not cross the limit defined by the // used API key. func RunReqScheduler(quit chan os.Signal, errCh chan error, wg *sync.WaitGroup) { slog.Info("Hello from " + sch) initQueue() done := false ok := make(chan error) timeoutchan := make(chan bool) defer close(quit) defer signal.Stop(quit) if !isTesting() { defer close(ok) defer close(timeoutchan) } go func() { // sleep intermittently, report to chan when done. <-time.After(schedInitNapDuration) timeoutchan <- true }() select { // listen for signal interrupts from the start. case <-quit: slog.Info("Interrupt received, shutting down " + sch) // this line below is a trick: we'd need to wait here for the entire // schedNapDuration, however, since the call is executed inside a goroutine // that is not synchronised, the channel read gets spawned and the // program simply continues, not caring one bit that the timeout has // not yet passed. that is a great workaround for when we *do not* // actually want to stand here and wait that time out. go func() { <-timeoutchan }() wg.Done() // break return // after we've had a nap, it's time to work. case <-timeoutchan: slog.Info("Starting " + sch) setNextCheckpoint(time.Now().Add(time.Minute)) for !done { checkNextCheckpoint() go doReqs(ok, wg) // nolint:wsl select { case <-quit: done = true if os.Getenv("GO_ENV_TESTING") == "1" { slog.Info("Testing, shutting down " + sch) wg.Done() return } err := <-ok errCh <- err slog.Info("Interrupt received, shutting down " + sch) break case err := <-ok: if err != nil { // alternatively, only notify that there was an error: // slog.Error(sch + " encountered an error", "error", err) slog.Error("Shutting down "+sch+" due to an error", "error", err) errCh <- err wg.Done() return } slog.Debug("Requests scheduler sleeping") time.Sleep(schedulerPeriod) slog.Debug("Requests scheduler awake again") continue } } } wg.Done() slog.Info(sch + " done") } func doReqs(ok chan error, wg *sync.WaitGroup) { if isWorking() { slog.Debug("Already WIP") return } setWorking(true) zeroTTWOrWait() wg.Add(1) go func() { defer wg.Done() if zeroQueue() { setWorking(false) slog.Debug("Queue empty, nothing to do") return } rqqLock.Lock() rfLock.Lock() for i := range queue.Requests { // again: if requestsFired < rpm { req := queue.Requests[i] respCh := queue.RespChans[i] errCh := queue.ErrChans[i] requestsFired++ slog.Debug("Sending the request") resp, err := client.Do(req) if err != nil { slog.Error("got err", "error", err) if errors.Is(err, http.ErrHandlerTimeout) { ok <- err rqqLock.Unlock() rfLock.Unlock() setWorking(false) return // alternatively: goto again } } *respCh <- resp *errCh <- err // remove the performed request and shift the remaining elements left. queue.Requests = append(queue.Requests[:i], queue.Requests[i+1:]...) // remove the corresponding response chan. queue.RespChans = append(queue.RespChans[:i], queue.RespChans[i+1:]...) // remove the corresponding error chan. queue.ErrChans = append(queue.ErrChans[:i], queue.ErrChans[i+1:]...) } else { slog.Error("Throttled - setting time to wait") ttwLock.Lock() nxckpLock.Lock() timeToWait = time.Until(nextCheckpoint) ttwLock.Unlock() nxckpLock.Unlock() break } } rfLock.Unlock() rqqLock.Unlock() }() setWorking(false) ok <- nil } // scheduleReq schedules a HTTP requests. respCh and errCh are the channels to // send back *http.Response and error, respectively. func scheduleReq(r *http.Request, respCh *chan *http.Response, errCh *chan error) { go func() { rqqLock.Lock() slog.Debug("Adding req to queue", "method", r.Method, "url", r.URL, "ua", r.UserAgent()) queue.Requests = append(queue.Requests, r) queue.RespChans = append(queue.RespChans, respCh) queue.ErrChans = append(queue.ErrChans, errCh) rqqLock.Unlock() slog.Debug("Added req to queue") }() } func setWorking(w bool) { wLock.Lock() defer wLock.Unlock() working = w } func isWorking() bool { wLock.RLock() defer wLock.RUnlock() return working } func setNextCheckpoint(t time.Time) { nxckpLock.Lock() defer nxckpLock.Unlock() slog.Debug("Setting next checkpoint") nextCheckpoint = t } func checkNextCheckpoint() { nxckpLock.Lock() defer nxckpLock.Unlock() if nextCheckpoint.Before(time.Now()) { slog.Debug("Checkpoint passed, updating", "checkpoint", nextCheckpoint) nextCheckpoint = time.Now().Add(1 * time.Minute) slog.Debug("Checkpoint updated", "checkpoint", nextCheckpoint) } } func initQueue() { slog.Debug("initialising the queue") rqqLock.Lock() defer rqqLock.Unlock() if isTesting() || queue.Requests == nil { slog.Debug("initialising queue.Requests") queue.Requests = make([]*http.Request, 0) } if isTesting() || queue.RespChans == nil { slog.Debug("initialising queue.RespChans") queue.RespChans = make([]*chan *http.Response, 0) } if isTesting() || queue.ErrChans == nil { slog.Debug("initialising queue.ErrChans") queue.ErrChans = make([]*chan error, 0) } } func zeroQueue() bool { rqqLock.RLock() defer rqqLock.RUnlock() if len(queue.Requests) == 0 { slog.Debug("Queue empty") return true } slog.Debug("Something in queue") return false } func zeroTTWOrWait() { timeoutchan := make(chan bool) ttwLock.RLock() if timeToWait > 0*time.Millisecond { slog.Debug("Waiting ttw") go func() { ttwLock.RLock() ttw := timeToWait ttwLock.RUnlock() <-time.After(ttw + backOff) timeoutchan <- true }() } else { ttwLock.RUnlock() return } <-timeoutchan ttwLock.Lock() timeToWait = 0 * time.Millisecond ttwLock.Unlock() slog.Debug("Waited ttw") } func isTesting() bool { return os.Getenv("GO_ENV_TESTING") == "1" }