sql_test.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. // Copyright 2022-2023 EMQ Technologies Co., Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package main
  15. import (
  16. "database/sql"
  17. "database/sql/driver"
  18. "fmt"
  19. "os"
  20. "reflect"
  21. "testing"
  22. "github.com/stretchr/testify/assert"
  23. "github.com/lf-edge/ekuiper/extensions/sqldatabase"
  24. econf "github.com/lf-edge/ekuiper/internal/conf"
  25. "github.com/lf-edge/ekuiper/internal/topo/context"
  26. )
  27. func TestSingle(t *testing.T) {
  28. db, err := sql.Open("sqlite", "file:test.db")
  29. if err != nil {
  30. t.Error(err)
  31. return
  32. }
  33. contextLogger := econf.Log.WithField("rule", "test")
  34. ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
  35. s := &sqlSink{}
  36. defer func() {
  37. db.Close()
  38. s.Close(ctx)
  39. err := os.Remove("test.db")
  40. if err != nil {
  41. fmt.Println(err)
  42. }
  43. }()
  44. _, err = db.Exec("CREATE TABLE IF NOT EXISTS single (id BIGINT PRIMARY KEY, name TEXT NOT NULL, address varchar(20), mobile varchar(20))")
  45. if err != nil {
  46. panic(err)
  47. }
  48. err = s.Configure(map[string]interface{}{
  49. "url": "sqlite://test.db",
  50. "table": "single",
  51. })
  52. if err != nil {
  53. t.Error(err)
  54. return
  55. }
  56. err = s.Open(ctx)
  57. if err != nil {
  58. t.Error(err)
  59. return
  60. }
  61. data := []map[string]interface{}{
  62. {"id": 1, "name": "John", "address": "343", "mobile": "334433"},
  63. {"id": 2, "name": "Susan", "address": "34", "mobile": "334433"},
  64. {"id": 3, "name": "Susan", "address": "34", "mobile": "334433"},
  65. }
  66. for _, d := range data {
  67. err = s.Collect(ctx, d)
  68. if err != nil {
  69. t.Error(err)
  70. return
  71. }
  72. }
  73. s.Close(ctx)
  74. rows, err := db.Query("SELECT * FROM single")
  75. if err != nil {
  76. t.Error(err)
  77. return
  78. }
  79. act, _ := rowsToMap(rows)
  80. exp := []map[string]interface{}{
  81. {"id": int64(1), "name": "John", "address": "343", "mobile": "334433"},
  82. {"id": int64(2), "name": "Susan", "address": "34", "mobile": "334433"},
  83. {"id": int64(3), "name": "Susan", "address": "34", "mobile": "334433"},
  84. }
  85. if !reflect.DeepEqual(act, exp) {
  86. t.Errorf("Expect %v but got %v", exp, act)
  87. }
  88. }
  89. func TestBatch(t *testing.T) {
  90. db, err := sql.Open("sqlite", "file:test.db")
  91. if err != nil {
  92. t.Error(err)
  93. return
  94. }
  95. contextLogger := econf.Log.WithField("rule", "test")
  96. ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
  97. s := &sqlSink{}
  98. defer func() {
  99. db.Close()
  100. s.Close(ctx)
  101. err := os.Remove("test.db")
  102. if err != nil {
  103. fmt.Println(err)
  104. }
  105. }()
  106. _, err = db.Exec("CREATE TABLE IF NOT EXISTS batch (id BIGINT PRIMARY KEY, name TEXT NOT NULL)")
  107. if err != nil {
  108. panic(err)
  109. }
  110. err = s.Configure(map[string]interface{}{
  111. "url": "sqlite://test.db",
  112. "table": "batch",
  113. "fields": []string{"id", "name"},
  114. })
  115. if err != nil {
  116. t.Error(err)
  117. return
  118. }
  119. err = s.Open(ctx)
  120. if err != nil {
  121. t.Error(err)
  122. return
  123. }
  124. data := []map[string]interface{}{
  125. {"id": 1, "name": "John", "address": "343", "mobile": "334433"},
  126. {"id": 2, "name": "Susan", "address": "34", "mobile": "334433"},
  127. {"id": 3, "name": "Susan", "address": "34", "mobile": "334433"},
  128. }
  129. err = s.Collect(ctx, data)
  130. if err != nil {
  131. t.Error(err)
  132. return
  133. }
  134. s.Close(ctx)
  135. rows, err := db.Query("SELECT * FROM batch")
  136. if err != nil {
  137. t.Error(err)
  138. return
  139. }
  140. act, _ := rowsToMap(rows)
  141. exp := []map[string]interface{}{
  142. {"id": int64(1), "name": "John"},
  143. {"id": int64(2), "name": "Susan"},
  144. {"id": int64(3), "name": "Susan"},
  145. }
  146. if !reflect.DeepEqual(act, exp) {
  147. t.Errorf("Expect %v but got %v", exp, act)
  148. }
  149. }
  150. func TestUpdate(t *testing.T) {
  151. db, err := sql.Open("sqlite", "file:test.db")
  152. if err != nil {
  153. t.Error(err)
  154. return
  155. }
  156. contextLogger := econf.Log.WithField("rule", "test")
  157. ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
  158. s := &sqlSink{}
  159. defer func() {
  160. db.Close()
  161. s.Close(ctx)
  162. err := os.Remove("test.db")
  163. if err != nil {
  164. fmt.Println(err)
  165. }
  166. }()
  167. _, err = db.Exec("CREATE TABLE IF NOT EXISTS updateTable (id BIGINT PRIMARY KEY, name TEXT NOT NULL)")
  168. if err != nil {
  169. panic(err)
  170. }
  171. err = s.Configure(map[string]interface{}{
  172. "url": "sqlite://test.db",
  173. "table": "updateTable",
  174. "rowkindField": "action",
  175. "keyField": "id",
  176. "fields": []string{"id", "name"},
  177. })
  178. if err != nil {
  179. t.Error(err)
  180. return
  181. }
  182. err = s.Open(ctx)
  183. if err != nil {
  184. t.Error(err)
  185. return
  186. }
  187. test := []struct {
  188. d []map[string]interface{}
  189. b bool
  190. r []map[string]interface{}
  191. }{
  192. {
  193. d: []map[string]interface{}{
  194. {"id": 1, "name": "John", "address": "343", "mobile": "334433"},
  195. {"action": "insert", "id": 2, "name": "Susan", "address": "34", "mobile": "334433"},
  196. {"action": "update", "id": 2, "name": "Diana"},
  197. },
  198. b: true,
  199. r: []map[string]interface{}{
  200. {"id": int64(1), "name": "John"},
  201. {"id": int64(2), "name": "Diana"},
  202. },
  203. }, {
  204. d: []map[string]interface{}{
  205. {"id": 4, "name": "Charles", "address": "343", "mobile": "334433"},
  206. {"action": "delete", "id": 2},
  207. {"action": "update", "id": 1, "name": "Lizz"},
  208. },
  209. b: false,
  210. r: []map[string]interface{}{
  211. {"id": int64(1), "name": "Lizz"},
  212. {"id": int64(4), "name": "Charles"},
  213. },
  214. }, {
  215. d: []map[string]interface{}{
  216. {"action": "upsert", "id": 4, "name": "Charles", "address": "343", "mobile": "334433"},
  217. {"action": "update", "id": 3, "name": "Lizz"},
  218. {"action": "update", "id": 1, "name": "Philips"},
  219. },
  220. b: true,
  221. r: []map[string]interface{}{
  222. {"id": int64(1), "name": "Philips"},
  223. {"id": int64(4), "name": "Charles"},
  224. },
  225. },
  226. }
  227. for i, tt := range test {
  228. if tt.b {
  229. err = s.Collect(ctx, tt.d)
  230. if err != nil {
  231. fmt.Println(err)
  232. }
  233. } else {
  234. for _, d := range tt.d {
  235. err = s.Collect(ctx, d)
  236. if err != nil {
  237. fmt.Println(err)
  238. }
  239. }
  240. }
  241. rows, err := db.Query("SELECT * FROM updateTable")
  242. if err != nil {
  243. t.Error(err)
  244. return
  245. }
  246. act, _ := rowsToMap(rows)
  247. if !reflect.DeepEqual(act, tt.r) {
  248. t.Errorf("Case %d Expect %v but got %v", i, tt.r, act)
  249. }
  250. }
  251. }
  252. func TestSaveSql(t *testing.T) {
  253. contextLogger := econf.Log.WithField("rule", "test")
  254. ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
  255. s := &sqlSink{}
  256. mdb := &sqldatabase.MockDB{}
  257. s.db = mdb
  258. s.conf = &sqlConfig{
  259. Fields: []string{ // set fields to make sure order is always testable
  260. "id", "name", "address", "mobile",
  261. },
  262. KeyField: "id",
  263. RowkindField: "action",
  264. }
  265. test := []struct {
  266. name string
  267. d map[string]interface{}
  268. s string
  269. }{
  270. {
  271. name: "insert",
  272. d: map[string]interface{}{"id": 1, "name": "John", "address": "343", "mobile": "334433"},
  273. s: "INSERT INTO test (id,name,address,mobile) values (1,'John','343','334433');",
  274. },
  275. {
  276. name: "update",
  277. d: map[string]interface{}{"action": "update", "id": 1, "name": "John", "address": "343", "mobile": "334433"},
  278. s: "UPDATE test SET id=1,name='John',address='343',mobile='334433' WHERE id = 1;",
  279. },
  280. }
  281. for _, tt := range test {
  282. t.Run(tt.name, func(t *testing.T) {
  283. err := s.save(ctx, "test", tt.d)
  284. assert.NoError(t, err)
  285. assert.Equal(t, tt.s, mdb.LastSql())
  286. })
  287. }
  288. }
  289. func rowsToMap(rows *sql.Rows) ([]map[string]interface{}, error) {
  290. cols, _ := rows.Columns()
  291. types, err := rows.ColumnTypes()
  292. if err != nil {
  293. return nil, err
  294. }
  295. var result []map[string]interface{}
  296. for rows.Next() {
  297. data := make(map[string]interface{})
  298. columns := make([]interface{}, len(cols))
  299. prepareValues(columns, types, cols)
  300. err := rows.Scan(columns...)
  301. if err != nil {
  302. return nil, err
  303. }
  304. scanIntoMap(data, columns, cols)
  305. result = append(result, data)
  306. }
  307. return result, nil
  308. }
  // scanIntoMap stores each scanned value from values into mapValue, under the
