Skip to content

Commit cca5300

Browse files
authored
Merge pull request #1 from nyaruka/string-map
add StringMap type
2 parents c108471 + 094e0a4 commit cca5300

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

null.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"database/sql/driver"
66
"encoding/json"
7+
"fmt"
78
)
89

910
// Int is an int that will write as null when it is zero both to databases and json
@@ -136,3 +137,76 @@ func (s String) Value() (driver.Value, error) {
136137
}
137138
return string(s), nil
138139
}
140+
141+
// StringMap is a one level deep dictionary that is represented as JSON text in the database.
142+
// Empty maps will be written as null to the database and to JSON.
143+
type StringMap struct {
144+
m map[string]string
145+
}
146+
147+
// NewStringMap creates a new StringMap
148+
func NewStringMap(m map[string]string) StringMap {
149+
return StringMap{m: m}
150+
}
151+
152+
// Map returns our underlying map
153+
func (m *StringMap) Map() map[string]string {
154+
if m.m == nil {
155+
m.m = make(map[string]string)
156+
}
157+
return m.m
158+
}
159+
160+
// Scan implements the Scanner interface for decoding from a database
161+
func (m *StringMap) Scan(src interface{}) error {
162+
m.m = make(map[string]string)
163+
if src == nil {
164+
return nil
165+
}
166+
167+
var source []byte
168+
switch src.(type) {
169+
case string:
170+
source = []byte(src.(string))
171+
case []byte:
172+
source = src.([]byte)
173+
default:
174+
return fmt.Errorf("incompatible type for map")
175+
}
176+
177+
// 0 length string is same as nil
178+
if len(source) == 0 {
179+
return nil
180+
}
181+
182+
err := json.Unmarshal(source, &m.m)
183+
if err != nil {
184+
return err
185+
}
186+
return nil
187+
}
188+
189+
// Value implements the driver Valuer interface
190+
func (m StringMap) Value() (driver.Value, error) {
191+
if m.m == nil || len(m.m) == 0 {
192+
return nil, nil
193+
}
194+
return json.Marshal(m.m)
195+
}
196+
197+
// MarshalJSON encodes our map to JSON
198+
func (m StringMap) MarshalJSON() ([]byte, error) {
199+
if m.m == nil || len(m.m) == 0 {
200+
return json.Marshal(nil)
201+
}
202+
return json.Marshal(m.m)
203+
}
204+
205+
// UnmarshalJSON sets our map from the passed in JSON
206+
func (m *StringMap) UnmarshalJSON(data []byte) error {
207+
m.m = make(map[string]string)
208+
if len(data) == 0 {
209+
return nil
210+
}
211+
return json.Unmarshal(data, &m.m)
212+
}

null_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,69 @@ func TestString(t *testing.T) {
297297
assert.True(t, tc.Value == str)
298298
}
299299
}
300+
301+
func TestMap(t *testing.T) {
302+
db, err := sql.Open("postgres", "postgres://localhost/null_test?sslmode=disable")
303+
assert.NoError(t, err)
304+
305+
_, err = db.Exec(`DROP TABLE IF EXISTS map; CREATE TABLE map(value varchar(255) null);`)
306+
assert.NoError(t, err)
307+
308+
sp := func(s string) *string {
309+
return &s
310+
}
311+
312+
tcs := []struct {
313+
Value StringMap
314+
JSON string
315+
DB *string
316+
}{
317+
{NewStringMap(map[string]string{"foo": "bar"}), `{"foo":"bar"}`, sp(`{"foo": "bar"}`)},
318+
{NewStringMap(map[string]string{}), "null", nil},
319+
{NewStringMap(nil), "null", nil},
320+
{NewStringMap(nil), "null", sp("")},
321+
}
322+
323+
for i, tc := range tcs {
324+
_, err = db.Exec(`DELETE FROM map;`)
325+
assert.NoError(t, err)
326+
327+
b, err := json.Marshal(tc.Value)
328+
assert.NoError(t, err)
329+
assert.Equal(t, tc.JSON, string(b), "%d: %s not equal to %s", i, tc.JSON, string(b))
330+
331+
m := StringMap{}
332+
err = json.Unmarshal(b, &m)
333+
assert.NoError(t, err)
334+
assert.Equal(t, tc.Value.Map(), m.Map(), "%d: %s not equal to %s", i, tc.Value, m)
335+
336+
_, err = db.Exec(`INSERT INTO map(value) VALUES($1)`, tc.Value)
337+
assert.NoError(t, err)
338+
339+
rows, err := db.Query(`SELECT value FROM map;`)
340+
assert.NoError(t, err)
341+
342+
m2 := StringMap{}
343+
assert.True(t, rows.Next())
344+
err = rows.Scan(&m2)
345+
assert.NoError(t, err)
346+
347+
assert.Equal(t, tc.Value.Map(), m2.Map())
348+
349+
_, err = db.Exec(`DELETE FROM map;`)
350+
assert.NoError(t, err)
351+
352+
_, err = db.Exec(`INSERT INTO map(value) VALUES($1)`, tc.DB)
353+
assert.NoError(t, err)
354+
355+
rows, err = db.Query(`SELECT value FROM map;`)
356+
assert.NoError(t, err)
357+
358+
m2 = StringMap{}
359+
assert.True(t, rows.Next())
360+
err = rows.Scan(&m2)
361+
assert.NoError(t, err)
362+
363+
assert.Equal(t, tc.Value.Map(), m2.Map())
364+
}
365+
}

0 commit comments

Comments
 (0)