funcsAstValidator.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. // Copyright 2021 EMQ Technologies Co., Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package xsql
  15. import (
  16. "fmt"
  17. "github.com/lf-edge/ekuiper/pkg/ast"
  18. "strings"
  19. )
  20. type AllowTypes struct {
  21. types []ast.Literal
  22. }
  23. func validateFuncs(funcName string, args []ast.Expr) error {
  24. lowerName := strings.ToLower(funcName)
  25. switch ast.FuncFinderSingleton().FuncType(lowerName) {
  26. case ast.NotFoundFunc:
  27. nf, _, err := parserFuncRuntime.Get(funcName)
  28. if err != nil {
  29. return fmt.Errorf("error getting function %s: %v", funcName, err)
  30. }
  31. var targs []interface{}
  32. for _, arg := range args {
  33. targs = append(targs, arg)
  34. }
  35. return nf.Validate(targs)
  36. case ast.AggFunc:
  37. return validateAggFunc(lowerName, args)
  38. case ast.MathFunc:
  39. return validateMathFunc(lowerName, args)
  40. case ast.ConvFunc:
  41. return validateConvFunc(lowerName, args)
  42. case ast.StrFunc:
  43. return validateStrFunc(lowerName, args)
  44. case ast.HashFunc:
  45. return validateHashFunc(lowerName, args)
  46. case ast.JsonFunc:
  47. return validateJsonFunc(lowerName, args)
  48. case ast.OtherFunc:
  49. return validateOtherFunc(lowerName, args)
  50. default:
  51. return fmt.Errorf("unkndow function %s", lowerName)
  52. }
  53. }
  54. func validateMathFunc(name string, args []ast.Expr) error {
  55. len := len(args)
  56. switch name {
  57. case "abs", "acos", "asin", "atan", "ceil", "cos", "cosh", "exp", "ln", "log", "round", "sign", "sin", "sinh",
  58. "sqrt", "tan", "tanh":
  59. if err := ast.ValidateLen(name, 1, len); err != nil {
  60. return err
  61. }
  62. if ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  63. return ast.ProduceErrInfo(name, 0, "number - float or int")
  64. }
  65. case "bitand", "bitor", "bitxor":
  66. if err := ast.ValidateLen(name, 2, len); err != nil {
  67. return err
  68. }
  69. if ast.IsFloatArg(args[0]) || ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  70. return ast.ProduceErrInfo(name, 0, "int")
  71. }
  72. if ast.IsFloatArg(args[1]) || ast.IsStringArg(args[1]) || ast.IsTimeArg(args[1]) || ast.IsBooleanArg(args[1]) {
  73. return ast.ProduceErrInfo(name, 1, "int")
  74. }
  75. case "bitnot":
  76. if err := ast.ValidateLen(name, 1, len); err != nil {
  77. return err
  78. }
  79. if ast.IsFloatArg(args[0]) || ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  80. return ast.ProduceErrInfo(name, 0, "int")
  81. }
  82. case "atan2", "mod", "power":
  83. if err := ast.ValidateLen(name, 2, len); err != nil {
  84. return err
  85. }
  86. if ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  87. return ast.ProduceErrInfo(name, 0, "number - float or int")
  88. }
  89. if ast.IsStringArg(args[1]) || ast.IsTimeArg(args[1]) || ast.IsBooleanArg(args[1]) {
  90. return ast.ProduceErrInfo(name, 1, "number - float or int")
  91. }
  92. case "rand":
  93. if err := ast.ValidateLen(name, 0, len); err != nil {
  94. return err
  95. }
  96. }
  97. return nil
  98. }
  99. func validateStrFunc(name string, args []ast.Expr) error {
  100. len := len(args)
  101. switch name {
  102. case "concat":
  103. if len == 0 {
  104. return fmt.Errorf("The arguments for %s should be at least one.\n", name)
  105. }
  106. for i, a := range args {
  107. if ast.IsNumericArg(a) || ast.IsTimeArg(a) || ast.IsBooleanArg(a) {
  108. return ast.ProduceErrInfo(name, i, "string")
  109. }
  110. }
  111. case "endswith", "indexof", "regexp_matches", "startswith":
  112. if err := ast.ValidateLen(name, 2, len); err != nil {
  113. return err
  114. }
  115. for i := 0; i < 2; i++ {
  116. if ast.IsNumericArg(args[i]) || ast.IsTimeArg(args[i]) || ast.IsBooleanArg(args[i]) {
  117. return ast.ProduceErrInfo(name, i, "string")
  118. }
  119. }
  120. case "format_time":
  121. if err := ast.ValidateLen(name, 2, len); err != nil {
  122. return err
  123. }
  124. if ast.IsNumericArg(args[0]) || ast.IsStringArg(args[0]) || ast.IsBooleanArg(args[0]) {
  125. return ast.ProduceErrInfo(name, 0, "datetime")
  126. }
  127. if ast.IsNumericArg(args[1]) || ast.IsTimeArg(args[1]) || ast.IsBooleanArg(args[1]) {
  128. return ast.ProduceErrInfo(name, 1, "string")
  129. }
  130. case "regexp_replace":
  131. if err := ast.ValidateLen(name, 3, len); err != nil {
  132. return err
  133. }
  134. for i := 0; i < 3; i++ {
  135. if ast.IsNumericArg(args[i]) || ast.IsTimeArg(args[i]) || ast.IsBooleanArg(args[i]) {
  136. return ast.ProduceErrInfo(name, i, "string")
  137. }
  138. }
  139. case "length", "lower", "ltrim", "numbytes", "rtrim", "trim", "upper":
  140. if err := ast.ValidateLen(name, 1, len); err != nil {
  141. return err
  142. }
  143. if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  144. return ast.ProduceErrInfo(name, 0, "string")
  145. }
  146. case "lpad", "rpad":
  147. if err := ast.ValidateLen(name, 2, len); err != nil {
  148. return err
  149. }
  150. if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  151. return ast.ProduceErrInfo(name, 0, "string")
  152. }
  153. if ast.IsFloatArg(args[1]) || ast.IsTimeArg(args[1]) || ast.IsBooleanArg(args[1]) || ast.IsStringArg(args[1]) {
  154. return ast.ProduceErrInfo(name, 1, "int")
  155. }
  156. case "substring":
  157. if len != 2 && len != 3 {
  158. return fmt.Errorf("the arguments for substring should be 2 or 3")
  159. }
  160. if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  161. return ast.ProduceErrInfo(name, 0, "string")
  162. }
  163. for i := 1; i < len; i++ {
  164. if ast.IsFloatArg(args[i]) || ast.IsTimeArg(args[i]) || ast.IsBooleanArg(args[i]) || ast.IsStringArg(args[i]) {
  165. return ast.ProduceErrInfo(name, i, "int")
  166. }
  167. }
  168. if s, ok := args[1].(*ast.IntegerLiteral); ok {
  169. sv := s.Val
  170. if sv < 0 {
  171. return fmt.Errorf("The start index should not be a nagtive integer.")
  172. }
  173. if len == 3 {
  174. if e, ok1 := args[2].(*ast.IntegerLiteral); ok1 {
  175. ev := e.Val
  176. if ev < sv {
  177. return fmt.Errorf("The end index should be larger than start index.")
  178. }
  179. }
  180. }
  181. }
  182. case "split_value":
  183. if len != 3 {
  184. return fmt.Errorf("the arguments for split_value should be 3")
  185. }
  186. if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  187. return ast.ProduceErrInfo(name, 0, "string")
  188. }
  189. if ast.IsNumericArg(args[1]) || ast.IsTimeArg(args[1]) || ast.IsBooleanArg(args[1]) {
  190. return ast.ProduceErrInfo(name, 1, "string")
  191. }
  192. if ast.IsFloatArg(args[2]) || ast.IsTimeArg(args[2]) || ast.IsBooleanArg(args[2]) || ast.IsStringArg(args[2]) {
  193. return ast.ProduceErrInfo(name, 2, "int")
  194. }
  195. if s, ok := args[2].(*ast.IntegerLiteral); ok {
  196. if s.Val < 0 {
  197. return fmt.Errorf("The index should not be a nagtive integer.")
  198. }
  199. }
  200. }
  201. return nil
  202. }
  203. func validateConvFunc(name string, args []ast.Expr) error {
  204. len := len(args)
  205. switch name {
  206. case "cast":
  207. if err := ast.ValidateLen(name, 2, len); err != nil {
  208. return err
  209. }
  210. a := args[1]
  211. if !ast.IsStringArg(a) {
  212. return ast.ProduceErrInfo(name, 1, "string")
  213. }
  214. if av, ok := a.(*ast.StringLiteral); ok {
  215. if !(av.Val == "bigint" || av.Val == "float" || av.Val == "string" || av.Val == "boolean" || av.Val == "datetime") {
  216. return fmt.Errorf("Expect one of following value for the 2nd parameter: bigint, float, string, boolean, datetime.")
  217. }
  218. }
  219. case "chr":
  220. if err := ast.ValidateLen(name, 1, len); err != nil {
  221. return err
  222. }
  223. if ast.IsFloatArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  224. return ast.ProduceErrInfo(name, 0, "int")
  225. }
  226. case "encode":
  227. if err := ast.ValidateLen(name, 2, len); err != nil {
  228. return err
  229. }
  230. if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  231. return ast.ProduceErrInfo(name, 0, "string")
  232. }
  233. a := args[1]
  234. if !ast.IsStringArg(a) {
  235. return ast.ProduceErrInfo(name, 1, "string")
  236. }
  237. if av, ok := a.(*ast.StringLiteral); ok {
  238. if av.Val != "base64" {
  239. return fmt.Errorf("Only base64 is supported for the 2nd parameter.")
  240. }
  241. }
  242. case "trunc":
  243. if err := ast.ValidateLen(name, 2, len); err != nil {
  244. return err
  245. }
  246. if ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) || ast.IsStringArg(args[0]) {
  247. return ast.ProduceErrInfo(name, 0, "number - float or int")
  248. }
  249. if ast.IsFloatArg(args[1]) || ast.IsTimeArg(args[1]) || ast.IsBooleanArg(args[1]) || ast.IsStringArg(args[1]) {
  250. return ast.ProduceErrInfo(name, 1, "int")
  251. }
  252. }
  253. return nil
  254. }
  255. func validateHashFunc(name string, args []ast.Expr) error {
  256. len := len(args)
  257. switch name {
  258. case "md5", "sha1", "sha224", "sha256", "sha384", "sha512":
  259. if err := ast.ValidateLen(name, 1, len); err != nil {
  260. return err
  261. }
  262. if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  263. return ast.ProduceErrInfo(name, 0, "string")
  264. }
  265. }
  266. return nil
  267. }
  268. func validateOtherFunc(name string, args []ast.Expr) error {
  269. len := len(args)
  270. switch name {
  271. case "isNull":
  272. if err := ast.ValidateLen(name, 1, len); err != nil {
  273. return err
  274. }
  275. case "cardinality":
  276. if err := ast.ValidateLen(name, 1, len); err != nil {
  277. return err
  278. }
  279. case "nanvl":
  280. if err := ast.ValidateLen(name, 2, len); err != nil {
  281. return err
  282. }
  283. if ast.IsIntegerArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) || ast.IsStringArg(args[0]) {
  284. return ast.ProduceErrInfo(name, 1, "float")
  285. }
  286. case "newuuid":
  287. if err := ast.ValidateLen(name, 0, len); err != nil {
  288. return err
  289. }
  290. case "mqtt":
  291. if err := ast.ValidateLen(name, 1, len); err != nil {
  292. return err
  293. }
  294. if ast.IsIntegerArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) || ast.IsStringArg(args[0]) || ast.IsFloatArg(args[0]) {
  295. return ast.ProduceErrInfo(name, 0, "meta reference")
  296. }
  297. if p, ok := args[0].(*ast.MetaRef); ok {
  298. name := strings.ToLower(p.Name)
  299. if name != "topic" && name != "messageid" {
  300. return fmt.Errorf("Parameter of mqtt function can be only topic or messageid.")
  301. }
  302. }
  303. case "meta":
  304. if err := ast.ValidateLen(name, 1, len); err != nil {
  305. return err
  306. }
  307. if _, ok := args[0].(*ast.MetaRef); ok {
  308. return nil
  309. }
  310. expr := args[0]
  311. for {
  312. if be, ok := expr.(*ast.BinaryExpr); ok {
  313. if _, ok := be.LHS.(*ast.MetaRef); ok && be.OP == ast.ARROW {
  314. return nil
  315. }
  316. expr = be.LHS
  317. } else {
  318. break
  319. }
  320. }
  321. return ast.ProduceErrInfo(name, 0, "meta reference")
  322. }
  323. return nil
  324. }
  325. func validateJsonFunc(name string, args []ast.Expr) error {
  326. len := len(args)
  327. if err := ast.ValidateLen(name, 2, len); err != nil {
  328. return err
  329. }
  330. if !ast.IsStringArg(args[1]) {
  331. return ast.ProduceErrInfo(name, 1, "string")
  332. }
  333. return nil
  334. }
  335. func validateAggFunc(name string, args []ast.Expr) error {
  336. len := len(args)
  337. switch name {
  338. case "avg", "max", "min", "sum":
  339. if err := ast.ValidateLen(name, 1, len); err != nil {
  340. return err
  341. }
  342. if ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
  343. return ast.ProduceErrInfo(name, 0, "number - float or int")
  344. }
  345. case "count":
  346. if err := ast.ValidateLen(name, 1, len); err != nil {
  347. return err
  348. }
  349. case "collect":
  350. if err := ast.ValidateLen(name, 1, len); err != nil {
  351. return err
  352. }
  353. case "deduplicate":
  354. if err := ast.ValidateLen(name, 2, len); err != nil {
  355. return err
  356. }
  357. if !ast.IsBooleanArg(args[1]) {
  358. return ast.ProduceErrInfo(name, 1, "bool")
  359. }
  360. }
  361. return nil
  362. }