funcs_ast_validator.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. package xsql
  2. import (
  3. "fmt"
  4. "github.com/emqx/kuiper/plugins"
  5. "strings"
  6. )
  7. type AllowTypes struct {
  8. types []Literal
  9. }
  10. func validateFuncs(funcName string, args []Expr) error {
  11. lowerName := strings.ToLower(funcName)
  12. if _, ok := mathFuncMap[lowerName]; ok {
  13. return validateMathFunc(funcName, args)
  14. } else if _, ok := strFuncMap[lowerName]; ok {
  15. return validateStrFunc(funcName, args)
  16. } else if _, ok := convFuncMap[lowerName]; ok {
  17. return validateConvFunc(lowerName, args)
  18. } else if _, ok := hashFuncMap[lowerName]; ok {
  19. return validateHashFunc(lowerName, args)
  20. } else if _, ok := otherFuncMap[lowerName]; ok {
  21. return validateOtherFunc(lowerName, args)
  22. } else if _, ok := aggFuncMap[lowerName]; ok {
  23. return validateAggFunc(lowerName, args)
  24. } else {
  25. if nf, err := plugins.GetFunction(funcName); err != nil {
  26. return err
  27. } else {
  28. var targs []interface{}
  29. for _, arg := range args {
  30. targs = append(targs, arg)
  31. }
  32. return nf.Validate(targs)
  33. }
  34. }
  35. }
  36. func validateMathFunc(name string, args []Expr) error {
  37. len := len(args)
  38. switch name {
  39. case "abs", "acos", "asin", "atan", "ceil", "cos", "cosh", "exp", "ln", "log", "round", "sign", "sin", "sinh",
  40. "sqrt", "tan", "tanh":
  41. if err := validateLen(name, 1, len); err != nil {
  42. return err
  43. }
  44. if isStringArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  45. return produceErrInfo(name, 0, "number - float or int")
  46. }
  47. case "bitand", "bitor", "bitxor":
  48. if err := validateLen(name, 2, len); err != nil {
  49. return err
  50. }
  51. if isFloatArg(args[0]) || isStringArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  52. return produceErrInfo(name, 0, "int")
  53. }
  54. if isFloatArg(args[1]) || isStringArg(args[1]) || isTimeArg(args[1]) || isBooleanArg(args[1]) {
  55. return produceErrInfo(name, 1, "int")
  56. }
  57. case "bitnot":
  58. if err := validateLen(name, 1, len); err != nil {
  59. return err
  60. }
  61. if isFloatArg(args[0]) || isStringArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  62. return produceErrInfo(name, 0, "int")
  63. }
  64. case "atan2", "mod", "power":
  65. if err := validateLen(name, 2, len); err != nil {
  66. return err
  67. }
  68. if isStringArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  69. return produceErrInfo(name, 0, "number - float or int")
  70. }
  71. if isStringArg(args[1]) || isTimeArg(args[1]) || isBooleanArg(args[1]) {
  72. return produceErrInfo(name, 1, "number - float or int")
  73. }
  74. case "rand":
  75. if err := validateLen(name, 0, len); err != nil {
  76. return err
  77. }
  78. }
  79. return nil
  80. }
  81. func validateStrFunc(name string, args []Expr) error {
  82. len := len(args)
  83. switch name {
  84. case "concat":
  85. if len == 0 {
  86. return fmt.Errorf("The arguments for %s should be at least one.\n", name)
  87. }
  88. for i, a := range args {
  89. if isNumericArg(a) || isTimeArg(a) || isBooleanArg(a) {
  90. return produceErrInfo(name, i, "string")
  91. }
  92. }
  93. case "endswith", "indexof", "regexp_matches", "startswith":
  94. if err := validateLen(name, 2, len); err != nil {
  95. return err
  96. }
  97. for i := 0; i < 2; i++ {
  98. if isNumericArg(args[i]) || isTimeArg(args[i]) || isBooleanArg(args[i]) {
  99. return produceErrInfo(name, i, "string")
  100. }
  101. }
  102. case "format_time":
  103. if err := validateLen(name, 2, len); err != nil {
  104. return err
  105. }
  106. if isNumericArg(args[0]) || isStringArg(args[0]) || isBooleanArg(args[0]) {
  107. return produceErrInfo(name, 0, "datetime")
  108. }
  109. if isNumericArg(args[1]) || isTimeArg(args[1]) || isBooleanArg(args[1]) {
  110. return produceErrInfo(name, 1, "string")
  111. }
  112. case "regexp_replace":
  113. if err := validateLen(name, 3, len); err != nil {
  114. return err
  115. }
  116. for i := 0; i < 3; i++ {
  117. if isNumericArg(args[i]) || isTimeArg(args[i]) || isBooleanArg(args[i]) {
  118. return produceErrInfo(name, i, "string")
  119. }
  120. }
  121. case "length", "lower", "ltrim", "numbytes", "rtrim", "trim", "upper":
  122. if err := validateLen(name, 1, len); err != nil {
  123. return err
  124. }
  125. if isNumericArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  126. return produceErrInfo(name, 0, "string")
  127. }
  128. case "lpad", "rpad":
  129. if err := validateLen(name, 2, len); err != nil {
  130. return err
  131. }
  132. if isNumericArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  133. return produceErrInfo(name, 0, "string")
  134. }
  135. if isFloatArg(args[1]) || isTimeArg(args[1]) || isBooleanArg(args[1]) || isStringArg(args[1]) {
  136. return produceErrInfo(name, 1, "int")
  137. }
  138. case "substring":
  139. if len != 2 && len != 3 {
  140. return fmt.Errorf("the arguments for substring should be 2 or 3")
  141. }
  142. if isNumericArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  143. return produceErrInfo(name, 0, "string")
  144. }
  145. for i := 1; i < len; i++ {
  146. if isFloatArg(args[i]) || isTimeArg(args[i]) || isBooleanArg(args[i]) || isStringArg(args[i]) {
  147. return produceErrInfo(name, i, "int")
  148. }
  149. }
  150. if s, ok := args[1].(*IntegerLiteral); ok {
  151. sv := s.Val
  152. if sv < 0 {
  153. return fmt.Errorf("The start index should not be a nagtive integer.")
  154. }
  155. if len == 3 {
  156. if e, ok1 := args[2].(*IntegerLiteral); ok1 {
  157. ev := e.Val
  158. if ev < sv {
  159. return fmt.Errorf("The end index should be larger than start index.")
  160. }
  161. }
  162. }
  163. }
  164. case "split_value":
  165. if len != 3 {
  166. return fmt.Errorf("the arguments for split_value should be 3")
  167. }
  168. if isNumericArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  169. return produceErrInfo(name, 0, "string")
  170. }
  171. if isNumericArg(args[1]) || isTimeArg(args[1]) || isBooleanArg(args[1]) {
  172. return produceErrInfo(name, 1, "string")
  173. }
  174. if isFloatArg(args[2]) || isTimeArg(args[2]) || isBooleanArg(args[2]) || isStringArg(args[2]) {
  175. return produceErrInfo(name, 2, "int")
  176. }
  177. if s, ok := args[2].(*IntegerLiteral); ok {
  178. if s.Val < 0 {
  179. return fmt.Errorf("The index should not be a nagtive integer.")
  180. }
  181. }
  182. }
  183. return nil
  184. }
  185. func validateConvFunc(name string, args []Expr) error {
  186. len := len(args)
  187. switch name {
  188. case "cast":
  189. if err := validateLen(name, 2, len); err != nil {
  190. return err
  191. }
  192. a := args[1]
  193. if !isStringArg(a) {
  194. return produceErrInfo(name, 1, "string")
  195. }
  196. if av, ok := a.(*StringLiteral); ok {
  197. if !(av.Val == "bigint" || av.Val == "float" || av.Val == "string" || av.Val == "boolean" || av.Val == "datetime") {
  198. return fmt.Errorf("Expect one of following value for the 2nd parameter: bigint, float, string, boolean, datetime.")
  199. }
  200. }
  201. case "chr":
  202. if err := validateLen(name, 1, len); err != nil {
  203. return err
  204. }
  205. if isFloatArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  206. return produceErrInfo(name, 0, "int")
  207. }
  208. case "encode":
  209. if err := validateLen(name, 2, len); err != nil {
  210. return err
  211. }
  212. if isNumericArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  213. return produceErrInfo(name, 0, "string")
  214. }
  215. a := args[1]
  216. if !isStringArg(a) {
  217. return produceErrInfo(name, 1, "string")
  218. }
  219. if av, ok := a.(*StringLiteral); ok {
  220. if av.Val != "base64" {
  221. return fmt.Errorf("Only base64 is supported for the 2nd parameter.")
  222. }
  223. }
  224. case "trunc":
  225. if err := validateLen(name, 2, len); err != nil {
  226. return err
  227. }
  228. if isTimeArg(args[0]) || isBooleanArg(args[0]) || isStringArg(args[0]) {
  229. return produceErrInfo(name, 0, "number - float or int")
  230. }
  231. if isFloatArg(args[1]) || isTimeArg(args[1]) || isBooleanArg(args[1]) || isStringArg(args[1]) {
  232. return produceErrInfo(name, 1, "int")
  233. }
  234. }
  235. return nil
  236. }
  237. func validateHashFunc(name string, args []Expr) error {
  238. len := len(args)
  239. switch name {
  240. case "md5", "sha1", "sha224", "sha256", "sha384", "sha512":
  241. if err := validateLen(name, 1, len); err != nil {
  242. return err
  243. }
  244. if isNumericArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  245. return produceErrInfo(name, 0, "string")
  246. }
  247. }
  248. return nil
  249. }
  250. func validateOtherFunc(name string, args []Expr) error {
  251. len := len(args)
  252. switch name {
  253. case "isNull":
  254. if err := validateLen(name, 1, len); err != nil {
  255. return err
  256. }
  257. case "nanvl":
  258. if err := validateLen(name, 2, len); err != nil {
  259. return err
  260. }
  261. if isIntegerArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) || isStringArg(args[0]) {
  262. return produceErrInfo(name, 1, "float")
  263. }
  264. case "newuuid":
  265. if err := validateLen(name, 0, len); err != nil {
  266. return err
  267. }
  268. case "mqtt":
  269. if err := validateLen(name, 1, len); err != nil {
  270. return err
  271. }
  272. if isIntegerArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) || isStringArg(args[0]) || isFloatArg(args[0]) {
  273. return produceErrInfo(name, 0, "meta reference")
  274. }
  275. if p, ok := args[0].(*MetaRef); ok {
  276. name := strings.ToLower(p.Name)
  277. if name != "topic" && name != "messageid" {
  278. return fmt.Errorf("Parameter of mqtt function can be only topic or messageid.")
  279. }
  280. }
  281. case "meta":
  282. if err := validateLen(name, 1, len); err != nil {
  283. return err
  284. }
  285. if _, ok := args[0].(*MetaRef); ok {
  286. return nil
  287. }
  288. expr := args[0]
  289. for {
  290. if be, ok := expr.(*BinaryExpr); ok {
  291. if _, ok := be.LHS.(*MetaRef); ok && be.OP == ARROW {
  292. return nil
  293. }
  294. expr = be.LHS
  295. } else {
  296. break
  297. }
  298. }
  299. return produceErrInfo(name, 0, "meta reference")
  300. }
  301. return nil
  302. }
  303. func validateAggFunc(name string, args []Expr) error {
  304. len := len(args)
  305. switch name {
  306. case "avg", "max", "min", "sum":
  307. if err := validateLen(name, 1, len); err != nil {
  308. return err
  309. }
  310. if isStringArg(args[0]) || isTimeArg(args[0]) || isBooleanArg(args[0]) {
  311. return produceErrInfo(name, 0, "number - float or int")
  312. }
  313. case "count":
  314. if err := validateLen(name, 1, len); err != nil {
  315. return err
  316. }
  317. }
  318. return nil
  319. }
  // produceErrInfo builds the error reported when the argument at zero-based
  // position index of function name is not of type expect. The message counts
  // parameters from 1, so index 0 is reported as "1 parameter".
  func produceErrInfo(name string, index int, expect string) (err error) {
  	position := index + 1
  	return fmt.Errorf("Expect %s type for %d parameter of function %s.", expect, position, name)
  }
  // validateLen returns an error unless exactly exp arguments were supplied to
  // funcName; actual is the number of arguments that were supplied.
  func validateLen(funcName string, exp, actual int) error {
  	if actual == exp {
  		return nil
  	}
  	return fmt.Errorf("The arguments for %s should be %d.", funcName, exp)
  }
  332. func isNumericArg(arg Expr) bool {
  333. if _, ok := arg.(*NumberLiteral); ok {
  334. return true
  335. } else if _, ok := arg.(*IntegerLiteral); ok {
  336. return true
  337. }
  338. return false
  339. }
  340. func isIntegerArg(arg Expr) bool {
  341. if _, ok := arg.(*IntegerLiteral); ok {
  342. return true
  343. }
  344. return false
  345. }
  346. func isFloatArg(arg Expr) bool {
  347. if _, ok := arg.(*NumberLiteral); ok {
  348. return true
  349. }
  350. return false
  351. }
  352. func isBooleanArg(arg Expr) bool {
  353. if _, ok := arg.(*BooleanLiteral); ok {
  354. return true
  355. }
  356. return false
  357. }
  358. func isStringArg(arg Expr) bool {
  359. if _, ok := arg.(*StringLiteral); ok {
  360. return true
  361. }
  362. return false
  363. }
  364. func isTimeArg(arg Expr) bool {
  365. if _, ok := arg.(*TimeLiteral); ok {
  366. return true
  367. }
  368. return false
  369. }