Quellcode durchsuchen

refactor: support DefaultContext.DecodeIntoList (#1833)

* refactor

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

---------

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao vor 1 Jahr
Ursprung
Commit
6cf494af4c

+ 6 - 3
internal/converter/json/converter.go

@@ -33,7 +33,10 @@ func (c *Converter) Encode(d interface{}) ([]byte, error) {
 }
 }
 
 
 func (c *Converter) Decode(b []byte) (interface{}, error) {
 func (c *Converter) Decode(b []byte) (interface{}, error) {
-	result := make(map[string]interface{})
-	e := json.Unmarshal(b, &result)
-	return result, e
+	var r0 interface{}
+	err := json.Unmarshal(b, &r0)
+	if err != nil {
+		return nil, err
+	}
+	return r0, nil
 }
 }

+ 21 - 2
internal/converter/json/converter_test.go

@@ -33,6 +33,7 @@ func TestMessageDecode(t *testing.T) {
 		payload []byte
 		payload []byte
 		format  string
 		format  string
 		result  map[string]interface{}
 		result  map[string]interface{}
+		results []interface{}
 	}{
 	}{
 		{
 		{
 			payload: []byte(fmt.Sprintf(`{"format":"jpg","content":"%s"}`, b64img)),
 			payload: []byte(fmt.Sprintf(`{"format":"jpg","content":"%s"}`, b64img)),
@@ -42,6 +43,18 @@ func TestMessageDecode(t *testing.T) {
 				"content": b64img,
 				"content": b64img,
 			},
 			},
 		},
 		},
+		{
+			payload: []byte(`[{"a":1},{"a":2}]`),
+			format:  "json",
+			results: []interface{}{
+				map[string]interface{}{
+					"a": float64(1),
+				},
+				map[string]interface{}{
+					"a": float64(2),
+				},
+			},
+		},
 	}
 	}
 	conv, _ := GetConverter()
 	conv, _ := GetConverter()
 	for i, tt := range tests {
 	for i, tt := range tests {
@@ -49,8 +62,14 @@ func TestMessageDecode(t *testing.T) {
 		if err != nil {
 		if err != nil {
 			t.Errorf("%d decode error: %v", i, err)
 			t.Errorf("%d decode error: %v", i, err)
 		}
 		}
-		if !reflect.DeepEqual(tt.result, result) {
-			t.Errorf("%d result mismatch:\n\nexp=%s\n\ngot=%s\n\n", i, tt.result, result)
+		if len(tt.results) > 0 {
+			if !reflect.DeepEqual(tt.results, result) {
+				t.Errorf("%d result mismatch:\n\nexp=%s\n\ngot=%s\n\n", i, tt.result, result)
+			}
+		} else {
+			if !reflect.DeepEqual(tt.result, result) {
+				t.Errorf("%d result mismatch:\n\nexp=%s\n\ngot=%s\n\n", i, tt.result, result)
+			}
 		}
 		}
 	}
 	}
 }
 }

+ 36 - 22
internal/io/mqtt/mqtt_source.go

@@ -122,39 +122,45 @@ func subscribe(ms *MQTTSource, ctx api.StreamContext, consumer chan<- api.Source
 		return e
 		return e
 	} else {
 	} else {
 		log.Infof("Successfully subscribed to topic %s.", ms.tpc)
 		log.Infof("Successfully subscribed to topic %s.", ms.tpc)
-		var t api.SourceTuple
+		var tuples []api.SourceTuple
 		for {
 		for {
 			select {
 			select {
 			case <-ctx.Done():
 			case <-ctx.Done():
 				log.Infof("Exit subscription to mqtt messagebus topic %s.", ms.tpc)
 				log.Infof("Exit subscription to mqtt messagebus topic %s.", ms.tpc)
 				return nil
 				return nil
 			case e1 := <-err:
 			case e1 := <-err:
-				t = &xsql.ErrorSourceTuple{
-					Error: fmt.Errorf("the subscription to mqtt topic %s have error %s.\n", ms.tpc, e1.Error()),
+				tuples = []api.SourceTuple{
+					&xsql.ErrorSourceTuple{
+						Error: fmt.Errorf("the subscription to mqtt topic %s have error %s.\n", ms.tpc, e1.Error()),
+					},
 				}
 				}
 			case env, ok := <-messages:
 			case env, ok := <-messages:
 				if !ok { // the source is closed
 				if !ok { // the source is closed
 					log.Infof("Exit subscription to mqtt messagebus topic %s.", ms.tpc)
 					log.Infof("Exit subscription to mqtt messagebus topic %s.", ms.tpc)
 					return nil
 					return nil
 				}
 				}
-				t = getTuple(ctx, ms, env)
+				tuples = getTuples(ctx, ms, env)
 			}
 			}
-			select {
-			case consumer <- t:
-				log.Debugf("send data to source node")
-			case <-ctx.Done():
-				return nil
+			for _, t := range tuples {
+				select {
+				case consumer <- t:
+					log.Debugf("send data to source node")
+				case <-ctx.Done():
+					return nil
+				}
 			}
 			}
 		}
 		}
 	}
 	}
 }
 }
 
 
