Pārlūkot izejas kodu

fix(rest): return error when replace stream sql does not match name (#1191)

Closes #1182

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
ngjaying 3 gadi atpakaļ
vecāks
revīzija
4b51bdddfd
2 mainītis faili ar 9 papildinājumiem un 6 dzēšanām
  1. 5 2
      internal/processor/stream.go
  2. 4 4
      internal/server/rest.go

+ 5 - 2
internal/processor/stream.go

@@ -1,4 +1,4 @@
-// Copyright 2021 EMQ Technologies Co., Ltd.
+// Copyright 2021-2022 EMQ Technologies Co., Ltd.
 //
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // you may not use this file except in compliance with the License.
@@ -115,7 +115,7 @@ func (p *StreamProcessor) execSave(stmt *ast.StreamStmt, statement string, repla
 	return err
 	return err
 }
 }
 
 
-func (p *StreamProcessor) ExecReplaceStream(statement string, st ast.StreamType) (string, error) {
+func (p *StreamProcessor) ExecReplaceStream(name string, statement string, st ast.StreamType) (string, error) {
 	parser := xsql.NewParser(strings.NewReader(statement))
 	parser := xsql.NewParser(strings.NewReader(statement))
 	stmt, err := xsql.Language.Parse(parser)
 	stmt, err := xsql.Language.Parse(parser)
 	if err != nil {
 	if err != nil {
@@ -127,6 +127,9 @@ func (p *StreamProcessor) ExecReplaceStream(statement string, st ast.StreamType)
 		if s.StreamType != st {
 		if s.StreamType != st {
 			return "", errorx.NewWithCode(errorx.NOT_FOUND, fmt.Sprintf("%s %s is not found", ast.StreamTypeMap[st], s.Name))
 			return "", errorx.NewWithCode(errorx.NOT_FOUND, fmt.Sprintf("%s %s is not found", ast.StreamTypeMap[st], s.Name))
 		}
 		}
+		if string(s.Name) != name {
+			return "", fmt.Errorf("Replace %s fails: the sql statement must update the %s source.", name, name)
+		}
 		err = p.execSave(s, statement, true)
 		err = p.execSave(s, statement, true)
 		if err != nil {
 		if err != nil {
 			return "", fmt.Errorf("Replace %s fails: %v.", stt, err)
 			return "", fmt.Errorf("Replace %s fails: %v.", stt, err)

+ 4 - 4
internal/server/rest.go

@@ -236,7 +236,7 @@ func sourceManageHandler(w http.ResponseWriter, r *http.Request, st ast.StreamTy
 			handleError(w, err, "Invalid body", logger)
 			handleError(w, err, "Invalid body", logger)
 			return
 			return
 		}
 		}
-		content, err := streamProcessor.ExecReplaceStream(v.Sql, st)
+		content, err := streamProcessor.ExecReplaceStream(name, v.Sql, st)
 		if err != nil {
 		if err != nil {
 			handleError(w, err, fmt.Sprintf("%s command error", strings.Title(ast.StreamTypeMap[st])), logger)
 			handleError(w, err, fmt.Sprintf("%s command error", strings.Title(ast.StreamTypeMap[st])), logger)
 			return
 			return
@@ -622,15 +622,15 @@ func portableHandler(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 }
 }
 
 
-func prebuildSourcePlugins(w http.ResponseWriter, r *http.Request) {
+func prebuildSourcePlugins(w http.ResponseWriter, _ *http.Request) {
 	prebuildPluginsHandler(w, plugin.SOURCE)
 	prebuildPluginsHandler(w, plugin.SOURCE)
 }
 }
 
 
-func prebuildSinkPlugins(w http.ResponseWriter, r *http.Request) {
+func prebuildSinkPlugins(w http.ResponseWriter, _ *http.Request) {
 	prebuildPluginsHandler(w, plugin.SINK)
 	prebuildPluginsHandler(w, plugin.SINK)
 }
 }
 
 
-func prebuildFuncsPlugins(w http.ResponseWriter, r *http.Request) {
+func prebuildFuncsPlugins(w http.ResponseWriter, _ *http.Request) {
 	prebuildPluginsHandler(w, plugin.FUNCTION)
 	prebuildPluginsHandler(w, plugin.FUNCTION)
 }
 }