1
1
Fork 1
mirror of https://github.com/go-gitea/gitea.git synced 2024-05-09 03:46:06 +02:00

adjust websocket

This commit is contained in:
Anbraten 2024-04-06 16:22:35 +02:00
parent 7cb8e2f37c
commit 602a42a70e
6 changed files with 187 additions and 66 deletions

59
services/pubsub/memory.go Normal file
View File

@ -0,0 +1,59 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package pubsub
import (
"context"
"sync"
)
type Memory struct {
sync.Mutex
topics map[string]map[*Subscriber]struct{}
}
// New creates an in-memory publisher.
func NewMemory() Broker {
return &Memory{
topics: make(map[string]map[*Subscriber]struct{}),
}
}
func (p *Memory) Publish(_ context.Context, message Message) {
p.Lock()
topic, ok := p.topics[message.Topic]
if !ok {
p.Unlock()
return
}
for s := range topic {
go (*s)(message)
}
p.Unlock()
}
func (p *Memory) Subscribe(c context.Context, topic string, subscriber Subscriber) {
// Subscribe
p.Lock()
_, ok := p.topics[topic]
if !ok {
p.topics[topic] = make(map[*Subscriber]struct{})
}
p.topics[topic][&subscriber] = struct{}{}
p.Unlock()
// Wait for context to be done
<-c.Done()
// Unsubscribe
p.Lock()
delete(p.topics[topic], &subscriber)
if len(p.topics[topic]) == 0 {
delete(p.topics, topic)
}
p.Unlock()
}

View File

@ -0,0 +1,47 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package pubsub
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPubsub(t *testing.T) {
var (
wg sync.WaitGroup
testMessage = Message{
Data: []byte("test"),
Topic: "hello-world",
}
)
ctx, cancel := context.WithCancelCause(
context.Background(),
)
broker := NewMemory()
go func() {
broker.Subscribe(ctx, "hello-world", func(message Message) { assert.Equal(t, testMessage, message); wg.Done() })
}()
go func() {
broker.Subscribe(ctx, "hello-world", func(_ Message) { wg.Done() })
}()
// Wait a bit for the subscriptions to be registered
<-time.After(100 * time.Millisecond)
wg.Add(2)
go func() {
broker.Publish(ctx, testMessage)
}()
wg.Wait()
cancel(nil)
}

23
services/pubsub/types.go Normal file
View File

@ -0,0 +1,23 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package pubsub
import "context"
// Message defines a published message.
type Message struct {
// Data is the actual data in the entry.
Data []byte `json:"data"`
// Topic is the topic of the message.
Topic string `json:"topic"`
}
// Subscriber receives published messages.
type Subscriber func(Message)
type Broker interface {
Publish(c context.Context, message Message)
Subscribe(c context.Context, topic string, subscriber Subscriber)
}

View File

