sorter.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. // Copyright 2022-2023 EMQ Technologies Co., Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package xsql
  15. import (
  16. "fmt"
  17. "sort"
  18. "github.com/lf-edge/ekuiper/pkg/ast"
  19. "github.com/lf-edge/ekuiper/pkg/cast"
  20. )
  21. // MultiSorter implements the Sort interface, sorting the changes within.
  22. type MultiSorter struct {
  23. SortingData
  24. fields ast.SortFields
  25. valuer *FunctionValuer
  26. aggValuer *AggregateFunctionValuer
  27. values []map[string]interface{}
  28. }
  29. // OrderedBy returns a Sorter that sorts using the less functions, in order.
  30. // Call its Sort method to sort the data.
  31. func OrderedBy(fields ast.SortFields, fv *FunctionValuer, afv *AggregateFunctionValuer) *MultiSorter {
  32. return &MultiSorter{
  33. fields: fields,
  34. valuer: fv,
  35. aggValuer: afv,
  36. }
  37. }
  38. // Less is part of sort.Interface. It is implemented by looping along the
  39. // less functions until it finds a comparison that discriminates between
  40. // the two items (one is less than the other). Note that it can call the
  41. // less functions twice per call. We could change the functions to return
  42. // -1, 0, 1 and reduce the number of calls for greater efficiency: an
  43. // exercise for the reader.
  44. func (ms *MultiSorter) Less(i, j int) bool {
  45. p, q := ms.values[i], ms.values[j]
  46. v := &ValuerEval{Valuer: MultiValuer(ms.valuer)}
  47. for _, field := range ms.fields {
  48. n := field.Uname
  49. vp, _ := p[n]
  50. vq, _ := q[n]
  51. if vp == nil && vq != nil {
  52. return false
  53. } else if vp != nil && vq == nil {
  54. return true
  55. } else if vp == nil && vq == nil {
  56. return false
  57. }
  58. switch {
  59. case v.simpleDataEval(vp, vq, ast.LT):
  60. return field.Ascending
  61. case v.simpleDataEval(vq, vp, ast.LT):
  62. return !field.Ascending
  63. }
  64. }
  65. return false
  66. }
  67. func (ms *MultiSorter) Swap(i, j int) {
  68. ms.values[i], ms.values[j] = ms.values[j], ms.values[i]
  69. ms.SortingData.Swap(i, j)
  70. }
  71. // Sort sorts the argument slice according to the less functions passed to OrderedBy.
  72. func (ms *MultiSorter) Sort(data SortingData) error {
  73. ms.SortingData = data
  74. types := make([]string, len(ms.fields))
  75. ms.values = make([]map[string]interface{}, data.Len())
  76. switch input := data.(type) {
  77. case error:
  78. return input
  79. case SingleCollection:
  80. err := input.RangeSet(func(i int, row Row) (bool, error) {
  81. ms.values[i] = make(map[string]interface{})
  82. vep := &ValuerEval{Valuer: MultiValuer(ms.valuer, row, ms.valuer, &WildcardValuer{Data: row})}
  83. for j, field := range ms.fields {
  84. vp := vep.Eval(field.FieldExpr)
  85. if types[j] == "" && vp != nil {
  86. types[j] = fmt.Sprintf("%T", vp)
  87. }
  88. if err := validate(types[j], vp); err != nil {
  89. return false, err
  90. } else {
  91. ms.values[i][field.Uname] = vp
  92. }
  93. }
  94. return true, nil
  95. })
  96. if err != nil {
  97. return err
  98. }
  99. case GroupedCollection:
  100. err := input.GroupRange(func(i int, aggRow CollectionRow) (bool, error) {
  101. ms.values[i] = make(map[string]interface{})
  102. ms.aggValuer.SetData(aggRow)
  103. vep := &ValuerEval{Valuer: MultiAggregateValuer(aggRow, ms.valuer, aggRow, ms.aggValuer, &WildcardValuer{Data: aggRow})}
  104. for j, field := range ms.fields {
  105. vp := vep.Eval(field.FieldExpr)
  106. if types[j] == "" && vp != nil {
  107. types[j] = fmt.Sprintf("%T", vp)
  108. }
  109. if err := validate(types[j], vp); err != nil {
  110. return false, err
  111. } else {
  112. ms.values[i][field.Uname] = vp
  113. }
  114. }
  115. return true, nil
  116. })
  117. if err != nil {
  118. return err
  119. }
  120. }
  121. sort.Sort(ms)
  122. return nil
  123. }
  124. func validate(t string, v interface{}) error {
  125. if v == nil || t == "" {
  126. return nil
  127. }
  128. vt := fmt.Sprintf("%T", v)
  129. switch t {
  130. case "int", "int64", "float64", "uint64":
  131. if vt == "int" || vt == "int64" || vt == "float64" || vt == "uint64" {
  132. return nil
  133. } else {
  134. return fmt.Errorf("incompatible types for comparison: %s and %s", t, vt)
  135. }
  136. case "bool":
  137. if vt == "bool" {
  138. return nil
  139. } else {
  140. return fmt.Errorf("incompatible types for comparison: %s and %s", t, vt)
  141. }
  142. case "string":
  143. if vt == "string" {
  144. return nil
  145. } else {
  146. return fmt.Errorf("incompatible types for comparison: %s and %s", t, vt)
  147. }
  148. case "time.Time":
  149. _, err := cast.InterfaceToTime(v, "")
  150. if err != nil {
  151. return fmt.Errorf("incompatible types for comparison: %s and %s", t, vt)
  152. } else {
  153. return nil
  154. }
  155. default:
  156. return fmt.Errorf("incompatible types for comparison: %s and %s", t, vt)
  157. }
  158. }