Bladeren bron

feat(extension): Support function extension

ngjaying 5 jaren geleden
bovenliggende
commit
014ef5adef

+ 40 - 0
common/plugin_manager/manager.go

@@ -0,0 +1,40 @@
+package plugin_manager
+
+import (
+	"fmt"
+	"plugin"
+	"unicode"
+)
+
+var registry map[string]plugin.Symbol
+
+func init(){
+	registry = make(map[string]plugin.Symbol)
+}
+
+func GetPlugin(t string, ptype string) (plugin.Symbol, error) {
+	t = ucFirst(t)
+	key := ptype + "/" + t
+	var nf plugin.Symbol
+	nf, ok := registry[key]
+	if !ok {
+		mod := "plugins/" + key + ".so"
+		plug, err := plugin.Open(mod)
+		if err != nil {
+			return nil, fmt.Errorf("cannot open %s: %v", mod, err)
+		}
+		nf, err = plug.Lookup(t)
+		if err != nil {
+			return nil, fmt.Errorf("cannot find symbol %s, please check if it is exported", t)
+		}
+	}
+	return nf, nil
+}
+
+func ucFirst(str string) string {
+	for i, v := range str {
+		return string(unicode.ToUpper(v)) + str[i+1:]
+	}
+	return ""
+}
+

+ 1 - 1
examples/testExtension.go

@@ -28,7 +28,7 @@ func main() {
 
 	rp := processors.NewRuleProcessor(BadgerDir)
 	rp.ExecDrop("$$test1")
-	rs, err := rp.ExecCreate("$$test1", "{\"sql\": \"SELECT count FROM ext where ext.count > 3\",\"actions\": [{\"memory\":  {}}]}")
+	rs, err := rp.ExecCreate("$$test1", "{\"sql\": \"SELECT echo(count) FROM ext where count > 3\",\"actions\": [{\"memory\":  {}}]}")
 	if err != nil {
 		msg := fmt.Sprintf("failed to create rule: %s.", err)
 		log.Printf(msg)

+ 22 - 0
plugins/functions/echo.go

@@ -0,0 +1,22 @@
+package main
+
+import (
+	"fmt"
+)
+
+type echo struct {
+}
+
+func (f *echo) Validate(args []interface{}) error{
+	if len(args) != 1{
+		return fmt.Errorf("echo function only supports 1 parameter but got %d", len(args))
+	}
+	return nil
+}
+
+func (f *echo) Exec(args []interface{}) (interface{}, bool) {
+	result := args[0]
+	return result, true
+}
+
+var Echo echo

+ 16 - 1
xsql/funcs_ast_validator.go

@@ -1,6 +1,8 @@
 package xsql
 
 import (
+	"engine/common/plugin_manager"
+	"engine/xstream/api"
 	"fmt"
 	"strings"
 )
@@ -21,8 +23,21 @@ func validateFuncs(funcName string, args []Expr) error {
 		return validateHashFunc(lowerName, args)
 	} else if _, ok := otherFuncMap[lowerName]; ok {
 		return validateOtherFunc(lowerName, args)
+	} else {
+		if nf, err := plugin_manager.GetPlugin(funcName, "functions"); err != nil {
+			return err
+		}else{
+			f, ok := nf.(api.Function)
+			if !ok {
+				return fmt.Errorf("exported symbol %s is not type of api.Function", funcName)
+			}
+			var targs []interface{}
+			for _, arg := range args{
+				targs = append(targs, arg)
+			}
+			return f.Validate(targs)
+		}
 	}
-	return nil
 }
 
 func validateMathFunc(name string, args []Expr) error {

+ 15 - 1
xsql/functions.go

@@ -1,6 +1,9 @@
 package xsql
 
 import (
+	"engine/common"
+	"engine/common/plugin_manager"
+	"engine/xstream/api"
 	"strings"
 )
 
@@ -65,6 +68,17 @@ func (*FunctionValuer) Call(name string, args []interface{}) (interface{}, bool)
 		return hashCall(lowerName, args)
 	} else if _, ok := otherFuncMap[lowerName]; ok {
 		return otherCall(lowerName, args)
+	} else {
+		if nf, err := plugin_manager.GetPlugin(name, "functions"); err != nil {
+			return nil, false
+		}else{
+			f, ok := nf.(api.Function)
+			if !ok {
+				return nil, false
+			}
+			result, ok := f.Exec(args)
+			common.Log.Debugf("run custom function %s, get result %v", name, result)
+			return result, ok
+		}
 	}
-	return nil, false
 }

