analyzer_test.go

// Copyright 2021-2023 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package planner

import (
    "encoding/json"
    "errors"
    "reflect"
    "strings"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/lf-edge/ekuiper/internal/pkg/store"
    "github.com/lf-edge/ekuiper/internal/testx"
    "github.com/lf-edge/ekuiper/internal/xsql"
    "github.com/lf-edge/ekuiper/pkg/api"
    "github.com/lf-edge/ekuiper/pkg/ast"
)
func init() {
}
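
// errorStruct holds the expected validation error for a test case: err is the
// expectation when the stream carries a schema; serr, when set, overrides it
// for the schemaless variant of the same query (see Serr).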
type errorStruct struct {
    err  string
    serr *string
}

func newErrorStruct(err string) *errorStruct {
    return &errorStruct{
        err: err,
    }
}

func newErrorStructWithS(err string, serr string) *errorStruct {
    return &errorStruct{
        err:  err,
        serr: &serr,
    }
}

func (e *errorStruct) Serr() string {
    if e.serr != nil {
        return *e.serr
    }
    return e.err
}
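
// tests is the shared table of SQL statements and the errors expected from
// createLogicalPlan. An empty err means the statement must plan cleanly; the
// numbered comments match the indices used by the tests below.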
var tests = []struct {
    sql string
    r   *errorStruct
}{
    { // 0
        sql: `SELECT count(*) FROM src1 HAVING sin(temp) > 0.3`,
        r:   newErrorStruct("Not allowed to call non-aggregate functions in HAVING clause."),
    },
    { // 1
        sql: `SELECT count(*) FROM src1 WHERE name = "dname" HAVING sin(count(*)) > 0.3`,
        r:   newErrorStruct(""),
    },
    { // 2
        sql: `SELECT count(*) as c FROM src1 WHERE name = "dname" HAVING sin(c) > 0.3`,
        r:   newErrorStruct(""),
    },
    { // 3
        sql: `SELECT count(*) as c FROM src1 WHERE name = "dname" HAVING sum(c) > 0.3`,
        r:   newErrorStruct("invalid argument for func sum: aggregate argument is not allowed"),
    },
    { // 4
        sql: `SELECT count(*) as c FROM src1 WHERE name = "dname" GROUP BY sin(c)`,
        r:   newErrorStruct("Not allowed to call aggregate functions in GROUP BY clause."),
    },
    { // 5
        sql: `SELECT count(*) as c FROM src1 WHERE name = "dname" HAVING sum(c) > 0.3 OR sin(temp) > 3`,
        r:   newErrorStruct("Not allowed to call non-aggregate functions in HAVING clause."),
    },
    { // 6
        sql: `SELECT collect(*) as c FROM src1 WHERE name = "dname" HAVING c[2]->temp > 20 AND sin(c[0]->temp) > 0`,
        r:   newErrorStruct(""),
    },
    { // 7
        sql: `SELECT collect(*) as c FROM src1 WHERE name = "dname" HAVING c[2]->temp + temp > 0`,
        r:   newErrorStruct("Not allowed to call non-aggregate functions in HAVING clause."),
    },
    { // 8
        sql: `SELECT deduplicate(temp, true) as de FROM src1 HAVING cardinality(de) > 20`,
        r:   newErrorStruct(""),
    },
    { // 9
        sql: `SELECT sin(temp) as temp FROM src1`,
        r:   newErrorStruct(""),
    },
    { // 10
        sql: `SELECT sum(temp) as temp1, count(temp) as temp FROM src1`,
        r:   newErrorStruct("invalid argument for func sum: aggregate argument is not allowed"),
    },
    { // 11
        sql: `SELECT sum(temp) as temp1, count(temp) as ct FROM src1`,
        r:   newErrorStruct(""),
    },
    { // 12
        sql: `SELECT collect(*)->abc FROM src1`,
        r:   newErrorStruct(""),
    },
    { // 13
        sql: `SELECT sin(temp) as temp1, cos(temp1) FROM src1`,
        r:   newErrorStruct(""),
    },
    { // 14
        sql: `SELECT collect(*)[-1] as current FROM src1 GROUP BY COUNTWINDOW(2, 1) HAVING isNull(current->name) = false`,
        r:   newErrorStruct(""),
    },
    { // 15
        sql: `SELECT sum(next->nid) as nid FROM src1 WHERE next->nid > 20 `,
        r:   newErrorStruct(""),
    },
    { // 16
        sql: `SELECT collect(*)[0] as last FROM src1 GROUP BY SlidingWindow(ss,5) HAVING last.temp > 30`,
        r:   newErrorStruct(""),
    },
    { // 17
        sql: `SELECT last_hit_time() FROM src1 GROUP BY SlidingWindow(ss,5) HAVING last_agg_hit_count() < 3`,
        r:   newErrorStruct("function last_hit_time is not allowed in an aggregate query"),
    },
    { // 18
        sql: `SELECT * FROM src1 GROUP BY SlidingWindow(ss,5) Over (WHEN last_hit_time() > 1) HAVING last_agg_hit_count() < 3`,
        r:   newErrorStruct(""),
    },
    { // 19
        sql: "select a + 1 as b, b + 1 as a from src1",
        r:   newErrorStruct("select fields have cycled alias"),
    },
    { // 20
        sql: "select a + 1 as b, b * 2 as c, c + 1 as a from src1",
        r:   newErrorStruct("select fields have cycled alias"),
    },
    // { // 21: already captured in the parser
    //     sql: `SELECT * FROM src1 GROUP BY SlidingWindow(ss,5) Over (WHEN abs(sum(a)) > 1) HAVING last_agg_hit_count() < 3`,
    //     r:   newErrorStruct("error compile sql: Not allowed to call aggregate functions in GROUP BY clause."),
    // },
}
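
// TestCheckTopoSort verifies that alias resolution does not loop back on
// itself: src1 has no field a, so `latest(a) as a` must fail with
// "unknown field a" rather than resolving a to its own alias.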
func TestCheckTopoSort(t *testing.T) {
    store, err := store.GetKV("stream")
    require.NoError(t, err)
    streamSqls := map[string]string{
        "src1": `CREATE STREAM src1 (
                    id1 BIGINT,
                    temp BIGINT,
                    name string,
                    next STRUCT(NAME STRING, NID BIGINT)
                ) WITH (DATASOURCE="src1", FORMAT="json", KEY="ts");`,
    }
    types := map[string]ast.StreamType{
        "src1": ast.TypeStream,
    }
    for name, sql := range streamSqls {
        s, err := json.Marshal(&xsql.StreamInfo{
            StreamType: types[name],
            Statement:  sql,
        })
        require.NoError(t, err)
        require.NoError(t, store.Set(name, string(s)))
    }
    streams := make(map[string]*ast.StreamStmt)
    for n := range streamSqls {
        streamStmt, err := xsql.GetDataSource(store, n)
        if err != nil {
            t.Errorf("failed to get stream %s; check that the stream has been created", n)
            return
        }
        streams[n] = streamStmt
    }
    sql := "select latest(a) as a from src1"
    stmt, err := xsql.NewParser(strings.NewReader(sql)).Parse()
    require.NoError(t, err)
    _, err = createLogicalPlan(stmt, &api.RuleOption{
        IsEventTime:        false,
        LateTol:            0,
        Concurrency:        0,
        BufferLength:       0,
        SendMetaToSink:     false,
        Qos:                0,
        CheckpointInterval: 0,
        SendError:          true,
    }, store)
    require.Equal(t, errors.New("unknown field a"), err)
}
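
// Test_validation runs the shared test table against a schema-bearing src1
// and compares each planner error against the schema-aware expectation.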
func Test_validation(t *testing.T) {
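    // Re-pin case 10's expectation (identical to its declared value) in case
    // a previous test mutated the shared table.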
    tests[10].r = newErrorStruct("invalid argument for func sum: aggregate argument is not allowed")
    store, err := store.GetKV("stream")
    if err != nil {
        t.Error(err)
        return
    }
    streamSqls := map[string]string{
        "src1": `CREATE STREAM src1 (
                    id1 BIGINT,
                    temp BIGINT,
                    name string,
                    next STRUCT(NAME STRING, NID BIGINT)
                ) WITH (DATASOURCE="src1", FORMAT="json", KEY="ts");`,
    }
    types := map[string]ast.StreamType{
        "src1": ast.TypeStream,
    }
    for name, sql := range streamSqls {
        s, err := json.Marshal(&xsql.StreamInfo{
            StreamType: types[name],
            Statement:  sql,
        })
        require.NoError(t, err)
        require.NoError(t, store.Set(name, string(s)))
    }
    streams := make(map[string]*ast.StreamStmt)
    for n := range streamSqls {
        streamStmt, err := xsql.GetDataSource(store, n)
        if err != nil {
            t.Errorf("failed to get stream %s; check that the stream has been created", n)
            return
        }
        streams[n] = streamStmt
    }
    t.Logf("The test bucket size is %d.", len(tests))
    for i, tt := range tests {
        stmt, err := xsql.NewParser(strings.NewReader(tt.sql)).Parse()
        if err != nil {
            t.Errorf("%d. %q: error compiling sql: %s\n", i, tt.sql, err)
            continue
        }
        _, err = createLogicalPlan(stmt, &api.RuleOption{
            IsEventTime:        false,
            LateTol:            0,
            Concurrency:        0,
            BufferLength:       0,
            SendMetaToSink:     false,
            Qos:                0,
            CheckpointInterval: 0,
            SendError:          true,
        }, store)
        assert.Equal(t, tt.r.err, testx.Errstring(err))
    }
}
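
// Test_validationSchemaless runs the same table against a schemaless src1 and
// compares each planner error against the schemaless expectation from Serr().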
func Test_validationSchemaless(t *testing.T) {
    store, err := store.GetKV("stream")
    if err != nil {
        t.Error(err)
        return
    }
    streamSqls := map[string]string{
        "src1": `CREATE STREAM src1 (
                ) WITH (DATASOURCE="src1", FORMAT="json", KEY="ts");`,
    }
    types := map[string]ast.StreamType{
        "src1": ast.TypeStream,
    }
    for name, sql := range streamSqls {
        s, err := json.Marshal(&xsql.StreamInfo{
            StreamType: types[name],
            Statement:  sql,
        })
        require.NoError(t, err)
        require.NoError(t, store.Set(name, string(s)))
    }
    streams := make(map[string]*ast.StreamStmt)
    for n := range streamSqls {
        streamStmt, err := xsql.GetDataSource(store, n)
        if err != nil {
            t.Errorf("failed to get stream %s; check that the stream has been created", n)
            return
        }
        streams[n] = streamStmt
    }
    t.Logf("The test bucket size is %d.", len(tests))
    for i, tt := range tests {
        stmt, err := xsql.NewParser(strings.NewReader(tt.sql)).Parse()
        if err != nil {
            t.Errorf("%d. %q: error compiling sql: %s\n", i, tt.sql, err)
            continue
        }
        _, err = createLogicalPlan(stmt, &api.RuleOption{
            IsEventTime:        false,
            LateTol:            0,
            Concurrency:        0,
            BufferLength:       0,
            SendMetaToSink:     false,
            Qos:                0,
            CheckpointInterval: 0,
            SendError:          true,
        }, store)
        if serr := tt.r.Serr(); serr != testx.Errstring(err) {
            t.Errorf("%d. %q: error mismatch:\n  exp=%s\n  got=%s\n\n", i, tt.sql, serr, err)
        }
    }
}
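
// TestConvertStreamInfo checks how convertStreamInfo reconciles declared
// stream fields with a registered schema: with FORMAT "protobuf" and a
// SCHEMAID, the external schema supplies (and overrides) the field types;
// with plain JSON, declared fields pass through unchanged; with no fields
// and no schema, the resulting schema is nil.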
func TestConvertStreamInfo(t *testing.T) {
    testCases := []struct {
        name       string
        streamStmt *ast.StreamStmt
        expected   ast.StreamFields
    }{
        {
            name: "with matching fields & schema",
            streamStmt: &ast.StreamStmt{
                StreamFields: []ast.StreamField{
                    {
                        Name: "field1",
                        FieldType: &ast.BasicType{
                            Type: ast.BIGINT,
                        },
                    },
                    {
                        Name: "field2",
                        FieldType: &ast.BasicType{
                            Type: ast.STRINGS,
                        },
                    },
                },
                Options: &ast.Options{
                    FORMAT:    "protobuf",
                    SCHEMAID:  "myschema.schema1",
                    TIMESTAMP: "ts",
                },
            },
            expected: []ast.StreamField{
                {
                    Name: "field1",
                    FieldType: &ast.BasicType{
                        Type: ast.BIGINT,
                    },
                },
                {
                    Name: "field2",
                    FieldType: &ast.BasicType{
                        Type: ast.STRINGS,
                    },
                },
            },
        },
        {
            name: "with unmatched fields & schema",
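            // field1 is declared STRING here, but the expectation below is
            // BIGINT: the registered schema takes precedence over the declaration.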
            streamStmt: &ast.StreamStmt{
                StreamFields: []ast.StreamField{
                    {
                        Name: "field1",
                        FieldType: &ast.BasicType{
                            Type: ast.STRINGS,
                        },
                    },
                    {
                        Name: "field2",
                        FieldType: &ast.BasicType{
                            Type: ast.STRINGS,
                        },
                    },
                },
                Options: &ast.Options{
                    FORMAT:    "protobuf",
                    SCHEMAID:  "myschema.schema1",
                    TIMESTAMP: "ts",
                },
            },
            expected: []ast.StreamField{
                {
                    Name: "field1",
                    FieldType: &ast.BasicType{
                        Type: ast.BIGINT,
                    },
                },
                {
                    Name: "field2",
                    FieldType: &ast.BasicType{
                        Type: ast.STRINGS,
                    },
                },
            },
        },
        {
            name: "without schema",
            streamStmt: &ast.StreamStmt{
                StreamFields: []ast.StreamField{
                    {
                        Name: "field1",
                        FieldType: &ast.BasicType{
                            Type: ast.FLOAT,
                        },
                    },
                    {
                        Name: "field2",
                        FieldType: &ast.BasicType{
                            Type: ast.STRINGS,
                        },
                    },
                },
                Options: &ast.Options{
                    FORMAT:    "json",
                    TIMESTAMP: "ts",
                },
            },
            expected: []ast.StreamField{
                {
                    Name: "field1",
                    FieldType: &ast.BasicType{
                        Type: ast.FLOAT,
                    },
                },
                {
                    Name: "field2",
                    FieldType: &ast.BasicType{
                        Type: ast.STRINGS,
                    },
                },
            },
        },
        {
            name: "without fields",
            streamStmt: &ast.StreamStmt{
                Options: &ast.Options{
                    FORMAT:    "protobuf",
                    SCHEMAID:  "myschema.schema1",
                    TIMESTAMP: "ts",
                },
            },
            expected: []ast.StreamField{
                {
                    Name: "field1",
                    FieldType: &ast.BasicType{
                        Type: ast.BIGINT,
                    },
                },
                {
                    Name: "field2",
                    FieldType: &ast.BasicType{
                        Type: ast.STRINGS,
                    },
                },
            },
        },
        {
            name: "schemaless",
            streamStmt: &ast.StreamStmt{
                Options: &ast.Options{
                    FORMAT:    "json",
                    TIMESTAMP: "ts",
                },
            },
            expected: nil,
        },
    }
    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            actual, err := convertStreamInfo(tc.streamStmt)
            if err != nil {
                t.Errorf("unexpected error: %v", err)
                return
            }
            if !reflect.DeepEqual(actual.schema, tc.expected) {
                t.Errorf("unexpected result: got %v, want %v", actual.schema, tc.expected)
            }
        })
    }
}