// column name at the same index in columns.
//
// Each element of values is a scan destination such as those built by
// prepareValues. Up to two levels of pointers are dereferenced. A nil pointer,
// or a nil element, is stored as nil. A value implementing driver.Valuer is
// replaced by the result of its Value method (the error is discarded), and
// sql.RawBytes is copied into a string.
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
	for i, name := range columns {
		rv := reflect.Indirect(reflect.ValueOf(values[i]))
		rv = reflect.Indirect(rv)
		if !rv.IsValid() {
			mapValue[name] = nil
			continue
		}
		v := rv.Interface()
		switch x := v.(type) {
		case driver.Valuer:
			v, _ = x.Value()
		case sql.RawBytes:
			// RawBytes may be reused by the driver on the next Next call; copy it out.
			v = string(x)
		}
		mapValue[name] = v
	}
}
  // prepareValues fills values with freshly allocated scan destinations, one
// per column. values must already have room for every column.
//
// When column type information is available, each destination is a pointer to
// a pointer of the driver-reported scan type. The extra pointer level lets
// database/sql represent a NULL as a nil inner pointer. A column whose ScanType
// is nil gets a *interface{} instead. Without any column types, every column
// named in columns gets a *interface{}.
func prepareValues(values []interface{}, columnTypes []*sql.ColumnType, columns []string) {
	if len(columnTypes) == 0 {
		for i := range columns {
			values[i] = new(interface{})
		}
		return
	}
	for i, ct := range columnTypes {
		scanType := ct.ScanType()
		if scanType == nil {
			values[i] = new(interface{})
			continue
		}
		values[i] = reflect.New(reflect.PtrTo(scanType)).Interface()
	}
}