1
0
mirror of https://github.com/joshuarubin/go-sway synced 2024-11-26 08:43:49 +01:00
go-sway/client.go
2019-04-15 10:11:13 -06:00

318 lines
6.7 KiB
Go

package sway
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"strings"
"github.com/joshuarubin/lifecycle"
"go.uber.org/multierr"
)
type client struct {
conn net.Conn
path string
}
// A Client provides simple communication with the sway IPC
type Client interface {
// Runs the payload as sway commands
RunCommand(context.Context, string) ([]RunCommandReply, error)
// Get the list of current workspaces
GetWorkspaces(context.Context) ([]Workspace, error)
// Get the list of current outputs
GetOutputs(context.Context) ([]Output, error)
// Get the node layout tree
GetTree(context.Context) (*Node, error)
// Get the names of all the marks currently set
GetMarks(context.Context) ([]string, error)
// Get the list of configured bar IDs
GetBarIDs(context.Context) ([]string, error)
// Get the specified bar config
GetBarConfig(context.Context, string) (*BarConfig, error)
// Get the version of sway that owns the IPC socket
GetVersion(context.Context) (*Version, error)
// Get the list of binding mode names
GetBindingModes(context.Context) ([]string, error)
// Returns the config that was last loaded
GetConfig(context.Context) (*Config, error)
// Sends a tick event with the specified payload
SendTick(context.Context, string) (*TickReply, error)
// Get the list of input devices
GetInputs(context.Context) ([]Input, error)
// Get the list of seats
GetSeats(context.Context) ([]Seat, error)
}
// Option can be passed to New to specify runtime configuration settings
type Option func(*client)
// WithSocketPath explicitly sets the sway socket path so it isn't read from
// $SWAYSOCK
func WithSocketPath(socketPath string) Option {
return func(c *client) {
c.path = socketPath
}
}
// New returns a Client configured to connect to $SWAYSOCK
func New(ctx context.Context, opts ...Option) (_ Client, err error) {
c := &client{}
for _, opt := range opts {
opt(c)
}
if c.path == "" {
c.path = strings.TrimSpace(os.Getenv("SWAYSOCK"))
}
if c.path == "" {
return nil, fmt.Errorf("$SWAYSOCK is empty")
}
if c.conn, err = (&net.Dialer{}).DialContext(ctx, "unix", c.path); err != nil {
return nil, err
}
if lifecycle.Exists(ctx) {
lifecycle.DeferErr(ctx, c.conn.Close)
} else {
go func() {
<-ctx.Done()
_ = c.conn.Close()
}()
}
return c, nil
}
type payloadReader struct {
io.Reader
}
func (r payloadReader) Close() error {
_, err := ioutil.ReadAll(r)
return err
}
func (c *client) recvMsg(ctx context.Context) (*message, error) {
var h header
err := do(ctx, func() error {
return binary.Read(c.conn, binary.LittleEndian, &h)
})
if err != nil {
return nil, err
}
return &message{
Type: h.Type,
Payload: payloadReader{io.LimitReader(c.conn, int64(h.Length))},
}, nil
}
func (c *client) roundTrip(ctx context.Context, t messageType, payload []byte) (*message, error) {
if c == nil {
return nil, fmt.Errorf("not connected")
}
err := do(ctx, func() error {
err := binary.Write(c.conn, binary.LittleEndian, &header{magic, uint32(len(payload)), t})
if err != nil {
return nil
}
_, err = c.conn.Write(payload)
return err
})
if err != nil {
return nil, err
}
return c.recvMsg(ctx)
}
func (c *client) GetTree(ctx context.Context) (*Node, error) {
b, err := c.roundTrip(ctx, messageTypeGetTree, nil)
if err != nil {
return nil, err
}
var n Node
return &n, b.Decode(&n)
}
func (c *client) subscribe(ctx context.Context, events ...EventType) error {
payload, err := json.Marshal(events)
if err != nil {
return err
}
msg, err := c.roundTrip(ctx, messageTypeSubscribe, payload)
if err != nil {
return err
}
var reply struct {
Success bool `json:"success"`
}
if err = msg.Decode(&reply); err != nil {
return err
}
if !reply.Success {
return fmt.Errorf("subscribe unsuccessful")
}
return nil
}
func (c *client) RunCommand(ctx context.Context, command string) ([]RunCommandReply, error) {
msg, err := c.roundTrip(ctx, messageTypeRunCommand, []byte(command))
if err != nil {
return nil, err
}
var replies []RunCommandReply
if err = msg.Decode(&replies); err != nil {
return nil, err
}
for _, reply := range replies {
if !reply.Success {
err = multierr.Append(err, fmt.Errorf("command %q unsuccessful: %v", command, reply.Error))
}
}
return replies, err
}
func (c *client) GetWorkspaces(ctx context.Context) ([]Workspace, error) {
msg, err := c.roundTrip(ctx, messageTypeGetWorkspaces, nil)
if err != nil {
return nil, err
}
var ret []Workspace
return ret, msg.Decode(&ret)
}
func (c *client) GetOutputs(ctx context.Context) ([]Output, error) {
msg, err := c.roundTrip(ctx, messageTypeGetOutputs, nil)
if err != nil {
return nil, err
}
var ret []Output
return ret, msg.Decode(&ret)
}
func (c *client) GetMarks(ctx context.Context) ([]string, error) {
msg, err := c.roundTrip(ctx, messageTypeGetMarks, nil)
if err != nil {
return nil, err
}
var ret []string
return ret, msg.Decode(&ret)
}
func (c *client) GetBarIDs(ctx context.Context) ([]string, error) {
msg, err := c.roundTrip(ctx, messageTypeGetBarConfig, nil)
if err != nil {
return nil, err
}
var ret []string
return ret, msg.Decode(&ret)
}
func (c *client) GetBarConfig(ctx context.Context, id string) (*BarConfig, error) {
msg, err := c.roundTrip(ctx, messageTypeGetBarConfig, []byte(id))
if err != nil {
return nil, err
}
var ret BarConfig
return &ret, msg.Decode(&ret)
}
func (c *client) GetVersion(ctx context.Context) (*Version, error) {
msg, err := c.roundTrip(ctx, messageTypeGetVersion, nil)
if err != nil {
return nil, err
}
var ret Version
return &ret, msg.Decode(&ret)
}
func (c *client) GetBindingModes(ctx context.Context) ([]string, error) {
msg, err := c.roundTrip(ctx, messageTypeGetBindingModes, nil)
if err != nil {
return nil, err
}
var ret []string
return ret, msg.Decode(&ret)
}
func (c *client) GetConfig(ctx context.Context) (*Config, error) {
msg, err := c.roundTrip(ctx, messageTypeGetConfig, nil)
if err != nil {
return nil, err
}
var ret Config
return &ret, msg.Decode(&ret)
}
func (c *client) SendTick(ctx context.Context, payload string) (*TickReply, error) {
msg, err := c.roundTrip(ctx, messageTypeSendTick, []byte(payload))
if err != nil {
return nil, err
}
var ret TickReply
return &ret, msg.Decode(&ret)
}
func (c *client) GetInputs(ctx context.Context) ([]Input, error) {
msg, err := c.roundTrip(ctx, messageTypeGetInputs, nil)
if err != nil {
return nil, err
}
var ret []Input
return ret, msg.Decode(&ret)
}
func (c *client) GetSeats(ctx context.Context) ([]Seat, error) {
msg, err := c.roundTrip(ctx, messageTypeGetSeats, nil)
if err != nil {
return nil, err
}
var ret []Seat
return ret, msg.Decode(&ret)
}