Skip to content

Commit 56e2adb

Browse files
committed
commands: add generic command processing framework for bots
1 parent 7c1b0c5 commit 56e2adb

File tree

3 files changed

+366
-0
lines changed

3 files changed

+366
-0
lines changed

commands/event.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Copyright (c) 2025 Tulir Asokan
2+
//
3+
// This Source Code Form is subject to the terms of the Mozilla Public
4+
// License, v. 2.0. If a copy of the MPL was not distributed with this
5+
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
7+
package commands
8+
9+
import (
10+
"context"
11+
"fmt"
12+
"strings"
13+
14+
"github.com/rs/zerolog"
15+
16+
"maunium.net/go/mautrix/event"
17+
"maunium.net/go/mautrix/format"
18+
)
19+
20+
// Event contains the data of a single command event.
21+
// It also provides some helper methods for responding to the command.
22+
type Event[MetaType any] struct {
23+
*event.Event
24+
// RawInput is the entire message before splitting into command and arguments.
25+
RawInput string
26+
// Command is the lowercased first word of the message.
27+
Command string
28+
// Args are the rest of the message split by whitespace ([strings.Fields]).
29+
Args []string
30+
// RawArgs is the same as args, but without the splitting by whitespace.
31+
RawArgs string
32+
33+
Ctx context.Context
34+
Proc *Processor[MetaType]
35+
Handler *Handler[MetaType]
36+
Meta MetaType
37+
}
38+
39+
var IDHTMLParser = &format.HTMLParser{
40+
PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
41+
if len(mxid) == 0 {
42+
return displayname
43+
}
44+
if eventID != "" {
45+
return fmt.Sprintf("https://matrix.to/#/%s/%s", mxid, eventID)
46+
}
47+
return mxid
48+
},
49+
ItalicConverter: func(s string, c format.Context) string {
50+
return fmt.Sprintf("*%s*", s)
51+
},
52+
Newline: "\n",
53+
}
54+
55+
// ParseEvent parses a message into a command event struct.
56+
func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] {
57+
content := evt.Content.Parsed.(*event.MessageEventContent)
58+
text := content.Body
59+
if content.Format == event.FormatHTML {
60+
text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx))
61+
}
62+
parts := strings.Fields(text)
63+
return &Event[MetaType]{
64+
Event: evt,
65+
RawInput: text,
66+
Command: strings.ToLower(parts[0]),
67+
Args: parts[1:],
68+
RawArgs: strings.TrimLeft(strings.TrimPrefix(text, parts[0]), " "),
69+
Ctx: ctx,
70+
}
71+
}
72+
73+
type ReplyOpts struct {
74+
AllowHTML bool
75+
AllowMarkdown bool
76+
Reply bool
77+
Thread bool
78+
SendAsText bool
79+
}
80+
81+
func (evt *Event[MetaType]) Reply(msg string, args ...any) {
82+
if len(args) > 0 {
83+
msg = fmt.Sprintf(msg, args...)
84+
}
85+
evt.Respond(msg, ReplyOpts{AllowMarkdown: true, Reply: true})
86+
}
87+
88+
func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) {
89+
content := format.RenderMarkdown(msg, opts.AllowMarkdown, opts.AllowHTML)
90+
if opts.Thread {
91+
content.SetThread(evt.Event)
92+
}
93+
if opts.Reply {
94+
content.SetReply(evt.Event)
95+
}
96+
if !opts.SendAsText {
97+
content.MsgType = event.MsgNotice
98+
}
99+
_, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, content)
100+
if err != nil {
101+
zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply")
102+
}
103+
}
104+
105+
func (evt *Event[MetaType]) React(emoji string) {
106+
_, err := evt.Proc.Client.SendReaction(evt.Ctx, evt.RoomID, evt.ID, emoji)
107+
if err != nil {
108+
zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reaction")
109+
}
110+
}
111+
112+
func (evt *Event[MetaType]) Redact() {
113+
_, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID)
114+
if err != nil {
115+
zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to redact command")
116+
}
117+
}
118+
119+
func (evt *Event[MetaType]) MarkRead() {
120+
err := evt.Proc.Client.MarkRead(evt.Ctx, evt.RoomID, evt.ID)
121+
if err != nil {
122+
zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send read receipt")
123+
}
124+
}

