Skip to content

Commit

Permalink
feat: Add options for geo commands and optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
diiyw committed May 24, 2024
1 parent 0d1e7ca commit 6aed84e
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 17 deletions.
81 changes: 64 additions & 17 deletions geo.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var (
dr = math.Pi / 180.0
)

func (n *Nodis) GeoAdd(key string, members ...*geo.Member) (int64, error) {
func (n *Nodis) GeoAdd(key string, members ...*geo.Member) int64 {
var v int64
_ = n.exec(func(tx *Tx) error {
meta := tx.writeKey(key, n.newZSet)
Expand All @@ -22,11 +22,11 @@ func (n *Nodis) GeoAdd(key string, members ...*geo.Member) (int64, error) {
}
return nil
})
return v, nil
return v
}

// GeoAddXX adds the specified members to the key only if the member already exists.
func (n *Nodis) GeoAddXX(key string, members ...*geo.Member) (int64, error) {
func (n *Nodis) GeoAddXX(key string, members ...*geo.Member) int64 {
var v int64
_ = n.exec(func(tx *Tx) error {
meta := tx.writeKey(key, nil)
Expand All @@ -38,11 +38,11 @@ func (n *Nodis) GeoAddXX(key string, members ...*geo.Member) (int64, error) {
}
return nil
})
return v, nil
return v
}

// GeoAddNX adds the specified members to the key only if the member does not already exist.
func (n *Nodis) GeoAddNX(key string, members ...*geo.Member) (int64, error) {
func (n *Nodis) GeoAddNX(key string, members ...*geo.Member) int64 {
var v int64
_ = n.exec(func(tx *Tx) error {
meta := tx.writeKey(key, n.newZSet)
Expand All @@ -54,7 +54,7 @@ func (n *Nodis) GeoAddNX(key string, members ...*geo.Member) (int64, error) {
}
return nil
})
return v, nil
return v
}

func (n *Nodis) GeoDist(key string, member1, member2 string) (float64, error) {
Expand Down Expand Up @@ -90,9 +90,9 @@ func distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
}

func (n *Nodis) GeoHash(key string, members ...string) ([]string, error) {
func (n *Nodis) GeoHash(key string, members ...string) []string {
var v []string
err := n.exec(func(tx *Tx) error {
_ = n.exec(func(tx *Tx) error {
meta := tx.readKey(key)
if !meta.isOk() {
return nil
Expand All @@ -107,37 +107,84 @@ func (n *Nodis) GeoHash(key string, members ...string) ([]string, error) {
}
return nil
})
return v, err
return v
}

func (n *Nodis) GeoPos(key string, members ...string) ([]*geo.Member, error) {
var v []*geo.Member
err := n.exec(func(tx *Tx) error {
func (n *Nodis) GeoPos(key string, members ...string) []*geo.Member {
var v = make([]*geo.Member, len(members))
_ = n.exec(func(tx *Tx) error {
meta := tx.readKey(key)
if meta == nil {
return nil
}
for _, member := range members {
for i, member := range members {
score, err := meta.value.(*zset.SortedSet).ZScore(member)
if err != nil {
return err
continue
}
lat, lng := geohash.DecodeInt(uint64(score))
v = append(v, &geo.Member{Name: member, Latitude: lat, Longitude: lng})
v[i] = &geo.Member{Name: member, Latitude: lat, Longitude: lng}
}
return nil
})
return v, err
return v
}

func (n *Nodis) GeoRadius(key string, longitude, latitude, radius float64, count int64, desc bool) ([]*geo.Member, error) {
var v []*geo.Member
err := n.exec(func(tx *Tx) error {
meta := tx.readKey(key)
if meta == nil {
if !meta.isOk() {
return nil
}
bits := estimatePrecisionByRadius(radius, latitude)
hash := geohash.EncodeIntWithPrecision(latitude, longitude, bits)
neighbors := geohash.NeighborsIntWithPrecision(hash, bits)
for _, lower := range neighbors {
var items []*zset.Item
r := uint64(1 << (64 - bits))
upper := lower + r
if desc {
items = meta.value.(*zset.SortedSet).ZRangeByScore(float64(lower), float64(upper), 0, count, 0)
} else {
items = meta.value.(*zset.SortedSet).ZRevRangeByScore(float64(lower), float64(upper), 0, count, 0)
}
for _, item := range items {
lat, lng := geohash.DecodeInt(uint64(item.Score))
v = append(v, &geo.Member{Name: item.Member, Latitude: lat, Longitude: lng})
}
}
return nil
})
return v, err
}

const (
MERCATOR_MAX = 20037726.37
)

func estimatePrecisionByRadius(radiusMeters float64, lat float64) uint {
if radiusMeters == 0 {
return 63
}
var precision uint = 1
for radiusMeters < MERCATOR_MAX {
radiusMeters *= 2
precision++
}
/* Make sure range is included in most of the base cases. */
precision -= 2
if lat > 66 || lat < -66 {
precision--
if lat > 80 || lat < -80 {
precision--
}
}
if precision < 1 {
precision = 1
}
if precision > 32 {
precision = 32
}
return precision*2 - 1
}
219 changes: 219 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/diiyw/nodis/ds"
"github.com/diiyw/nodis/ds/zset"
"github.com/diiyw/nodis/geo"
"github.com/diiyw/nodis/internal/strings"
"github.com/diiyw/nodis/redis"
)
Expand Down Expand Up @@ -261,6 +262,16 @@ func getCommand(name string) func(n *Nodis, conn *redis.Conn, cmd redis.Command)
return zInterStore
case "ZEXISTS":
return zExists
case "GEOADD":
return geoAdd
case "GEODIST":
return geoDist
case "GEOHASH":
return geoHash
case "GEOPOS":
return geoPos
case "GEORADIUS":
return geoRadius
}
return cmdNotFound
}
Expand Down Expand Up @@ -2592,3 +2603,211 @@ func save(n *Nodis, conn *redis.Conn, cmd redis.Command) {
conn.WriteString("OK")
})
}

func geoAdd(n *Nodis, conn *redis.Conn, cmd redis.Command) {
if len(cmd.Args) < 4 {
conn.WriteError("GEOADD requires at least four arguments")
return
}
key := cmd.Args[0]
var items = make([]*geo.Member, 0)
var args []string
if cmd.Options.NX == 2 || cmd.Options.XX == 2 {
args = cmd.Args[2:]
} else if cmd.Options.NX == 1 || cmd.Options.XX == 1 {
if cmd.Options.CH == 2 {
args = cmd.Args[2:]
} else {
args = cmd.Args[1:]
}
} else {
args = cmd.Args[1:]
}
if len(args) < 3 {
conn.WriteError("GEOADD requires at least four arguments")
return
}
if len(args)%3 != 0 {
conn.WriteError("syntax error")
return
}
for i := 0; i < len(args); i += 3 {
longitude, err := redis.FormatFloat64(args[i])
if err != nil {
conn.WriteError("ERR longitude value is not a valid float")
return
}
latitude, err := redis.FormatFloat64(args[i+1])
if err != nil {
conn.WriteError("ERR latitude value is not a valid float")
return
}
items = append(items, &geo.Member{Name: args[i+2], Longitude: longitude, Latitude: latitude})
}
execCommand(conn, func() {
if cmd.Options.NX == 1 {
conn.WriteInteger(n.GeoAddNX(key, items...))
return
}
if cmd.Options.XX == 1 {
conn.WriteInteger(n.GeoAddXX(key, items...))
return
}
v := n.GeoAdd(key, items...)
conn.WriteInteger(v)
})
}

func geoDist(n *Nodis, conn *redis.Conn, cmd redis.Command) {
if len(cmd.Args) < 3 {
conn.WriteError("GEODIST requires at least three arguments")
return
}
key := cmd.Args[0]
member1 := cmd.Args[1]
member2 := cmd.Args[2]
execCommand(conn, func() {
v, err := n.GeoDist(key, member1, member2)
if err != nil {
conn.WriteBulkNull()
return
}
switch true {
case cmd.Options.KM == 3:
conn.WriteBulk(strconv.FormatFloat(v/1000, 'f', -1, 64))
case cmd.Options.MI == 3:
conn.WriteBulk(strconv.FormatFloat(v/1609.34, 'f', -1, 64))
case cmd.Options.FT == 3:
conn.WriteBulk(strconv.FormatFloat(v/0.3048, 'f', -1, 64))
default:
conn.WriteBulk(strconv.FormatFloat(v, 'f', -1, 64))
}
})
}

func geoHash(n *Nodis, conn *redis.Conn, cmd redis.Command) {
if len(cmd.Args) < 2 {
conn.WriteError("GEOHASH requires at least two arguments")
return
}
key := cmd.Args[0]
execCommand(conn, func() {
results := n.GeoHash(key, cmd.Args[1:]...)
conn.WriteArray(len(results))
for _, v := range results {
conn.WriteBulk(v)
}
})
}

func geoPos(n *Nodis, conn *redis.Conn, cmd redis.Command) {
if len(cmd.Args) < 2 {
conn.WriteError("GEOPOS requires at least two arguments")
return
}
key := cmd.Args[0]
execCommand(conn, func() {
results := n.GeoPos(key, cmd.Args[1:]...)
conn.WriteArray(len(results))
for _, v := range results {
if v == nil {
conn.WriteBulkNull()
continue
}
conn.WriteArray(2)
conn.WriteBulk(strconv.FormatFloat(v.Longitude, 'f', -1, 64))
conn.WriteBulk(strconv.FormatFloat(v.Latitude, 'f', -1, 64))
}
})
}

func geoRadius(n *Nodis, conn *redis.Conn, cmd redis.Command) {
if len(cmd.Args) < 4 {
conn.WriteError("GEORADIUS requires at least four arguments")
return
}
key := cmd.Args[0]
longitude, err := redis.FormatFloat64(cmd.Args[1])
if err != nil {
conn.WriteError("ERR longitude value is not a valid float")
return
}
latitude, err := redis.FormatFloat64(cmd.Args[2])
if err != nil {
conn.WriteError("ERR latitude value is not a valid float")
return
}
radius, err := redis.FormatFloat64(cmd.Args[3])
if err != nil {
conn.WriteError("ERR radius value is not a valid float")
return
}
if cmd.Options.KM > 3 {
radius *= 1000
}
if cmd.Options.MI > 3 {
radius *= 1609.34
}
if cmd.Options.FT > 3 {
radius *= 0.3048
}
var count int64 = -1
if cmd.Options.COUNT > 3 && cmd.Options.ANY == 0 {
count, err = strconv.ParseInt(cmd.Args[cmd.Options.COUNT], 10, 64)
if err != nil {
conn.WriteError("ERR count value is not a valid integer")
return
}
}
execCommand(conn, func() {
var results []*geo.Member
var err error
results, err = n.GeoRadius(key, longitude, latitude, radius, count, cmd.Options.DESC > 3)
if err != nil {
conn.WriteArrayNull()
return
}
if cmd.Options.WITHCOORD == 0 && cmd.Options.WITHDIST == 0 && cmd.Options.WITHHASH == 0 {
conn.WriteArray(len(results))
for _, v := range results {
conn.WriteBulk(v.Name)
}
return
}
conn.WriteArray(len(results))
l := 1
if cmd.Options.WITHCOORD > 3 {
l++
}
if cmd.Options.WITHDIST > 3 {
l++
}
if cmd.Options.WITHHASH > 3 {
l++
}
for _, v := range results {
conn.WriteArray(l)
conn.WriteBulk(v.Name)
if cmd.Options.WITHDIST > 3 {
dist := distance(latitude, longitude, v.Latitude, v.Longitude)
if cmd.Options.KM > 3 {
conn.WriteBulk(strconv.FormatFloat(dist/1000, 'f', -1, 64))
} else if cmd.Options.MI > 3 {
conn.WriteBulk(strconv.FormatFloat(dist/1609.34, 'f', -1, 64))
} else if cmd.Options.FT > 3 {
conn.WriteBulk(strconv.FormatFloat(dist/0.3048, 'f', -1, 64))
} else {
conn.WriteBulk(strconv.FormatFloat(dist, 'f', -1, 64))
}
}
if cmd.Options.WITHHASH > 3 {
conn.WriteInteger(int64(v.Hash()))
}
if cmd.Options.WITHCOORD > 3 {
conn.WriteArray(2)
conn.WriteBulk(strconv.FormatFloat(v.Longitude, 'f', -1, 64))
conn.WriteBulk(strconv.FormatFloat(v.Latitude, 'f', -1, 64))
}
}
})
}
Loading

0 comments on commit 6aed84e

Please sign in to comment.