From f2025395b26ff0faa6d2af257182b54f9343a94e Mon Sep 17 00:00:00 2001 From: surtur Date: Tue, 22 Aug 2023 19:57:48 +0200 Subject: [PATCH] go: add basic hibp handling, requests scheduling * change hibp schema's date field to string, as the date format would prevent direct unmarshaling. instead, marshal to string, convert later * the scheduler is in place in order not to get throttled after going over API limit * the scheduler detects when in testing mode and changes little bits of behaviour * add tests for some basic requests * run the requests scheduler as a background service during testing --- app/settings/settings.go | 1 + ent/schema/hibp.go | 3 +- modules/hibp/hibp.go | 130 ++++++++++++++ modules/hibp/hibp_test.go | 90 ++++++++++ modules/hibp/schedule.go | 355 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 578 insertions(+), 1 deletion(-) create mode 100644 modules/hibp/hibp.go create mode 100644 modules/hibp/hibp_test.go create mode 100644 modules/hibp/schedule.go diff --git a/app/settings/settings.go b/app/settings/settings.go index d124533..df7f4d1 100644 --- a/app/settings/settings.go +++ b/app/settings/settings.go @@ -60,6 +60,7 @@ var cleantgt = []string{ "PCMT_SESSION_AUTH_SECRET", "PCMT_SESSION_ENCR_SECRET", "PCMT_INIT_ADMIN_PASSWORD", + "PCMT_HIBP_API_KEY", } // New returns a new instance of the settings struct. diff --git a/ent/schema/hibp.go b/ent/schema/hibp.go index c8d2568..2a9e229 100644 --- a/ent/schema/hibp.go +++ b/ent/schema/hibp.go @@ -27,7 +27,8 @@ type HIBPSchema struct { // The domain of the primary website the breach occurred on. This may be used for identifying other assets external systems may have for the site. Domain string `json:"Domain"` // The date (with no time) the breach originally occurred on in ISO 8601 format. This is not always accurate — frequently breaches are discovered and reported long after the original incident. Use this attribute as a guide only. 
-	BreachDate time.Time `json:"BreachDate"`
+	// ISO 8601 date only (YYYY-MM-DD) -> unmarshal as a string, convert to proper time later.
+	BreachDate string `json:"BreachDate"`
 	// The date and time (precision to the minute) the breach was added to the system in ISO 8601 format.
 	AddedDate time.Time `json:"AddedDate"`
 	// The date and time (precision to the minute) the breach was modified in ISO 8601 format. This will only differ from the AddedDate attribute if other attributes represented here are changed or data in the breach itself is changed (i.e. additional data is identified and loaded). It is always either equal to or greater then the AddedDate attribute, never less than.