commands/prevalidate.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) 2025 Tulir Asokan
2+
//
3+
// This Source Code Form is subject to the terms of the Mozilla Public
4+
// License, v. 2.0. If a copy of the MPL was not distributed with this
5+
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
7+
package commands
8+
9+
import (
10+
"strings"
11+
)
12+
13+
// A PreValidator contains a function that takes an Event and returns true if the event should be processed further.
14+
//
15+
// The [PreValidator] field in [Processor] is called before the handler of the command is checked.
16+
// It can be used to modify the command or arguments, or to skip the command entirely.
17+
//
18+
// The primary use case is removing a static command prefix, such as requiring all commands start with `!`.
19+
type PreValidator[MetaType any] interface {
20+
Validate(*Event[MetaType]) bool
21+
}
22+
23+
// FuncPreValidator is a simple function that implements the PreValidator interface.
24+
type FuncPreValidator[MetaType any] func(*Event[MetaType]) bool
25+
26+
func (f FuncPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool {
27+
return f(ce)
28+
}
29+
30+
// AllPreValidator can be used to combine multiple PreValidators, such that
31+
// all of them must return true for the command to be processed further.
32+
type AllPreValidator[MetaType any] []PreValidator[MetaType]
33+
34+
func (f AllPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool {
35+
for _, validator := range f {
36+
if !validator.Validate(ce) {
37+
return false
38+
}
39+
}
40+
return true
41+
}
42+
43+
// AnyPreValidator can be used to combine multiple PreValidators, such that
44+
// at least one of them must return true for the command to be processed further.
45+
type AnyPreValidator[MetaType any] []PreValidator[MetaType]
46+
47+
func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool {
48+
for _, validator := range f {
49+
if validator.Validate(ce) {
50+
return true
51+
}
52+
}
53+
return false
54+
}
55+
56+
// ValidatePrefixCommand checks that the first word in the input is exactly the given string,
57+
// and if so, removes it from the command and sets the command to the next word.
58+
//
59+
// For example, `ValidateCommandPrefix("!mybot")` would only allow commands in the form `!mybot foo`,
60+
// where `foo` would be used to look up the command handler.
61+
func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] {
62+
return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool {
63+
if ce.Command == prefix && len(ce.Args) > 0 {
64+
ce.Command = strings.ToLower(ce.Args[0])
65+
ce.RawArgs = strings.TrimLeft(strings.TrimPrefix(ce.RawArgs, ce.Args[0]), " ")
66+
ce.Args = ce.Args[1:]
67+
return true
68+
}
69+
return false
70+
})
71+
}
72+
73+
// ValidatePrefixSubstring checks that the command starts with the given prefix,
74+
// and if so, removes it from the command.
75+
//
76+
// For example, `ValidatePrefixSubstring("!")` would only allow commands in the form `!foo`,
77+
// where `foo` would be used to look up the command handler.
78+
func ValidatePrefixSubstring[MetaType any](prefix string) PreValidator[MetaType] {
79+
return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool {
80+
if strings.HasPrefix(ce.Command, prefix) {
81+
ce.Command = ce.Command[len(prefix):]
82+
return true
83+
}
84+
return false
85+
})
86+
}

