fix(tinykv): remove expired values from UnmarshalJSON

This commit is contained in:
Alexis Couvreur
2022-11-02 15:27:07 +00:00
parent 62e9d33ec8
commit 411c2b2dc6
3 changed files with 66 additions and 317 deletions

View File

@@ -182,7 +182,7 @@ func (s *SessionsManager) RequestReadySession(names []string, duration time.Dura
}
func (s *SessionsManager) ExpiresAfter(instance *instance.State, duration time.Duration) {
s.store.Put(instance.Name, *instance, tinykv.ExpiresAfter(duration))
s.store.Put(instance.Name, *instance, duration)
}
func (s *SessionState) MarshalJSON() ([]byte, error) {

View File

@@ -2,7 +2,6 @@ package tinykv
import (
"encoding/json"
"fmt"
"sync"
"time"
)
@@ -58,13 +57,11 @@ type entry[T any] struct {
// KV is a registry for values (like/is a concurrent map) with timeout and sliding timeout
type KV[T any] interface {
Delete(k string)
Get(k string) (v T, ok bool)
Keys() (keys []string)
Values() (values []T)
Entries() (entries map[string]entry[T])
Put(k string, v T, options ...PutOption) error
Take(k string) (v T, ok bool)
Put(k string, v T, expiresAfter time.Duration) error
Stop()
MarshalJSON() ([]byte, error)
UnmarshalJSON(b []byte) error
@@ -72,30 +69,6 @@ type KV[T any] interface {
//-----------------------------------------------------------------------------
type putOpt struct {
expiresAfter time.Duration
cas func(interface{}, bool) bool
}
// PutOption extra options for put
type PutOption func(*putOpt)
// ExpiresAfter entry will expire after this time
func ExpiresAfter(expiresAfter time.Duration) PutOption {
return func(opt *putOpt) {
opt.expiresAfter = expiresAfter
}
}
// CAS for performing a compare and swap
func CAS(cas func(oldValue interface{}, found bool) bool) PutOption {
return func(opt *putOpt) {
opt.cas = cas
}
}
//-----------------------------------------------------------------------------
// store is a registry for values (like/is a concurrent map) with timeout and sliding timeout
type store[T any] struct {
onExpire func(k string, v T)
@@ -131,15 +104,6 @@ func (kv *store[T]) Stop() {
kv.stopOnce.Do(func() { close(kv.stop) })
}
// Delete deletes an entry
func (kv *store[T]) Delete(k string) {
kv.mx.Lock()
defer kv.mx.Unlock()
delete(kv.kv, k)
}
// Get gets an entry from KV store
// and if a sliding timeout is set, it will be slided
func (kv *store[T]) Get(k string) (T, bool) {
var zero T
kv.mx.Lock()
@@ -202,23 +166,16 @@ func (kv *store[T]) Entries() (entries map[string]entry[T]) {
}
// Put puts an entry inside kv store with provided options
func (kv *store[T]) Put(k string, v T, options ...PutOption) error {
opt := &putOpt{}
for _, v := range options {
v(opt)
}
func (kv *store[T]) Put(k string, v T, expiresAfter time.Duration) error {
e := &entry[T]{
value: v,
}
kv.mx.Lock()
defer kv.mx.Unlock()
if opt.expiresAfter > 0 {
e.timeout = newTimeout(k, opt.expiresAfter)
timeheapPush(&kv.heap, e.timeout)
}
if opt.cas != nil {
return kv.cas(k, e, opt.cas)
}
e.timeout = newTimeout(k, expiresAfter)
timeheapPush(&kv.heap, e.timeout)
kv.kv[k] = e
return nil
}
@@ -238,29 +195,28 @@ func (e *entry[T]) MarshalJSON() ([]byte, error) {
Value: e.value,
ExpiresAt: e.expiresAt,
})
} else {
return json.Marshal(&struct {
Value T `json:"value"`
}{
Value: e.value,
})
}
return nil, nil
}
type minimalEntry[T any] struct {
Value T
ExpiresAfter time.Duration
expired bool
}
func (kv *store[T]) UnmarshalJSON(b []byte) error {
var result map[string]minimalEntry[T]
var entries map[string]minimalEntry[T]
// Unmarshal or Decode the JSON to the interface.
json.Unmarshal([]byte(b), &result)
if err := json.Unmarshal([]byte(b), &entries); err != nil {
return err
}
for k, v := range result {
kv.Put(k, v.Value, ExpiresAfter(v.ExpiresAfter))
for k, v := range entries {
if !v.expired {
kv.Put(k, v.Value, v.ExpiresAfter)
}
}
return nil
@@ -268,54 +224,25 @@ func (kv *store[T]) UnmarshalJSON(b []byte) error {
func (e *minimalEntry[T]) UnmarshalJSON(b []byte) error {
result := &struct {
entry := &struct {
Value T `json:"value"`
ExpiresAt time.Time `json:"expiresAt"`
}{}
// Unmarshal or Decode the JSON to the interface.
json.Unmarshal([]byte(b), &result)
if err := json.Unmarshal([]byte(b), &entry); err != nil {
return err
}
if result.ExpiresAt.After(time.Now()) {
e.Value = result.Value
e.ExpiresAfter = time.Until(result.ExpiresAt)
if entry.ExpiresAt.After(time.Now()) {
e.Value = entry.Value
e.ExpiresAfter = time.Until(entry.ExpiresAt)
e.expired = false
} else {
e.expired = true
}
return nil
}
func (kv *store[T]) cas(k string, e *entry[T], casFunc func(interface{}, bool) bool) error {
old, ok := kv.kv[k]
var oldValue T
if ok && old != nil {
oldValue = old.value
}
if !casFunc(oldValue, ok) {
return ErrCASCond
}
if ok && old != nil {
if e.timeout != nil {
old.timeout = e.timeout
}
old.value = e.value
e = old
}
kv.kv[k] = e
return nil
}
// Take takes an entry out of kv store
func (kv *store[T]) Take(k string) (T, bool) {
var zero T
kv.mx.Lock()
defer kv.mx.Unlock()
e, ok := kv.kv[k]
if ok {
delete(kv.kv, k)
return e.value, ok
}
return zero, ok
}
//-----------------------------------------------------------------------------
func (kv *store[T]) expireLoop() {
@@ -413,21 +340,3 @@ func notifyExpirations[T any](
})
}
}
//-----------------------------------------------------------------------------
// errors
var (
ErrCASCond = errorf("CAS COND FAILED")
)
//-----------------------------------------------------------------------------
type sentinelErr string
func (v sentinelErr) Error() string { return string(v) }
func errorf(format string, a ...interface{}) error {
return sentinelErr(fmt.Sprintf(format, a...))
}
//-----------------------------------------------------------------------------

View File

@@ -43,12 +43,12 @@ func TestGetPut(t *testing.T) {
rg := New[int](0)
defer rg.Stop()
rg.Put("1", 1)
rg.Put("1", 1, time.Minute*50)
v, ok := rg.Get("1")
assert.True(ok)
assert.Equal(1, v)
rg.Put("2", 2, ExpiresAfter(time.Millisecond*50))
rg.Put("2", 2, time.Millisecond*50)
v, ok = rg.Get("2")
assert.True(ok)
assert.Equal(2, v)
@@ -64,8 +64,8 @@ func TestKeys(t *testing.T) {
rg := New[int](0)
defer rg.Stop()
rg.Put("1", 1)
rg.Put("2", 2)
rg.Put("1", 1, time.Minute*50)
rg.Put("2", 2, time.Minute*50)
keys := rg.Keys()
assert.NotEmpty(keys)
@@ -78,8 +78,8 @@ func TestValues(t *testing.T) {
rg := New[int](0)
defer rg.Stop()
rg.Put("1", 1)
rg.Put("2", 2)
rg.Put("1", 1, time.Minute*50)
rg.Put("2", 2, time.Minute*50)
values := rg.Values()
assert.NotEmpty(values)
@@ -92,9 +92,9 @@ func TestEntries(t *testing.T) {
rg := New[int](0)
defer rg.Stop()
rg.Put("1", 1)
rg.Put("2", 2)
rg.Put("3", 3, ExpiresAfter(time.Minute*50))
rg.Put("1", 1, time.Minute*50)
rg.Put("2", 2, time.Minute*50)
rg.Put("3", 3, time.Minute*50)
entries := rg.Entries()
assert.NotEmpty(entries)
@@ -108,14 +108,12 @@ func TestMarshalJSON(t *testing.T) {
rg := New[int](0)
defer rg.Stop()
rg.Put("1", 1)
rg.Put("2", 2)
rg.Put("3", 3, ExpiresAfter(time.Minute*50))
rg.Put("3", 3, time.Minute*50)
jsonb, err := json.Marshal(rg)
assert.Nil(err)
json := string(jsonb)
assert.Regexp(`{"1":{"value":1},"2":{"value":2},"3":{"value":3,"expiresAt":"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d+Z"}}`, json)
assert.Regexp(`{"3":{"value":3,"expiresAt":"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d+Z"}}`, json)
}
func TestUnmarshalJSON(t *testing.T) {
@@ -130,6 +128,24 @@ func TestUnmarshalJSON(t *testing.T) {
err = json.Unmarshal([]byte(jsons), &rg)
assert.Nil(err)
assert.Len(rg.Entries(), 1)
}
func TestUnmarshalJSONExpired(t *testing.T) {
assert := assert.New(t)
since5Minutes := time.Now().Add(-time.Minute * 5)
since5MinutesJson, err := json.Marshal(since5Minutes)
assert.Nil(err)
jsons := `{"1":{"value":1},"2":{"value":2},"3":{"value":3,"expiresAt":` + string(since5MinutesJson) + `}}`
rg := New[int](0)
defer rg.Stop()
err = json.Unmarshal([]byte(jsons), &rg)
assert.Nil(err)
assert.Empty(rg.Entries())
}
func TestTimeout(t *testing.T) {
@@ -141,7 +157,7 @@ func TestTimeout(t *testing.T) {
rg := New(time.Millisecond*10, notify)
n := 1000
for i := n; i < 2*n; i++ {
rg.Put(strconv.Itoa(i), i, ExpiresAfter(time.Millisecond*10))
rg.Put(strconv.Itoa(i), i, time.Millisecond*10)
}
got := make([]string, n)
OUT01:
@@ -179,31 +195,12 @@ func Test03(t *testing.T) {
})
putAt = time.Now()
kv.Put("1", 1, ExpiresAfter(time.Millisecond*10))
kv.Put("1", 1, time.Millisecond*10)
<-time.After(time.Millisecond * 100)
assert.WithinDuration(putAt, putAt.Add(<-elapsed), time.Millisecond*60)
}
func Test04(t *testing.T) {
assert := assert.New(t)
kv := New(
time.Millisecond*10,
func(k string, v interface{}) {
t.Fatal(k, v)
})
err := kv.Put("1", 1, ExpiresAfter(time.Millisecond*10000))
assert.NoError(err)
<-time.After(time.Millisecond * 50)
kv.Delete("1")
kv.Delete("1")
<-time.After(time.Millisecond * 100)
_, ok := kv.Get("1")
assert.False(ok)
}
func Test05(t *testing.T) {
assert := assert.New(t)
N := 10000
@@ -219,8 +216,7 @@ func Test05(t *testing.T) {
for i := 0; i < N; i++ {
k := fmt.Sprintf("%d", i)
kv.Put(k, fmt.Sprintf("VAL::%v", k),
ExpiresAfter(
time.Millisecond*time.Duration(rnd.Intn(10)+1)))
time.Millisecond*time.Duration(rnd.Intn(10)+1))
}
<-time.After(time.Millisecond * 100)
@@ -231,66 +227,6 @@ func Test05(t *testing.T) {
}
}
func Test07(t *testing.T) {
assert := assert.New(t)
kv := New[int](-1)
kv.Put("1", 1)
v, ok := kv.Take("1")
assert.True(ok)
assert.Equal(1, v)
_, ok = kv.Get("1")
assert.False(ok)
}
func Test08(t *testing.T) {
assert := assert.New(t)
kv := New[interface{}](-1)
err := kv.Put(
"QQG", "G",
CAS(func(interface{}, bool) bool { return true }),
ExpiresAfter(time.Millisecond))
assert.NoError(err)
v, ok := kv.Take("QQG")
assert.True(ok)
assert.Equal("G", v)
}
// ignore new timeouts when cas, and just use the old ones from the old value (if exists)
func Test09IgnoreTimeoutParamsOnCAS(t *testing.T) {
assert := assert.New(t)
key := "QQG"
kv := New[interface{}](time.Millisecond)
err := kv.Put(
key, "G",
CAS(func(interface{}, bool) bool { return true }),
ExpiresAfter(time.Millisecond*30))
assert.NoError(err)
v, ok := kv.Get(key)
assert.True(ok)
assert.Equal("G", v)
<-time.After(time.Millisecond * 20)
err = kv.Put(key, "OK",
CAS(func(currentValue interface{}, found bool) bool {
assert.True(found)
assert.Equal("G", currentValue)
return true
}))
assert.NoError(err)
<-time.After(time.Millisecond * 12)
_, ok = kv.Get(key)
assert.False(ok)
}
func Test11(t *testing.T) {
assert := assert.New(t)
@@ -302,7 +238,7 @@ func Test11(t *testing.T) {
kv := New(time.Millisecond*100, onExpired)
err := kv.Put(
key, "G",
ExpiresAfter(time.Millisecond*15))
time.Millisecond*15)
assert.NoError(err)
<-time.After(time.Millisecond * 10)
@@ -334,7 +270,7 @@ func Test12(t *testing.T) {
kv := New(time.Millisecond*100, onExpired)
err := kv.Put(
key, "G",
ExpiresAfter(time.Millisecond))
time.Millisecond)
assert.NoError(err)
<-time.After(time.Millisecond * 10)
@@ -355,7 +291,7 @@ func Test13(t *testing.T) {
kv := New(time.Millisecond*10, onExpired)
err := kv.Put(
"1", 123,
ExpiresAfter(time.Millisecond))
time.Millisecond)
assert.NoError(err)
<-time.After(time.Millisecond * 50)
@@ -385,7 +321,7 @@ func TestOrdering(t *testing.T) {
for i := 1; i <= 10; i++ {
k := strconv.Itoa(i)
v := i
kv.Put(k, v, ExpiresAfter(time.Millisecond*time.Duration(i)*50))
kv.Put(k, v, time.Millisecond*time.Duration(i)*50)
}
var order = make([]int, 10)
@@ -413,71 +349,6 @@ func TestOrdering(t *testing.T) {
assert.Equal(1, 1)
}
func TestCASOldFound(t *testing.T) {
assert := assert.New(t)
kv := New[interface{}](time.Millisecond * 10)
key := "KEY01"
value := "VALUE01"
err := kv.Put(
key, value,
CAS(func(old interface{}, found bool) bool {
assert.Nil(old)
assert.False(found)
return true
}))
assert.NoError(err)
err = kv.Put(
key, value,
CAS(func(old interface{}, found bool) bool {
assert.Equal(value, old)
assert.True(found)
return true
}))
assert.NoError(err)
kv.Delete(key)
err = kv.Put(
key, value,
CAS(func(old interface{}, found bool) bool {
assert.Nil(old)
assert.False(found)
return true
}))
assert.NoError(err)
v, ok := kv.Take(key)
assert.True(ok)
assert.Equal(value, v)
err = kv.Put(
key, value,
CAS(func(old interface{}, found bool) bool {
assert.Nil(old)
assert.False(found)
return true
}))
assert.NoError(err)
}
func ExampleNew() {
key := "KEY"
value := "VALUE"
kv := New[interface{}](time.Millisecond * 10)
defer kv.Stop()
kv.Put(key, value)
v, ok := kv.Get(key)
if !ok {
// ...
}
fmt.Println(key, v)
kv.Delete(key)
_, ok = kv.Get(key)
fmt.Println(ok)
// Output:
// KEY VALUE
// false
}
func BenchmarkGetNoValue(b *testing.B) {
rg := New[interface{}](-1)
for n := 0; n < b.N; n++ {
@@ -487,7 +358,7 @@ func BenchmarkGetNoValue(b *testing.B) {
func BenchmarkGetValue(b *testing.B) {
rg := New[interface{}](-1)
rg.Put("1", 1)
rg.Put("1", 1, time.Minute*50)
for n := 0; n < b.N; n++ {
rg.Get("1")
}
@@ -495,46 +366,15 @@ func BenchmarkGetValue(b *testing.B) {
func BenchmarkGetSlidingTimeout(b *testing.B) {
rg := New[interface{}](-1)
rg.Put("1", 1, ExpiresAfter(time.Second*10))
rg.Put("1", 1, time.Second*10)
for n := 0; n < b.N; n++ {
rg.Get("1")
}
}
func BenchmarkPutOne(b *testing.B) {
rg := New[interface{}](-1)
for n := 0; n < b.N; n++ {
rg.Put("1", 1)
}
}
func BenchmarkPutN(b *testing.B) {
rg := New[interface{}](-1)
for n := 0; n < b.N; n++ {
k := strconv.Itoa(n)
rg.Put(k, n)
}
}
func BenchmarkPutExpire(b *testing.B) {
rg := New[interface{}](-1)
for n := 0; n < b.N; n++ {
rg.Put("1", 1, ExpiresAfter(time.Second*10))
}
}
func BenchmarkCASTrue(b *testing.B) {
rg := New[interface{}](-1)
rg.Put("1", 1)
for n := 0; n < b.N; n++ {
rg.Put("1", 2, CAS(func(interface{}, bool) bool { return true }))
}
}
func BenchmarkCASFalse(b *testing.B) {
rg := New[interface{}](-1)
rg.Put("1", 1)
for n := 0; n < b.N; n++ {
rg.Put("1", 2, CAS(func(interface{}, bool) bool { return false }))
rg.Put("1", 1, time.Second*10)
}
}