default_test.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package contexts
  2. import (
  3. "github.com/emqx/kuiper/common"
  4. "github.com/emqx/kuiper/xstream/api"
  5. "github.com/emqx/kuiper/xstream/states"
  6. "log"
  7. "os"
  8. "path"
  9. "reflect"
  10. "testing"
  11. )
  12. func TestState(t *testing.T) {
  13. var (
  14. i = 0
  15. ruleId = "testStateRule"
  16. value1 = 21
  17. value2 = "hello"
  18. value3 = "world"
  19. s = map[string]interface{}{
  20. "key1": 21,
  21. "key3": "world",
  22. }
  23. )
  24. //initialization
  25. store, err := states.CreateStore(ruleId, api.AtLeastOnce)
  26. if err != nil {
  27. t.Errorf("Get store for rule %s error: %s", ruleId, err)
  28. return
  29. }
  30. ctx := Background().WithMeta("testStateRule", "op1", store)
  31. defer cleanStateData()
  32. // Do state function
  33. ctx.IncrCounter("key1", 20)
  34. ctx.IncrCounter("key1", 1)
  35. v, err := ctx.GetCounter("key1")
  36. if err != nil {
  37. t.Errorf("%d.Get counter error: %s", i, err)
  38. return
  39. }
  40. if !reflect.DeepEqual(value1, v) {
  41. t.Errorf("%d.Get counter\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, value1, v)
  42. }
  43. err = ctx.PutState("key2", value2)
  44. if err != nil {
  45. t.Errorf("%d.Put state key2 error: %s", i, err)
  46. return
  47. }
  48. err = ctx.PutState("key3", value3)
  49. if err != nil {
  50. t.Errorf("%d.Put state key3 error: %s", i, err)
  51. return
  52. }
  53. v2, err := ctx.GetState("key2")
  54. if err != nil {
  55. t.Errorf("%d.Get state key2 error: %s", i, err)
  56. return
  57. }
  58. if !reflect.DeepEqual(value2, v2) {
  59. t.Errorf("%d.Get state\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, value2, v2)
  60. }
  61. err = ctx.DeleteState("key2")
  62. if err != nil {
  63. t.Errorf("%d.Delete state key2 error: %s", i, err)
  64. return
  65. }
  66. err = ctx.Snapshot()
  67. if err != nil {
  68. t.Errorf("%d.Snapshot error: %s", i, err)
  69. return
  70. }
  71. rs := ctx.(*DefaultContext).snapshot
  72. if !reflect.DeepEqual(s, rs) {
  73. t.Errorf("%d.Snapshot\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, s, rs)
  74. }
  75. }
  76. func cleanStateData() {
  77. dbDir, err := common.GetDataLoc()
  78. if err != nil {
  79. log.Panic(err)
  80. }
  81. c := path.Join(dbDir, "checkpoints")
  82. err = os.RemoveAll(c)
  83. if err != nil {
  84. common.Log.Error(err)
  85. }
  86. }