Skip to content

Commit

Permalink
fix: command - distinguish between key not found and zero value
Browse files Browse the repository at this point in the history
  • Loading branch information
nalgeon committed May 4, 2024
1 parent 91ba6db commit 4aff4a9
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 59 deletions.
21 changes: 11 additions & 10 deletions internal/command/hash/hdel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hash
import (
"testing"

"github.com/nalgeon/redka/internal/core"
"github.com/nalgeon/redka/internal/redis"
"github.com/nalgeon/redka/internal/testx"
)
Expand Down Expand Up @@ -68,8 +69,8 @@ func TestHDelExec(t *testing.T) {
testx.AssertEqual(t, res, 1)
testx.AssertEqual(t, conn.Out(), "1")

name, _ := db.Hash().Get("person", "name")
testx.AssertEqual(t, name.Exists(), false)
_, err = db.Hash().Get("person", "name")
testx.AssertErr(t, err, core.ErrNotFound)
age, _ := db.Hash().Get("person", "age")
testx.AssertEqual(t, age.String(), "25")
})
Expand All @@ -89,12 +90,12 @@ func TestHDelExec(t *testing.T) {
testx.AssertEqual(t, res, 2)
testx.AssertEqual(t, conn.Out(), "2")

name, _ := db.Hash().Get("person", "name")
testx.AssertEqual(t, name.Exists(), false)
_, err = db.Hash().Get("person", "name")
testx.AssertErr(t, err, core.ErrNotFound)
age, _ := db.Hash().Get("person", "age")
testx.AssertEqual(t, age.String(), "25")
happy, _ := db.Hash().Get("person", "happy")
testx.AssertEqual(t, happy.Exists(), false)
_, err = db.Hash().Get("person", "happy")
testx.AssertErr(t, err, core.ErrNotFound)
})
t.Run("all", func(t *testing.T) {
db, red := getDB(t)
Expand All @@ -111,9 +112,9 @@ func TestHDelExec(t *testing.T) {
testx.AssertEqual(t, res, 2)
testx.AssertEqual(t, conn.Out(), "2")

name, _ := db.Hash().Get("person", "name")
testx.AssertEqual(t, name.Exists(), false)
age, _ := db.Hash().Get("person", "age")
testx.AssertEqual(t, age.Exists(), false)
_, err = db.Hash().Get("person", "name")
testx.AssertErr(t, err, core.ErrNotFound)
_, err = db.Hash().Get("person", "age")
testx.AssertErr(t, err, core.ErrNotFound)
})
}
14 changes: 6 additions & 8 deletions internal/command/hash/hmget.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,20 @@ func (cmd *HMGet) Run(w redis.Writer, red redis.Redka) (any, error) {
return nil, err
}

// Build the result slice.
// Write the result.
// It will contain all values in the order of fields.
// Missing fields will have nil values.
w.WriteArray(len(cmd.fields))
vals := make([]core.Value, len(cmd.fields))
for i, field := range cmd.fields {
vals[i] = items[field]
}

// Write the result.
w.WriteArray(len(vals))
for _, v := range vals {
if v.Exists() {
v, ok := items[field]
vals[i] = v
if ok {
w.WriteBulk(v.Bytes())
} else {
w.WriteNull()
}
}

return vals, nil
}
5 changes: 3 additions & 2 deletions internal/command/key/del_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package key
import (
"testing"

"github.com/nalgeon/redka/internal/core"
"github.com/nalgeon/redka/internal/redis"
"github.com/nalgeon/redka/internal/testx"
)
Expand Down Expand Up @@ -80,8 +81,8 @@ func TestDelExec(t *testing.T) {
testx.AssertEqual(t, res, test.res)
testx.AssertEqual(t, conn.Out(), test.out)

name, _ := db.Str().Get("name")
testx.AssertEqual(t, name.Exists(), false)
_, err = db.Str().Get("name")
testx.AssertErr(t, err, core.ErrNotFound)
city, _ := db.Str().Get("city")
testx.AssertEqual(t, city.String(), "paris")
})
Expand Down
10 changes: 7 additions & 3 deletions internal/command/string/getset.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package string

import "github.com/nalgeon/redka/internal/redis"
import (
"github.com/nalgeon/redka/internal/core"
"github.com/nalgeon/redka/internal/redis"
)

