go: add basic hibp handling, requests scheduling
All checks were successful
continuous-integration/drone/push Build is passing

* change hibp schema's date field to string, as the date format would
  prevent direct unmarshaling. instead, unmarshal to string, convert later
* the scheduler is in place in order not to get throttled after going
  over API limit
* the scheduler detects when in testing mode and changes little bits of
  behaviour
* add tests for some basic requests
* run the requests scheduler as a background service during testing
This commit is contained in:
surtur 2023-08-22 19:57:48 +02:00
parent 3077eb80c6
commit f2025395b2
Signed by: wanderer
SSH Key Fingerprint: SHA256:MdCZyJ2sHLltrLBp0xQO0O1qTW9BT/xl5nXkDvhlMCI
5 changed files with 578 additions and 1 deletions

@ -60,6 +60,7 @@ var cleantgt = []string{
"PCMT_SESSION_AUTH_SECRET",
"PCMT_SESSION_ENCR_SECRET",
"PCMT_INIT_ADMIN_PASSWORD",
"PCMT_HIBP_API_KEY",
}
// New returns a new instance of the settings struct.

@ -27,7 +27,8 @@ type HIBPSchema struct {
// The domain of the primary website the breach occurred on. This may be used for identifying other assets external systems may have for the site.
Domain string `json:"Domain"`
// The date (with no time) the breach originally occurred on in ISO 8601 format. This is not always accurate — frequently breaches are discovered and reported long after the original incident. Use this attribute as a guide only.
BreachDate time.Time `json:"BreachDate"`
// YYYY-MM-DD -> unmarshal to string, convert to proper time later.
BreachDate string `json:"BreachDate"`
// The date and time (precision to the minute) the breach was added to the system in ISO 8601 format.
AddedDate time.Time `json:"AddedDate"`
// The date and time (precision to the minute) the breach was modified in ISO 8601 format. This will only differ from the AddedDate attribute if other attributes represented here are changed or data in the breach itself is changed (i.e. additional data is identified and loaded). It is always either equal to or greater then the AddedDate attribute, never less than.

130
modules/hibp/hibp.go Normal file

@ -0,0 +1,130 @@
// Copyright 2023 wanderer <a_mirre at utb dot cz>
// SPDX-License-Identifier: AGPL-3.0-only
package hibp
import (
"encoding/json"
"io"
"log"
"net/http"
"os"
"time"
"git.dotya.ml/mirre-mt/pcmt/ent/schema"
"golang.org/x/exp/slog"
)
// Subscription models the HIBP subscription status object, see
// https://haveibeenpwned.com/API/v3#SubscriptionStatus.
//
// The JSON names are spelled out explicitly, the same way schema.HIBPSchema
// tags its fields, so that decoding does not silently depend on the Go field
// names staying in sync with the wire format.
type Subscription struct {
	// SubscriptionName is the name representing the subscription, being
	// either "Pwned 1", "Pwned 2", "Pwned 3" or "Pwned 4".
	SubscriptionName string `json:"SubscriptionName"`
	// Description is a human readable sentence explaining the scope of the
	// subscription.
	Description string `json:"Description"`
	// SubscribedUntil is the date and time the current subscription ends, in
	// ISO 8601 format on the wire.
	SubscribedUntil time.Time `json:"SubscribedUntil"`
	// Rpm is the rate limit in requests per minute. This applies to the rate
	// the breach search by email address API can be requested.
	Rpm int `json:"Rpm"`
	// DomainSearchMaxBreachedAccounts is the size of the largest domain the
	// subscription can search. This is expressed in the total number of
	// breached accounts on the domain, excluding those that appear solely in
	// spam list.
	DomainSearchMaxBreachedAccounts int `json:"DomainSearchMaxBreachedAccounts"`
}
// Endpoint, client identity and transport defaults used when talking to HIBP.
const (
	// api is the base URL of the HIBP v3 API; endpoint paths are appended to it.
	api = "https://haveibeenpwned.com/api/v3"
	// appID is the value sent in the user agent header of outgoing requests.
	appID = "pcmt (https://git.dotya.ml/mirre-mt/pcmt)"
	// reqTmOut is the default request timeout, set so as not to hang forever.
	reqTmOut = 5 * time.Second
)

// Names of the HTTP headers set on outgoing requests.
const (
	// headerUA is the name of the user agent header.
	headerUA = "user-agent"
	// headerHIBP is the name of the header carrying the HIBP API key.
	headerHIBP = "hibp-api-key"
)
// apiKey is the HIBP API key, read from the PCMT_HIBP_API_KEY environment
// variable when the package-level variables are initialised.
var apiKey = os.Getenv("PCMT_HIBP_API_KEY")

// client is the package's shared HTTP client; its timeout is reqTmOut.
var client = &http.Client{Timeout: reqTmOut}
// SubscriptionStatus retrieves the status of the subscription tied to the
// configured API key, as per
// https://haveibeenpwned.com/API/v3#SubscriptionStatus.
//
// The request is sent directly through the package's HTTP client with the
// user agent and API key headers set. It returns the decoded subscription, or
// an error if the request fails, the body cannot be read, the response status
// is not 200 OK, or the body is not valid JSON for a Subscription.
func SubscriptionStatus() (*Subscription, error) {
	// the HIBP documentation lists the endpoint as /subscription/status.
	u := api + "/subscription/status"

	req, err := http.NewRequest(http.MethodGet, u, nil)
	if err != nil {
		// the method and the URL are both constants, so a failure here means
		// the constants themselves are broken.
		log.Fatalln(err)
	}

	setUA(req)
	setAuthHeader(req)

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// an error response could otherwise decode "successfully" into an
	// all-zero Subscription, as unknown JSON fields are ignored.
	if resp.StatusCode != http.StatusOK {
		return nil, &statusError{status: resp.Status}
	}

	var s Subscription
	if err := json.Unmarshal(body, &s); err != nil {
		return nil, err
	}

	return &s, nil
}

// statusError reports a response that came back with a status other than
// 200 OK.
type statusError struct {
	// status is the status text of the response, as found in
	// http.Response.Status.
	status string
}

// Error implements the error interface.
func (e *statusError) Error() string {
	return "hibp: unexpected response status: " + e.status
}
// GetAllBreaches retrieves all breaches available in HIBP, as per
// https://haveibeenpwned.com/API/v3#AllBreaches.
//
// The request is not sent directly; it is handed to the requests scheduler
// and this function blocks until the scheduler delivers the outcome. It
// returns a pointer to the decoded breaches, or an error if the request
// failed, the body could not be read or the body could not be decoded.
func GetAllBreaches() (*[]schema.HIBPSchema, error) {
	u := api + "/breaches"

	req, err := http.NewRequest(http.MethodGet, u, nil)
	if err != nil {
		// the method and the URL are both constants, so a failure here means
		// the constants themselves are broken.
		log.Fatalln(err)
	}

	respCh, errCh := rChans()

	setUA(req)

	slog.Info("scheduling all breaches")
	scheduleReq(req, &respCh, &errCh)
	slog.Info("scheduled all breaches")

	// the scheduler sends the response first and the error second; both
	// channels are unbuffered, so both values have to be received.
	resp := <-respCh
	err = <-errCh

	// a failed request comes back with a nil response, which must not be
	// dereferenced.
	if resp != nil {
		defer resp.Body.Close()
	}

	if err != nil {
		return nil, err
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	ab := make([]schema.HIBPSchema, 0)
	if err = json.Unmarshal(body, &ab); err != nil {
		return nil, err
	}

	return &ab, nil
}
// setUA sets the user agent header of r to appID.
func setUA(r *http.Request) {
	hdr := r.Header
	hdr.Set(headerUA, appID)
}
// setAuthHeader sets the HIBP API key header of r to the configured apiKey.
func setAuthHeader(r *http.Request) {
	hdr := r.Header
	hdr.Set(headerHIBP, apiKey)
}
// rChans creates a fresh pair of unbuffered channels: one for a response and
// one for an error.
func rChans() (chan *http.Response, chan error) {
	var (
		respCh = make(chan *http.Response)
		errCh  = make(chan error)
	)

	return respCh, errCh
}

90
modules/hibp/hibp_test.go Normal file

@ -0,0 +1,90 @@
// Copyright 2023 wanderer <a_mirre at utb dot cz>
// SPDX-License-Identifier: AGPL-3.0-only
package hibp
import (
"io"
"net/http"
"os"
"os/signal"
"sync"
"testing"
"time"
)
// hibpTestDomain is the "@domain" suffix appended to a test account name to
// form the address queried by the tests.
const hibpTestDomain = "@hibp-integration-tests.com"
// init marks the process as running under test via GO_ENV_TESTING and starts
// the requests scheduler in the background before any test runs.
func init() {
	os.Setenv("GO_ENV_TESTING", "1")

	// purposely ignoring the output values.
	_, _, _ = prepScheduler()
}
// prepScheduler starts the requests scheduler as a background goroutine for
// the tests.
//
// It removes the scheduler's initial nap, launches RunReqScheduler and then
// sleeps briefly to give it time to start. It returns pointers to the quit
// and error channels handed to the scheduler, together with the WaitGroup the
// scheduler signals on when it stops.
func prepScheduler() (*chan os.Signal, *chan error, *sync.WaitGroup) {
	quit := make(chan os.Signal, 1)
	errCh := make(chan error)
	wg := &sync.WaitGroup{}

	signal.Notify(quit, os.Interrupt)

	schedInitNapDuration = 0 * time.Millisecond

	// Add has to happen before the goroutine is started, otherwise a Wait on
	// the returned WaitGroup could observe a zero counter and return early.
	wg.Add(1)

	go RunReqScheduler(quit, errCh, wg) //nolint:wsl

	// give the scheduler some time to start up.
	time.Sleep(200 * time.Millisecond)

	return &quit, &errCh, wg
}
// TestMultipleBreaches queries the breached account endpoint for the
// "multiple-breaches" test account through the requests scheduler.
//
// With an API key configured the request is authenticated and a 200 is
// expected; without one the request goes out unauthenticated and a 401 is
// the expected outcome.
func TestMultipleBreaches(t *testing.T) {
	a := "multiple-breaches"
	u := api + "/breachedaccount/" + a + hibpTestDomain

	req, err := http.NewRequest("GET", u, nil)
	if err != nil {
		// nothing below can work without a request.
		t.Fatal(err)
	}

	respCh, errCh := rChans()

	setUA(req)

	// without the key header the "wanted 200" expectation below could never
	// be met, no matter what key is configured.
	if apiKey != "" {
		setAuthHeader(req)
	}

	scheduleReq(req, &respCh, &errCh)

	// the scheduler sends the response first and the error second.
	resp := <-respCh
	err = <-errCh

	// a failed request comes back with a nil response, so bail out before
	// touching it.
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	t.Logf("%+v\n", resp)

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	t.Log(string(body))

	switch {
	case resp.StatusCode == 200:
		// all good.
	case apiKey != "":
		t.Errorf("wanted 200, got: %q", resp.Status)
	case resp.StatusCode == 401:
		t.Logf("apiKey is empty, expected a 401")
	default:
		t.Errorf("apiKey is empty, expected 401, got: %q", resp.Status)
	}
}
// TestGetAllBreaches checks that GetAllBreaches completes without returning
// an error; the returned breaches themselves are not inspected.
func TestGetAllBreaches(t *testing.T) {
	if _, err := GetAllBreaches(); err != nil {
		t.Errorf("error: %q", err)
	}
}

355
modules/hibp/schedule.go Normal file

@ -0,0 +1,355 @@
// Copyright 2023 wanderer <a_mirre at utb dot cz>
// SPDX-License-Identifier: AGPL-3.0-only
// implement request scheduling in order to attempt to comply with the
// HIBP's rate-limiting.
package hibp
import (
"errors"
"net/http"
"os"
"os/signal"
"sync"
"time"
"golang.org/x/exp/slog"
)
// reqQueue is the backlog of requests waiting to be sent by the scheduler.
//
// The three slices are parallel: the entries at the same index belong to the
// same scheduled request.
type reqQueue struct {
	// Requests holds the queued requests.
	Requests []*http.Request
	// RespChans holds the channels on which the responses are delivered.
	RespChans []*chan *http.Response
	// ErrChans holds the channels on which the request errors are delivered.
	ErrChans []*chan error
}
// sch is the scheduler's human-readable name, used in its log messages.
const sch = "HIBP requests scheduler"
// Scheduler settings and bookkeeping.
var (
	// rpm is the requests-per-minute limit; consult
	// https://haveibeenpwned.com/API/Key.
	rpm = 10
	// backOff is slept on top of timeToWait whenever the scheduler waits.
	backOff = 50 * time.Millisecond
	// schedulerPeriod is how long the scheduler sleeps between two rounds.
	schedulerPeriod = 2000 * time.Millisecond
	// schedInitNapDuration is how long the scheduler naps before starting
	// to work.
	schedInitNapDuration = 2000 * time.Millisecond

	// requestsFired counts the requests sent so far.
	requestsFired int
	// working tells whether a round of requests is being set up.
	working bool
	// timeToWait is how long to hold off before the next round; it is set
	// when the scheduler gets throttled.
	timeToWait time.Duration
	// nextCheckpoint is the point in time used as the target of the throttle
	// wait; it gets moved a minute ahead once it has passed.
	nextCheckpoint time.Time
	// queue holds the requests waiting to be sent.
	queue = &reqQueue{}
)

// Locks guarding the bookkeeping above.
var (
	// requestsFired mutex.
	rfLock sync.RWMutex
	// timeToWait mutex.
	ttwLock sync.RWMutex
	// requests queue mutex.
	rqqLock sync.RWMutex
	// nextCheckpoint mutex.
	nxckpLock sync.RWMutex
	// working state mutex.
	wLock sync.RWMutex
)
// RunReqScheduler runs the HIBP requests scheduler, which schedules the
// requests in such a fashion that they do not cross the limit defined by the
// used API key.
//
// The scheduler first naps for schedInitNapDuration, then repeatedly kicks
// off a round of requests and sleeps for schedulerPeriod, until it is
// interrupted or a round reports an error.
//
// quit delivers the interrupt that stops the scheduler; signal delivery to it
// is stopped and the channel is closed when the scheduler returns. errCh
// receives the error that made the scheduler shut down, or, outside of
// testing mode, the outcome of the round that was in flight when the
// interrupt arrived. The caller is expected to have called wg.Add(1); wg.Done
// is called exactly once, when the scheduler returns.
func RunReqScheduler(quit chan os.Signal, errCh chan error, wg *sync.WaitGroup) {
	slog.Info("Hello from " + sch)

	initQueue()

	// deferred calls run last-in-first-out: signal delivery is stopped, quit
	// is closed and only then is the scheduler reported as done.
	defer wg.Done()
	defer close(quit)
	defer signal.Stop(quit)

	// ok is deliberately never closed: rounds started by this function send
	// on it, and a channel must not be closed under its senders.
	ok := make(chan error)

	// a timer that is selected on directly needs neither a helper goroutine
	// nor a channel that somebody has to drain or close when the nap gets
	// interrupted.
	nap := time.NewTimer(schedInitNapDuration)
	defer nap.Stop()

	select {
	// listen for signal interrupts from the start.
	case <-quit:
		slog.Info("Interrupt received, shutting down " + sch)
		return

	// after we've had a nap, it's time to work.
	case <-nap.C:
	}

	slog.Info("Starting " + sch)
	setNextCheckpoint(time.Now().Add(time.Minute))

	for {
		checkNextCheckpoint()

		go doReqs(ok, wg) // nolint:wsl

		select {
		case <-quit:
			if isTesting() {
				slog.Info("Testing, shutting down " + sch)
				return
			}

			// wait for the round in flight and pass its outcome on.
			err := <-ok
			errCh <- err

			slog.Info("Interrupt received, shutting down " + sch)
			slog.Info(sch + " done")

			return

		case err := <-ok:
			if err != nil {
				// alternatively, only notify that there was an error:
				// slog.Error(sch + " encountered an error", "error", err)
				slog.Error("Shutting down "+sch+" due to an error",
					"error", err)
				errCh <- err

				return
			}

			slog.Debug("Requests scheduler sleeping")
			time.Sleep(schedulerPeriod)
			slog.Debug("Requests scheduler awake again")
		}
	}
}
// rfWindow records the value of nextCheckpoint that requestsFired was last
// counted against, so that the count can start over once the checkpoint has
// moved; guarded by rfLock.
var rfWindow time.Time

// doReqs performs one round of the scheduler's work.
//
// It first waits out any pending time to wait, then starts a goroutine
// (tracked by wg) that sends the queued requests in order, oldest first,
// until the queue is empty or rpm requests have been fired since the
// checkpoint last moved. Each response and error is handed back on the
// channels queued along with the request. Once throttled, the time left
// until the next checkpoint is stored as the time to wait.
//
// A nil is sent on ok as soon as the round has been started; if a round is
// already being set up, doReqs returns without sending anything.
func doReqs(ok chan error, wg *sync.WaitGroup) {
	if isWorking() {
		slog.Debug("Already WIP")
		return
	}

	setWorking(true)
	zeroTTWOrWait()

	wg.Add(1)

	go func() {
		defer wg.Done()

		if zeroQueue() {
			setWorking(false)
			slog.Debug("Queue empty, nothing to do")

			return
		}

		rqqLock.Lock()
		defer rqqLock.Unlock()

		rfLock.Lock()
		defer rfLock.Unlock()

		// start counting from zero whenever the checkpoint has moved since
		// the last round; otherwise the scheduler would stay throttled
		// forever once rpm requests have been fired.
		nxckpLock.RLock()
		window := nextCheckpoint
		nxckpLock.RUnlock()

		if !window.Equal(rfWindow) {
			requestsFired = 0
			rfWindow = window
		}

		// always work on the head of the queue and pop it when done; ranging
		// over the slices while deleting from them would skip entries and
		// eventually index past the end.
		for len(queue.Requests) > 0 {
			if requestsFired >= rpm {
				slog.Error("Throttled - setting time to wait")
				ttwLock.Lock()
				nxckpLock.Lock()
				timeToWait = time.Until(nextCheckpoint)
				ttwLock.Unlock()
				nxckpLock.Unlock()

				break
			}

			req := queue.Requests[0]
			respCh := queue.RespChans[0]
			errCh := queue.ErrChans[0]

			requestsFired++

			slog.Debug("Sending the request")

			resp, err := client.Do(req)
			if err != nil {
				slog.Error("got err", "error", err)

				// NOTE(review): net/http documents ErrHandlerTimeout as a
				// server-side handler error; confirm whether client.Do can
				// ever yield it. The request stays queued in this case.
				if errors.Is(err, http.ErrHandlerTimeout) {
					ok <- err

					setWorking(false)

					return
				}
			}

			*respCh <- resp
			*errCh <- err

			// drop the performed request together with its channels; the
			// slots are cleared first so that they can be collected.
			queue.Requests[0] = nil
			queue.Requests = queue.Requests[1:]
			queue.RespChans[0] = nil
			queue.RespChans = queue.RespChans[1:]
			queue.ErrChans[0] = nil
			queue.ErrChans = queue.ErrChans[1:]
		}
	}()

	setWorking(false)

	ok <- nil
}
// scheduleReq schedules an HTTP request. respCh and errCh are the channels
// on which the *http.Response and the error, respectively, are sent back.
//
// The request is appended to the queue from a separate goroutine, so the
// call itself does not wait for the queue lock.
func scheduleReq(r *http.Request, respCh *chan *http.Response, errCh *chan error) {
	enqueue := func() {
		rqqLock.Lock()
		slog.Debug("Adding req to queue", "method", r.Method, "url", r.URL, "ua", r.UserAgent())

		// the three slices are kept in step, one entry each per request.
		queue.Requests = append(queue.Requests, r)
		queue.RespChans = append(queue.RespChans, respCh)
		queue.ErrChans = append(queue.ErrChans, errCh)
		rqqLock.Unlock()

		slog.Debug("Added req to queue")
	}

	go enqueue()
}
// setWorking stores w as the working state while holding the working state
// mutex.
func setWorking(w bool) {
	wLock.Lock()
	working = w
	wLock.Unlock()
}
// isWorking returns the working state, read while holding the working state
// mutex for reading.
func isWorking() bool {
	wLock.RLock()
	state := working
	wLock.RUnlock()

	return state
}
// setNextCheckpoint stores t as the next checkpoint while holding the
// checkpoint mutex.
func setNextCheckpoint(t time.Time) {
	nxckpLock.Lock()

	slog.Debug("Setting next checkpoint")

	nextCheckpoint = t

	nxckpLock.Unlock()
}
// checkNextCheckpoint moves the next checkpoint a minute into the future if
// it already lies in the past; otherwise it leaves it alone.
func checkNextCheckpoint() {
	nxckpLock.Lock()
	defer nxckpLock.Unlock()

	// still ahead of us (or exactly now), nothing to update.
	if !nextCheckpoint.Before(time.Now()) {
		return
	}

	slog.Debug("Checkpoint passed, updating", "checkpoint", nextCheckpoint)

	nextCheckpoint = time.Now().Add(1 * time.Minute)

	slog.Debug("Checkpoint updated", "checkpoint", nextCheckpoint)
}
// initQueue makes sure all three slices of the queue are ready for use.
//
// A slice that is still nil gets replaced with an empty one; in testing mode
// every slice is replaced regardless, which empties the queue.
func initQueue() {
	slog.Debug("initialising the queue")

	rqqLock.Lock()
	defer rqqLock.Unlock()

	// in testing mode the queue always starts out fresh.
	fresh := isTesting()

	if fresh || queue.Requests == nil {
		slog.Debug("initialising queue.Requests")

		queue.Requests = []*http.Request{}
	}

	if fresh || queue.RespChans == nil {
		slog.Debug("initialising queue.RespChans")

		queue.RespChans = []*chan *http.Response{}
	}

	if fresh || queue.ErrChans == nil {
		slog.Debug("initialising queue.ErrChans")

		queue.ErrChans = []*chan error{}
	}
}
// zeroQueue reports whether there are no requests in the queue. The queue
// length is read while holding the queue mutex for reading.
func zeroQueue() bool {
	rqqLock.RLock()
	pending := len(queue.Requests)
	rqqLock.RUnlock()

	if pending > 0 {
		slog.Debug("Something in queue")
		return false
	}

	slog.Debug("Queue empty")

	return true
}
// zeroTTWOrWait returns immediately if no time to wait is set. Otherwise it
// sleeps for the time to wait plus backOff and then resets the time to wait
// to zero.
func zeroTTWOrWait() {
	// take a snapshot and let go of the read lock right away: holding it
	// across the wait would make the write lock taken below unobtainable.
	ttwLock.RLock()
	ttw := timeToWait
	ttwLock.RUnlock()

	if ttw <= 0 {
		return
	}

	slog.Debug("Waiting ttw")

	time.Sleep(ttw + backOff)

	ttwLock.Lock()
	timeToWait = 0 * time.Millisecond
	ttwLock.Unlock()

	slog.Debug("Waited ttw")
}
// isTesting reports whether the GO_ENV_TESTING environment variable is set
// to "1". The environment is consulted on every call.
func isTesting() bool {
	flag := os.Getenv("GO_ENV_TESTING")

	return flag == "1"
}