Quellcode durchsuchen

test(store): kvstore test and bugfix

ngjaying vor 4 Jahren
Ursprung
Commit
9aa9389b86

+ 1 - 1
common/util.go

@@ -343,7 +343,7 @@ func ConvertArray(s []interface{}) []interface{} {
 func SyncMapToMap(sm *sync.Map) map[string]interface{} {
 func SyncMapToMap(sm *sync.Map) map[string]interface{} {
 	m := make(map[string]interface{})
 	m := make(map[string]interface{})
 	sm.Range(func(k interface{}, v interface{}) bool {
 	sm.Range(func(k interface{}, v interface{}) bool {
-		m[k.(string)] = v
+		m[fmt.Sprintf("%v", k)] = v
 		return true
 		return true
 	})
 	})
 	return m
 	return m

+ 6 - 2
xsql/processors/xsql_processor_test.go

@@ -36,10 +36,14 @@ func cleanStateData() {
 	}
 	}
 	c := path.Join(dbDir, "checkpoints")
 	c := path.Join(dbDir, "checkpoints")
 	err = os.RemoveAll(c)
 	err = os.RemoveAll(c)
-	log.Errorf("%s", err)
+	if err != nil {
+		log.Errorf("%s", err)
+	}
 	s := path.Join(dbDir, "sink")
 	s := path.Join(dbDir, "sink")
 	err = os.RemoveAll(s)
 	err = os.RemoveAll(s)
-	log.Errorf("%s", err)
+	if err != nil {
+		log.Errorf("%s", err)
+	}
 }
 }
 
 
 func TestStreamCreateProcessor(t *testing.T) {
 func TestStreamCreateProcessor(t *testing.T) {

+ 89 - 0
xstream/contexts/default_test.go

@@ -0,0 +1,89 @@
+package contexts
+
+import (
+	"github.com/emqx/kuiper/common"
+	"github.com/emqx/kuiper/xstream/api"
+	"github.com/emqx/kuiper/xstream/states"
+	"log"
+	"os"
+	"path"
+	"reflect"
+	"testing"
+)
+
+func TestState(t *testing.T) {
+	var (
+		i      = 0
+		ruleId = "testStateRule"
+		value1 = 21
+		value2 = "hello"
+		value3 = "world"
+		s      = map[string]interface{}{
+			"key1": 21,
+			"key3": "world",
+		}
+	)
+	//initialization
+	store, err := states.CreateStore(ruleId, api.AtLeastOnce)
+	if err != nil {
+		t.Errorf("Get store for rule %s error: %s", ruleId, err)
+		return
+	}
+	ctx := Background().WithMeta("testStateRule", "op1", store)
+	defer cleanStateData()
+	// Do state function
+	ctx.IncrCounter("key1", 20)
+	ctx.IncrCounter("key1", 1)
+	v, err := ctx.GetCounter("key1")
+	if err != nil {
+		t.Errorf("%d.Get counter error: %s", i, err)
+		return
+	}
+	if !reflect.DeepEqual(value1, v) {
+		t.Errorf("%d.Get counter\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, value1, v)
+	}
+	err = ctx.PutState("key2", value2)
+	if err != nil {
+		t.Errorf("%d.Put state key2 error: %s", i, err)
+		return
+	}
+	err = ctx.PutState("key3", value3)
+	if err != nil {
+		t.Errorf("%d.Put state key3 error: %s", i, err)
+		return
+	}
+	v2, err := ctx.GetState("key2")
+	if err != nil {
+		t.Errorf("%d.Get state key2 error: %s", i, err)
+		return
+	}
+	if !reflect.DeepEqual(value2, v2) {
+		t.Errorf("%d.Get state\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, value2, v2)
+	}
+	err = ctx.DeleteState("key2")
+	if err != nil {
+		t.Errorf("%d.Delete state key2 error: %s", i, err)
+		return
+	}
+	err = ctx.Snapshot()
+	if err != nil {
+		t.Errorf("%d.Snapshot error: %s", i, err)
+		return
+	}
+	rs := ctx.(*DefaultContext).snapshot
+	if !reflect.DeepEqual(s, rs) {
+		t.Errorf("%d.Snapshot\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, s, rs)
+	}
+}
+
+func cleanStateData() {
+	dbDir, err := common.GetDataLoc()
+	if err != nil {
+		log.Panic(err)
+	}
+	c := path.Join(dbDir, "checkpoints")
+	err = os.RemoveAll(c)
+	if err != nil {
+		common.Log.Error(err)
+	}
+}

+ 31 - 23
xstream/states/kv_store.go

@@ -37,7 +37,7 @@ type KVStore struct {
 func getKVStore(ruleId string) (*KVStore, error) {
 func getKVStore(ruleId string) (*KVStore, error) {
 	dr, _ := common.GetDataLoc()
 	dr, _ := common.GetDataLoc()
 	db := common.GetSimpleKVStore(path.Join(dr, "checkpoints", ruleId))
 	db := common.GetSimpleKVStore(path.Join(dr, "checkpoints", ruleId))
-	s := &KVStore{db: db, max: 3}
+	s := &KVStore{db: db, max: 3, mapStore: &sync.Map{}}
 	//read data from badger db
 	//read data from badger db
 	if err := s.restore(); err != nil {
 	if err := s.restore(); err != nil {
 		return nil, err
 		return nil, err
@@ -56,18 +56,17 @@ func (s *KVStore) restore() error {
 			return fmt.Errorf("invalid checkpoint data: %s", err)
 			return fmt.Errorf("invalid checkpoint data: %s", err)
 		} else {
 		} else {
 			s.checkpoints = cs
 			s.checkpoints = cs
-			if bytes, ok := s.db.Get(string(cs[len(cs)-1])); ok {
-				if m, err := bytesToMap(bytes.([]byte)); err != nil {
-					return fmt.Errorf("invalid last checkpoint data: %s", err)
-				} else {
-					s.mapStore = m
-					return nil
+			for _, c := range cs {
+				if bytes, ok := s.db.Get(string(c)); ok {
+					if m, err := bytesToMap(bytes.([]byte)); err != nil {
+						return fmt.Errorf("invalid checkpoint data: %s", err)
+					} else {
+						s.mapStore.Store(c, common.MapToSyncMap(m))
+					}
 				}
 				}
 			}
 			}
 		}
 		}
-
 	}
 	}
-	s.mapStore = &sync.Map{}
 	return nil
 	return nil
 }
 }
 
 
@@ -103,7 +102,7 @@ func (s *KVStore) SaveCheckpoint(checkpointId int64) error {
 			if err != nil {
 			if err != nil {
 				return fmt.Errorf("save checkpoint err, fail to encode states: %s", err)
 				return fmt.Errorf("save checkpoint err, fail to encode states: %s", err)
 			}
 			}
-			err = s.db.Set(string(checkpointId), b)
+			err = s.db.Replace(string(checkpointId), b)
 			if err != nil {
 			if err != nil {
 				return fmt.Errorf("save checkpoint err: %v", err)
 				return fmt.Errorf("save checkpoint err: %v", err)
 			}
 			}
@@ -121,7 +120,7 @@ func (s *KVStore) SaveCheckpoint(checkpointId int64) error {
 			if !ok {
 			if !ok {
 				return fmt.Errorf("save checkpoint err: fail to encode checkpoint counts")
 				return fmt.Errorf("save checkpoint err: fail to encode checkpoint counts")
 			}
 			}
-			err = s.db.Set(CheckpointListKey, cs)
+			err = s.db.Replace(CheckpointListKey, cs)
 			if err != nil {
 			if err != nil {
 				return fmt.Errorf("save checkpoint err: %v", err)
 				return fmt.Errorf("save checkpoint err: %v", err)
 			}
 			}
@@ -132,18 +131,27 @@ func (s *KVStore) SaveCheckpoint(checkpointId int64) error {
 
 
 //Only run in the initialization
 //Only run in the initialization
 func (s *KVStore) GetOpState(opId string) (*sync.Map, error) {
 func (s *KVStore) GetOpState(opId string) (*sync.Map, error) {
-	if sm, ok := s.mapStore.Load(opId); ok {
-		switch m := sm.(type) {
-		case *sync.Map:
-			return m, nil
-		case map[string]interface{}:
-			return common.MapToSyncMap(m), nil
-		default:
-			return nil, fmt.Errorf("invalid state %v stored for op %s: data type is not *sync.Map", sm, opId)
+	if len(s.checkpoints) > 0 {
+		if v, ok := s.mapStore.Load(s.checkpoints[len(s.checkpoints)-1]); ok {
+			if cstore, ok := v.(*sync.Map); !ok {
+				return nil, fmt.Errorf("invalid state %v stored for op %s: data type is not *sync.Map", v, opId)
+			} else {
+				if sm, ok := cstore.Load(opId); ok {
+					switch m := sm.(type) {
+					case *sync.Map:
+						return m, nil
+					case map[string]interface{}:
+						return common.MapToSyncMap(m), nil
+					default:
+						return nil, fmt.Errorf("invalid state %v stored for op %s: data type is not *sync.Map", sm, opId)
+					}
+				}
+			}
+		} else {
+			return nil, fmt.Errorf("store for checkpoint %d not found", s.checkpoints[len(s.checkpoints)-1])
 		}
 		}
-	} else {
-		return &sync.Map{}, nil
 	}
 	}
+	return &sync.Map{}, nil
 }
 }
 
 
 func mapToBytes(sm *sync.Map) ([]byte, error) {
 func mapToBytes(sm *sync.Map) ([]byte, error) {
@@ -156,14 +164,14 @@ func mapToBytes(sm *sync.Map) ([]byte, error) {
 	return buf.Bytes(), nil
 	return buf.Bytes(), nil
 }
 }
 
 
-func bytesToMap(input []byte) (*sync.Map, error) {
+func bytesToMap(input []byte) (map[string]interface{}, error) {
 	var result map[string]interface{}
 	var result map[string]interface{}
 	buf := bytes.NewBuffer(input)
 	buf := bytes.NewBuffer(input)
 	dec := gob.NewDecoder(buf)
 	dec := gob.NewDecoder(buf)
 	if err := dec.Decode(&result); err != nil {
 	if err := dec.Decode(&result); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return common.MapToSyncMap(result), nil
+	return result, nil
 }
 }
 
 
 func sliceToBytes(s []int64) ([]byte, bool) {
 func sliceToBytes(s []int64) ([]byte, bool) {

+ 258 - 0
xstream/states/kv_store_test.go

@@ -0,0 +1,258 @@
+package states
+
+import (
+	"fmt"
+	"github.com/emqx/kuiper/common"
+	"log"
+	"os"
+	"path"
+	"reflect"
+	"sync"
+	"testing"
+)
+
+func TestLifecycle(t *testing.T) {
+	var (
+		i             = 0
+		ruleId        = "test1"
+		checkpointIds = []int64{1, 2, 3}
+		opIds         = []string{"op1", "op2", "op3"}
+		r             = map[string]interface{}{
+			"1": map[string]interface{}{
+				"op1": map[string]interface{}{
+					"op": "op1",
+					"oi": 0,
+					"ci": 0,
+				},
+				"op2": map[string]interface{}{
+					"op": "op2",
+					"oi": 1,
+					"ci": 0,
+				},
+				"op3": map[string]interface{}{
+					"op": "op3",
+					"oi": 2,
+					"ci": 0,
+				},
+			},
+			"2": map[string]interface{}{
+				"op1": map[string]interface{}{
+					"op": "op1",
+					"oi": 0,
+					"ci": 1,
+				},
+				"op2": map[string]interface{}{
+					"op": "op2",
+					"oi": 1,
+					"ci": 1,
+				},
+				"op3": map[string]interface{}{
+					"op": "op3",
+					"oi": 2,
+					"ci": 1,
+				},
+			},
+			"3": map[string]interface{}{
+				"op1": map[string]interface{}{
+					"op": "op1",
+					"oi": 0,
+					"ci": 2,
+				},
+				"op2": map[string]interface{}{
+					"op": "op2",
+					"oi": 1,
+					"ci": 2,
+				},
+				"op3": map[string]interface{}{
+					"op": "op3",
+					"oi": 2,
+					"ci": 2,
+				},
+			},
+		}
+		rm = map[string]interface{}{
+			"1": map[string]interface{}{
+				"op1": map[string]interface{}{
+					"op": "op1",
+					"oi": 0,
+					"ci": 0,
+				},
+				"op2": map[string]interface{}{
+					"op": "op2",
+					"oi": 1,
+					"ci": 0,
+				},
+				"op3": map[string]interface{}{
+					"op": "op3",
+					"oi": 2,
+					"ci": 0,
+				},
+			},
+			"2": map[string]interface{}{
+				"op1": map[string]interface{}{
+					"op": "op1",
+					"oi": 0,
+					"ci": 1,
+				},
+				"op2": map[string]interface{}{
+					"op": "op2",
+					"oi": 1,
+					"ci": 1,
+				},
+				"op3": map[string]interface{}{
+					"op": "op3",
+					"oi": 2,
+					"ci": 1,
+				},
+			},
+			"3": map[string]interface{}{
+				"op1": map[string]interface{}{
+					"op": "op1",
+					"oi": 0,
+					"ci": 2,
+				},
+				"op2": map[string]interface{}{
+					"op": "op2",
+					"oi": 1,
+					"ci": 2,
+				},
+				"op3": map[string]interface{}{
+					"op": "op3",
+					"oi": 2,
+					"ci": 2,
+				},
+			},
+			"10000": map[string]interface{}{
+				"op2": map[string]interface{}{
+					"op": "op2",
+					"oi": 1,
+					"ci": 10000,
+				},
+				"op3": map[string]interface{}{
+					"op": "op3",
+					"oi": 2,
+					"ci": 10000,
+				},
+			},
+		}
+	)
+	func() {
+		defer cleanStateData()
+		store, err := getKVStore(ruleId)
+		if err != nil {
+			t.Errorf("Get store for rule %s error: %s", ruleId, err)
+			return
+		}
+		//Save for all checkpoints
+		for i, cid := range checkpointIds {
+			for j, opId := range opIds {
+				err := store.SaveState(cid, opId, map[string]interface{}{
+					"op": opId,
+					"oi": j,
+					"ci": i,
+				})
+				if err != nil {
+					t.Errorf("Save state for rule %s op %s error: %s", ruleId, opId, err)
+					return
+				}
+			}
+			err := store.SaveCheckpoint(cid)
+			if err != nil {
+				t.Errorf("Save checkpoint %d for rule %s error: %s", cid, ruleId, err)
+				return
+			}
+		}
+		// compare checkpoints
+		if !reflect.DeepEqual(checkpointIds, store.checkpoints) {
+			t.Errorf("%d.Save checkpoint\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, checkpointIds, store.checkpoints)
+		}
+		// compare contents
+		result := mapStoreToMap(store.mapStore)
+		if !reflect.DeepEqual(r, result) {
+			t.Errorf("%d.Save checkpoint\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, r, result)
+		}
+		//Save additional state but not serialized in checkpoint
+		err = store.SaveState(10000, opIds[1], map[string]interface{}{
+			"op": opIds[1],
+			"oi": 1,
+			"ci": 10000,
+		})
+		if err != nil {
+			t.Errorf("Save state for rule %s op %s error: %s", ruleId, opIds[1], err)
+			return
+		}
+		err = store.SaveState(10000, opIds[2], map[string]interface{}{
+			"op": opIds[2],
+			"oi": 2,
+			"ci": 10000,
+		})
+		if err != nil {
+			t.Errorf("Save state for rule %s op %s error: %s", ruleId, opIds[2], err)
+			return
+		}
+		// compare checkpoints
+		if !reflect.DeepEqual(checkpointIds, store.checkpoints) {
+			t.Errorf("%d.Save state\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, checkpointIds, store.checkpoints)
+		}
+		// compare contents
+		result = mapStoreToMap(store.mapStore)
+		if !reflect.DeepEqual(rm, result) {
+			t.Errorf("%d.Save state\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, r, result)
+		}
+		//simulate restore
+		store = nil
+		store, err = getKVStore(ruleId)
+		if err != nil {
+			t.Errorf("Restore store for rule %s error: %s", ruleId, err)
+			return
+		}
+		// compare checkpoints
+		if !reflect.DeepEqual(checkpointIds, store.checkpoints) {
+			t.Errorf("%d.Restore checkpoint\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, checkpointIds, store.checkpoints)
+			return
+		}
+		// compare contents
+		result = mapStoreToMap(store.mapStore)
+		if !reflect.DeepEqual(r, result) {
+			t.Errorf("%d.Restore checkpoint\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, r, result)
+			return
+		}
+		ns, err := store.GetOpState(opIds[1])
+		if err != nil {
+			t.Errorf("Get op %s state for rule %s error: %s", opIds[1], ruleId, err)
+			return
+		}
+		sm := r[fmt.Sprintf("%v", checkpointIds[len(checkpointIds)-1])].(map[string]interface{})[opIds[1]]
+		nsm := common.SyncMapToMap(ns)
+		if !reflect.DeepEqual(sm, nsm) {
+			t.Errorf("%d.Restore op state\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, sm, nsm)
+			return
+		}
+	}()
+}
+
+func mapStoreToMap(sm *sync.Map) map[string]interface{} {
+	m := make(map[string]interface{})
+	sm.Range(func(k interface{}, v interface{}) bool {
+		switch t := v.(type) {
+		case *sync.Map:
+			m[fmt.Sprintf("%v", k)] = mapStoreToMap(t)
+		default:
+			m[fmt.Sprintf("%v", k)] = t
+		}
+		return true
+	})
+	return m
+}
+
+func cleanStateData() {
+	dbDir, err := common.GetDataLoc()
+	if err != nil {
+		log.Panic(err)
+	}
+	c := path.Join(dbDir, "checkpoints")
+	err = os.RemoveAll(c)
+	if err != nil {
+		common.Log.Error(err)
+	}
+}