diff --git a/modules/hibp/hibp.go b/modules/hibp/hibp.go
new file mode 100644
index 0000000..4bec9d8
--- /dev/null
+++ b/modules/hibp/hibp.go
@@ -0,0 +1,130 @@
+// Copyright 2023 wanderer
+// SPDX-License-Identifier: AGPL-3.0-only
+
+package hibp
+
+import (
+	"encoding/json"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"time"
+
+	"git.dotya.ml/mirre-mt/pcmt/ent/schema"
+	"golang.org/x/exp/slog"
+)
+
+// Subscription models the subscription status object of the HIBP API, as decoded by SubscriptionStatus.
+type Subscription struct {
+	// The name representing the subscription being either "Pwned 1", "Pwned 2", "Pwned 3" or "Pwned 4".
+	SubscriptionName string
+	// A human readable sentence explaining the scope of the subscription.
+	Description string
+	// The date and time the current subscription ends in ISO 8601 format.
+	SubscribedUntil time.Time
+	// The rate limit in requests per minute. This applies to the rate the breach search by email address API can be requested.
+	Rpm int
+	// The size of the largest domain the subscription can search. This is expressed in the total number of breached accounts on the domain, excluding those that appear solely in spam list.
+	DomainSearchMaxBreachedAccounts int
+}
+
+const (
+	api   = "https://haveibeenpwned.com/api/v3"
+	appID = "pcmt (https://git.dotya.ml/mirre-mt/pcmt)"
+	// set default request timeout so as not to hang forever.
+	reqTmOut = 5 * time.Second
+
+	headerUA   = "user-agent"
+	headerHIBP = "hibp-api-key"
+)
+
+var (
+	apiKey = os.Getenv("PCMT_HIBP_API_KEY")
+	client = &http.Client{Timeout: reqTmOut}
+)
+
+// SubscriptionStatus queries the HIBP subscription endpoint, as per
+// https://haveibeenpwned.com/API/v3#SubscriptionStatus, and returns the
+// decoded subscription details.
+//
+// The request carries the user-agent and the API key headers and is sent
+// directly using the package-level client, i.e. it is not queued with the
+// requests scheduler. A non-nil error is returned when the request fails or
+// when the response body cannot be read or decoded; the returned pointer is
+// nil in that case.
+func SubscriptionStatus() (*Subscription, error) {
+	u := api + "/subscription"
+
+	req, err := http.NewRequest("GET", u, nil)
+	if err != nil {
+		// the method and the URL are built from constants, this is not
+		// expected to ever trigger.
+		log.Fatalln(err)
+	}
+
+	setUA(req)
+	setAuthHeader(req)
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	var s Subscription
+
+	if err := json.Unmarshal(body, &s); err != nil {
+		return nil, err
+	}
+
+	// hand back what has just been decoded, not an empty struct.
+	return &s, nil
+}
+
+// GetAllBreaches retrieves all breaches available in HIBP, as per
+// https://haveibeenpwned.com/API/v3#AllBreaches.
+func GetAllBreaches() (*[]schema.HIBPSchema, error) {
+	u := api + "/breaches"
+
+	req, err := http.NewRequest("GET", u, nil)
+	if err != nil {
+		// the method and the URL are built from constants, this is not
+		// expected to ever trigger.
+		log.Fatalln(err)
+	}
+
+	respCh, errCh := rChans()
+
+	setUA(req)
+	slog.Info("scheduling all breaches")
+	scheduleReq(req, &respCh, &errCh)
+	slog.Info("scheduled all breaches")
+
+	// the scheduler reports the response first and the error second; both
+	// need to be received, in this order.
+	resp := <-respCh
+	err = <-errCh
+
+	// check the error before touching resp: the response may be nil when the
+	// request failed, in which case resp.Body must not be dereferenced.
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	ab := make([]schema.HIBPSchema, 0)
+
+	if err = json.Unmarshal(body, &ab); err != nil {
+		return nil, err
+	}
+
+	return &ab, nil
+}
+
+// setUA sets the application identifier as the request's user-agent header.
+func setUA(r *http.Request) {
+	r.Header.Set(headerUA, appID)
+}
+
+// setAuthHeader sets the HIBP API key header of the request to apiKey.
+func setAuthHeader(r *http.Request) {
+	r.Header.Set(headerHIBP, apiKey)
+}
+
+// rChans returns a fresh pair of unbuffered channels, one for the response
+// and one for the error of a scheduled request.
+func rChans() (chan *http.Response, chan error) {
+	return make(chan *http.Response), make(chan error)
+}
diff --git a/modules/hibp/hibp_test.go b/modules/hibp/hibp_test.go
new file mode 100644
index 0000000..e2ed9f2
--- /dev/null
+++ b/modules/hibp/hibp_test.go
@@ -0,0 +1,90 @@
+// Copyright 2023 wanderer
+// SPDX-License-Identifier: AGPL-3.0-only
+
+package hibp
+
+import (
+	"io"
+	"net/http"
+	"os"
+	"os/signal"
+	"sync"
+	"testing"
+	"time"
+)
+
+const hibpTestDomain = "@hibp-integration-tests.com"
+
+func init() {
+	os.Setenv("GO_ENV_TESTING", "1")
+	prepScheduler() // purposely ignoring the output values.
+}
+
+// prepScheduler starts the requests scheduler in the background (with the
+// initial nap disabled) and returns the quit channel, the error channel and
+// the WaitGroup that were handed to it.
+func prepScheduler() (*chan os.Signal, *chan error, *sync.WaitGroup) {
+	quit := make(chan os.Signal, 1)
+	errCh := make(chan error)
+	wg := &sync.WaitGroup{}
+
+	signal.Notify(quit, os.Interrupt)
+
+	schedInitNapDuration = 0 * time.Millisecond
+
+	// register with the WaitGroup before the scheduler goroutine is started,
+	// so that the Add cannot race with a Wait on the returned WaitGroup.
+	wg.Add(1)
+
+	go RunReqScheduler(quit, errCh, wg)
+
+	// give the scheduler some time to start up.
+	time.Sleep(200 * time.Millisecond)
+
+	return &quit, &errCh, wg
+}
+
+// TestMultipleBreaches schedules a breached-account lookup of the HIBP
+// integration-test account and checks the status code of the response: 200 is
+// expected when an API key is configured, 401 when it is not.
+func TestMultipleBreaches(t *testing.T) {
+	a := "multiple-breaches"
+	u := api + "/breachedaccount/" + a + hibpTestDomain
+
+	req, err := http.NewRequest("GET", u, nil)
+	if err != nil {
+		// there is no request to schedule, bail out.
+		t.Fatal(err)
+	}
+
+	respCh, errCh := rChans()
+
+	setUA(req)
+	// the 200 expected below is only attainable if the key is actually sent.
+	setAuthHeader(req)
+	scheduleReq(req, &respCh, &errCh)
+
+	resp := <-respCh
+	err = <-errCh
+
+	// the response may be nil when the request failed, do not touch it
+	// before the error is checked.
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+
+	t.Logf("%+v\n", resp)
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	b := string(body)
+	t.Log(b)
+
+	if resp.StatusCode != 200 {
+		if apiKey == "" {
+			if resp.StatusCode == 401 {
+				t.Logf("apiKey is empty, expected a 401")
+			} else {
+				t.Errorf("apiKey is empty, expected 401, got: %q", resp.Status)
+			}
+		} else {
+			t.Errorf("wanted 200, got: %q", resp.Status)
+		}
+	}
+}
+
+func TestGetAllBreaches(t *testing.T) {
+	_, err := GetAllBreaches()
+	if err != nil {
+		t.Errorf("error: %q", err)
+	}
+}
diff --git a/modules/hibp/schedule.go b/modules/hibp/schedule.go
new file mode 100644
index 0000000..b9f53d3
--- /dev/null
+++ b/modules/hibp/schedule.go
@@ -0,0 +1,355 @@
+// Copyright 2023 wanderer
+// SPDX-License-Identifier: AGPL-3.0-only
+
+// implement request scheduling in order to attempt to comply with the
+// HIBP's rate-limiting.
+package hibp
+
+import (
+	"errors"
+	"net/http"
+	"os"
+	"os/signal"
+	"sync"
+	"time"
+
+	"golang.org/x/exp/slog"
+)
+
+// reqQueue holds the scheduled requests along with the channels used to
+// report back on them; entries at the same index belong together.
+type reqQueue struct {
+	Requests  []*http.Request
+	RespChans []*chan *http.Response
+	ErrChans  []*chan error
+}
+
+const sch = "HIBP requests scheduler"
+
+var (
+	// requests-per-minute limit; consult https://haveibeenpwned.com/API/Key.
+	rpm           = 10
+	requestsFired int
+	working       bool
+
+	timeToWait           = 0 * time.Millisecond
+	backOff              = 50 * time.Millisecond
+	schedulerPeriod      = 2000 * time.Millisecond
+	schedInitNapDuration = 2000 * time.Millisecond
+	nextCheckpoint       time.Time
+
+	queue = &reqQueue{}
+
+	// requestsFired mutex.
+	rfLock sync.RWMutex
+	// timeToWait mutex.
+	ttwLock sync.RWMutex
+	// requests queue mutex.
+	rqqLock sync.RWMutex
+	// nextCheckpoint mutex.
+	nxckpLock sync.RWMutex
+	// working state mutex.
+	wLock sync.RWMutex
+)
+
+// RunReqScheduler runs the HIBP requests scheduler, which schedules the
+// request in such a fashion that it does not cross the limit defined by the
+// used API key.
+//
+// quit is expected to deliver interrupt signals; it is detached from signal
+// delivery and closed when the scheduler returns. errCh receives the error
+// that made the scheduler shut down. wg.Done is called once before returning,
+// the caller is therefore expected to have called wg.Add(1) beforehand.
+//
+// The scheduler first naps for schedInitNapDuration (an interrupt received in
+// the meantime shuts it down right away) and only then starts processing the
+// queue.
+func RunReqScheduler(quit chan os.Signal, errCh chan error, wg *sync.WaitGroup) {
+	slog.Info("Hello from " + sch)
+
+	initQueue()
+
+	done := false
+	ok := make(chan error)
+	// buffered and never closed: the nap goroutine below sends exactly once,
+	// and that send must neither block nor hit a closed channel should the
+	// scheduler have returned before the nap is over.
+	timeoutchan := make(chan bool, 1)
+
+	defer close(quit)
+	defer signal.Stop(quit)
+
+	if !isTesting() {
+		defer close(ok)
+	}
+
+	go func() {
+		// sleep intermittently, report to chan when done.
+		<-time.After(schedInitNapDuration)
+		timeoutchan <- true
+	}()
+
+	select {
+	// listen for signal interrupts from the start.
+	case <-quit:
+		slog.Info("Interrupt received, shutting down " + sch)
+		// there is no need to wait for the nap to end: the nap goroutine
+		// reports into a buffered channel and terminates on its own.
+
+		wg.Done()
+		// break
+
+		return
+
+	// after we've had a nap, it's time to work.
+	case <-timeoutchan:
+		slog.Info("Starting " + sch)
+		setNextCheckpoint(time.Now().Add(time.Minute))
+
+		for !done {
+			checkNextCheckpoint()
+			go doReqs(ok, wg) // nolint:wsl
+
+			select {
+			case <-quit:
+				done = true
+
+				if os.Getenv("GO_ENV_TESTING") == "1" {
+					slog.Info("Testing, shutting down " + sch)
+					wg.Done()
+
+					return
+				}
+
+				err := <-ok
+				errCh <- err
+
+				slog.Info("Interrupt received, shutting down " + sch)
+
+				break
+
+			case err := <-ok:
+				if err != nil {
+					// alternatively, only notify that there was an error:
+					// slog.Error(sch + " encountered an error", "error", err)
+					slog.Error("Shutting down "+sch+" due to an error",
+						"error", err)
+
+					errCh <- err
+
+					wg.Done()
+
+					return
+				}
+
+				slog.Debug("Requests scheduler sleeping")
+				time.Sleep(schedulerPeriod)
+				slog.Debug("Requests scheduler awake again")
+
+				continue
+			}
+		}
+	}
+
+	wg.Done()
+	slog.Info(sch + " done")
+}
+
+// doReqs waits out a pending throttling period (if any) and then kicks off a
+// goroutine that fires the queued requests, oldest first, until the queue is
+// drained or the requests-per-minute limit is hit. The response and the error
+// of each performed request are sent on the channels that were queued with
+// it. nil is sent on ok once the goroutine has been started; wg tracks the
+// goroutine.
+func doReqs(ok chan error, wg *sync.WaitGroup) {
+	if isWorking() {
+		slog.Debug("Already WIP")
+		return
+	}
+
+	setWorking(true)
+	zeroTTWOrWait()
+
+	wg.Add(1)
+
+	go func() {
+		defer wg.Done()
+
+		if zeroQueue() {
+			setWorking(false)
+
+			slog.Debug("Queue empty, nothing to do")
+
+			return
+		}
+
+		rqqLock.Lock()
+		rfLock.Lock()
+
+		// always serve the head of the queue and pop it afterwards. removing
+		// elements while ranging over the original length of the slice would
+		// skip requests and eventually index past the end of the shrunken
+		// slice.
+		for len(queue.Requests) > 0 {
+			if requestsFired >= rpm {
+				slog.Error("Throttled - setting time to wait")
+
+				ttwLock.Lock()
+				nxckpLock.Lock()
+
+				timeToWait = time.Until(nextCheckpoint)
+
+				ttwLock.Unlock()
+				nxckpLock.Unlock()
+
+				break
+			}
+
+			req := queue.Requests[0]
+			respCh := queue.RespChans[0]
+			errCh := queue.ErrChans[0]
+
+			requestsFired++
+
+			slog.Debug("Sending the request")
+
+			resp, err := client.Do(req)
+			if err != nil {
+				slog.Error("got err", "error", err)
+
+				// NOTE(review): http.ErrHandlerTimeout is documented as a
+				// server-side handler error, a client timeout is not expected
+				// to match it - confirm what this was meant to catch.
+				if errors.Is(err, http.ErrHandlerTimeout) {
+					ok <- err
+
+					rqqLock.Unlock()
+					rfLock.Unlock()
+
+					setWorking(false)
+
+					return // alternatively: goto again
+				}
+			}
+
+			// resp may be nil here if err is not; the receiver is expected to
+			// read the response first and the error second.
+			*respCh <- resp
+			*errCh <- err
+
+			// pop the performed request along with its channels, clearing the
+			// vacated slots so that they do not keep the values alive.
+			queue.Requests[0], queue.RespChans[0], queue.ErrChans[0] = nil, nil, nil
+			queue.Requests = queue.Requests[1:]
+			queue.RespChans = queue.RespChans[1:]
+			queue.ErrChans = queue.ErrChans[1:]
+		}
+
+		rfLock.Unlock()
+		rqqLock.Unlock()
+	}()
+
+	setWorking(false)
+	ok <- nil
+}
+
+// scheduleReq schedules a HTTP requests. respCh and errCh are the channels to
+// send back *http.Response and error, respectively. The request is appended
+// to the queue from a separate goroutine, so the call itself does not block.
+func scheduleReq(r *http.Request, respCh *chan *http.Response, errCh *chan error) {
+	go func() {
+		rqqLock.Lock()
+
+		slog.Debug("Adding req to queue", "method", r.Method, "url", r.URL, "ua", r.UserAgent())
+
+		queue.Requests = append(queue.Requests, r)
+		queue.RespChans = append(queue.RespChans, respCh)
+		queue.ErrChans = append(queue.ErrChans, errCh)
+		rqqLock.Unlock()
+
+		slog.Debug("Added req to queue")
+	}()
+}
+
+// setWorking sets the working state flag.
+func setWorking(w bool) {
+	wLock.Lock()
+	defer wLock.Unlock()
+
+	working = w
+}
+
+// isWorking reports the working state flag.
+func isWorking() bool {
+	wLock.RLock()
+	defer wLock.RUnlock()
+
+	return working
+}
+
+// setNextCheckpoint sets the time at which the current rate-limiting window
+// ends.
+func setNextCheckpoint(t time.Time) {
+	nxckpLock.Lock()
+	defer nxckpLock.Unlock()
+
+	slog.Debug("Setting next checkpoint")
+
+	nextCheckpoint = t
+}
+
+// checkNextCheckpoint starts a new one-minute rate-limiting window if the
+// current checkpoint has passed, which also resets the counter of requests
+// fired. The reset waits for a batch of requests in flight (if any) to
+// finish, as the batch holds the counter's lock.
+func checkNextCheckpoint() {
+	nxckpLock.Lock()
+
+	passed := nextCheckpoint.Before(time.Now())
+	if passed {
+		slog.Debug("Checkpoint passed, updating", "checkpoint", nextCheckpoint)
+		nextCheckpoint = time.Now().Add(1 * time.Minute)
+		slog.Debug("Checkpoint updated", "checkpoint", nextCheckpoint)
+	}
+
+	// release the checkpoint lock before taking the counter lock: doReqs
+	// takes the two in the opposite order.
+	nxckpLock.Unlock()
+
+	if !passed {
+		return
+	}
+
+	// without the reset the counter only ever grows and the scheduler stays
+	// throttled for good once rpm requests have been fired.
+	rfLock.Lock()
+	requestsFired = 0
+	rfLock.Unlock()
+}
+
+// initQueue makes sure the slices of the queue are initialised; when testing
+// they are re-created (and thus emptied) unconditionally.
+func initQueue() {
+	slog.Debug("initialising the queue")
+	rqqLock.Lock()
+	defer rqqLock.Unlock()
+
+	if isTesting() || queue.Requests == nil {
+		slog.Debug("initialising queue.Requests")
+
+		queue.Requests = make([]*http.Request, 0)
+	}
+
+	if isTesting() || queue.RespChans == nil {
+		slog.Debug("initialising queue.RespChans")
+
+		queue.RespChans = make([]*chan *http.Response, 0)
+	}
+
+	if isTesting() || queue.ErrChans == nil {
+		slog.Debug("initialising queue.ErrChans")
+
+		queue.ErrChans = make([]*chan error, 0)
+	}
+}
+
+// zeroQueue reports whether there are no requests in the queue.
+func zeroQueue() bool {
+	rqqLock.RLock()
+	defer rqqLock.RUnlock()
+
+	if len(queue.Requests) == 0 {
+		slog.Debug("Queue empty")
+		return true
+	}
+
+	slog.Debug("Something in queue")
+
+	return false
+}
+
+// zeroTTWOrWait returns immediately if there is no time to wait; otherwise it
+// sleeps for timeToWait plus backOff and zeroes timeToWait afterwards.
+func zeroTTWOrWait() {
+	// only hold the read lock for the duration of the read: keeping it while
+	// asking for the write lock below would block forever.
+	ttwLock.RLock()
+	ttw := timeToWait
+	ttwLock.RUnlock()
+
+	if ttw <= 0 {
+		return
+	}
+
+	slog.Debug("Waiting ttw")
+	time.Sleep(ttw + backOff)
+
+	ttwLock.Lock()
+	timeToWait = 0 * time.Millisecond
+	ttwLock.Unlock()
+
+	slog.Debug("Waited ttw")
+}
+
+// isTesting reports whether the GO_ENV_TESTING environment variable is set
+// to "1".
+func isTesting() bool {
+	return os.Getenv("GO_ENV_TESTING") == "1"
+}