浏览代码

refactor(source): add error callback for source open

ngjaying 5 年之前
父节点
当前提交
eaaf86a2fc

+ 1 - 2
plugins/sources/random.go

@@ -35,7 +35,7 @@ func (s *randomSource) Configure(topic string, props map[string]interface{}) err
 	return nil
 }
 
-func (s *randomSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err error) {
+func (s *randomSource) Open(ctx api.StreamContext, consume api.ConsumeFunc, onError api.ErrorFunc) {
 	t := time.NewTicker(time.Duration(s.interval) * time.Millisecond)
 	exeCtx, cancel := ctx.WithCancel()
 	s.cancel = cancel
@@ -50,7 +50,6 @@ func (s *randomSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err
 			}
 		}
 	}(exeCtx)
-	return nil
 }
 
 func randomize(p map[string]interface{}, seed int) map[string]interface{} {

+ 5 - 5
plugins/sources/zmq.go

@@ -25,15 +25,16 @@ func (s *zmqSource) Configure(topic string, props map[string]interface{}) error
 	return nil
 }
 
-func (s *zmqSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err error) {
+func (s *zmqSource) Open(ctx api.StreamContext, consume api.ConsumeFunc, onError api.ErrorFunc) {
 	logger := ctx.GetLogger()
+	var err error
 	s.subscriber, err = zmq.NewSocket(zmq.SUB)
 	if err != nil {
-		return fmt.Errorf("zmq source fails to create socket: %v", err)
+		onError(fmt.Errorf("zmq source fails to create socket: %v", err))
 	}
 	err = s.subscriber.Connect(s.srv)
 	if err != nil {
-		return fmt.Errorf("zmq source fails to connect to %s: %v", s.srv, err)
+		onError(fmt.Errorf("zmq source fails to connect to %s: %v", s.srv, err))
 	}
 	s.subscriber.SetSubscribe(s.topic)
 	logger.Debugf("zmq source subscribe to topic %s", s.topic)
@@ -45,7 +46,7 @@ func (s *zmqSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err er
 			msgs, err := s.subscriber.RecvMessage(0)
 			if err != nil {
 				id, err := s.subscriber.GetIdentity()
-				logger.Warnf("zmq source getting message %s error: %v", id, err)
+				onError(fmt.Errorf("zmq source getting message %s error: %v", id, err))
 			} else {
 				logger.Debugf("zmq source receive %v", msgs)
 				var m string
@@ -78,7 +79,6 @@ func (s *zmqSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err er
 			}
 		}
 	}(exeCtx)
-	return nil
 }
 
 func (s *zmqSource) Close(ctx api.StreamContext) error {

+ 2 - 1
xstream/api/stream.go

@@ -6,6 +6,7 @@ import (
 
 //The function to call when data is emitted by the source.
 type ConsumeFunc func(message map[string]interface{}, metadata map[string]interface{})
+type ErrorFunc func(err error)
 type Logger interface {
 	Debug(args ...interface{})
 	Info(args ...interface{})
@@ -27,7 +28,7 @@ type Closable interface {
 
 type Source interface {
 	//Should be sync function for normal case. The container will run it in go func
-	Open(ctx StreamContext, consume ConsumeFunc) error
+	Open(ctx StreamContext, consume ConsumeFunc, onError ErrorFunc)
 	//Called during initialization. Configure the source with the data source(e.g. topic for mqtt) and the properties
 	//read from the yaml
 	Configure(datasource string, props map[string]interface{}) error

+ 6 - 8
xstream/extensions/mqtt_source.go

@@ -70,13 +70,13 @@ func (ms *MQTTSource) Configure(topic string, props map[string]interface{}) erro
 	return nil
 }
 
-func (ms *MQTTSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) error {
+func (ms *MQTTSource) Open(ctx api.StreamContext, consume api.ConsumeFunc, onError api.ErrorFunc) {
 	log := ctx.GetLogger()
 
 	opts := MQTT.NewClientOptions().AddBroker(ms.srv).SetProtocolVersion(ms.pVersion)
 	if ms.clientid == "" {
 		if uuid, err := uuid.NewUUID(); err != nil {
-			return fmt.Errorf("failed to get uuid, the error is %s", err)
+			onError(fmt.Errorf("failed to get uuid, the error is %s", err))
 		} else {
 			ms.clientid = uuid.String()
 			opts.SetClientID(uuid.String())
@@ -92,15 +92,15 @@ func (ms *MQTTSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) error
 			if kp, err1 := common.ProcessPath(ms.pkeyPath); err1 == nil {
 				log.Infof("The private key file is %s.", kp)
 				if cer, err2 := tls.LoadX509KeyPair(cp, kp); err2 != nil {
-					return err2
+					onError(err2)
 				} else {
 					opts.SetTLSConfig(&tls.Config{Certificates: []tls.Certificate{cer}})
 				}
 			} else {
-				return err1
+				onError(err1)
 			}
 		} else {
-			return err
+			onError(err)
 		}
 	} else {
 		log.Infof("Connect MQTT broker with username and password.")
@@ -132,14 +132,12 @@ func (ms *MQTTSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) error
 
 	c := MQTT.NewClient(opts)
 	if token := c.Connect(); token.Wait() && token.Error() != nil {
-		return fmt.Errorf("found error when connecting to %s: %s", ms.srv, token.Error())
+		onError(fmt.Errorf("found error when connecting to %s: %s", ms.srv, token.Error()))
 	}
 	log.Infof("The connection to server %s was established successfully", ms.srv)
 	ms.conn = c
 	subscribe(ms.tpc, c, ctx, consume)
 	log.Infof("Successfully subscribe to topic %s", ms.srv+": "+ms.clientid)
-
-	return nil
 }
 
 func subscribe(topic string, client MQTT.Client, ctx api.StreamContext, consume api.ConsumeFunc) {

+ 3 - 4
xstream/nodes/source_node.go

@@ -109,7 +109,7 @@ func (m *SourceNode) Open(ctx api.StreamContext, errCh chan<- error) {
 				m.statManagers = append(m.statManagers, stats)
 				m.mutex.Unlock()
 
-				if err := source.Open(ctx.WithInstance(instance), func(message map[string]interface{}, meta map[string]interface{}) {
+				source.Open(ctx.WithInstance(instance), func(message map[string]interface{}, meta map[string]interface{}) {
 					stats.IncTotalRecordsIn()
 					stats.ProcessTimeStart()
 					tuple := &xsql.Tuple{Emitter: m.name, Message: message, Timestamp: common.GetNowInMilli(), Metadata: meta}
@@ -118,10 +118,9 @@ func (m *SourceNode) Open(ctx api.StreamContext, errCh chan<- error) {
 					stats.IncTotalRecordsOut()
 					stats.SetBufferLength(int64(m.getBufferLength()))
 					logger.Debugf("%s consume data %v complete", m.name, tuple)
-				}); err != nil {
+				}, func(err error) {
 					m.drainError(errCh, err, ctx, logger)
-					return
-				}
+				})
 				logger.Infof("Start source %s instance %d successfully", m.name, instance)
 			}(i)
 		}

+ 1 - 2
xstream/test/mock_source.go

@@ -23,7 +23,7 @@ func NewMockSource(data []*xsql.Tuple, done <-chan int, isEventTime bool) *MockS
 	return mock
 }
 
-func (m *MockSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err error) {
+func (m *MockSource) Open(ctx api.StreamContext, consume api.ConsumeFunc, onError api.ErrorFunc) {
 	log := ctx.GetLogger()
 	mockClock := GetMockClock()
 	log.Debugln("mock source starts")
@@ -40,7 +40,6 @@ func (m *MockSource) Open(ctx api.StreamContext, consume api.ConsumeFunc) (err e
 			time.Sleep(1)
 		}
 	}()
-	return nil
 }
 
 func (m *MockSource) Close(ctx api.StreamContext) error {