Browse Source

fix(schema): pb support enum type when refer schema

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
Jiyong Huang 1 year ago
parent
commit
5041dbaa4b

+ 3 - 2
internal/schema/ext_inferer_protobuf.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -108,7 +108,8 @@ func convertFieldType(tt dpb.FieldDescriptorProto_Type, f *desc.FieldDescriptor)
 	case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32,
 		dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64,
 		dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32,
-		dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
+		dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64,
+		dpb.FieldDescriptorProto_TYPE_ENUM:
 		ft = &ast.BasicType{Type: ast.BIGINT}
 	case dpb.FieldDescriptorProto_TYPE_BOOL:
 		ft = &ast.BasicType{Type: ast.BOOLEAN}

+ 53 - 1
internal/schema/ext_inferer_protobuf_test.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@ import (
 	"reflect"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
+
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/testx"
 	"github.com/lf-edge/ekuiper/pkg/ast"
@@ -80,3 +82,53 @@ func TestInferProtobuf(t *testing.T) {
 		t.Errorf("InferProtobuf result is not expected, got %v, expected %v", result, expected)
 	}
 }
+
+func TestInferProtobufWithEmbedType(t *testing.T) {
+	testx.InitEnv()
+	// Move test schema file to etc dir
+	etcDir, err := conf.GetDataLoc()
+	if err != nil {
+		t.Fatal(err)
+	}
+	etcDir = filepath.Join(etcDir, "schemas", "protobuf")
+	err = os.MkdirAll(etcDir, os.ModePerm)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Copy init.proto
+	bytesRead, err := os.ReadFile("test/test3.proto")
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = os.WriteFile(filepath.Join(etcDir, "test3.proto"), bytesRead, 0o755)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err = os.RemoveAll(etcDir)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+	err = InitRegistry()
+	if err != nil {
+		t.Errorf("InitRegistry error: %v", err)
+		return
+	}
+	// Test infer
+	result, err := InferProtobuf("test3", "DrivingData")
+	if err != nil {
+		t.Errorf("InferProtobuf error: %v", err)
+		return
+	}
+	expected := ast.StreamFields{
+		{Name: "drvg_mod", FieldType: &ast.BasicType{Type: ast.BIGINT}},
+		{Name: "average_speed", FieldType: &ast.BasicType{Type: ast.FLOAT}},
+		{Name: "brk_pedal_sts", FieldType: &ast.RecType{StreamFields: []ast.StreamField{
+			{Name: "valid", FieldType: &ast.BasicType{Type: ast.BIGINT}},
+		}}},
+	}
+	if !assert.Equal(t, expected, result) {
+		t.Errorf("InferProtobuf result is not expected, got %v, expected %v", result, expected)
+	}
+}