Skip to content

Commit f32bd13

Browse files
committed
change to generic map interface
1 parent cca5300 commit f32bd13

File tree

2 files changed

+57
-24
lines changed

2 files changed

+57
-24
lines changed

null.go

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,28 +138,56 @@ func (s String) Value() (driver.Value, error) {
138138
return string(s), nil
139139
}
140140

141-
// StringMap is a one level deep dictionary that is represented as JSON text in the database.
141+
// Map is a one level deep dictionary that is represented as JSON text in the database.
142142
// Empty maps will be written as null to the database and to JSON.
143-
type StringMap struct {
144-
m map[string]string
143+
type Map struct {
144+
m map[string]interface{}
145145
}
146146

147-
// NewStringMap creates a new StringMap
148-
func NewStringMap(m map[string]string) StringMap {
149-
return StringMap{m: m}
147+
// NewMap creates a new Map
148+
func NewMap(m map[string]interface{}) Map {
149+
return Map{m: m}
150150
}
151151

152152
// Map returns our underlying map
153-
func (m *StringMap) Map() map[string]string {
153+
func (m *Map) Map() map[string]interface{} {
154154
if m.m == nil {
155-
m.m = make(map[string]string)
155+
m.m = make(map[string]interface{})
156156
}
157157
return m.m
158158
}
159159

160+
// GetString returns the string value with the passed in key, or def if not found or of wrong type
161+
func (m *Map) GetString(key string, def string) string {
162+
if m.m == nil {
163+
return def
164+
}
165+
val := m.m[key]
166+
if val == nil {
167+
return def
168+
}
169+
str, isStr := val.(string)
170+
if !isStr {
171+
return def
172+
}
173+
return str
174+
}
175+
176+
// Get returns the value with the passed in key, or def if not found
177+
func (m *Map) Get(key string, def interface{}) interface{} {
178+
if m.m == nil {
179+
return def
180+
}
181+
val := m.m[key]
182+
if val == nil {
183+
return def
184+
}
185+
return val
186+
}
187+
160188
// Scan implements the Scanner interface for decoding from a database
161-
func (m *StringMap) Scan(src interface{}) error {
162-
m.m = make(map[string]string)
189+
func (m *Map) Scan(src interface{}) error {
190+
m.m = make(map[string]interface{})
163191
if src == nil {
164192
return nil
165193
}
@@ -187,24 +215,24 @@ func (m *StringMap) Scan(src interface{}) error {
187215
}
188216

189217
// Value implements the driver Valuer interface
190-
func (m StringMap) Value() (driver.Value, error) {
218+
func (m Map) Value() (driver.Value, error) {
191219
if m.m == nil || len(m.m) == 0 {
192220
return nil, nil
193221
}
194222
return json.Marshal(m.m)
195223
}
196224

197225
// MarshalJSON encodes our map to JSON
198-
func (m StringMap) MarshalJSON() ([]byte, error) {
226+
func (m Map) MarshalJSON() ([]byte, error) {
199227
if m.m == nil || len(m.m) == 0 {
200228
return json.Marshal(nil)
201229
}
202230
return json.Marshal(m.m)
203231
}
204232

205233
// UnmarshalJSON sets our map from the passed in JSON
206-
func (m *StringMap) UnmarshalJSON(data []byte) error {
207-
m.m = make(map[string]string)
234+
func (m *Map) UnmarshalJSON(data []byte) error {
235+
m.m = make(map[string]interface{})
208236
if len(data) == 0 {
209237
return nil
210238
}

null_test.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -310,14 +310,16 @@ func TestMap(t *testing.T) {
310310
}
311311

312312
tcs := []struct {
313-
Value StringMap
314-
JSON string
315-
DB *string
313+
Value Map
314+
JSON string
315+
DB *string
316+
Key string
317+
KeyValue string
316318
}{
317-
{NewStringMap(map[string]string{"foo": "bar"}), `{"foo":"bar"}`, sp(`{"foo": "bar"}`)},
318-
{NewStringMap(map[string]string{}), "null", nil},
319-
{NewStringMap(nil), "null", nil},
320-
{NewStringMap(nil), "null", sp("")},
319+
{NewMap(map[string]interface{}{"foo": "bar"}), `{"foo":"bar"}`, sp(`{"foo": "bar"}`), "foo", "bar"},
320+
{NewMap(map[string]interface{}{}), "null", nil, "foo", ""},
321+
{NewMap(nil), "null", nil, "foo", ""},
322+
{NewMap(nil), "null", sp(""), "foo", ""},
321323
}
322324

323325
for i, tc := range tcs {
@@ -328,23 +330,25 @@ func TestMap(t *testing.T) {
328330
assert.NoError(t, err)
329331
assert.Equal(t, tc.JSON, string(b), "%d: %s not equal to %s", i, tc.JSON, string(b))
330332

331-
m := StringMap{}
333+
m := Map{}
332334
err = json.Unmarshal(b, &m)
333335
assert.NoError(t, err)
334336
assert.Equal(t, tc.Value.Map(), m.Map(), "%d: %s not equal to %s", i, tc.Value, m)
337+
assert.Equal(t, m.GetString(tc.Key, ""), tc.KeyValue)
335338

336339
_, err = db.Exec(`INSERT INTO map(value) VALUES($1)`, tc.Value)
337340
assert.NoError(t, err)
338341

339342
rows, err := db.Query(`SELECT value FROM map;`)
340343
assert.NoError(t, err)
341344

342-
m2 := StringMap{}
345+
m2 := Map{}
343346
assert.True(t, rows.Next())
344347
err = rows.Scan(&m2)
345348
assert.NoError(t, err)
346349

347350
assert.Equal(t, tc.Value.Map(), m2.Map())
351+
assert.Equal(t, m2.GetString(tc.Key, ""), tc.KeyValue)
348352

349353
_, err = db.Exec(`DELETE FROM map;`)
350354
assert.NoError(t, err)
@@ -355,11 +359,12 @@ func TestMap(t *testing.T) {
355359
rows, err = db.Query(`SELECT value FROM map;`)
356360
assert.NoError(t, err)
357361

358-
m2 = StringMap{}
362+
m2 = Map{}
359363
assert.True(t, rows.Next())
360364
err = rows.Scan(&m2)
361365
assert.NoError(t, err)
362366

363367
assert.Equal(t, tc.Value.Map(), m2.Map())
368+
assert.Equal(t, m2.GetString(tc.Key, ""), tc.KeyValue)
364369
}
365370
}

0 commit comments

Comments
 (0)