barrier_handler.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package checkpoint
  2. import (
  3. "github.com/lf-edge/ekuiper/pkg/api"
  4. )
// BarrierHandler is implemented by the types that react to checkpoint barriers
// found in the items passed to Process.
type BarrierHandler interface {
	// Process inspects data. The implementations in this file return true when
	// data is a barrier; BarrierAligner additionally returns true when it has
	// buffered data coming from a blocked channel. Otherwise false is returned.
	Process(data *BufferOrEvent, ctx api.StreamContext) bool
	// SetOutput hands the handler a send-only channel. BarrierAligner sends the
	// items it buffered while blocking to this channel; BarrierTracker ignores it.
	SetOutput(chan<- *BufferOrEvent)
}
// BarrierTracker is the barrier handler for qos 1. It only counts the barriers
// received per checkpoint; it never buffers data.
type BarrierTracker struct {
	// responder has TriggerCheckpoint called on it once a checkpoint is complete.
	responder Responder
	// inputCount is the number of barriers that completes a checkpoint.
	inputCount int
	// pendingCheckpoints maps a checkpoint id to the barriers counted so far.
	pendingCheckpoints map[int64]int
}
  15. func NewBarrierTracker(responder Responder, inputCount int) *BarrierTracker {
  16. return &BarrierTracker{
  17. responder: responder,
  18. inputCount: inputCount,
  19. pendingCheckpoints: make(map[int64]int),
  20. }
  21. }
  22. func (h *BarrierTracker) Process(data *BufferOrEvent, ctx api.StreamContext) bool {
  23. d := data.Data
  24. if b, ok := d.(*Barrier); ok {
  25. h.processBarrier(b, ctx)
  26. return true
  27. }
  28. return false
  29. }
// SetOutput exists to satisfy BarrierHandler; the channel argument is discarded.
func (h *BarrierTracker) SetOutput(_ chan<- *BufferOrEvent) {
	// Nothing to do: the tracker does not need an output channel.
}
  33. func (h *BarrierTracker) processBarrier(b *Barrier, ctx api.StreamContext) {
  34. logger := ctx.GetLogger()
  35. if h.inputCount == 1 {
  36. err := h.responder.TriggerCheckpoint(b.CheckpointId)
  37. if err != nil {
  38. logger.Errorf("trigger checkpoint for %s err: %s", h.responder.GetName(), err)
  39. }
  40. return
  41. }
  42. if c, ok := h.pendingCheckpoints[b.CheckpointId]; ok {
  43. c += 1
  44. if c == h.inputCount {
  45. err := h.responder.TriggerCheckpoint(b.CheckpointId)
  46. if err != nil {
  47. logger.Errorf("trigger checkpoint for %s err: %s", h.responder.GetName(), err)
  48. return
  49. }
  50. delete(h.pendingCheckpoints, b.CheckpointId)
  51. for cid := range h.pendingCheckpoints {
  52. if cid < b.CheckpointId {
  53. delete(h.pendingCheckpoints, cid)
  54. }
  55. }
  56. } else {
  57. h.pendingCheckpoints[b.CheckpointId] = c
  58. }
  59. } else {
  60. h.pendingCheckpoints[b.CheckpointId] = 1
  61. }
  62. }