// Returns the previous string value of a key after setting it to a new value.
// GETSET key value
Expand All @@ -27,9 +30,10 @@ func (cmd *GetSet) Run(w redis.Writer, red redis.Redka) (any, error) {
w.WriteError(cmd.Error(err))
return nil, err
}
if !out.Prev.Exists() {
if out.Created {
// no previous value
w.WriteNull()
return out.Prev, nil
return core.Value(nil), nil
}
w.WriteBulk(out.Prev)
return out.Prev, nil
Expand Down
14 changes: 6 additions & 8 deletions internal/command/string/mget.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,20 @@ func (cmd *MGet) Run(w redis.Writer, red redis.Redka) (any, error) {
return nil, err
}

// Build the result slice.
// Write the result.
// It will contain all values in the order of keys.
// Missing keys will have nil values.
w.WriteArray(len(cmd.keys))
vals := make([]core.Value, len(cmd.keys))
for i, key := range cmd.keys {
vals[i] = items[key]
}

// Write the result.
w.WriteArray(len(vals))
for _, v := range vals {
if v.Exists() {
v, ok := items[key]
vals[i] = v
if ok {
w.WriteBulk(v.Bytes())
} else {
w.WriteNull()
}
}

return vals, nil
}
4 changes: 2 additions & 2 deletions internal/command/string/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ func (cmd *Set) Run(w redis.Writer, red redis.Redka) (any, error) {
}

if cmd.get {
// GET given: The key didn't exist before the SET.
if !out.Prev.Exists() {
if out.Created {
// no previous value
w.WriteNull()
return core.Value(nil), nil
}
Expand Down
12 changes: 7 additions & 5 deletions internal/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (v Value) Bytes() []byte {

// Bool returns the value as a boolean.
func (v Value) Bool() (bool, error) {
if !v.Exists() {
if v.IsZero() {
return false, nil
}
return strconv.ParseBool(string(v))
Expand All @@ -99,7 +99,7 @@ func (v Value) MustBool() bool {

// Int returns the value as an integer.
func (v Value) Int() (int, error) {
if !v.Exists() {
if v.IsZero() {
return 0, nil
}
return strconv.Atoi(string(v))
Expand All @@ -118,7 +118,7 @@ func (v Value) MustInt() int {

// Float returns the value as a float64.
func (v Value) Float() (float64, error) {
if !v.Exists() {
if v.IsZero() {
return 0, nil
}
return strconv.ParseFloat(string(v), 64)
Expand All @@ -134,8 +134,10 @@ func (v Value) MustFloat() float64 {
}
return f
}
func (v Value) Exists() bool {
return len(v) != 0

// IsZero reports whether the value is empty.
func (v Value) IsZero() bool {
return len(v) == 0
}

// IsValueType reports if the value has a valid type to be persisted
Expand Down
121 changes: 121 additions & 0 deletions internal/core/core_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package core

import (
"testing"

"github.com/nalgeon/redka/internal/testx"
)

func TestValue(t *testing.T) {
t.Run("bytes", func(t *testing.T) {
v := Value([]byte("hello"))
testx.AssertEqual(t, v.IsZero(), false)
testx.AssertEqual(t, v.Bytes(), []byte("hello"))
testx.AssertEqual(t, v.String(), "hello")
_, err := v.Bool()
testx.AssertEqual(t, err.Error(), `strconv.ParseBool: parsing "hello": invalid syntax`)
_, err = v.Int()
testx.AssertEqual(t, err.Error(), `strconv.Atoi: parsing "hello": invalid syntax`)
_, err = v.Float()
testx.AssertEqual(t, err.Error(), `strconv.ParseFloat: parsing "hello": invalid syntax`)
})
t.Run("string", func(t *testing.T) {
v := Value("hello")
testx.AssertEqual(t, v.IsZero(), false)
testx.AssertEqual(t, v.Bytes(), []byte("hello"))
testx.AssertEqual(t, v.String(), "hello")
_, err := v.Bool()
testx.AssertEqual(t, err.Error(), `strconv.ParseBool: parsing "hello": invalid syntax`)
_, err = v.Int()
testx.AssertEqual(t, err.Error(), `strconv.Atoi: parsing "hello": invalid syntax`)
_, err = v.Float()
testx.AssertEqual(t, err.Error(), `strconv.ParseFloat: parsing "hello": invalid syntax`)
})
t.Run("bool true", func(t *testing.T) {
v := Value("1")
testx.AssertEqual(t, v.IsZero(), false)
testx.AssertEqual(t, v.Bytes(), []byte("1"))
testx.AssertEqual(t, v.String(), "1")
vbool, err := v.Bool()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vbool, true)
vint, err := v.Int()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vint, 1)
vfloat, err := v.Float()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vfloat, 1.0)
})
t.Run("bool false", func(t *testing.T) {
v := Value("0")
testx.AssertEqual(t, v.IsZero(), false)
testx.AssertEqual(t, v.Bytes(), []byte("0"))
testx.AssertEqual(t, v.String(), "0")
vbool, err := v.Bool()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vbool, false)
vint, err := v.Int()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vint, 0)
vfloat, err := v.Float()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vfloat, 0.0)
})
t.Run("int", func(t *testing.T) {
v := Value("42")
testx.AssertEqual(t, v.IsZero(), false)
testx.AssertEqual(t, v.Bytes(), []byte("42"))
testx.AssertEqual(t, v.String(), "42")
_, err := v.Bool()
testx.AssertEqual(t, err.Error(), `strconv.ParseBool: parsing "42": invalid syntax`)
vint, err := v.Int()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vint, 42)
vfloat, err := v.Float()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vfloat, 42.0)
})
t.Run("float", func(t *testing.T) {
v := Value("42.5")
testx.AssertEqual(t, v.IsZero(), false)
testx.AssertEqual(t, v.Bytes(), []byte("42.5"))
testx.AssertEqual(t, v.String(), "42.5")
_, err := v.Bool()
testx.AssertEqual(t, err.Error(), `strconv.ParseBool: parsing "42.5": invalid syntax`)
_, err = v.Int()
testx.AssertEqual(t, err.Error(), `strconv.Atoi: parsing "42.5": invalid syntax`)
vfloat, err := v.Float()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vfloat, 42.5)
})
t.Run("empty string", func(t *testing.T) {
v := Value("")
testx.AssertEqual(t, v.IsZero(), true)
testx.AssertEqual(t, v.Bytes(), []byte{})
testx.AssertEqual(t, v.String(), "")
vbool, err := v.Bool()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vbool, false)
vint, err := v.Int()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vint, 0)
vfloat, err := v.Float()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vfloat, 0.0)
})
t.Run("nil", func(t *testing.T) {
v := Value(nil)
testx.AssertEqual(t, v.IsZero(), true)
testx.AssertEqual(t, v.Bytes(), []byte(nil))
testx.AssertEqual(t, v.String(), "")
vbool, err := v.Bool()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vbool, false)
vint, err := v.Int()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vint, 0)
vfloat, err := v.Float()
testx.AssertNoErr(t, err)
testx.AssertEqual(t, vfloat, 0.0)
})
}

0 comments on commit 4aff4a9

Please sign in to comment.