commands/processor.go

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright (c) 2025 Tulir Asokan
2+
//
3+
// This Source Code Form is subject to the terms of the Mozilla Public
4+
// License, v. 2.0. If a copy of the MPL was not distributed with this
5+
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
7+
package commands
8+
9+
import (
10+
"context"
11+
"fmt"
12+
"runtime/debug"
13+
"strings"
14+
"sync"
15+
16+
"github.com/rs/zerolog"
17+
18+
"maunium.net/go/mautrix"
19+
"maunium.net/go/mautrix/event"
20+
)
21+
22+
// Processor implements boilerplate code for splitting messages into a command and arguments,
23+
// and finding the appropriate handler for the command.
24+
type Processor[MetaType any] struct {
25+
Client *mautrix.Client
26+
LogArgs bool
27+
PreValidator PreValidator[MetaType]
28+
Meta MetaType
29+
commands map[string]*Handler[MetaType]
30+
aliases map[string]string
31+
lock sync.RWMutex
32+
}
33+
34+
type Handler[MetaType any] struct {
35+
Func func(ce *Event[MetaType])
36+
37+
// Name is the primary name of the command. It must be lowercase.
38+
Name string
39+
// Aliases are alternative names for the command. They must be lowercase.
40+
Aliases []string
41+
}
42+
43+
// UnknownCommandName is the name of the fallback handler which is used if no other handler is found.
44+
// If even the unknown command handler is not found, the command is ignored.
45+
const UnknownCommandName = "unknown-command"
46+
47+
func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] {
48+
proc := &Processor[MetaType]{
49+
Client: cli,
50+
PreValidator: ValidatePrefixSubstring[MetaType]("!"),
51+
commands: make(map[string]*Handler[MetaType]),
52+
aliases: make(map[string]string),
53+
}
54+
proc.Register(&Handler[MetaType]{
55+
Name: UnknownCommandName,
56+
Func: func(ce *Event[MetaType]) {
57+
ce.Reply("Unknown command")
58+
},
59+
})
60+
return proc
61+
}
62+
63+
// Register registers the given command handlers.
64+
func (proc *Processor[MetaType]) Register(handlers ...*Handler[MetaType]) {
65+
proc.lock.Lock()
66+
defer proc.lock.Unlock()
67+
for _, handler := range handlers {
68+
proc.registerOne(handler)
69+
}
70+
}
71+
72+
func (proc *Processor[MetaType]) registerOne(handler *Handler[MetaType]) {
73+
if strings.ToLower(handler.Name) != handler.Name {
74+
panic(fmt.Errorf("command %q is not lowercase", handler.Name))
75+
}
76+
proc.commands[handler.Name] = handler
77+
for _, alias := range handler.Aliases {
78+
if strings.ToLower(alias) != alias {
79+
panic(fmt.Errorf("alias %q is not lowercase", alias))
80+
}
81+
proc.aliases[alias] = handler.Name
82+
}
83+
}
84+
85+
func (proc *Processor[MetaType]) Unregister(handlers ...*Handler[MetaType]) {
86+
proc.lock.Lock()
87+
defer proc.lock.Unlock()
88+
for _, handler := range handlers {
89+
proc.unregisterOne(handler)
90+
}
91+
}
92+
93+
func (proc *Processor[MetaType]) unregisterOne(handler *Handler[MetaType]) {
94+
delete(proc.commands, handler.Name)
95+
for _, alias := range handler.Aliases {
96+
if proc.aliases[alias] == handler.Name {
97+
delete(proc.aliases, alias)
98+
}
99+
}
100+
}
101+
102+
func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) {
103+
log := *zerolog.Ctx(ctx)
104+
defer func() {
105+
panicErr := recover()
106+
if panicErr != nil {
107+
logEvt := log.Error().
108+
Bytes(zerolog.ErrorStackFieldName, debug.Stack())
109+
if realErr, ok := panicErr.(error); ok {
110+
logEvt = logEvt.Err(realErr)
111+
} else {
112+
logEvt = logEvt.Any(zerolog.ErrorFieldName, panicErr)
113+
}
114+
logEvt.Msg("Panic in command handler")
115+
_, err := proc.Client.SendReaction(ctx, evt.RoomID, evt.ID, "💥")
116+
if err != nil {
117+
log.Err(err).Msg("Failed to send reaction after panic")
118+
}
119+
}
120+
}()
121+
parsed := ParseEvent[MetaType](ctx, evt)
122+
if !proc.PreValidator.Validate(parsed) {
123+
return
124+
}
125+
126+
realCommand := parsed.Command
127+
proc.lock.RLock()
128+
alias, ok := proc.aliases[realCommand]
129+
if ok {
130+
realCommand = alias
131+
}
132+
handler, ok := proc.commands[realCommand]
133+
if !ok {
134+
handler, ok = proc.commands[UnknownCommandName]
135+
}
136+
proc.lock.RUnlock()
137+
if !ok {
138+
return
139+
}
140+
141+
logWith := log.With().
142+
Str("command", realCommand).
143+
Stringer("sender", evt.Sender).
144+
Stringer("room_id", evt.RoomID)
145+
if proc.LogArgs {
146+
logWith = logWith.Strs("args", parsed.Args)
147+
}
148+
log = logWith.Logger()
149+
parsed.Ctx = log.WithContext(ctx)
150+
parsed.Handler = handler
151+
parsed.Proc = proc
152+
parsed.Meta = proc.Meta
153+
154+
log.Debug().Msg("Processing command")
155+
handler.Func(parsed)
156+
}

0 commit comments

Comments
 (0)