// BarrierAligner is the barrier handler for qos 2. Once a barrier has arrived
// on an input, data from that input is buffered until all inputs have delivered
// their barrier.
type BarrierAligner struct {
	// responder has TriggerCheckpoint called on it when alignment completes.
	responder Responder
	// inputCount is the number of blocked channels that completes an alignment.
	inputCount int
	// currentCheckpointId is the id of the most recent checkpoint accepted.
	currentCheckpointId int64
	// output is set through SetOutput; buffered items are sent to it after a
	// checkpoint has been triggered.
	output chan<- *BufferOrEvent
	// blockedChannels is the set of channels whose barrier has arrived. Entries
	// are recorded under Barrier.OpId and looked up by BufferOrEvent.Channel.
	blockedChannels map[string]bool
	// buffer holds the items withheld from blocked channels.
	buffer []*BufferOrEvent
}
  72. func NewBarrierAligner(responder Responder, inputCount int) *BarrierAligner {
  73. ba := &BarrierAligner{
  74. responder: responder,
  75. inputCount: inputCount,
  76. blockedChannels: make(map[string]bool),
  77. }
  78. return ba
  79. }
  80. func (h *BarrierAligner) Process(data *BufferOrEvent, ctx api.StreamContext) bool {
  81. switch d := data.Data.(type) {
  82. case *Barrier:
  83. h.processBarrier(d, ctx)
  84. return true
  85. default:
  86. //If blocking, save to buffer
  87. if h.inputCount > 1 && len(h.blockedChannels) > 0 {
  88. if _, ok := h.blockedChannels[data.Channel]; ok {
  89. h.buffer = append(h.buffer, data)
  90. return true
  91. }
  92. }
  93. }
  94. return false
  95. }
  96. func (h *BarrierAligner) processBarrier(b *Barrier, ctx api.StreamContext) {
  97. logger := ctx.GetLogger()
  98. logger.Debugf("Aligner process barrier %+v", b)
  99. if h.inputCount == 1 {
  100. if b.CheckpointId > h.currentCheckpointId {
  101. h.currentCheckpointId = b.CheckpointId
  102. err := h.responder.TriggerCheckpoint(b.CheckpointId)
  103. if err != nil {
  104. logger.Errorf("trigger checkpoint for %s err: %s", h.responder.GetName(), err)
  105. }
  106. }
  107. return
  108. }
  109. if len(h.blockedChannels) > 0 {
  110. if b.CheckpointId == h.currentCheckpointId {
  111. h.onBarrier(b.OpId, ctx)
  112. } else if b.CheckpointId > h.currentCheckpointId {
  113. logger.Infof("Received checkpoint barrier for checkpoint %d before complete current checkpoint %d. Skipping current checkpoint.", b.CheckpointId, h.currentCheckpointId)
  114. //TODO Abort checkpoint
  115. h.releaseBlocksAndResetBarriers()
  116. h.beginNewAlignment(b, ctx)
  117. } else {
  118. return
  119. }
  120. } else if b.CheckpointId > h.currentCheckpointId {
  121. logger.Debugf("Aligner process new alignment", b)
  122. h.beginNewAlignment(b, ctx)
  123. } else {
  124. return
  125. }
  126. if len(h.blockedChannels) == h.inputCount {
  127. logger.Debugf("Received all barriers, triggering checkpoint %d", b.CheckpointId)
  128. err := h.responder.TriggerCheckpoint(b.CheckpointId)
  129. if err != nil {
  130. logger.Errorf("trigger checkpoint for %s err: %s", h.responder.GetName(), err)
  131. return
  132. }
  133. h.releaseBlocksAndResetBarriers()
  134. // clean up all the buffer
  135. var temp []*BufferOrEvent
  136. for _, d := range h.buffer {
  137. temp = append(temp, d)
  138. }
  139. go func() {
  140. for _, d := range temp {
  141. h.output <- d
  142. }
  143. }()
  144. h.buffer = make([]*BufferOrEvent, 0)
  145. }
  146. }
  147. func (h *BarrierAligner) onBarrier(name string, ctx api.StreamContext) {
  148. logger := ctx.GetLogger()
  149. if _, ok := h.blockedChannels[name]; !ok {
  150. h.blockedChannels[name] = true
  151. logger.Debugf("Received barrier from channel %s", name)
  152. }
  153. }
// SetOutput stores the channel to which processBarrier sends the buffered
// items after a checkpoint has been triggered.
func (h *BarrierAligner) SetOutput(output chan<- *BufferOrEvent) {
	h.output = output
}
  157. func (h *BarrierAligner) releaseBlocksAndResetBarriers() {
  158. h.blockedChannels = make(map[string]bool)
  159. }
  160. func (h *BarrierAligner) beginNewAlignment(barrier *Barrier, ctx api.StreamContext) {
  161. logger := ctx.GetLogger()
  162. h.currentCheckpointId = barrier.CheckpointId
  163. h.onBarrier(barrier.OpId, ctx)
  164. logger.Debugf("Starting stream alignment for checkpoint %d", barrier.CheckpointId)
  165. }