Przeglądaj źródła

feat(extension): context and sink refactor. stream compose refactor

ngjaying 5 lat temu
rodzic
commit
1aaf875785

+ 0 - 12
common/util.go

@@ -2,7 +2,6 @@ package common
 
 import (
 	"bytes"
-	"context"
 	"fmt"
 	"github.com/dgraph-io/badger"
 	"github.com/go-yaml/yaml"
@@ -15,7 +14,6 @@ import (
 
 const (
 	logFileName = "stream.log"
-	LoggerKey = "logger"
 	etc_dir = "/etc/"
 	data_dir = "/data/"
 	log_dir = "/log/"
@@ -51,16 +49,6 @@ func (l *logRedirect) Debugf(f string, v ...interface{}) {
 	Log.Debug(fmt.Sprintf(f, v...))
 }
 
-func GetLogger(ctx context.Context) *logrus.Entry {
-	if ctx != nil{
-		l, ok := ctx.Value(LoggerKey).(*logrus.Entry)
-		if l != nil && ok {
-			return l
-		}
-	}
-	return Log.WithField("caller", "default")
-}
-
 func LoadConf(confName string) []byte {
 	confDir, err := GetConfLoc()
 	if err != nil {

+ 3 - 4
xsql/plans/aggregate_operator.go

@@ -1,9 +1,8 @@
 package plans
 
 import (
-	"context"
-	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 	"fmt"
 )
 
@@ -16,8 +15,8 @@ type AggregatePlan struct {
  *  input: *xsql.Tuple from preprocessor | xsql.WindowTuplesSet from windowOp | xsql.JoinTupleSets from joinOp
  *  output: xsql.GroupedTuplesSet
  */
-func (p *AggregatePlan) Apply(ctx context.Context, data interface{}) interface{} {
-	log := common.GetLogger(ctx)
+func (p *AggregatePlan) Apply(ctx api.StreamContext, data interface{}) interface{} {
+	log := ctx.GetLogger()
 	log.Debugf("aggregate plan receive %s", data)
 	var ms []xsql.DataValuer
 	switch input := data.(type) {

+ 3 - 4
xsql/plans/filter_operator.go

@@ -1,9 +1,8 @@
 package plans
 
 import (
-	"context"
-	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 )
 
 type FilterPlan struct {
@@ -14,8 +13,8 @@ type FilterPlan struct {
   *  input: *xsql.Tuple from preprocessor | xsql.WindowTuplesSet from windowOp | xsql.JoinTupleSets from joinOp
   *  output: *xsql.Tuple | xsql.WindowTuplesSet | xsql.JoinTupleSets
  */
-func (p *FilterPlan) Apply(ctx context.Context, data interface{}) interface{} {
-	log := common.GetLogger(ctx)
+func (p *FilterPlan) Apply(ctx api.StreamContext, data interface{}) interface{} {
+	log := ctx.GetLogger()
 	log.Debugf("filter plan receive %s", data)
 	switch input := data.(type) {
 	case xsql.Valuer:

+ 3 - 3
xsql/plans/join_operator.go

@@ -1,9 +1,9 @@
 package plans
 
 import (
-	"context"
 	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 	"fmt"
 )
 
@@ -15,8 +15,8 @@ type JoinPlan struct {
 
 // input:  xsql.WindowTuplesSet from windowOp, window is required for join
 // output: xsql.JoinTupleSets
-func (jp *JoinPlan) Apply(ctx context.Context, data interface{}) interface{} {
-	log := common.GetLogger(ctx)
+func (jp *JoinPlan) Apply(ctx api.StreamContext, data interface{}) interface{} {
+	log := ctx.GetLogger()
 	var input xsql.WindowTuplesSet
 	if d, ok := data.(xsql.WindowTuplesSet); !ok {
 		log.Errorf("Expect WindowTuplesSet type.\n")

+ 3 - 4
xsql/plans/order_operator.go

@@ -1,9 +1,8 @@
 package plans
 
 import (
-	"context"
-	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 )
 
 type OrderPlan struct {
@@ -14,8 +13,8 @@ type OrderPlan struct {
   *  input: *xsql.Tuple from preprocessor | xsql.WindowTuplesSet from windowOp | xsql.JoinTupleSets from joinOp
   *  output: *xsql.Tuple | xsql.WindowTuplesSet | xsql.JoinTupleSets
  */
-func (p *OrderPlan) Apply(ctx context.Context, data interface{}) interface{} {
-	log := common.GetLogger(ctx)
+func (p *OrderPlan) Apply(ctx api.StreamContext, data interface{}) interface{} {
+	log := ctx.GetLogger()
 	log.Debugf("order plan receive %s", data)
 	sorter := xsql.OrderedBy(p.SortFields)
 	switch input := data.(type) {

+ 3 - 3
xsql/plans/preprocessor.go

@@ -1,9 +1,9 @@
 package plans
 
 import (
-	"context"
 	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 	"fmt"
 	"reflect"
 	"strings"
@@ -38,8 +38,8 @@ func NewPreprocessor(s *xsql.StreamStmt, fs xsql.Fields, iet bool) (*Preprocesso
  *	input: *xsql.Tuple
  *	output: *xsql.Tuple
  */
-func (p *Preprocessor) Apply(ctx context.Context, data interface{}) interface{} {
-	log := common.GetLogger(ctx)
+func (p *Preprocessor) Apply(ctx api.StreamContext, data interface{}) interface{} {
+	log := ctx.GetLogger()
 	tuple, ok := data.(*xsql.Tuple)
 	if !ok {
 		log.Errorf("Expect tuple data type")

+ 3 - 4
xsql/plans/project_operator.go

@@ -1,10 +1,9 @@
 package plans
 
 import (
-	"context"
 	"encoding/json"
-	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 	"fmt"
 	"strconv"
 	"strings"
@@ -19,8 +18,8 @@ type ProjectPlan struct {
  *  input: *xsql.Tuple from preprocessor or filterOp | xsql.WindowTuplesSet from windowOp or filterOp | xsql.JoinTupleSets from joinOp or filterOp
  *  output: []map[string]interface{}
  */
-func (pp *ProjectPlan) Apply(ctx context.Context, data interface{}) interface{} {
-	log := common.GetLogger(ctx)
+func (pp *ProjectPlan) Apply(ctx api.StreamContext, data interface{}) interface{} {
+	log := ctx.GetLogger()
 	log.Debugf("project plan receive %s", data)
 	var results []map[string]interface{}
 	switch input := data.(type) {

+ 4 - 4
xsql/processors/xsql_processor.go

@@ -219,13 +219,13 @@ func (p *RuleProcessor) ExecInitRule(rule *api.Rule) (*xstream.TopologyNew, erro
 				switch name {
 				case "log":
 					log.Printf("Create log sink with %s", action)
-					tp.AddSink(inputs, sinks.NewLogSink("sink_log", rule.Id))
+					tp.AddSink(inputs, nodes.NewSinkNode("sink_log", sinks.NewLogSink()))
 				case "mqtt":
 					log.Printf("Create mqtt sink with %s", action)
-					if ms, err := sinks.NewMqttSink("mqtt_log", rule.Id, action); err != nil{
+					if ms, err := sinks.NewMqttSink(action); err != nil{
 						return nil, err
 					}else{
-						tp.AddSink(inputs, ms)
+						tp.AddSink(inputs, nodes.NewSinkNode("sink_mqtt", ms))
 					}
 				default:
 					return nil, fmt.Errorf("unsupported action: %s", name)
@@ -240,7 +240,7 @@ func (p *RuleProcessor) ExecQuery(ruleid, sql string) (*xstream.TopologyNew, err
 	if tp, inputs, err := p.createTopo(&api.Rule{Id: ruleid, Sql: sql}); err != nil {
 		return nil, err
 	} else {
-		tp.AddSink(inputs, sinks.NewLogSinkToMemory("sink_log", ruleid))
+		tp.AddSink(inputs, nodes.NewSinkNode("sink_memory_log", sinks.NewLogSinkToMemory()))
 		go func() {
 			select {
 			case err := <-tp.Open():

+ 9 - 6
xsql/processors/xsql_processor_test.go

@@ -427,7 +427,8 @@ func TestSingleSQL(t *testing.T) {
 		if err != nil{
 			t.Error(err)
 		}
-		sink := test.NewMockSink("mockSink", tt.name)
+		mockSink := test.NewMockSink()
+		sink := nodes.NewSinkNode("MockSink", mockSink)
 		tp.AddSink(inputs, sink)
 		count := len(sources)
 		errCh := tp.Open()
@@ -451,7 +452,7 @@ func TestSingleSQL(t *testing.T) {
 				}
 			}
 		}()
-		results := sink.GetResults()
+		results := mockSink.GetResults()
 		var maps [][]map[string]interface{}
 		for _, v := range results{
 			var mapRes []map[string]interface{}
@@ -691,7 +692,8 @@ func TestWindow(t *testing.T) {
 		if err != nil{
 			t.Error(err)
 		}
-		sink := test.NewMockSink("mockSink", tt.name)
+		mockSink := test.NewMockSink()
+		sink := nodes.NewSinkNode("mockSink", mockSink)
 		tp.AddSink(inputs, sink)
 		count := len(sources)
 		errCh := tp.Open()
@@ -715,7 +717,7 @@ func TestWindow(t *testing.T) {
 				}
 			}
 		}()
-		results := sink.GetResults()
+		results := mockSink.GetResults()
 		var maps [][]map[string]interface{}
 		for _, v := range results{
 			var mapRes []map[string]interface{}
@@ -1240,7 +1242,8 @@ func TestEventWindow(t *testing.T) {
 		if err != nil{
 			t.Error(err)
 		}
-		sink := test.NewMockSink("mockSink", tt.name)
+		mockSink := test.NewMockSink()
+		sink := nodes.NewSinkNode("MockSink", mockSink)
 		tp.AddSink(inputs, sink)
 		count := len(sources)
 		errCh := tp.Open()
@@ -1264,7 +1267,7 @@ func TestEventWindow(t *testing.T) {
 				}
 			}
 		}()
-		results := sink.GetResults()
+		results := mockSink.GetResults()
 		var maps [][]map[string]interface{}
 		for _, v := range results{
 			var mapRes []map[string]interface{}

+ 19 - 12
xstream/api/stream.go

@@ -7,8 +7,21 @@ import (
 
 type ConsumeFunc func(data interface{})
 
+type Closable interface {
+	Close(StreamContext) error
+}
+
 type Source interface {
-	Open(context StreamContext, consume ConsumeFunc) error
+	//Should be sync function for normal case. The container will run it in go func
+	Open(StreamContext, ConsumeFunc) error
+	Closable
+}
+
+type Sink interface {
+	//Should be sync function for normal case. The container will run it in go func
+	Open(StreamContext) error
+	Collect(StreamContext, interface{}) error
+	Closable
 }
 
 type Emitter interface {
@@ -31,24 +44,18 @@ type Rule struct {
 }
 
 type StreamContext interface {
-	GetContext() context.Context
+	context.Context
 	GetLogger()  *logrus.Entry
 	GetRuleId() string
 	GetOpId() string
-}
-
-type SinkConnector interface {
-	Open(context.Context, chan<- error)
-}
-
-type Sink interface {
-	Collector
-	SinkConnector
+	WithMeta(ruleId string, opId string) StreamContext
+	WithCancel() (StreamContext, context.CancelFunc)
 }
 
 type Operator interface {
 	Emitter
 	Collector
-	Exec(context context.Context) error
+	Exec(StreamContext, chan<- error)
+	GetName() string
 }
 

+ 15 - 37
xstream/collectors/func.go

@@ -1,67 +1,45 @@
 package collectors
 
 import (
-	"context"
-	"engine/common"
+	"engine/xstream/api"
 	"errors"
 )
 
 // CollectorFunc is a function used to colllect
 // incoming stream data. It can be used as a
 // stream sink.
-type CollectorFunc func(context.Context, interface{}) error
+type CollectorFunc func(api.StreamContext, interface{}) error
 
 // FuncCollector is a colletor that uses a function
 // to collect data.  The specified function must be
 // of type:
 //   CollectorFunc
 type FuncCollector struct {
-	input chan interface{}
-	//logf  api.LogFunc
-	//errf  api.ErrorFunc
 	f     CollectorFunc
-	name  string
 }
 
 // Func creates a new value *FuncCollector that
 // will use the specified function parameter to
 // collect streaming data.
-func Func(name string, f CollectorFunc) *FuncCollector {
-	return &FuncCollector{f: f, name:name, input: make(chan interface{}, 1024)}
-}
-
-func (c *FuncCollector) GetName() string  {
-	return c.name
-}
-
-func (c *FuncCollector) GetInput() (chan<- interface{}, string)  {
-	return c.input, c.name
+func Func(f CollectorFunc) *FuncCollector {
+	return &FuncCollector{f: f}
 }
 
 // Open is the starting point that starts the collector
-func (c *FuncCollector) Open(ctx context.Context, result chan<- error) {
-	//c.logf = autoctx.GetLogFunc(ctx)
-	//c.errf = autoctx.GetErrFunc(ctx)
-	log := common.GetLogger(ctx)
+func (c *FuncCollector) Open(ctx api.StreamContext) error {
+	log := ctx.GetLogger()
 	log.Println("Opening func collector")
 
 	if c.f == nil {
-		err := errors.New("Func collector missing function")
-		log.Println(err)
-		go func() { result <- err }()
+		return errors.New("func collector missing function")
 	}
+	return nil
+}
 
-	go func() {
-		for {
-			select {
-			case item := <-c.input:
-				if err := c.f(ctx, item); err != nil {
-					log.Println(err)
-				}
-			case <-ctx.Done():
-				log.Infof("Func collector %s done", c.name)
-				return
-			}
-		}
-	}()
+func (c *FuncCollector) Collect(ctx api.StreamContext, item interface{}) error {
+	return c.f(ctx, item)
 }
+
+func (c *FuncCollector) Close(api.StreamContext) error {
+	return nil
+}

+ 53 - 9
xstream/contexts/default.go

@@ -3,32 +3,59 @@ package contexts
 import (
 	"context"
 	"engine/common"
+	"engine/xstream/api"
 	"github.com/sirupsen/logrus"
+	"time"
 )
 
+const LoggerKey = "$$logger"
+
 type DefaultContext struct {
 	ruleId string
 	opId   string
-	ctx    context.Context
-	logger *logrus.Entry
+	ctx context.Context
 }
 
-func NewDefaultContext(ruleId string, opId string, ctx context.Context) *DefaultContext{
+func Background() *DefaultContext {
 	c := &DefaultContext{
-		ruleId: ruleId,
-		opId:	opId,
-		ctx:    ctx,
-		logger: common.GetLogger(ctx),
+		ctx:context.Background(),
 	}
 	return c
 }
 
-func (c *DefaultContext) GetContext() context.Context {
+func WithValue(parent *DefaultContext, key, val interface{}) *DefaultContext {
+	parent.ctx = context.WithValue(parent.ctx, key, val)
+	return parent
+}
+
+//Implement context interface
+func (c *DefaultContext) Deadline() (deadline time.Time, ok bool){
+	return c.ctx.Deadline()
+}
+
+func (c *DefaultContext) Done() <-chan struct{}{
+	return c.ctx.Done()
+}
+
+func (c *DefaultContext) Err() error{
+	return c.ctx.Err()
+}
+
+func (c *DefaultContext) Value(key interface{}) interface{}{
+	return c.ctx.Value(key)
+}
+
+// Stream metas
+func (c *DefaultContext) GetContext() context.Context{
 	return c.ctx
 }
 
 func (c *DefaultContext) GetLogger() *logrus.Entry {
-	return c.logger
+	l, ok := c.ctx.Value(LoggerKey).(*logrus.Entry)
+	if l != nil && ok {
+		return l
+	}
+	return common.Log.WithField("caller", "default")
 }
 
 func (c *DefaultContext) GetRuleId() string {
@@ -37,4 +64,21 @@ func (c *DefaultContext) GetRuleId() string {
 
 func (c *DefaultContext) GetOpId() string {
 	return c.opId
+}
+
+func (c *DefaultContext) WithMeta(ruleId string, opId string) api.StreamContext{
+	return &DefaultContext{
+		ruleId: ruleId,
+		opId: opId,
+		ctx:c.ctx,
+	}
+}
+
+func (c *DefaultContext) WithCancel() (api.StreamContext, context.CancelFunc) {
+	ctx, cancel := context.WithCancel(c.ctx)
+	return &DefaultContext{
+		ruleId: c.ruleId,
+		opId: c.opId,
+		ctx: ctx,
+	}, cancel
 }

+ 48 - 60
xstream/extensions/mqtt_source.go

@@ -1,7 +1,6 @@
 package extensions
 
 import (
-	"context"
 	"encoding/json"
 	"engine/common"
 	"engine/xsql"
@@ -92,73 +91,62 @@ func (ms *MQTTSource) WithSchema(schema string) *MQTTSource {
 
 func (ms *MQTTSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) error {
 	log := ctx.GetLogger()
-	go func() {
-		exeCtx, cancel := context.WithCancel(ctx.GetContext())
-		opts := MQTT.NewClientOptions().AddBroker(ms.srv).SetProtocolVersion(ms.pVersion)
-
-		if ms.clientid == "" {
-			if uuid, err := uuid.NewUUID(); err != nil {
-				log.Printf("Failed to get uuid, the error is %s", err)
-				cancel()
-				return
-			} else {
-				opts.SetClientID(uuid.String())
-			}
-		} else {
-			opts.SetClientID(ms.clientid)
-		}
-
-		if ms.uName != "" {
-			opts.SetUsername(ms.uName)
-		}
 
-		if ms.password != "" {
-			opts.SetPassword(ms.password)
+	opts := MQTT.NewClientOptions().AddBroker(ms.srv).SetProtocolVersion(ms.pVersion)
+	if ms.clientid == "" {
+		if uuid, err := uuid.NewUUID(); err != nil {
+			return fmt.Errorf("failed to get uuid, the error is %s", err)
+		} else {
+			opts.SetClientID(uuid.String())
 		}
+	} else {
+		opts.SetClientID(ms.clientid)
+	}
+	if ms.uName != "" {
+		opts.SetUsername(ms.uName)
+	}
 
+	if ms.password != "" {
+		opts.SetPassword(ms.password)
+	}
 
-		h := func(client MQTT.Client, msg MQTT.Message) {
-			log.Infof("received %s", msg.Payload())
+	h := func(client MQTT.Client, msg MQTT.Message) {
+		log.Infof("received %s", msg.Payload())
 
-			result := make(map[string]interface{})
-			//The unmarshal type can only be bool, float64, string, []interface{}, map[string]interface{}, nil
-			if e := json.Unmarshal(msg.Payload(), &result); e != nil {
-				log.Errorf("Invalid data format, cannot convert %s into JSON with error %s", string(msg.Payload()), e)
-				return
-			}
-			//Convert the keys to lowercase
-			result = xsql.LowercaseKeyMap(result)
+		result := make(map[string]interface{})
+		//The unmarshal type can only be bool, float64, string, []interface{}, map[string]interface{}, nil
+		if e := json.Unmarshal(msg.Payload(), &result); e != nil {
+			log.Errorf("Invalid data format, cannot convert %s into JSON with error %s", string(msg.Payload()), e)
+			return
+		}
+		//Convert the keys to lowercase
+		result = xsql.LowercaseKeyMap(result)
 
-			meta := make(map[string]interface{})
-			meta[xsql.INTERNAL_MQTT_TOPIC_KEY] = msg.Topic()
-			meta[xsql.INTERNAL_MQTT_MSG_ID_KEY] = strconv.Itoa(int(msg.MessageID()))
+		meta := make(map[string]interface{})
+		meta[xsql.INTERNAL_MQTT_TOPIC_KEY] = msg.Topic()
+		meta[xsql.INTERNAL_MQTT_MSG_ID_KEY] = strconv.Itoa(int(msg.MessageID()))
 
-			tuple := &xsql.Tuple{Emitter: ms.tpc, Message:result, Timestamp: common.TimeToUnixMilli(time.Now()), Metadata:meta}
-			consume(tuple)
-		}
+		tuple := &xsql.Tuple{Emitter: ms.tpc, Message:result, Timestamp: common.TimeToUnixMilli(time.Now()), Metadata:meta}
+		consume(tuple)
+	}
+	//TODO error listener?
+	opts.SetDefaultPublishHandler(h)
+	c := MQTT.NewClient(opts)
+	if token := c.Connect(); token.Wait() && token.Error() != nil {
+		return fmt.Errorf("found error when connecting to %s: %s", ms.srv, token.Error())
+	}
+	log.Printf("The connection to server %s was established successfully", ms.srv)
+	ms.conn = c
+	if token := c.Subscribe(ms.tpc, 0, nil); token.Wait() && token.Error() != nil {
+		return fmt.Errorf("Found error: %s", token.Error())
+	}
+	log.Printf("Successfully subscribe to topic %s", ms.tpc)
 
-		opts.SetDefaultPublishHandler(h)
-		c := MQTT.NewClient(opts)
-		if token := c.Connect(); token.Wait() && token.Error() != nil {
-			log.Printf("Found error when connecting to %s: %s", ms.srv, token.Error())
-			cancel()
-			return
-		}
-		log.Printf("The connection to server %s was established successfully", ms.srv)
-		ms.conn = c
-		if token := c.Subscribe(ms.tpc, 0, nil); token.Wait() && token.Error() != nil {
-			log.Printf("Found error: %s", token.Error())
-			cancel()
-			return
-		}
-		log.Printf("Successfully subscribe to topic %s", ms.tpc)
-		select {
-		case <-exeCtx.Done():
-			log.Println("Mqtt Source Done")
-			ms.conn.Disconnect(5000)
-			cancel()
-		}
-	}()
+	return nil
+}
 
+func (ms *MQTTSource) Close(ctx api.StreamContext) error{
+	ctx.GetLogger().Println("Mqtt Source Done")
+	ms.conn.Disconnect(5000)
 	return nil
 }

+ 4 - 3
xstream/funcs.go

@@ -2,6 +2,7 @@ package xstream
 
 import (
 	"context"
+	"engine/xstream/api"
 	"engine/xstream/operators"
 	"fmt"
 	"reflect"
@@ -34,7 +35,7 @@ func ProcessFunc(f interface{}) (operators.UnFunc, error) {
 
 	fnval := reflect.ValueOf(f)
 
-	return operators.UnFunc(func(ctx context.Context, data interface{}) interface{} {
+	return operators.UnFunc(func(ctx api.StreamContext, data interface{}) interface{} {
 		result := callOpFunc(fnval, ctx, data, funcForm)
 		return result.Interface()
 	}), nil
@@ -63,7 +64,7 @@ func FilterFunc(f interface{}) (operators.UnFunc, error) {
 	}
 
 	fnval := reflect.ValueOf(f)
-	return operators.UnFunc(func(ctx context.Context, data interface{}) interface{} {
+	return operators.UnFunc(func(ctx api.StreamContext, data interface{}) interface{} {
 		result := callOpFunc(fnval, ctx, data, funcForm)
 		predicate := result.Bool()
 		if !predicate {
@@ -103,7 +104,7 @@ func FlatMapFunc(f interface{}) (operators.UnFunc, error) {
 	}
 
 	fnval := reflect.ValueOf(f)
-	return operators.UnFunc(func(ctx context.Context, data interface{}) interface{} {
+	return operators.UnFunc(func(ctx api.StreamContext, data interface{}) interface{} {
 		result := callOpFunc(fnval, ctx, data, funcForm)
 		return result.Interface()
 	}), nil

+ 9 - 9
xstream/nodes/common_func.go

@@ -1,21 +1,21 @@
 package nodes
 
-import "fmt"
+import (
+	"engine/xstream/api"
+)
 
-func Broadcast(outputs map[string]chan<- interface{}, val interface{}) (err error) {
+func Broadcast(outputs map[string]chan<- interface{}, val interface{}, ctx api.StreamContext) int {
+	count := 0
+	logger := ctx.GetLogger()
 	for n, out := range outputs {
 		select {
 		case out <- val:
-			//All ok
+			count++
 		default: //TODO channel full strategy?
-			if err != nil {
-				err = fmt.Errorf("%v;channel full for %s", err, n)
-			} else {
-				err = fmt.Errorf("channel full for %s", n)
-			}
+			logger.Errorf("send output from %s to %s fail: channel full", ctx.GetOpId(), n)
 		}
 	}
-	return err
+	return count
 }
 
 

+ 56 - 0
xstream/nodes/sink_node.go

@@ -0,0 +1,56 @@
+package nodes
+
+import (
+	"engine/xstream/api"
+)
+
+type SinkNode struct {
+	sink   api.Sink
+	input  chan interface{}
+	name   string
+	ctx    api.StreamContext
+}
+
+func NewSinkNode(name string, sink api.Sink) *SinkNode{
+	return &SinkNode{
+		sink: sink,
+		input: make(chan interface{}, 1024),
+		name: name,
+		ctx: nil,
+	}
+}
+
+func (m *SinkNode) Open(ctx api.StreamContext, result chan<- error) {
+	m.ctx = ctx
+	logger := ctx.GetLogger()
+	logger.Debugf("open sink node %s", m.name)
+	go func() {
+		if err := m.sink.Open(ctx); err != nil{
+			go func() { result <- err }()
+			return
+		}
+		for {
+			select {
+			case item := <-m.input:
+				if err := m.sink.Collect(ctx, item); err != nil{
+					//TODO deal with publish error
+					logger.Errorf("sink node %s publish %v error: %v", ctx.GetOpId(), item, err)
+				}
+			case <-ctx.Done():
+				logger.Infof("sink node %s done", m.name)
+				if err := m.sink.Close(ctx); err != nil{
+					go func() { result <- err }()
+				}
+				return
+			}
+		}
+	}()
+}
+
+func (m *SinkNode) GetName() string{
+	return m.name
+}
+
+func (m *SinkNode) GetInput() (chan<- interface{}, string)  {
+	return m.input, m.name
+}

+ 27 - 7
xstream/nodes/source_node.go

@@ -21,18 +21,38 @@ func NewSourceNode(name string, source api.Source) *SourceNode{
 	}
 }
 
-func (m *SourceNode) Open(ctx api.StreamContext) error {
+func (m *SourceNode) Open(ctx api.StreamContext, errCh chan<- error) {
 	m.ctx = ctx
 	logger := ctx.GetLogger()
 	logger.Debugf("open source node %s", m.name)
-	return m.source.Open(ctx, func(data interface{}){
-		m.Broadcast(data)
-		logger.Debugf("%s consume data %v complete", m.name, data)
-	})
+	go func(){
+		if err := m.source.Open(ctx, func(data interface{}){
+			m.Broadcast(data)
+			logger.Debugf("%s consume data %v complete", m.name, data)
+		}); err != nil{
+			select {
+			case errCh <- err:
+			case <-ctx.Done():
+				if err := m.source.Close(ctx); err != nil{
+					go func() { errCh <- err }()
+				}
+			}
+		}
+		for {
+			select {
+			case <-ctx.Done():
+				logger.Infof("source %s done", m.name)
+				if err := m.source.Close(ctx); err != nil{
+					go func() { errCh <- err }()
+				}
+				return
+			}
+		}
+	}()
 }
 
-func (m *SourceNode) Broadcast(data interface{}) (err error){
-	return Broadcast(m.outs, data)
+func (m *SourceNode) Broadcast(data interface{}) int{
+	return Broadcast(m.outs, data, m.ctx)
 }
 
 func (m *SourceNode) GetName() string{

+ 16 - 20
xstream/operators/operations.go

@@ -1,8 +1,7 @@
 package operators
 
 import (
-	"context"
-	"engine/common"
+	"engine/xstream/api"
 	"engine/xstream/nodes"
 	"fmt"
 	"sync"
@@ -10,14 +9,14 @@ import (
 
 // UnOperation interface represents unary operations (i.e. Map, Filter, etc)
 type UnOperation interface {
-	Apply(ctx context.Context, data interface{}) interface{}
+	Apply(ctx api.StreamContext, data interface{}) interface{}
 }
 
 // UnFunc implements UnOperation as type func (context.Context, interface{})
-type UnFunc func(context.Context, interface{}) interface{}
+type UnFunc func(api.StreamContext, interface{}) interface{}
 
 // Apply implements UnOperation.Apply method
-func (f UnFunc) Apply(ctx context.Context, data interface{}) interface{} {
+func (f UnFunc) Apply(ctx api.StreamContext, data interface{}) interface{} {
 	return f(ctx, data)
 }
 
@@ -62,7 +61,7 @@ func (o *UnaryOperator) SetConcurrency(concurr int) {
 	}
 }
 
-func (o *UnaryOperator) AddOutput(output chan<- interface{}, name string) (err error){
+func (o *UnaryOperator) AddOutput(output chan<- interface{}, name string) error{
 	if _, ok := o.outputs[name]; !ok{
 		o.outputs[name] = output
 	}else{
@@ -76,12 +75,12 @@ func (o *UnaryOperator) GetInput() (chan<- interface{}, string) {
 }
 
 // Exec is the entry point for the executor
-func (o *UnaryOperator) Exec(ctx context.Context) (err error) {
-	log := common.GetLogger(ctx)
-	log.Printf("Unary operator %s is started", o.name)
+func (o *UnaryOperator) Exec(ctx api.StreamContext, errCh chan<- error ) {
+	log := ctx.GetLogger()
+	log.Tracef("Unary operator %s is started", o.name)
 
 	if len(o.outputs) <= 0 {
-		err = fmt.Errorf("no output channel found")
+		go func(){errCh <- fmt.Errorf("no output channel found")}()
 		return
 	}
 
@@ -98,7 +97,7 @@ func (o *UnaryOperator) Exec(ctx context.Context) (err error) {
 		for i := 0; i < o.concurrency; i++ { // workers
 			go func(wg *sync.WaitGroup) {
 				defer wg.Done()
-				o.doOp(ctx)
+				o.doOp(ctx, errCh)
 			}(&barrier)
 		}
 
@@ -119,17 +118,15 @@ func (o *UnaryOperator) Exec(ctx context.Context) (err error) {
 			return
 		}
 	}()
-
-	return nil
 }
 
-func (o *UnaryOperator) doOp(ctx context.Context) {
-	log := common.GetLogger(ctx)
+func (o *UnaryOperator) doOp(ctx api.StreamContext, errCh chan<- error) {
+	log := ctx.GetLogger()
 	if o.op == nil {
 		log.Println("Unary operator missing operation")
 		return
 	}
-	exeCtx, cancel := context.WithCancel(ctx)
+	exeCtx, cancel := ctx.WithCancel()
 
 	defer func() {
 		log.Infof("unary operator %s done, cancelling future items", o.name)
@@ -145,17 +142,16 @@ func (o *UnaryOperator) doOp(ctx context.Context) {
 			switch val := result.(type) {
 			case nil:
 				continue
-			case error:
+			case error: //TODO error handling
 				log.Println(val)
 				log.Println(val.Error())
 				continue
-
 			default:
-				nodes.Broadcast(o.outputs, val)
+				nodes.Broadcast(o.outputs, val, ctx)
 			}
 
 		// is cancelling
-		case <-exeCtx.Done():
+		case <-ctx.Done():
 			log.Printf("unary operator %s cancelling....", o.name)
 			o.mutex.Lock()
 			cancel()

+ 13 - 14
xstream/operators/watermark.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 	"fmt"
 	"math"
 	"sort"
@@ -62,8 +63,8 @@ func NewWatermarkGenerator(window *WindowConfig, l int64, s []string, stream cha
 	return w, nil
 }
 
-func (w *WatermarkGenerator) track(s string, ts int64, ctx context.Context) bool {
-	log := common.GetLogger(ctx)
+func (w *WatermarkGenerator) track(s string, ts int64, ctx api.StreamContext) bool {
+	log := ctx.GetLogger()
 	log.Infof("watermark generator track event from topic %s at %d", s, ts)
 	currentVal, ok := w.topicToTs[s]
 	if !ok || ts > currentVal {
@@ -79,9 +80,8 @@ func (w *WatermarkGenerator) track(s string, ts int64, ctx context.Context) bool
 	return r
 }
 
-func (w *WatermarkGenerator) start(ctx context.Context) {
-	exeCtx, cancel := context.WithCancel(ctx)
-	log := common.GetLogger(ctx)
+func (w *WatermarkGenerator) start(ctx api.StreamContext) {
+	log := ctx.GetLogger()
 	var c <-chan time.Time
 
 	if w.ticker != nil {
@@ -91,19 +91,18 @@ func (w *WatermarkGenerator) start(ctx context.Context) {
 		select {
 		case <-c:
 			w.trigger(ctx)
-		case <-exeCtx.Done():
+		case <-ctx.Done():
 			log.Println("Cancelling watermark generator....")
 			if w.ticker != nil{
 				w.ticker.Stop()
 			}
-			cancel()
 			return
 		}
 	}
 }
 
-func (w *WatermarkGenerator) trigger(ctx context.Context) {
-	log := common.GetLogger(ctx)
+func (w *WatermarkGenerator) trigger(ctx api.StreamContext) {
+	log := ctx.GetLogger()
 	watermark := w.computeWatermarkTs(ctx)
 	log.Infof("compute watermark event at %d with last %d", watermark, w.lastWatermarkTs)
 	if watermark > w.lastWatermarkTs {
@@ -184,10 +183,10 @@ func (w *WatermarkGenerator) getNextWindow(inputs []*xsql.Tuple,current int64, w
 	}
 }
 
-func (o *WindowOperator) execEventWindow(ctx context.Context) {
-	exeCtx, cancel := context.WithCancel(ctx)
-	log := common.GetLogger(ctx)
-	go o.watermarkGenerator.start(ctx)
+func (o *WindowOperator) execEventWindow(ctx api.StreamContext, errCh chan<- error) {
+	exeCtx, cancel := ctx.WithCancel()
+	log := ctx.GetLogger()
+	go o.watermarkGenerator.start(exeCtx)
 	var (
 		inputs []*xsql.Tuple
 		triggered bool
@@ -236,7 +235,7 @@ func (o *WindowOperator) execEventWindow(ctx context.Context) {
 
 			}
 		// is cancelling
-		case <-exeCtx.Done():
+		case <-ctx.Done():
 			log.Println("Cancelling window....")
 			if o.ticker != nil{
 				o.ticker.Stop()

+ 14 - 17
xstream/operators/window_op.go

@@ -1,9 +1,9 @@
 package operators
 
 import (
-	"context"
 	"engine/common"
 	"engine/xsql"
+	"engine/xstream/api"
 	"engine/xstream/nodes"
 	"fmt"
 	"github.com/sirupsen/logrus"
@@ -96,26 +96,23 @@ func (o *WindowOperator) GetInput() (chan<- interface{}, string) {
 // Exec is the entry point for the executor
 // input: *xsql.Tuple from preprocessor
 // output: xsql.WindowTuplesSet
-func (o *WindowOperator) Exec(ctx context.Context) (err error) {
-	log := common.GetLogger(ctx)
+func (o *WindowOperator) Exec(ctx api.StreamContext, errCh chan<- error ){
+	log := ctx.GetLogger()
 	log.Printf("Window operator %s is started", o.name)
 
 	if len(o.outputs) <= 0 {
-		err = fmt.Errorf("no output channel found")
+		go func(){errCh <- fmt.Errorf("no output channel found")}()
 		return
 	}
 	if o.isEventTime{
-		go o.execEventWindow(ctx)
+		go o.execEventWindow(ctx, errCh)
 	}else{
-		go o.execProcessingWindow(ctx)
+		go o.execProcessingWindow(ctx, errCh)
 	}
-
-	return nil
 }
 
-func (o *WindowOperator) execProcessingWindow(ctx context.Context) {
-	exeCtx, cancel := context.WithCancel(ctx)
-	log := common.GetLogger(ctx)
+func (o *WindowOperator) execProcessingWindow(ctx api.StreamContext, errCh chan<- error) {
+	log := ctx.GetLogger()
 	var (
 		inputs []*xsql.Tuple
 		c <-chan time.Time
@@ -179,19 +176,18 @@ func (o *WindowOperator) execProcessingWindow(ctx context.Context) {
 				inputs = make([]*xsql.Tuple, 0)
 			}
 		// is cancelling
-		case <-exeCtx.Done():
+		case <-ctx.Done():
 			log.Println("Cancelling window....")
 			if o.ticker != nil{
 				o.ticker.Stop()
 			}
-			cancel()
 			return
 		}
 	}
 }
 
-func (o *WindowOperator) scan(inputs []*xsql.Tuple, triggerTime int64, ctx context.Context) ([]*xsql.Tuple, bool){
-	log := common.GetLogger(ctx)
+func (o *WindowOperator) scan(inputs []*xsql.Tuple, triggerTime int64, ctx api.StreamContext) ([]*xsql.Tuple, bool){
+	log := ctx.GetLogger()
 	log.Printf("window %s triggered at %s", o.name, time.Unix(triggerTime/1000, triggerTime%1000))
 	var delta int64
 	if o.window.Type == xsql.HOPPING_WINDOW || o.window.Type == xsql.SLIDING_WINDOW {
@@ -227,8 +223,9 @@ func (o *WindowOperator) scan(inputs []*xsql.Tuple, triggerTime int64, ctx conte
 		if o.isEventTime{
 			results.Sort()
 		}
-		err := nodes.Broadcast(o.outputs, results)
-		if err != nil{
+		count := nodes.Broadcast(o.outputs, results, ctx)
+		//TODO deal with partial fail
+		if count > 0{
 			triggered = true
 		}
 	}

+ 7 - 8
xstream/sinks/log_sink.go

@@ -1,8 +1,7 @@
 package sinks
 
 import (
-	"context"
-	"engine/common"
+	"engine/xstream/api"
 	"engine/xstream/collectors"
 	"fmt"
 	"sync"
@@ -11,10 +10,10 @@ import (
 
 // log action, no properties now
 // example: {"log":{}}
-func NewLogSink(name string, ruleId string) *collectors.FuncCollector {
-	return collectors.Func(name, func(ctx context.Context, data interface{}) error {
-		log := common.GetLogger(ctx)
-		log.Printf("sink result for rule %s: %s", ruleId, data)
+func NewLogSink() *collectors.FuncCollector {
+	return collectors.Func(func(ctx api.StreamContext, data interface{}) error {
+		log := ctx.GetLogger()
+		log.Printf("sink result for rule %s: %s", ctx.GetRuleId(), data)
 		return nil
 	})
 }
@@ -27,9 +26,9 @@ type QueryResult struct {
 
 var QR = &QueryResult{LastFetch:time.Now()}
 
-func NewLogSinkToMemory(name string, ruleId string) *collectors.FuncCollector {
+func NewLogSinkToMemory() *collectors.FuncCollector {
 	QR.Results = make([]string, 10)
-	return collectors.Func(name, func(ctx context.Context, data interface{}) error {
+	return collectors.Func(func(ctx api.StreamContext, data interface{}) error {
 		QR.Mux.Lock()
 		QR.Results = append(QR.Results, fmt.Sprintf("%s", data))
 		QR.Mux.Unlock()

+ 35 - 53
xstream/sinks/mqtt_sink.go

@@ -1,8 +1,7 @@
 package sinks
 
 import (
-	"context"
-	"engine/common"
+	"engine/xstream/api"
 	"fmt"
 	MQTT "github.com/eclipse/paho.mqtt.golang"
 	"github.com/google/uuid"
@@ -16,15 +15,10 @@ type MQTTSink struct {
 	pVersion uint
 	uName 	string
 	password string
-
-	input chan interface{}
 	conn MQTT.Client
-	ruleId   string
-	name 	 string
-	//ctx context.Context
 }
 
-func NewMqttSink(name string, ruleId string, properties interface{}) (*MQTTSink, error) {
+func NewMqttSink(properties interface{}) (*MQTTSink, error) {
 	ps, ok := properties.(map[string]interface{})
 	if !ok {
 		return nil, fmt.Errorf("expect map[string]interface{} type for the mqtt sink properties")
@@ -76,58 +70,46 @@ func NewMqttSink(name string, ruleId string, properties interface{}) (*MQTTSink,
 		}
 	}
 
-	ms := &MQTTSink{name:name, ruleId: ruleId, input: make(chan interface{}), srv: srv.(string), tpc: tpc.(string), clientid: clientid.(string), pVersion:pVersion, uName:uName, password:password}
+	ms := &MQTTSink{srv: srv.(string), tpc: tpc.(string), clientid: clientid.(string), pVersion:pVersion, uName:uName, password:password}
 	return ms, nil
 }
 
-func (ms *MQTTSink) GetName() string {
-	return ms.name
-}
-
-func (ms *MQTTSink) GetInput() (chan<- interface{}, string)  {
-	return ms.input, ms.name
-}
-
-func (ms *MQTTSink) Open(ctx context.Context, result chan<- error) {
-	log := common.GetLogger(ctx)
-	log.Printf("Opening mqtt sink for rule %s", ms.ruleId)
-
-	go func() {
-		exeCtx, cancel := context.WithCancel(ctx)
-		opts := MQTT.NewClientOptions().AddBroker(ms.srv).SetClientID(ms.clientid).SetProtocolVersion(ms.pVersion)
-		if ms.uName != "" {
-			opts = opts.SetUsername(ms.uName)
-		}
-
-		if ms.password != "" {
-			opts = opts.SetPassword(ms.password)
-		}
+func (ms *MQTTSink) Open(ctx api.StreamContext) error {
+	log := ctx.GetLogger()
+	log.Printf("Opening mqtt sink for rule %s", ctx.GetRuleId())
+	opts := MQTT.NewClientOptions().AddBroker(ms.srv).SetClientID(ms.clientid)
+	if ms.uName != "" {
+		opts = opts.SetUsername(ms.uName)
+	}
 
-		c := MQTT.NewClient(opts)
-		if token := c.Connect(); token.Wait() && token.Error() != nil {
-			result <- fmt.Errorf("Found error: %s", token.Error())
-			cancel()
-		}
-		log.Printf("The connection to server %s was established successfully", ms.srv)
-		ms.conn = c
+	if ms.password != "" {
+		opts = opts.SetPassword(ms.password)
+	}
 
-		for {
-			select {
-			case item := <-ms.input:
-				log.Infof("publish %s", item)
-				if token := c.Publish(ms.tpc, 0, false, item); token.Wait() && token.Error() != nil {
-					result <- fmt.Errorf("Publish error: %s", token.Error())
-				}
+	c := MQTT.NewClient(opts)
+	if token := c.Connect(); token.Wait() && token.Error() != nil {
+		return fmt.Errorf("Found error: %s", token.Error())
+	}
+	log.Printf("The connection to server %s was established successfully", ms.srv)
+	ms.conn = c
+	return nil
+}
 
-			case <-exeCtx.Done():
-				c.Disconnect(5000)
-				log.Infof("Closing mqtt sink")
-				cancel()
-				return
-			}
-		}
+func (ms *MQTTSink) Collect(ctx api.StreamContext, item interface{}) error {
+	logger := ctx.GetLogger()
+	c := ms.conn
+	logger.Infof("publish %s", item)
+	if token := c.Publish(ms.tpc, 0, false, item); token.Wait() && token.Error() != nil {
+		return fmt.Errorf("publish error: %s", token.Error())
+	}
+	return nil
+}
 
-	}()
+func (ms *MQTTSink) Close(ctx api.StreamContext) error {
+	logger := ctx.GetLogger()
+	logger.Infof("Closing mqtt sink")
+	ms.conn.Disconnect(5000)
+	return nil
 }
 
 

+ 15 - 22
xstream/streams.go

@@ -11,8 +11,8 @@ import (
 
 type TopologyNew struct {
 	sources []*nodes.SourceNode
-	sinks []api.Sink
-	ctx context.Context
+	sinks []*nodes.SinkNode
+	ctx api.StreamContext
 	cancel context.CancelFunc
 	drain chan error
 	ops []api.Operator
@@ -37,7 +37,7 @@ func (s *TopologyNew) AddSrc(src *nodes.SourceNode) *TopologyNew {
 	return s
 }
 
-func (s *TopologyNew) AddSink(inputs []api.Emitter, snk api.Sink) *TopologyNew {
+func (s *TopologyNew) AddSink(inputs []api.Emitter, snk *nodes.SinkNode) *TopologyNew {
 	for _, input := range inputs{
 		input.AddOutput(snk.GetInput())
 	}
@@ -60,7 +60,7 @@ func Transform(op operators.UnOperation, name string) *operators.UnaryOperator {
 }
 
 func (s *TopologyNew) Map(f interface{}) *TopologyNew {
-	log := common.GetLogger(s.ctx)
+	log := s.ctx.GetLogger()
 	op, err := MapFunc(f)
 	if err != nil {
 		log.Println(err)
@@ -94,9 +94,9 @@ func (s *TopologyNew) Transform(op operators.UnOperation) *TopologyNew {
 // stream starts execution.
 func (s *TopologyNew) prepareContext() {
 	if s.ctx == nil || s.ctx.Err() != nil {
-		s.ctx, s.cancel = context.WithCancel(context.Background())
 		contextLogger := common.Log.WithField("rule", s.name)
-		s.ctx = context.WithValue(s.ctx, common.LoggerKey, contextLogger)
+		ctx := contexts.WithValue(contexts.Background(), contexts.LoggerKey, contextLogger)
+		s.ctx, s.cancel = ctx.WithCancel()
 	}
 }
 
@@ -106,41 +106,34 @@ func (s *TopologyNew) drainErr(err error) {
 
 func (s *TopologyNew) Open() <-chan error {
 	s.prepareContext() // ensure context is set
-	log := common.GetLogger(s.ctx)
+	log := s.ctx.GetLogger()
 	log.Println("Opening stream")
 
 	// open stream
 	go func() {
-		sinkErr := make(chan error)
+		streamErr := make(chan error)
 		defer func() {
-			log.Println("Closing sinkErr channel")
-			close(sinkErr)
+			log.Println("Closing streamErr channel")
+			close(streamErr)
 		}()
 		// open stream sink, after log sink is ready.
 		for _, snk := range s.sinks{
-			snk.Open(s.ctx, sinkErr)
+			snk.Open(s.ctx.WithMeta(s.name, snk.GetName()), streamErr)
 		}
 
 		//apply operators, if err bail
 		for _, op := range s.ops {
-			if err := op.Exec(s.ctx); err != nil {
-				s.drainErr(err)
-				log.Println("Closing stream")
-				return
-			}
+			op.Exec(s.ctx.WithMeta(s.name, op.GetName()), streamErr)
 		}
 
 		// open source, if err bail
 		for _, node := range s.sources{
-			if err := node.Open(contexts.NewDefaultContext(s.name, node.GetName(), s.ctx)); err != nil {
-				s.drainErr(err)
-				log.Println("Closing stream")
-				return
-			}
+			node.Open(s.ctx.WithMeta(s.name, node.GetName()), streamErr)
 		}
 
 		select {
-		case err := <- sinkErr:
+		case err := <-streamErr:
+			//TODO error handling
 			log.Println("Closing stream")
 			s.drain <- err
 		}

+ 20 - 32
xstream/test/mock_sink.go

@@ -1,53 +1,41 @@
 package test
 
 import (
-	"context"
-	"engine/common"
+	"engine/xstream/api"
 )
 
 type MockSink struct {
-	ruleId   string
-	name 	 string
 	results  [][]byte
-	input chan interface{}
 }
 
-func NewMockSink(name, ruleId string) *MockSink{
-	m := &MockSink{
-		ruleId:  ruleId,
-		name:    name,
-		input: make(chan interface{}),
-	}
+func NewMockSink() *MockSink{
+	m := &MockSink{}
 	return m
 }
 
-func (m *MockSink) Open(ctx context.Context, result chan<- error) {
-	log := common.GetLogger(ctx)
+func (m *MockSink) Open(ctx api.StreamContext) error {
+	log := ctx.GetLogger()
 	log.Trace("Opening mock sink")
 	m.results = make([][]byte, 0)
-	go func() {
-		for {
-			select {
-			case item := <-m.input:
-				if v, ok := item.([]byte); ok {
-					log.Infof("mock sink receive %s", item)
-					m.results = append(m.results, v)
-				}else{
-					log.Info("mock sink receive non byte data")
-				}
+	return nil
+}
 
-			case <-ctx.Done():
-				log.Infof("mock sink %s done", m.name)
-				return
-			}
-		}
-	}()
+func (m *MockSink) Collect(ctx api.StreamContext, item interface{}) error {
+	logger := ctx.GetLogger()
+	if v, ok := item.([]byte); ok {
+		logger.Infof("mock sink receive %s", item)
+		m.results = append(m.results, v)
+	}else{
+		logger.Info("mock sink receive non byte data")
+	}
+	return nil
 }
 
-func (m *MockSink) GetInput() (chan<- interface{}, string)  {
-	return m.input, m.name
+func (m *MockSink) Close(ctx api.StreamContext) error {
+	//do nothing
+	return nil
 }
 
 func (m *MockSink) GetResults() [][]byte {
 	return m.results
-}
+}

+ 5 - 0
xstream/test/mock_source.go

@@ -25,6 +25,7 @@ func NewMockSource(data []*xsql.Tuple, done chan<- struct{}, isEventTime bool) *
 
 func (m *MockSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err error) {
 	log := ctx.GetLogger()
+
 	log.Trace("mock source starts")
 	go func(){
 		for _, d := range m.data{
@@ -55,4 +56,8 @@ func (m *MockSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err e
 		m.done <- struct{}{}
 	}()
 	return nil
+}
+
+func (m *MockSource) Close(ctx api.StreamContext) error{
+	return nil
 }