+ 6 - 29
xsql/processors/xsql_processor.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"engine/common"
+	"engine/common/plugin_manager"
 	"engine/xsql"
 	"engine/xsql/plans"
 	"engine/xstream"
@@ -16,9 +17,7 @@ import (
 	"github.com/dgraph-io/badger"
 	"github.com/go-yaml/yaml"
 	"path"
-	"plugin"
 	"strings"
-	"unicode"
 )
 
 var log = common.Log
@@ -444,14 +443,13 @@ func getSource(streamStmt *xsql.StreamStmt) (api.Source, error) {
 	if !ok{
 		t = "mqtt"
 	}
-	t = ucFirst(t)
 	var s api.Source
 	switch t {
-	case "Mqtt":
+	case "mqtt":
 		s = &extensions.MQTTSource{}
 		log.Debugf("Source mqtt created")
 	default:
-		nf, err := getPlugin(t, "sources")
+		nf, err := plugin_manager.GetPlugin(t, "sources")
 		if err != nil {
 			return nil, err
 		}
@@ -495,30 +493,16 @@ func getConf(t string, confkey string) map[string]interface{} {
 	return props
 }
 
-func getPlugin(t string, ptype string) (plugin.Symbol, error) {
-	mod := "plugins/" + ptype + "/" + t + ".so"
-	plug, err := plugin.Open(mod)
-	if err != nil {
-		return nil, fmt.Errorf("cannot open %s: %v", mod, err)
-	}
-	nf, err := plug.Lookup(t)
-	if err != nil {
-		return nil, fmt.Errorf("cannot find symbol %s, please check if it is exported", t)
-	}
-	return nf, nil
-}
-
 func getSink(name string, action map[string]interface{}) (api.Sink, error) {
 	log.Tracef("trying to get sink %s with action %v", name, action)
 	var s api.Sink
-	name = ucFirst(name)
 	switch name {
-	case "Log":
+	case "log":
 		s = sinks.NewLogSink()
-	case "Mqtt":
+	case "mqtt":
 		s = &sinks.MQTTSink{}
 	default:
-		nf, err := getPlugin(name, "sinks")
+		nf, err := plugin_manager.GetPlugin(name, "sinks")
 		if err != nil {
 			return nil, err
 		}
@@ -537,10 +521,3 @@ func getSink(name string, action map[string]interface{}) (api.Sink, error) {
 	return s, nil
 }
 
-func ucFirst(str string) string {
-	for i, v := range str {
-		return string(unicode.ToUpper(v)) + str[i+1:]
-	}
-	return ""
-}
-

+ 9 - 2
xstream/api/stream.go

@@ -2,12 +2,11 @@ package api
 
 import (
 	"context"
-	"engine/xsql"
 	"github.com/sirupsen/logrus"
 )
 
 //The function to call when data is emitted by the source.
-type ConsumeFunc func(message xsql.Message, metadata xsql.Metadata)
+type ConsumeFunc func(message map[string]interface{}, metadata map[string]interface{})
 
 type Closable interface {
 	Close(ctx StreamContext) error
@@ -67,3 +66,11 @@ type Operator interface {
 	GetName() string
 }
 
+type Function interface {
+	//The argument is a list of xsql.Expr
+	Validate(args []interface{}) error
+	//Execute the function, return the result and if execution is successful.
+	//If execution fails, return the error and false.
+	Exec(args []interface{}) (interface{}, bool)
+}
+

+ 1 - 1
xstream/nodes/source_node.go

@@ -28,7 +28,7 @@ func (m *SourceNode) Open(ctx api.StreamContext, errCh chan<- error) {
 	logger := ctx.GetLogger()
 	logger.Debugf("open source node %s", m.name)
 	go func(){
-		if err := m.source.Open(ctx, func(message xsql.Message, meta xsql.Metadata){
+		if err := m.source.Open(ctx, func(message map[string]interface{}, meta map[string]interface{}){
 			tuple := &xsql.Tuple{Emitter: m.name, Message:message, Timestamp: common.GetNowInMilli(), Metadata:meta}
 			m.Broadcast(tuple)
 			logger.Debugf("%s consume data %v complete", m.name, tuple)