Browse Source

fix(proto): support enum type

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
Jiyong Huang 1 year ago
parent
commit
1d5f2b2584

+ 39 - 1
internal/converter/protobuf/converter_test.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -63,6 +63,44 @@ func TestEncode(t *testing.T) {
 	}
 }
 
+func TestEmbedType(t *testing.T) {
+	c, err := NewConverter("../../schema/test/test3.proto", "", "DrivingData")
+	if err != nil {
+		t.Fatal(err)
+	}
+	tests := []struct {
+		m map[string]interface{}
+		r []byte
+		e string
+	}{
+		{
+			m: map[string]interface{}{
+				"drvg_mod": int64(1),
+				"brk_pedal_sts": map[string]interface{}{
+					"valid": int64(0),
+				},
+				"average_speed": 90.56,
+			},
+			r: []byte{0x08, 0x01, 0x11, 0xa4, 0x70, 0x3d, 0x0a, 0xd7, 0xa3, 0x56, 0x40, 0x1a, 0x02, 0x08, 0x00},
+		},
+	}
+	fmt.Printf("The test bucket size is %d.\n\n", len(tests))
+	for i, tt := range tests {
+		a, err := c.Encode(tt.m)
+		if !reflect.DeepEqual(tt.e, testx.Errstring(err)) {
+			t.Errorf("%d.error mismatch:\n  exp=%s\n  got=%s\n\n", i, tt.e, err)
+		} else if tt.e == "" && !reflect.DeepEqual(tt.r, a) {
+			t.Errorf("%d. \n\nresult mismatch:\n\nexp=%x\n\ngot=%x\n\n", i, tt.r, a)
+		}
+		m, err := c.Decode(a)
+		if !reflect.DeepEqual(tt.e, testx.Errstring(err)) {
+			t.Errorf("%d.error mismatch:\n  exp=%s\n  got=%s\n\n", i, tt.e, err)
+		} else if tt.e == "" && !reflect.DeepEqual(tt.m, m) {
+			t.Errorf("%d. \n\nresult mismatch:\n\nexp=%v\n\ngot=%v\n\n", i, tt.m, m)
+		}
+	}
+}
+
 func TestDecode(t *testing.T) {
 	c, err := NewConverter("../../schema/test/test1.proto", "", "Person")
 	if err != nil {

+ 7 - 7
internal/converter/protobuf/fieldConverterSingleton.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -98,11 +98,11 @@ func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}
 			result, err = cast.ToFloat64Slice(v, cast.STRICT)
 		case dpb.FieldDescriptorProto_TYPE_FLOAT:
 			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-				r, err := cast.ToFloat64(input, sn)
+				r, err := cast.ToFloat32(input, sn)
 				if err != nil {
 					return 0, nil
 				} else {
-					return float32(r), nil
+					return r, nil
 				}
 			}, "float", cast.STRICT)
 		case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32:
@@ -165,13 +165,13 @@ func (fc *FieldConverter) encodeSingleField(field *desc.FieldDescriptor, v inter
 			return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
 		}
 	case dpb.FieldDescriptorProto_TYPE_FLOAT:
-		r, err := cast.ToFloat64(v, cast.STRICT)
+		r, err := cast.ToFloat32(v, cast.STRICT)
 		if err == nil {
-			return float32(r), nil
+			return r, nil
 		} else {
 			return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
 		}
-	case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32:
+	case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32, dpb.FieldDescriptorProto_TYPE_ENUM:
 		r, err := cast.ToInt(v, cast.STRICT)
 		if err == nil {
 			return int32(r), nil
@@ -245,7 +245,7 @@ func (fc *FieldConverter) DecodeField(src interface{}, field *desc.FieldDescript
 		} else {
 			r, e = cast.ToFloat64(src, sn)
 		}
-	case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32, dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64, dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32, dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
+	case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32, dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64, dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32, dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64, dpb.FieldDescriptorProto_TYPE_ENUM:
 		if field.IsRepeated() {
 			r, e = cast.ToInt64Slice(src, sn)
 		} else {

+ 2 - 2
internal/schema/registry_test.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -78,7 +78,7 @@ func TestProtoRegistry(t *testing.T) {
 	expectedSchema := &Info{
 		Type:     "protobuf",
 		Name:     "test1",
-		Content:  "syntax = \"proto2\";message Person {required string name = 1;optional int32 id = 2;optional string email = 3;repeated ListOfDoubles code = 4;}message ListOfDoubles {repeated double doubles=1;}",
+		Content:  "syntax = \"proto2\";message Person {required string name = 1;optional int32 id = 2;optional string email = 3;repeated ListOfDoubles code = 4;}message ListOfDoubles {repeated double doubles = 1;}",
 		FilePath: filepath.Join(etcDir, "test1.proto"),
 	}
 	gottenSchema, err := GetSchema("protobuf", "test1")

+ 1 - 1
internal/schema/test/test1.proto

@@ -1 +1 @@
-syntax = "proto2";message Person {required string name = 1;optional int32 id = 2;optional string email = 3;repeated ListOfDoubles code = 4;}message ListOfDoubles {repeated double doubles=1;}
+syntax = "proto2";message Person {required string name = 1;optional int32 id = 2;optional string email = 3;repeated ListOfDoubles code = 4;}message ListOfDoubles {repeated double doubles = 1;}

+ 20 - 0
internal/schema/test/test3.proto

@@ -0,0 +1,20 @@
+message DrivingData {
+  optional DrvgMod drvg_mod = 1;
+  optional double average_speed = 2;
+  optional BrkPedalStatus brk_pedal_sts = 3;
+
+  message BrkPedalStatus {
+    optional BrkPedalValid valid = 1;
+  }
+
+  enum BrkPedalValid {
+    BRK_PED_VALID = 0;
+    BRK_PED_INVALID = 1;
+  }
+
+  enum DrvgMod {
+    AUTO_MODE = 0;
+    ECO_MODE = 1;
+    SPORT_MODE = 2;
+  }
+}