-func getTuple(ctx api.StreamContext, ms *MQTTSource, env interface{}) api.SourceTuple {
+func getTuples(ctx api.StreamContext, ms *MQTTSource, env interface{}) []api.SourceTuple {
 	rcvTime := conf.GetNow()
 	rcvTime := conf.GetNow()
 	msg, ok := env.(pahoMqtt.Message)
 	msg, ok := env.(pahoMqtt.Message)
 	if !ok { // should never happen
 	if !ok { // should never happen
-		return &xsql.ErrorSourceTuple{
-			Error: fmt.Errorf("can not convert interface data to mqtt message %v.", env),
+		return []api.SourceTuple{
+			&xsql.ErrorSourceTuple{
+				Error: fmt.Errorf("can not convert interface data to mqtt message %v.", env),
+			},
 		}
 		}
 	}
 	}
 	payload := msg.Payload()
 	payload := msg.Payload()
@@ -162,29 +168,37 @@ func getTuple(ctx api.StreamContext, ms *MQTTSource, env interface{}) api.Source
 	if ms.decompressor != nil {
 	if ms.decompressor != nil {
 		payload, err = ms.decompressor.Decompress(payload)
 		payload, err = ms.decompressor.Decompress(payload)
 		if err != nil {
 		if err != nil {
-			return &xsql.ErrorSourceTuple{
-				Error: fmt.Errorf("can not decompress mqtt message %v.", err),
+			return []api.SourceTuple{
+				&xsql.ErrorSourceTuple{
+					Error: fmt.Errorf("can not decompress mqtt message %v.", err),
+				},
 			}
 			}
 		}
 		}
 	}
 	}
-	result, e := ctx.Decode(payload)
+	results, e := ctx.DecodeIntoList(payload)
 	//The unmarshal type can only be bool, float64, string, []interface{}, map[string]interface{}, nil
 	//The unmarshal type can only be bool, float64, string, []interface{}, map[string]interface{}, nil
 	if e != nil {
 	if e != nil {
-		return &xsql.ErrorSourceTuple{
-			Error: fmt.Errorf("Invalid data format, cannot decode %s with error %s", string(msg.Payload()), e),
+		return []api.SourceTuple{
+			&xsql.ErrorSourceTuple{
+				Error: fmt.Errorf("Invalid data format, cannot decode %s with error %s", string(msg.Payload()), e),
+			},
 		}
 		}
 	}
 	}
 	meta := make(map[string]interface{})
 	meta := make(map[string]interface{})
 	meta["topic"] = msg.Topic()
 	meta["topic"] = msg.Topic()
 	meta["messageid"] = strconv.Itoa(int(msg.MessageID()))
 	meta["messageid"] = strconv.Itoa(int(msg.MessageID()))
 
 
-	if nil != ms.model {
-		sliErr := ms.model.checkType(result, msg.Topic())
-		for _, v := range sliErr {
-			ctx.GetLogger().Errorf(v)
+	tuples := make([]api.SourceTuple, 0, len(results))
+	for _, result := range results {
+		if nil != ms.model {
+			sliErr := ms.model.checkType(result, msg.Topic())
+			for _, v := range sliErr {
+				ctx.GetLogger().Errorf(v)
+			}
 		}
 		}
+		tuples = append(tuples, api.NewDefaultSourceTupleWithTime(result, meta, rcvTime))
 	}
 	}
-	return api.NewDefaultSourceTupleWithTime(result, meta, rcvTime)
+	return tuples
 }
 }
 
 
 func (ms *MQTTSource) Close(ctx api.StreamContext) error {
 func (ms *MQTTSource) Close(ctx api.StreamContext) error {

+ 12 - 11
internal/io/mqtt/mqtt_source_test.go

@@ -52,18 +52,19 @@ func TestGetTupleWithZlibCompressor(t *testing.T) {
 		topic:   "test/topic",
 		topic:   "test/topic",
 	}
 	}
 	// Call getTuple with the mock MQTT message
 	// Call getTuple with the mock MQTT message
-	result := getTuple(ctx, ms, msg)
-
-	// Check if the result is a valid SourceTuple and has the correct content
-	if st, ok := result.(api.SourceTuple); ok {
-		if !reflect.DeepEqual(st.Message(), map[string]interface{}{"key": "value"}) {
-			t.Errorf("Expected message to be %v, but got %v", map[string]interface{}{"key": "value"}, st.Message())
-		}
-		if !reflect.DeepEqual(st.Meta(), map[string]interface{}{"topic": "test/topic", "messageid": "1"}) {
-			t.Errorf("Expected metadata to be %v, but got %v", map[string]interface{}{"topic": "test/topic", "messageid": "1"}, st.Meta())
+	results := getTuples(ctx, ms, msg)
+	for _, result := range results {
+		// Check if the result is a valid SourceTuple and has the correct content
+		if st, ok := result.(api.SourceTuple); ok {
+			if !reflect.DeepEqual(st.Message(), map[string]interface{}{"key": "value"}) {
+				t.Errorf("Expected message to be %v, but got %v", map[string]interface{}{"key": "value"}, st.Message())
+			}
+			if !reflect.DeepEqual(st.Meta(), map[string]interface{}{"topic": "test/topic", "messageid": "1"}) {
+				t.Errorf("Expected metadata to be %v, but got %v", map[string]interface{}{"topic": "test/topic", "messageid": "1"}, st.Meta())
+			}
+		} else {
+			t.Errorf("Expected result to be a SourceTuple, but got %T", result)
 		}
 		}
-	} else {
-		t.Errorf("Expected result to be a SourceTuple, but got %T", result)
 	}
 	}
 }
 }
 
 

