Browse Source

fix(format): pb import and message array (#2165)

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
ngjaying 1 year ago
parent
commit
04f03771fa

+ 5 - 2
internal/converter/protobuf/converter.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import (
 	"github.com/jhump/protoreflect/desc"
 	"github.com/jhump/protoreflect/desc/protoparse"
 
+	kconf "github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/converter/static"
 	"github.com/lf-edge/ekuiper/pkg/message"
 )
@@ -32,7 +33,9 @@ type Converter struct {
 var protoParser *protoparse.Parser
 
 func init() {
-	protoParser = &protoparse.Parser{}
+	etcDir, _ := kconf.GetLoc("etc/schemas/protobuf/")
+	dataDir, _ := kconf.GetLoc("data/schemas/protobuf/")
+	protoParser = &protoparse.Parser{ImportPaths: []string{etcDir, dataDir}}
 }
 
 func NewConverter(schemaFile string, soFile string, messageName string) (message.Converter, error) {

+ 31 - 9
internal/converter/protobuf/converter_test.go

@@ -21,6 +21,8 @@ import (
 	"reflect"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
+
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/schema"
 	"github.com/lf-edge/ekuiper/internal/testx"
@@ -42,7 +44,7 @@ func TestEncode(t *testing.T) {
 				"id":   1,
 				"age":  1,
 			},
-			r: []byte{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x1a, 0x00},
+			r: []byte{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01},
 		}, {
 			m: map[string]interface{}{
 				"name":  "test",
@@ -50,6 +52,16 @@ func TestEncode(t *testing.T) {
 				"email": "Dddd",
 			},
 			r: []byte{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x1a, 0x04, 0x44, 0x64, 0x64, 0x64},
+		}, {
+			m: map[string]interface{}{
+				"name": "test",
+				"id":   1,
+				"code": []any{
+					map[string]any{"doubles": []any{1.1, 2.2, 3.3}},
+					map[string]any{"doubles": []any{3.3, 1.1}},
+				},
+			},
+			r: []byte{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x22, 0x1b, 0x09, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xf1, 0x3f, 0x09, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0x01, 0x40, 0x09, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x0a, 0x40, 0x22, 0x12, 0x09, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x0a, 0x40, 0x09, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xf1, 0x3f},
 		},
 	}
 	fmt.Printf("The test bucket size is %d.\n\n", len(tests))
@@ -109,7 +121,6 @@ func TestDecode(t *testing.T) {
 	tests := []struct {
 		m map[string]interface{}
 		r []byte
-		e string
 	}{
 		{
 			m: map[string]interface{}{
@@ -120,15 +131,26 @@ func TestDecode(t *testing.T) {
 			},
 			r: []byte{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x1a, 0x04, 0x44, 0x64, 0x64, 0x64},
 		},
+		{
+			m: map[string]interface{}{
+				"name":  "test",
+				"id":    int64(1),
+				"email": "",
+				"code": []map[string]any{
+					{"doubles": []float64{1.1, 2.2, 3.3}},
+					{"doubles": []float64{3.3, 1.1}},
+				},
+			},
+			r: []byte{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x22, 0x1b, 0x09, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xf1, 0x3f, 0x09, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0x01, 0x40, 0x09, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x0a, 0x40, 0x22, 0x12, 0x09, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x0a, 0x40, 0x09, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xf1, 0x3f},
+		},
 	}
-	fmt.Printf("The test bucket size is %d.\n\n", len(tests))
+
 	for i, tt := range tests {
-		a, err := c.Decode(tt.r)
-		if !reflect.DeepEqual(tt.e, testx.Errstring(err)) {
-			t.Errorf("%d.error mismatch:\n  exp=%s\n  got=%s\n\n", i, tt.e, err)
-		} else if tt.e == "" && !reflect.DeepEqual(tt.m, a) {
-			t.Errorf("%d. \n\nresult mismatch:\n\nexp=%v\n\ngot=%v\n\n", i, tt.m, a)
-		}
+		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
+			a, err := c.Decode(tt.r)
+			assert.NoError(t, err)
+			assert.Equal(t, tt.m, a)
+		})
 	}
 }
 

+ 2 - 2
internal/converter/protobuf/fieldConverterSingleton.go

@@ -72,7 +72,7 @@ func (fc *FieldConverter) encodeMap(im *desc.MessageDescriptor, i interface{}) (
 				if field.IsRequired() {
 					return nil, fmt.Errorf("field %s not found", field.GetName())
 				} else {
-					v = field.GetDefaultValue()
+					continue
 				}
 			}
 			fv, err := fc.EncodeField(field, v)
@@ -135,7 +135,7 @@ func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}
 			result, err = cast.ToBytesSlice(v, cast.STRICT)
 		case dpb.FieldDescriptorProto_TYPE_MESSAGE:
 			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-				r, err := cast.ToStringMap(v)
+				r, err := cast.ToStringMap(input)
 				if err == nil {
 					return fc.encodeMap(field.GetMessageType(), r)
 				} else {