Skip to content
This repository was archived by the owner on Jan 14, 2022. It is now read-only.

Commit 3cfb502

Browse files
Random tweaks, codegen requests and responses
1 parent 2e65b6a commit 3cfb502

File tree

10 files changed

+249
-170
lines changed

10 files changed

+249
-170
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ import (
2121
obs "github.com/christopher-dG/go-obs-websocket"
2222
)
2323

24-
client := obs.NewClient("localhost", 4444, "")
24+
client := obs.Client{Host: "localhost", Port: 4444}
2525
if err := client.Connect(); err != nil {
2626
log.Fatal(err)
2727
}
28-
defer client.Close()
28+
defer client.Disconnect()
2929

3030
// TODO
3131
```

client.go

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,96 @@
11
package obsws
22

3-
import "github.com/gorilla/websocket"
3+
import (
4+
"crypto/sha256"
5+
"encoding/base64"
6+
"fmt"
7+
"strconv"
48

5-
// client is the interface to obs-websocket.
6-
type client struct {
9+
"github.com/gorilla/websocket"
10+
"github.com/pkg/errors"
11+
)
12+
13+
// Client is the interface to obs-websocket.
14+
// Client{Host: "localhost", Port: 4444} will probably work if you haven't configured OBS.
15+
type Client struct {
716
Host string // Host (probably "localhost").
817
Port int // Port (OBS default is 4444).
918
Password string // Password (OBS default is "").
1019
conn *websocket.Conn
1120
id int
1221
}
1322

14-
// NewClient creates a new client. If you haven't configured obs-websocket at
15-
// all, then host should be "localhost", port should be 4444, and password
16-
// should be "".
17-
func NewClient(host string, port int, password string) *client {
18-
return &client{Host: host, Port: port, Password: password}
23+
// Connect opens a WebSocket connection and authenticates if necessary.
24+
func (c *Client) Connect() error {
25+
conn, err := connectWS(c.Host, c.Port)
26+
if err != nil {
27+
return err
28+
}
29+
c.conn = conn
30+
31+
reqGAR := GetAuthRequiredRequest{
32+
MessageID: c.getMessageID(),
33+
RequestType: "GetAuthRequired",
34+
}
35+
36+
if err = c.conn.WriteJSON(reqGAR); err != nil {
37+
return errors.Wrap(err, "write Authenticate")
38+
}
39+
40+
respGAR := &GetAuthRequiredResponse{}
41+
if err = c.conn.ReadJSON(respGAR); err != nil {
42+
return errors.Wrap(err, "read GetAuthRequired")
43+
}
44+
45+
if !respGAR.AuthRequired {
46+
logger.Info("no authentication required")
47+
return nil
48+
}
49+
50+
auth := getAuth(c.Password, respGAR.Salt, respGAR.Challenge)
51+
logger.Debugf("auth: %s", auth)
52+
53+
reqA := AuthenticateRequest{
54+
Auth: auth,
55+
_request: _request{
56+
MessageID: c.getMessageID(),
57+
RequestType: "Authenticate",
58+
},
59+
}
60+
if err = c.conn.WriteJSON(reqA); err != nil {
61+
return errors.Wrap(err, "write Authenticate")
62+
}
63+
64+
logger.Info("logged in")
65+
return nil
66+
}
67+
68+
// Disconnect closes the WebSocket connection.
69+
func (c *Client) Disconnect() error {
70+
return c.conn.Close()
71+
}
72+
73+
func connectWS(host string, port int) (*websocket.Conn, error) {
74+
url := fmt.Sprintf("ws://%s:%d", host, port)
75+
logger.Infof("connecting to %s", url)
76+
conn, _, err := websocket.DefaultDialer.Dial(url, nil)
77+
if err != nil {
78+
return nil, err
79+
}
80+
return conn, nil
81+
}
82+
83+
func getAuth(password, salt, challenge string) string {
84+
sha := sha256.Sum256([]byte(password + salt))
85+
b64 := base64.StdEncoding.EncodeToString([]byte(sha[:]))
86+
87+
sha = sha256.Sum256([]byte(b64 + challenge))
88+
b64 = base64.StdEncoding.EncodeToString([]byte(sha[:]))
89+
90+
return b64
91+
}
92+
93+
func (c *Client) getMessageID() string {
94+
c.id++
95+
return strconv.Itoa(c.id)
1996
}

client_test.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@ package obsws
22

33
import "testing"
44

5-
func TestNewClient(t *testing.T) {
6-
c := NewClient("localhost", 4444, "")
7-
if c.Host != "localhost" {
8-
t.Errorf("expected c.Host == 'localhost', got '%s'", c.Host)
9-
}
10-
if c.Port != 4444 {
11-
t.Errorf("expected c.Port == 4444, got '%d'", c.Port)
12-
}
13-
if c.Password != "" {
14-
t.Errorf("expected c.Password == '', got '%s'", c.Password)
5+
func TestGetAuth(t *testing.T) {
6+
expected := "zTM5ki6L2vVvBQiTG9ckH1Lh64AbnCf6XZ226UmnkIA="
7+
observed := getAuth("password", "salt", "challenge")
8+
if observed != expected {
9+
t.Errorf("expected auth == '%s', got '%s'", expected, observed)
1510
}
1611
}

codegen/protocol.py

Lines changed: 131 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
package = "obsws"
99

1010
type_map = {
11+
"bool": "bool",
1112
"boolean": "bool",
1213
"int": "int",
14+
"float": "float64",
1315
"double": "float64",
1416
"string": "string",
1517
"array": "[]string",
@@ -36,36 +38,37 @@ def optional_type(s: str) -> Tuple[str, bool]:
3638

3739

3840
def process_json(d: Dict):
39-
process_events(d["events"])
40-
process_requests(d["requests"])
41+
gen_events(d["events"])
42+
gen_requests(d["requests"])
4143

4244

43-
def process_events(events: Dict):
45+
def gen_events(events: Dict):
46+
"""Generate all events."""
4447
for category, data in events.items():
45-
process_events_category(category, data)
48+
gen_events_category(category, data)
4649

4750

48-
def process_events_category(category: str, data: Dict):
49-
events = "\n\n".join(generate_event(event) for event in data)
51+
def gen_events_category(category: str, data: Dict):
52+
"""Generate all events in one category."""
53+
events = "\n\n".join(gen_event(event) for event in data)
5054
with open(go_filename("events", category), "w") as f:
5155
f.write(f"""\
5256
package {package}
5357
54-
// This code is automatically generated.
55-
// See: https://github.com/christopher-dG/go-obs-websocket/blob/master/codegen/protocol.py
56-
57-
// https://github.com/Palakis/obs-websocket/blob/master/docs/generated/protocol.md#{category}
58+
// This file is automatically generated.
59+
// https://github.com/christopher-dG/go-obs-websocket/blob/master/codegen/protocol.py
5860
5961
{events}
6062
""")
6163

6264

63-
def generate_event(data: Dict) -> str:
64-
"""Generate Go code with type definitions and interface functions."""
65-
if "returns" in data:
65+
def gen_event(data: Dict) -> str:
66+
"""Write Go code with a type definition and interface functions."""
67+
reserved = ["Type", "StreamTC", "RecTC"]
68+
if data.get("returns"):
6669
struct = f"""\
6770
type {data['name']}Event struct {{
68-
{go_variables(data['returns'])}
71+
{go_variables(data['returns'], reserved)}
6972
_event
7073
}}\
7174
"""
@@ -74,8 +77,10 @@ def generate_event(data: Dict) -> str:
7477

7578
description = data["description"].replace("\n", " ")
7679
description = f"{data['name']}Event : {description}"
77-
if "since" in data:
78-
description += f" Since: {data['since'].capitalize()}"
80+
if description and not description.endswith("."):
81+
description += "."
82+
if data.get("since"):
83+
description += f" Since: {data['since'].capitalize()}."
7984

8085
return f"""\
8186
// {description}
@@ -93,44 +98,135 @@ def generate_event(data: Dict) -> str:
9398
"""
9499

95100

96-
def process_requests(requests: Dict):
97-
pass
101+
def gen_requests(requests: Dict):
102+
"""Generate all requests and responses."""
103+
for category, data in requests.items():
104+
gen_requests_category(category, data)
105+
106+
107+
def gen_requests_category(category: str, data: Dict):
108+
requests = "\n\n".join(gen_request(request) for request in data)
109+
with open(go_filename("requests", category), "w") as f:
110+
f.write(f"""\
111+
package {package}
98112
113+
// This file is automatically generated.
114+
// https://github.com/christopher-dG/go-obs-websocket/blob/master/codegen/protocol.py
99115
100-
def go_variables(vars: List) -> str:
116+
{requests}
117+
""")
118+
119+
120+
def gen_request(data: Dict) -> str:
121+
"""Write Go code with type definitions and interface functions."""
122+
reserved = ["ID", "Type"]
123+
if data.get("params"):
124+
struct = f"""\
125+
type {data['name']}Request struct {{
126+
{go_variables(data['params'], reserved)}
127+
_request
128+
}}
129+
"""
130+
else:
131+
struct = f"type {data['name']}Request _request"
132+
133+
description = data["description"].replace("\n", " ")
134+
description = f"{data['name']}Request : {description}"
135+
if description and not description.endswith("."):
136+
description += "."
137+
if data.get("since"):
138+
description += f" Since: {data['since'].capitalize()}."
139+
140+
request = f"""\
141+
// {description}
142+
// https://github.com/Palakis/obs-websocket/blob/master/docs/generated/protocol.md#{data['heading']['text'].lower()}
143+
{struct}
144+
145+
// ID returns the request's message ID.
146+
func (r {data['name']}Request) ID() string {{ return r.MessageID }}
147+
148+
// Type returns the request's message type.
149+
func (r {data['name']}Request) Type() string {{ return r.RequestType }}
101150
"""
102-
Convert a list of variable definition into Go code to be put
151+
152+
if data.get("returns"):
153+
reserved = ["ID", "Stat", "Err"]
154+
struct = f"""\
155+
type {data['name']}Response struct {{
156+
{go_variables(data['returns'], reserved)}
157+
_response
158+
}}
159+
"""
160+
else:
161+
struct = f"type {data['name']}Response _response"
162+
163+
description = f"{data['name']}Response : Response for {data['name']}Request."
164+
if data.get("since"):
165+
description += f" Since: {data['since'].capitalize()}."
166+
167+
response = f"""\
168+
// {description}
169+
// https://github.com/Palakis/obs-websocket/blob/master/docs/generated/protocol.md#{data['heading']['text'].lower()}
170+
{struct}
171+
172+
// ID returns the response's message ID.
173+
func (r {data['name']}Response) ID() string {{ return r.MessageID }}
174+
175+
// Stat returns the response's status.
176+
func (r {data['name']}Response) Stat() string {{ return r.Status }}
177+
178+
// Err returns the response's error.
179+
func (r {data['name']}Response) Err() string {{ return r.Error }}
180+
"""
181+
182+
return f"{request}\n\n{response}"
183+
184+
185+
def go_variables(names: List, reserved: List) -> str:
186+
"""
187+
Convert a list of variable names into Go code to be put
103188
inside a struct definition.
104189
"""
105-
lines = []
106-
for v in vars:
107-
line = go_name(v["name"])
190+
lines, varnames = [], []
191+
for v in names:
108192
typename, optional = optional_type(v["type"])
109-
line += f" {type_map[typename.lower()]} // {v['description']}"
193+
varname = go_var(v["name"])
194+
description = v["description"].replace("\n", " ")
195+
if description and not description.endswith("."):
196+
description += "."
197+
tag = '`json:"%s"`' % v['name']
198+
line = f"{go_var(v['name'])} {type_map[typename.lower()]} {tag} // {description}"
110199
if optional:
111-
line += " Optional."
200+
line += " Optional." if description else "Optional."
201+
if varname in reserved:
202+
line += " TODO: Reserved name."
203+
if varname in varnames:
204+
line += " TODO: Duplicate name."
205+
else:
206+
varnames.append(varname)
112207
if typename.lower() in unknown_types:
113-
line += f" TODO: Unknown type ({typename})."
208+
line += f" TODO: Unknown type ({v['type']})."
114209
lines.append(line)
115210
return "\n".join(lines)
116211

117212

118-
def go_name(s: str) -> str:
119-
"""
120-
Convert a variable name in the input file to a Go variable name.
121-
Note: This makes lots of assumptions about the input,
122-
i.e. nothing ends with a separator.
123-
"""
124-
s = s.capitalize()
213+
def go_var(s: str) -> str:
214+
"""Convert a variable name in the input file to a Go variable name."""
215+
s = f"{s[0].upper()}{s[1:]}"
125216
for sep in ["-", "_", ".*.", "[].", "."]:
126217
while sep in s:
127-
i = s.find(sep)
128218
_len = len(sep)
219+
if s.endswith(sep):
220+
s = s[:-_len]
221+
continue
222+
i = s.find(sep)
129223
s = f"{s[:i]}{s[i+_len].upper()}{s[i+_len+1:]}"
130-
return s
224+
225+
return s.replace("Id", "ID") # Yuck.
131226

132227

133228
def go_filename(category, section):
229+
"""Generate a Go filename from a category and section."""
134230
return f"{category}_{section.replace(' ', '_')}.go"
135231

136232

0 commit comments

Comments
 (0)