labelImage.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // +build tflite
  2. package main
  3. import (
  4. "bufio"
  5. "bytes"
  6. "fmt"
  7. "github.com/emqx/kuiper/pkg/api"
  8. tflite "github.com/mattn/go-tflite"
  9. "github.com/nfnt/resize"
  10. "image"
  11. _ "image/jpeg"
  12. _ "image/png"
  13. "os"
  14. "path"
  15. "sort"
  16. "sync"
  17. )
  18. type labelImage struct {
  19. modelPath string
  20. labelPath string
  21. once sync.Once
  22. interpreter *tflite.Interpreter
  23. labels []string
  24. }
  25. func (f *labelImage) Validate(args []interface{}) error {
  26. if len(args) != 1 {
  27. return fmt.Errorf("labelImage function only supports 1 parameter but got %d", len(args))
  28. }
  29. return nil
  30. }
  31. func (f *labelImage) Exec(args []interface{}, ctx api.FunctionContext) (interface{}, bool) {
  32. arg0, ok := args[0].([]byte)
  33. if !ok {
  34. return fmt.Errorf("labelImage function parameter must be a bytea, but got %[1]T(%[1]v)", args[0]), false
  35. }
  36. img, _, err := image.Decode(bytes.NewReader(arg0))
  37. if err != nil {
  38. return err, false
  39. }
  40. var outerErr error
  41. f.once.Do(func() {
  42. ploc := path.Join(ctx.GetRootPath(), "etc", "functions")
  43. f.labels, err = loadLabels(path.Join(ploc, f.labelPath))
  44. if err != nil {
  45. outerErr = fmt.Errorf("fail to load labels: %s", err)
  46. return
  47. }
  48. model := tflite.NewModelFromFile(path.Join(ploc, f.modelPath))
  49. if model == nil {
  50. outerErr = fmt.Errorf("fail to load model: %s", err)
  51. return
  52. }
  53. defer model.Delete()
  54. options := tflite.NewInterpreterOptions()
  55. options.SetNumThread(4)
  56. options.SetErrorReporter(func(msg string, user_data interface{}) {
  57. fmt.Println(msg)
  58. }, nil)
  59. defer options.Delete()
  60. interpreter := tflite.NewInterpreter(model, options)
  61. if interpreter == nil {
  62. outerErr = fmt.Errorf("cannot create interpreter")
  63. return
  64. }
  65. status := interpreter.AllocateTensors()
  66. if status != tflite.OK {
  67. outerErr = fmt.Errorf("allocate failed")
  68. interpreter.Delete()
  69. return
  70. }
  71. f.interpreter = interpreter
  72. // TODO If created, the interpreter will be kept through the whole life of kuiper. Refactor this later.
  73. //defer interpreter.Delete()
  74. })
  75. if f.interpreter == nil {
  76. return fmt.Errorf("fail to load model %s %s", f.modelPath, outerErr), false
  77. }
  78. input := f.interpreter.GetInputTensor(0)
  79. wantedHeight := input.Dim(1)
  80. wantedWidth := input.Dim(2)
  81. wantedChannels := input.Dim(3)
  82. wantedType := input.Type()
  83. resized := resize.Resize(uint(wantedWidth), uint(wantedHeight), img, resize.NearestNeighbor)
  84. bounds := resized.Bounds()
  85. dx, dy := bounds.Dx(), bounds.Dy()
  86. if wantedType == tflite.UInt8 {
  87. bb := make([]byte, dx*dy*wantedChannels)
  88. for y := 0; y < dy; y++ {
  89. for x := 0; x < dx; x++ {
  90. col := resized.At(x, y)
  91. r, g, b, _ := col.RGBA()
  92. bb[(y*dx+x)*3+0] = byte(float64(r) / 255.0)
  93. bb[(y*dx+x)*3+1] = byte(float64(g) / 255.0)
  94. bb[(y*dx+x)*3+2] = byte(float64(b) / 255.0)
  95. }
  96. }
  97. input.CopyFromBuffer(bb)
  98. } else {
  99. return fmt.Errorf("is not wanted type"), false
  100. }
  101. status := f.interpreter.Invoke()
  102. if status != tflite.OK {
  103. return fmt.Errorf("invoke failed"), false
  104. }
  105. output := f.interpreter.GetOutputTensor(0)
  106. outputSize := output.Dim(output.NumDims() - 1)
  107. b := make([]byte, outputSize)
  108. type result struct {
  109. score float64
  110. index int
  111. }
  112. status = output.CopyToBuffer(&b[0])
  113. if status != tflite.OK {
  114. return fmt.Errorf("output failed"), false
  115. }
  116. var results []result
  117. for i := 0; i < outputSize; i++ {
  118. score := float64(b[i]) / 255.0
  119. if score < 0.2 {
  120. continue
  121. }
  122. results = append(results, result{score: score, index: i})
  123. }
  124. sort.Slice(results, func(i, j int) bool {
  125. return results[i].score > results[j].score
  126. })
  127. // output is the biggest score labelImage
  128. if len(results) > 0 {
  129. return f.labels[results[0].index], true
  130. } else {
  131. return "", true
  132. }
  133. }
  134. func (f *labelImage) IsAggregate() bool {
  135. return false
  136. }
  // loadLabels reads filename and returns its lines, in order. Line i is the
// label for output class index i. A read error part-way through the file
// (including a line longer than bufio.MaxScanTokenSize) is returned as an error
// instead of silently truncating the label list.
func loadLabels(filename string) ([]string, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close() // read-only: a Close error cannot lose data

	var labels []string
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}
	// Scan returns false both at EOF and on error; only Err tells them apart.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading %s: %w", filename, err)
	}
	return labels, nil
}
  150. var LabelImage = labelImage{
  151. modelPath: "labelImage/mobilenet_quant_v1_224.tflite",
  152. labelPath: "labelImage/labels.txt",
  153. }