+ 30 - 0
internal/topo/context/decoder.go

@@ -37,3 +37,33 @@ func (c *DefaultContext) Decode(data []byte) (map[string]interface{}, error) {
 	}
 	}
 	return nil, fmt.Errorf("no decoder configured")
 	return nil, fmt.Errorf("no decoder configured")
 }
 }
+
+func (c *DefaultContext) DecodeIntoList(data []byte) ([]map[string]interface{}, error) {
+	v := c.Value(DecodeKey)
+	f, ok := v.(message.Converter)
+	if ok {
+		t, err := f.Decode(data)
+		if err != nil {
+			return nil, fmt.Errorf("decode failed: %v", err)
+		}
+		typeErr := fmt.Errorf("only map[string]interface{} and []map[string]interface{} is supported but got: %v", t)
+		switch r := t.(type) {
+		case map[string]interface{}:
+			return []map[string]interface{}{r}, nil
+		case []map[string]interface{}:
+			return r, nil
+		case []interface{}:
+			rs := make([]map[string]interface{}, len(r))
+			for i, v := range r {
+				if vc, ok := v.(map[string]interface{}); ok {
+					rs[i] = vc
+				} else {
+					return nil, typeErr
+				}
+			}
+			return rs, nil
+		}
+		return nil, typeErr
+	}
+	return nil, fmt.Errorf("no decoder configured")
+}

+ 2 - 0
pkg/api/stream.go

@@ -207,6 +207,8 @@ type StreamContext interface {
 	// Decode is set in the source according to the format.
 	// Decode is set in the source according to the format.
 	// It decodes byte array into map or map slice.
 	// It decodes byte array into map or map slice.
 	Decode(data []byte) (map[string]interface{}, error)
 	Decode(data []byte) (map[string]interface{}, error)
+
+	DecodeIntoList(data []byte) ([]map[string]interface{}, error)
 }
 }
 
 
 type Operator interface {
 type Operator interface {