浏览代码

feat(sink): support general property 'format' (#1274)

* feat(sink): support general property 'format'

1. Support protobuf 'format' in sink, must use with 'schemaId' together
2. Refactor protobuf schema encoder to be shared between external service and sink format

Signed-off-by: Jiyong Huang <huangjy@emqx.io>

* refactor: move protobuf schema file location

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
ngjaying 2 年之前
父节点
当前提交
518ef606fe

+ 74 - 0
internal/schema/protobuf/converter.go

@@ -0,0 +1,74 @@
+// Copyright 2022 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package protobuf
+
+import (
+	"fmt"
+	"github.com/jhump/protoreflect/desc"
+	"github.com/jhump/protoreflect/desc/protoparse"
+	"strings"
+)
+
+type Converter struct {
+	descriptor *desc.MessageDescriptor
+	fc         *FieldConverter
+}
+
+var protoParser *protoparse.Parser
+
+func init() {
+	//etcDir, err := conf.GetConfLoc()
+	//if err != nil {
+	//	panic(err)
+	//}
+	protoParser = &protoparse.Parser{}
+}
+
+func NewConverter(schemaId string, fileName string) (*Converter, error) {
+	des := strings.Split(schemaId, ".")
+	if len(des) != 2 {
+		return nil, fmt.Errorf("invalid schema id %s for protobuf, the format must be protoName.mesageName", schemaId)
+	}
+	if fds, err := protoParser.ParseFiles(fileName); err != nil {
+		return nil, fmt.Errorf("parse schema file %s failed: %s", fileName, err)
+	} else {
+		messageDescriptor := fds[0].FindMessage(des[1])
+		if messageDescriptor == nil {
+			return nil, fmt.Errorf("message type %s not found in schema file %s", des[1], fileName)
+		}
+		return &Converter{
+			descriptor: messageDescriptor,
+			fc:         GetFieldConverter(),
+		}, nil
+	}
+}
+
+func (c *Converter) Encode(d interface{}) ([]byte, error) {
+	switch m := d.(type) {
+	case map[string]interface{}:
+		msg, err := c.fc.encodeMap(c.descriptor, m)
+		if err != nil {
+			return nil, err
+		}
+		return msg.Marshal()
+	default:
+		return nil, fmt.Errorf("unsupported type %v, must be a map", d)
+	}
+}
+
+func (c *Converter) Decode(b []byte) (interface{}, error) {
+	//TODO implement me
+	panic("implement me")
+}

+ 59 - 0
internal/schema/protobuf/converter_test.go

@@ -0,0 +1,59 @@
+// Copyright 2022 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package protobuf
+
+import (
+	"fmt"
+	"github.com/lf-edge/ekuiper/internal/testx"
+	"reflect"
+	"testing"
+)
+
+func TestEncode(t *testing.T) {
+	c, err := NewConverter("test1.Person", "../test/test1.proto")
+	if err != nil {
+		t.Fatal(err)
+	}
+	tests := []struct {
+		m map[string]interface{}
+		r []byte
+		e string
+	}{
+		{
+			m: map[string]interface{}{
+				"name": "test",
+				"id":   1,
+				"age":  1,
+			},
+			e: "field email not found",
+		}, {
+			m: map[string]interface{}{
+				"name":  "test",
+				"id":    1,
+				"email": "Dddd",
+			},
+			r: []byte{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x1a, 0x04, 0x44, 0x64, 0x64, 0x64},
+		},
+	}
+	fmt.Printf("The test bucket size is %d.\n\n", len(tests))
+	for i, tt := range tests {
+		a, err := c.Encode(tt.m)
+		if !reflect.DeepEqual(tt.e, testx.Errstring(err)) {
+			t.Errorf("%d.error mismatch:\n  exp=%s\n  got=%s\n\n", i, tt.e, err)
+		} else if tt.e == "" && !reflect.DeepEqual(tt.r, a) {
+			t.Errorf("%d. \n\nresult mismatch:\n\nexp=%x\n\ngot=%x\n\n", i, tt.r, a)
+		}
+	}
+}

+ 200 - 0
internal/schema/protobuf/fieldConverterSingleton.go

@@ -0,0 +1,200 @@
+// Copyright 2022 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package protobuf
+
+import (
+	"fmt"
+	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
+	"github.com/jhump/protoreflect/desc"
+	"github.com/jhump/protoreflect/dynamic"
+	"github.com/lf-edge/ekuiper/pkg/cast"
+)
+
+var (
+	fieldConverterIns = &FieldConverter{}
+	mf                = dynamic.NewMessageFactoryWithDefaults()
+)
+
+type FieldConverter struct{}
+
+func GetFieldConverter() *FieldConverter {
+	return fieldConverterIns
+}
+
+func (fc *FieldConverter) encodeMap(im *desc.MessageDescriptor, i interface{}) (*dynamic.Message, error) {
+	result := mf.NewDynamicMessage(im)
+	fields := im.GetFields()
+	if m, ok := i.(map[string]interface{}); ok {
+		for _, field := range fields {
+			v, ok := m[field.GetName()]
+			if !ok {
+				return nil, fmt.Errorf("field %s not found", field.GetName())
+			}
+			fv, err := fc.EncodeField(field, v)
+			if err != nil {
+				return nil, err
+			}
+			result.SetFieldByName(field.GetName(), fv)
+		}
+	}
+	return result, nil
+}
+
+func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
+	fn := field.GetName()
+	ft := field.GetType()
+	if field.IsRepeated() {
+		var (
+			result interface{}
+			err    error
+		)
+		switch ft {
+		case dpb.FieldDescriptorProto_TYPE_DOUBLE:
+			result, err = cast.ToFloat64Slice(v, cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_FLOAT:
+			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
+				r, err := cast.ToFloat64(input, sn)
+				if err != nil {
+					return 0, nil
+				} else {
+					return float32(r), nil
+				}
+			}, "float", cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32:
+			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
+				r, err := cast.ToInt(input, sn)
+				if err != nil {
+					return 0, nil
+				} else {
+					return int32(r), nil
+				}
+			}, "int", cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64:
+			result, err = cast.ToInt64Slice(v, cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32:
+			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
+				r, err := cast.ToUint64(input, sn)
+				if err != nil {
+					return 0, nil
+				} else {
+					return uint32(r), nil
+				}
+			}, "uint", cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
+			result, err = cast.ToUint64Slice(v, cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_BOOL:
+			result, err = cast.ToBoolSlice(v, cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_STRING:
+			result, err = cast.ToStringSlice(v, cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_BYTES:
+			result, err = cast.ToBytesSlice(v, cast.STRICT)
+		case dpb.FieldDescriptorProto_TYPE_MESSAGE:
+			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
+				r, err := cast.ToStringMap(v)
+				if err == nil {
+					return fc.encodeMap(field.GetMessageType(), r)
+				} else {
+					return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
+				}
+			}, "map", cast.STRICT)
+		default:
+			return nil, fmt.Errorf("invalid type for field '%s'", fn)
+		}
+		if err != nil {
+			err = fmt.Errorf("failed to encode field '%s':%v", fn, err)
+		}
+		return result, err
+	} else {
+		return fc.encodeSingleField(field, v)
+	}
+}
+
+func (fc *FieldConverter) encodeSingleField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
+	fn := field.GetName()
+	switch field.GetType() {
+	case dpb.FieldDescriptorProto_TYPE_DOUBLE:
+		r, err := cast.ToFloat64(v, cast.STRICT)
+		if err == nil {
+			return r, nil
+		} else {
+			return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_FLOAT:
+		r, err := cast.ToFloat64(v, cast.STRICT)
+		if err == nil {
+			return float32(r), nil
+		} else {
+			return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32:
+		r, err := cast.ToInt(v, cast.STRICT)
+		if err == nil {
+			return int32(r), nil
+		} else {
+			return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64:
+		r, err := cast.ToInt64(v, cast.STRICT)
+		if err == nil {
+			return r, nil
+		} else {
+			return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32:
+		r, err := cast.ToUint64(v, cast.STRICT)
+		if err == nil {
+			return uint32(r), nil
+		} else {
+			return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
+		r, err := cast.ToUint64(v, cast.STRICT)
+		if err == nil {
+			return r, nil
+		} else {
+			return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_BOOL:
+		r, err := cast.ToBool(v, cast.STRICT)
+		if err == nil {
+			return r, nil
+		} else {
+			return nil, fmt.Errorf("invalid type for bool type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_STRING:
+		r, err := cast.ToString(v, cast.STRICT)
+		if err == nil {
+			return r, nil
+		} else {
+			return nil, fmt.Errorf("invalid type for string type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_BYTES:
+		r, err := cast.ToBytes(v, cast.STRICT)
+		if err == nil {
+			return r, nil
+		} else {
+			return nil, fmt.Errorf("invalid type for bytes type field '%s': %v", fn, err)
+		}
+	case dpb.FieldDescriptorProto_TYPE_MESSAGE:
+		r, err := cast.ToStringMap(v)
+		if err == nil {
+			return fc.encodeMap(field.GetMessageType(), r)
+		} else {
+			return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
+		}
+	default:
+		return nil, fmt.Errorf("invalid type for field '%s'", fn)
+	}
+}

+ 19 - 11
internal/schema/registry.go

@@ -50,7 +50,7 @@ func InitRegistry() error {
 		return fmt.Errorf("cannot find etc folder: %s", err)
 		return fmt.Errorf("cannot find etc folder: %s", err)
 	}
 	}
 	for _, schemaType := range def.SchemaTypes {
 	for _, schemaType := range def.SchemaTypes {
-		schemaDir := filepath.Join(etcDir, string(schemaType))
+		schemaDir := filepath.Join(etcDir, "schemas", string(schemaType))
 		var newSchemas map[string]string
 		var newSchemas map[string]string
 		files, err := ioutil.ReadDir(schemaDir)
 		files, err := ioutil.ReadDir(schemaDir)
 		if err != nil {
 		if err != nil {
@@ -97,7 +97,7 @@ func CreateOrUpdateSchema(info *Info) error {
 		return fmt.Errorf("schema type %s not found", info.Type)
 		return fmt.Errorf("schema type %s not found", info.Type)
 	}
 	}
 	etcDir, _ := conf.GetConfLoc()
 	etcDir, _ := conf.GetConfLoc()
-	etcDir = filepath.Join(etcDir, string(info.Type))
+	etcDir = filepath.Join(etcDir, "schemas", string(info.Type))
 	if err := os.MkdirAll(etcDir, os.ModePerm); err != nil {
 	if err := os.MkdirAll(etcDir, os.ModePerm); err != nil {
 		return err
 		return err
 	}
 	}
@@ -125,16 +125,11 @@ func CreateOrUpdateSchema(info *Info) error {
 	return nil
 	return nil
 }
 }
 
 
-func GetSchemaContent(schemaType def.SchemaType, name string) (*Info, error) {
-	registry.RLock()
-	defer registry.RUnlock()
-	if _, ok := registry.schemas[schemaType]; !ok {
-		return nil, fmt.Errorf("schema type %s not found", schemaType)
-	}
-	if _, ok := registry.schemas[schemaType][name]; !ok {
-		return nil, fmt.Errorf("schema %s.%s not found", schemaType, name)
+func GetSchema(schemaType def.SchemaType, name string) (*Info, error) {
+	schemaFile, err := getSchemaFile(schemaType, name)
+	if err != nil {
+		return nil, err
 	}
 	}
-	schemaFile := registry.schemas[schemaType][name]
 	content, err := ioutil.ReadFile(schemaFile)
 	content, err := ioutil.ReadFile(schemaFile)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("cannot read schema file %s: %s", schemaFile, err)
 		return nil, fmt.Errorf("cannot read schema file %s: %s", schemaFile, err)
@@ -147,6 +142,19 @@ func GetSchemaContent(schemaType def.SchemaType, name string) (*Info, error) {
 	}, nil
 	}, nil
 }
 }
 
 
+func getSchemaFile(schemaType def.SchemaType, name string) (string, error) {
+	registry.RLock()
+	defer registry.RUnlock()
+	if _, ok := registry.schemas[schemaType]; !ok {
+		return "", fmt.Errorf("schema type %s not found", schemaType)
+	}
+	if _, ok := registry.schemas[schemaType][name]; !ok {
+		return "", fmt.Errorf("schema %s.%s not found", schemaType, name)
+	}
+	schemaFile := registry.schemas[schemaType][name]
+	return schemaFile, nil
+}
+
 func DeleteSchema(schemaType def.SchemaType, name string) error {
 func DeleteSchema(schemaType def.SchemaType, name string) error {
 	registry.Lock()
 	registry.Lock()
 	defer registry.Unlock()
 	defer registry.Unlock()

+ 2 - 2
internal/schema/registry_test.go

@@ -33,7 +33,7 @@ func TestRegistry(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	etcDir = filepath.Join(etcDir, "protobuf")
+	etcDir = filepath.Join(etcDir, "schemas", "protobuf")
 	err = os.MkdirAll(etcDir, os.ModePerm)
 	err = os.MkdirAll(etcDir, os.ModePerm)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
@@ -81,7 +81,7 @@ func TestRegistry(t *testing.T) {
 		Content:  "syntax = \"proto3\";message Person {string name = 1;int32 id = 2;string email = 3;}",
 		Content:  "syntax = \"proto3\";message Person {string name = 1;int32 id = 2;string email = 3;}",
 		FilePath: filepath.Join(etcDir, "test1.proto"),
 		FilePath: filepath.Join(etcDir, "test1.proto"),
 	}
 	}
-	gottenSchema, err := GetSchemaContent("protobuf", "test1")
+	gottenSchema, err := GetSchema("protobuf", "test1")
 	if !reflect.DeepEqual(gottenSchema, expectedSchema) {
 	if !reflect.DeepEqual(gottenSchema, expectedSchema) {
 		t.Errorf("Get test1 unmatch: Expect\n%v\nbut got\n%v", *expectedSchema, *gottenSchema)
 		t.Errorf("Get test1 unmatch: Expect\n%v\nbut got\n%v", *expectedSchema, *gottenSchema)
 		return
 		return

+ 24 - 1
internal/schema/schema.go

@@ -14,7 +14,11 @@
 
 
 package schema
 package schema
 
 
-import "github.com/lf-edge/ekuiper/internal/pkg/def"
+import (
+	"fmt"
+	"github.com/lf-edge/ekuiper/internal/pkg/def"
+	"github.com/lf-edge/ekuiper/internal/schema/protobuf"
+)
 
 
 type Info struct {
 type Info struct {
 	Type     def.SchemaType `json:"type"`
 	Type     def.SchemaType `json:"type"`
@@ -28,3 +32,22 @@ var (
 		def.PROTOBUF: ".proto",
 		def.PROTOBUF: ".proto",
 	}
 	}
 )
 )
+
+// Converter converts bytes & map or []map according to the schema
+type Converter interface {
+	Encode(d interface{}) ([]byte, error)
+	Decode(b []byte) (interface{}, error)
+}
+
+func GetOrCreateSchema(t def.SchemaType, schemaFile string, schemaId string) (Converter, error) {
+	switch t {
+	case def.PROTOBUF:
+		fileName, err := getSchemaFile(t, schemaFile)
+		if err != nil {
+			return nil, err
+		}
+		return protobuf.NewConverter(schemaId, fileName)
+	default:
+		return nil, fmt.Errorf("unsupported schema type: %s", t)
+	}
+}

+ 1 - 1
internal/server/schema_init.go

@@ -85,7 +85,7 @@ func schemaHandler(w http.ResponseWriter, r *http.Request) {
 	name := vars["name"]
 	name := vars["name"]
 	switch r.Method {
 	switch r.Method {
 	case http.MethodGet:
 	case http.MethodGet:
-		j, err := schema.GetSchemaContent(def.SchemaType(st), name)
+		j, err := schema.GetSchema(def.SchemaType(st), name)
 		if err != nil {
 		if err != nil {
 			handleError(w, err, "", logger)
 			handleError(w, err, "", logger)
 			return
 			return

+ 7 - 170
internal/service/schema.go

@@ -1,4 +1,4 @@
-// Copyright 2021 EMQ Technologies Co., Ltd.
+// Copyright 2021-2022 EMQ Technologies Co., Ltd.
 //
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import (
 	"github.com/jhump/protoreflect/desc/protoparse"
 	"github.com/jhump/protoreflect/desc/protoparse"
 	"github.com/jhump/protoreflect/dynamic"
 	"github.com/jhump/protoreflect/dynamic"
 	kconf "github.com/lf-edge/ekuiper/internal/conf"
 	kconf "github.com/lf-edge/ekuiper/internal/conf"
+	"github.com/lf-edge/ekuiper/internal/schema/protobuf"
 	"github.com/lf-edge/ekuiper/internal/xsql"
 	"github.com/lf-edge/ekuiper/internal/xsql"
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	_ "google.golang.org/genproto/googleapis/api/annotations"
 	_ "google.golang.org/genproto/googleapis/api/annotations"
@@ -121,6 +122,7 @@ func parse(schema schema, file string) (descriptor, error) {
 			result := &wrappedProtoDescriptor{
 			result := &wrappedProtoDescriptor{
 				FileDescriptor: fds[0],
 				FileDescriptor: fds[0],
 				mf:             dynamic.NewMessageFactoryWithDefaults(),
 				mf:             dynamic.NewMessageFactoryWithDefaults(),
+				fc:             protobuf.GetFieldConverter(),
 			}
 			}
 			err := result.parseHttpOptions()
 			err := result.parseHttpOptions()
 			if err != nil {
 			if err != nil {
@@ -138,6 +140,7 @@ type wrappedProtoDescriptor struct {
 	*desc.FileDescriptor
 	*desc.FileDescriptor
 	methodOptions map[string]*httpOptions
 	methodOptions map[string]*httpOptions
 	mf            *dynamic.MessageFactory
 	mf            *dynamic.MessageFactory
+	fc            *protobuf.FieldConverter
 }
 }
 
 
 //TODO support for duplicate names
 //TODO support for duplicate names
@@ -236,7 +239,7 @@ func (d *wrappedProtoDescriptor) convertParams(im *desc.MessageDescriptor, param
 		}
 		}
 		// For non map params, treat it as special case of multiple params
 		// For non map params, treat it as special case of multiple params
 		if len(fields) == 1 {
 		if len(fields) == 1 {
-			param0, err := d.encodeField(fields[0], params[0])
+			param0, err := d.fc.EncodeField(fields[0], params[0])
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
@@ -247,7 +250,7 @@ func (d *wrappedProtoDescriptor) convertParams(im *desc.MessageDescriptor, param
 	default:
 	default:
 		if len(fields) == len(params) {
 		if len(fields) == len(params) {
 			for i, field := range fields {
 			for i, field := range fields {
-				param, err := d.encodeField(field, params[i])
+				param, err := d.fc.EncodeField(field, params[i])
 				if err != nil {
 				if err != nil {
 					return nil, err
 					return nil, err
 				}
 				}
@@ -319,7 +322,7 @@ func (d *wrappedProtoDescriptor) unfoldMap(ft *desc.MessageDescriptor, i interfa
 			if !ok {
 			if !ok {
 				return nil, fmt.Errorf("field %s not found", field.GetName())
 				return nil, fmt.Errorf("field %s not found", field.GetName())
 			}
 			}
-			fv, err := d.encodeField(field, v)
+			fv, err := d.fc.EncodeField(field, v)
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
@@ -331,172 +334,6 @@ func (d *wrappedProtoDescriptor) unfoldMap(ft *desc.MessageDescriptor, i interfa
 	return result, nil
 	return result, nil
 }
 }
 
 
-func (d *wrappedProtoDescriptor) encodeMap(im *desc.MessageDescriptor, i interface{}) (*dynamic.Message, error) {
-	result := d.mf.NewDynamicMessage(im)
-	fields := im.GetFields()
-	if m, ok := i.(map[string]interface{}); ok {
-		for _, field := range fields {
-			v, ok := m[field.GetName()]
-			if !ok {
-				return nil, fmt.Errorf("field %s not found", field.GetName())
-			}
-			fv, err := d.encodeField(field, v)
-			if err != nil {
-				return nil, err
-			}
-			result.SetFieldByName(field.GetName(), fv)
-		}
-	}
-	return result, nil
-}
-
-func (d *wrappedProtoDescriptor) encodeField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
-	fn := field.GetName()
-	ft := field.GetType()
-	if field.IsRepeated() {
-		var (
-			result interface{}
-			err    error
-		)
-		switch ft {
-		case dpb.FieldDescriptorProto_TYPE_DOUBLE:
-			result, err = cast.ToFloat64Slice(v, cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_FLOAT:
-			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-				r, err := cast.ToFloat64(input, sn)
-				if err != nil {
-					return 0, nil
-				} else {
-					return float32(r), nil
-				}
-			}, "float", cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32:
-			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-				r, err := cast.ToInt(input, sn)
-				if err != nil {
-					return 0, nil
-				} else {
-					return int32(r), nil
-				}
-			}, "int", cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64:
-			result, err = cast.ToInt64Slice(v, cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32:
-			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-				r, err := cast.ToUint64(input, sn)
-				if err != nil {
-					return 0, nil
-				} else {
-					return uint32(r), nil
-				}
-			}, "uint", cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
-			result, err = cast.ToUint64Slice(v, cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_BOOL:
-			result, err = cast.ToBoolSlice(v, cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_STRING:
-			result, err = cast.ToStringSlice(v, cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_BYTES:
-			result, err = cast.ToBytesSlice(v, cast.STRICT)
-		case dpb.FieldDescriptorProto_TYPE_MESSAGE:
-			result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-				r, err := cast.ToStringMap(v)
-				if err == nil {
-					return d.encodeMap(field.GetMessageType(), r)
-				} else {
-					return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
-				}
-			}, "map", cast.STRICT)
-		default:
-			return nil, fmt.Errorf("invalid type for field '%s'", fn)
-		}
-		if err != nil {
-			err = fmt.Errorf("failed to encode field '%s':%v", fn, err)
-		}
-		return result, err
-	} else {
-		return d.encodeSingleField(field, v)
-	}
-}
-
-func (d *wrappedProtoDescriptor) encodeSingleField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
-	fn := field.GetName()
-	switch field.GetType() {
-	case dpb.FieldDescriptorProto_TYPE_DOUBLE:
-		r, err := cast.ToFloat64(v, cast.STRICT)
-		if err == nil {
-			return r, nil
-		} else {
-			return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_FLOAT:
-		r, err := cast.ToFloat64(v, cast.STRICT)
-		if err == nil {
-			return float32(r), nil
-		} else {
-			return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32:
-		r, err := cast.ToInt(v, cast.STRICT)
-		if err == nil {
-			return int32(r), nil
-		} else {
-			return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64:
-		r, err := cast.ToInt64(v, cast.STRICT)
-		if err == nil {
-			return r, nil
-		} else {
-			return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32:
-		r, err := cast.ToUint64(v, cast.STRICT)
-		if err == nil {
-			return uint32(r), nil
-		} else {
-			return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
-		r, err := cast.ToUint64(v, cast.STRICT)
-		if err == nil {
-			return r, nil
-		} else {
-			return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_BOOL:
-		r, err := cast.ToBool(v, cast.STRICT)
-		if err == nil {
-			return r, nil
-		} else {
-			return nil, fmt.Errorf("invalid type for bool type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_STRING:
-		r, err := cast.ToString(v, cast.STRICT)
-		if err == nil {
-			return r, nil
-		} else {
-			return nil, fmt.Errorf("invalid type for string type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_BYTES:
-		r, err := cast.ToBytes(v, cast.STRICT)
-		if err == nil {
-			return r, nil
-		} else {
-			return nil, fmt.Errorf("invalid type for bytes type field '%s': %v", fn, err)
-		}
-	case dpb.FieldDescriptorProto_TYPE_MESSAGE:
-		r, err := cast.ToStringMap(v)
-		if err == nil {
-			return d.encodeMap(field.GetMessageType(), r)
-		} else {
-			return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
-		}
-	default:
-		return nil, fmt.Errorf("invalid type for field '%s'", fn)
-	}
-}
-
 func decodeMessage(message *dynamic.Message, outputType *desc.MessageDescriptor) interface{} {
 func decodeMessage(message *dynamic.Message, outputType *desc.MessageDescriptor) interface{} {
 	if _, ok := WRAPPER_TYPES[outputType.GetFullyQualifiedName()]; ok {
 	if _, ok := WRAPPER_TYPES[outputType.GetFullyQualifiedName()]; ok {
 		return message.GetFieldByNumber(1)
 		return message.GetFieldByNumber(1)

+ 10 - 1
internal/topo/node/sink_node.go

@@ -24,6 +24,7 @@ import (
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	"github.com/lf-edge/ekuiper/pkg/errorx"
 	"github.com/lf-edge/ekuiper/pkg/errorx"
 	"github.com/lf-edge/ekuiper/pkg/infra"
 	"github.com/lf-edge/ekuiper/pkg/infra"
+	"github.com/lf-edge/ekuiper/pkg/message"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -39,6 +40,8 @@ type SinkConf struct {
 	Omitempty         bool   `json:"omitIfEmpty"`
 	Omitempty         bool   `json:"omitIfEmpty"`
 	SendSingle        bool   `json:"sendSingle"`
 	SendSingle        bool   `json:"sendSingle"`
 	DataTemplate      string `json:"dataTemplate"`
 	DataTemplate      string `json:"dataTemplate"`
+	Format            string `json:"format"`
+	SchemaId          string `json:"schemaId"`
 }
 }
 
 
 type SinkNode struct {
 type SinkNode struct {
@@ -139,8 +142,14 @@ func (m *SinkNode) Open(ctx api.StreamContext, result chan<- error) {
 				logger.Warnf("invalid type for cacheSaveInterval property, should be positive integer but found %t", sconf.CacheSaveInterval)
 				logger.Warnf("invalid type for cacheSaveInterval property, should be positive integer but found %t", sconf.CacheSaveInterval)
 				sconf.CacheSaveInterval = 1000
 				sconf.CacheSaveInterval = 1000
 			}
 			}
+			if sconf.Format == "" {
+				sconf.Format = "json"
+			} else if sconf.Format != message.FormatJson && sconf.Format != message.FormatProtobuf {
+				logger.Warnf("invalid type for format property, should be json or protobuf but found %s", sconf.Format)
+				sconf.Format = "json"
+			}
 
 
-			tf, err := transform.GenTransform(sconf.DataTemplate)
+			tf, err := transform.GenTransform(sconf.DataTemplate, sconf.Format, sconf.SchemaId)
 			if err != nil {
 			if err != nil {
 				msg := fmt.Sprintf("property dataTemplate %v is invalid: %v", sconf.DataTemplate, err)
 				msg := fmt.Sprintf("property dataTemplate %v is invalid: %v", sconf.DataTemplate, err)
 				logger.Warnf(msg)
 				logger.Warnf(msg)

+ 77 - 0
internal/topo/node/sink_node_test.go

@@ -20,10 +20,14 @@ package node
 import (
 import (
 	"fmt"
 	"fmt"
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/conf"
+	"github.com/lf-edge/ekuiper/internal/schema"
 	"github.com/lf-edge/ekuiper/internal/topo/context"
 	"github.com/lf-edge/ekuiper/internal/topo/context"
 	"github.com/lf-edge/ekuiper/internal/topo/topotest/mocknode"
 	"github.com/lf-edge/ekuiper/internal/topo/topotest/mocknode"
 	"github.com/lf-edge/ekuiper/internal/topo/transform"
 	"github.com/lf-edge/ekuiper/internal/topo/transform"
 	"github.com/lf-edge/ekuiper/internal/xsql"
 	"github.com/lf-edge/ekuiper/internal/xsql"
+	"io/ioutil"
+	"os"
+	"path/filepath"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -222,3 +226,76 @@ func TestOmitEmpty_Apply(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestFormat_Apply(t *testing.T) {
+	conf.InitConf()
+	etcDir, err := conf.GetConfLoc()
+	if err != nil {
+		t.Fatal(err)
+	}
+	etcDir = filepath.Join(etcDir, "schemas", "protobuf")
+	err = os.MkdirAll(etcDir, os.ModePerm)
+	if err != nil {
+		t.Fatal(err)
+	}
+	//Copy init.proto
+	bytesRead, err := ioutil.ReadFile("../../schema/test/test1.proto")
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = ioutil.WriteFile(filepath.Join(etcDir, "test1.proto"), bytesRead, 0755)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err = os.RemoveAll(etcDir)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+	schema.InitRegistry()
+	transform.RegisterAdditionalFuncs()
+	var tests = []struct {
+		config map[string]interface{}
+		data   []map[string]interface{}
+		result [][]byte
+	}{
+		{
+			config: map[string]interface{}{
+				"sendSingle": true,
+				"format":     `protobuf`,
+				"schemaId":   "test1.Person",
+			},
+			data: []map[string]interface{}{{
+				"name":  "test",
+				"id":    1,
+				"email": "Dddd",
+			}},
+			result: [][]byte{{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x1a, 0x04, 0x44, 0x64, 0x64, 0x64}},
+		}, {
+			config: map[string]interface{}{
+				"sendSingle":   true,
+				"dataTemplate": `{"name":"test","email":"{{.ab}}","id":1}`,
+				"format":       `protobuf`,
+				"schemaId":     "test1.Person",
+			},
+			data:   []map[string]interface{}{{"ab": "Dddd"}},
+			result: [][]byte{{0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01, 0x1a, 0x04, 0x44, 0x64, 0x64, 0x64}},
+		},
+	}
+	fmt.Printf("The test bucket size is %d.\n\n", len(tests))
+	contextLogger := conf.Log.WithField("rule", "TestSinkFormat_Apply")
+	ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
+
+	for i, tt := range tests {
+		mockSink := mocknode.NewMockSink()
+		s := NewSinkNodeWithSink("mockSink", mockSink, tt.config)
+		s.Open(ctx, make(chan error))
+		s.input <- tt.data
+		time.Sleep(1 * time.Second)
+		results := mockSink.GetResults()
+		if !reflect.DeepEqual(tt.result, results) {
+			t.Errorf("%d \tresult mismatch:\n\nexp=%x\n\ngot=%x\n\n", i, tt.result, results)
+		}
+	}
+}

+ 2 - 2
internal/topo/sink/rest_sink_test.go

@@ -1,4 +1,4 @@
-// Copyright 2021 EMQ Technologies Co., Ltd.
+// Copyright 2021-2022 EMQ Technologies Co., Ltd.
 //
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // you may not use this file except in compliance with the License.
@@ -185,7 +185,7 @@ func TestRestSink_Apply(t *testing.T) {
 		contextLogger.Debugf(string(body))
 		contextLogger.Debugf(string(body))
 		fmt.Fprintf(w, string(body))
 		fmt.Fprintf(w, string(body))
 	}))
 	}))
-	tf, _ := transform.GenTransform("")
+	tf, _ := transform.GenTransform("", "json", "")
 	defer ts.Close()
 	defer ts.Close()
 	for i, tt := range tests {
 	for i, tt := range tests {
 		requests = nil
 		requests = nil

+ 47 - 4
internal/topo/transform/template.go

@@ -19,13 +19,33 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/conf"
+	"github.com/lf-edge/ekuiper/internal/pkg/def"
+	"github.com/lf-edge/ekuiper/internal/schema"
+	"github.com/lf-edge/ekuiper/pkg/message"
+	"strings"
 	"text/template"
 	"text/template"
 )
 )
 
 
 type TransFunc func(interface{}) ([]byte, bool, error)
 type TransFunc func(interface{}) ([]byte, bool, error)
 
 
-func GenTransform(dt string) (TransFunc, error) {
-	var tp *template.Template = nil
+func GenTransform(dt string, format string, schemaId string) (TransFunc, error) {
+	var (
+		tp  *template.Template = nil
+		c   schema.Converter
+		err error
+	)
+	switch format {
+	case message.FormatProtobuf:
+		r := strings.Split(schemaId, ".")
+		if len(r) != 2 {
+			return nil, fmt.Errorf("invalid schemaId: %s", schemaId)
+		}
+		c, err = schema.GetOrCreateSchema(def.PROTOBUF, r[0], schemaId)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	if dt != "" {
 	if dt != "" {
 		temp, err := template.New("sink").Funcs(conf.FuncMap).Parse(dt)
 		temp, err := template.New("sink").Funcs(conf.FuncMap).Parse(dt)
 		if err != nil {
 		if err != nil {
@@ -34,16 +54,39 @@ func GenTransform(dt string) (TransFunc, error) {
 		tp = temp
 		tp = temp
 	}
 	}
 	return func(d interface{}) ([]byte, bool, error) {
 	return func(d interface{}) ([]byte, bool, error) {
+		var (
+			bs          []byte
+			transformed bool
+		)
 		if tp != nil {
 		if tp != nil {
 			var output bytes.Buffer
 			var output bytes.Buffer
 			err := tp.Execute(&output, d)
 			err := tp.Execute(&output, d)
 			if err != nil {
 			if err != nil {
 				return nil, false, fmt.Errorf("fail to encode data %v with dataTemplate for error %v", d, err)
 				return nil, false, fmt.Errorf("fail to encode data %v with dataTemplate for error %v", d, err)
 			}
 			}
-			return output.Bytes(), true, nil
-		} else {
+			bs = output.Bytes()
+			transformed = true
+		}
+		switch format {
+		case message.FormatJson:
+			if transformed {
+				return bs, transformed, nil
+			}
 			j, err := json.Marshal(d)
 			j, err := json.Marshal(d)
 			return j, false, err
 			return j, false, err
+		case message.FormatProtobuf:
+			if transformed {
+				m := make(map[string]interface{})
+				err := json.Unmarshal(bs, &m)
+				if err != nil {
+					return nil, false, fmt.Errorf("fail to decode data %s after applying dataTemplate for error %v", string(bs), err)
+				}
+				d = m
+			}
+			b, err := c.Encode(d)
+			return b, transformed, err
+		default: // should not happen
+			return nil, false, fmt.Errorf("unsupported format %v", format)
 		}
 		}
 	}, nil
 	}, nil
 }
 }

+ 4 - 3
pkg/message/decode.go

@@ -1,4 +1,4 @@
-// Copyright 2021 EMQ Technologies Co., Ltd.
+// Copyright 2021-2022 EMQ Technologies Co., Ltd.
 //
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // you may not use this file except in compliance with the License.
@@ -21,8 +21,9 @@ import (
 )
 )
 
 
 const (
 const (
-	FormatBinary = "binary"
-	FormatJson   = "json"
+	FormatBinary   = "binary"
+	FormatJson     = "json"
+	FormatProtobuf = "protobuf"
 
 
 	DefaultField = "self"
 	DefaultField = "self"
 	MetaKey      = "__meta"
 	MetaKey      = "__meta"