@ -5,45 +5,22 @@ package websocket
import (
"context"
"encoding/json"
"fmt"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/models/perm/access"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"github.com/olahol/melody"
"code.gitea.io/gitea/services/pubsub"
)
func (n *websocketNotifier) filterIssueSessions(ctx context.Context, repo *repo_model.Repository, issue *issues_model.Issue) []*melody.Session {
return n.filterSessions(func(s *melody.Session, data *sessionData) bool {
// if the user is watching the issue, they will get notifications
if !data.isOnURL(fmt.Sprintf("/%s/%s/issues/%d", repo.Owner.Name, repo.Name, issue.Index)) {
return false
}
func (n *websocketNotifier) DeleteComment(ctx context.Context, doer *user_model.User, c *issues_model.Comment) {
d, err := json.Marshal(c)
if err != nil {
return
}
// the user will get notifications if they have access to the repos issues
hasAccess, err := access.HasAccessUnit(ctx, data.user, repo, unit.TypeIssues, perm.AccessModeRead)
if err != nil {
log.Error("Failed to check access: %v", err)
return false
}
return hasAccess
n.pubsub.Publish(ctx, pubsub.Message{
Data: d,
Topic: fmt.Sprintf("repo:%s/%s", c.RefRepo.OwnerName, c.RefRepo.Name),
})
}
func (n *websocketNotifier) DeleteComment(ctx context.Context, doer *user_model.User, c *issues_model.Comment) {
sessions := n.filterIssueSessions(ctx, c.Issue.Repo, c.Issue)
for _, s := range sessions {
msg := fmt.Sprintf(htmxRemoveElement, fmt.Sprintf("#%s", c.HashTag()))
err := s.Write([]byte(msg))
if err != nil {
log.Error("Failed to write to session: %v", err)
}
}
}

View File

@ -4,17 +4,18 @@
package websocket
import (
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/templates"
notify_service "code.gitea.io/gitea/services/notify"
"code.gitea.io/gitea/services/pubsub"
"github.com/olahol/melody"
)
type websocketNotifier struct {
notify_service.NullNotifier
m *melody.Melody
rnd *templates.HTMLRender
m *melody.Melody
rnd *templates.HTMLRender
pubsub pubsub.Broker
}
// NewNotifier create a new webhooksNotifier notifier
@ -29,25 +30,3 @@ func newNotifier(m *melody.Melody) notify_service.Notifier {
// htmxUpdateElement = "<div hx-swap-oob=\"outerHTML:%s\">%s</div>"
var htmxRemoveElement = "<div hx-swap-oob=\"delete:%s\"></div>"
func (n *websocketNotifier) filterSessions(fn func(*melody.Session, *sessionData) bool) []*melody.Session {
sessions, err := n.m.Sessions()
if err != nil {
log.Error("Failed to get sessions: %v", err)
return nil
}
_sessions := make([]*melody.Session, 0, len(sessions))
for _, s := range sessions {
data, err := getSessionData(s)
if err != nil {
continue
}
if fn(s, data) {
_sessions = append(_sessions, s)
}
}
return _sessions
}

View File

@ -4,9 +4,17 @@
package websocket
import (
goContext "context"
"fmt"
"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/models/perm/access"
"code.gitea.io/gitea/models/unit"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/services/context"
notify_service "code.gitea.io/gitea/services/notify"
"code.gitea.io/gitea/services/pubsub"
"github.com/mitchellh/mapstructure"
"github.com/olahol/melody"
@ -20,19 +28,26 @@ type websocketMessage struct {
}
type subscribeMessageData struct {
URL string `json:"url"`
Repo string `json:"repo"`
}
func Init() *melody.Melody {
m = melody.New()
m.HandleConnect(handleConnect)
m.HandleMessage(handleMessage)
hub := &hub{
pubsub: pubsub.NewMemory(),
}
m.HandleConnect(hub.handleConnect)
m.HandleMessage(hub.handleMessage)
m.HandleDisconnect(handleDisconnect)
notify_service.RegisterNotifier(newNotifier(m))
return m
}
func handleConnect(s *melody.Session) {
type hub struct {
pubsub pubsub.Broker
}
func (h *hub) handleConnect(s *melody.Session) {
ctx := context.GetWebContext(s.Request)
data := &sessionData{}
@ -45,7 +60,7 @@ func handleConnect(s *melody.Session) {
// TODO: handle logouts
}
func handleMessage(s *melody.Session, _msg []byte) {
func (h *hub) handleMessage(s *melody.Session, _msg []byte) {
data, err := getSessionData(s)
if err != nil {
return
@ -59,21 +74,42 @@ func handleMessage(s *melody.Session, _msg []byte) {
switch msg.Action {
case "subscribe":
err := handleSubscribeMessage(data, msg.Data)
err := h.handleSubscribeMessage(s, data, msg.Data)
if err != nil {
return
}
}
}
func handleSubscribeMessage(data *sessionData, _data any) error {
func (h *hub) handleSubscribeMessage(s *melody.Session, data *sessionData, _data any) error {
msgData := &subscribeMessageData{}
err := mapstructure.Decode(_data, &msgData)
if err != nil {
return err
}
data.onURL = msgData.URL
ctx := goContext.Background() // TODO: use proper context
h.pubsub.Subscribe(ctx, msgData.Repo, func(msg pubsub.Message) {
if data.user != nil {
return
}
// TODO: check permissions
hasAccess, err := access.HasAccessUnit(ctx, data.user, repo, unit.TypeIssues, perm.AccessModeRead)
if err != nil {
log.Error("Failed to check access: %v", err)
return
}
if !hasAccess {
return
}
// TODO: check the actual data received from pubsub and send it correctly to the client
d := fmt.Sprintf(htmxRemoveElement, fmt.Sprintf("#%s", c.HashTag()))
_ = s.Write([]byte(d))
})
return nil
}