fieldConverterSingleton.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. // Copyright 2022 EMQ Technologies Co., Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package protobuf
  15. import (
  16. "fmt"
  17. "github.com/golang/protobuf/proto"
  18. dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
  19. "github.com/jhump/protoreflect/desc"
  20. "github.com/jhump/protoreflect/dynamic"
  21. "github.com/lf-edge/ekuiper/pkg/cast"
  22. )
  23. const (
  24. WrapperBool = "google.protobuf.BoolValue"
  25. WrapperBytes = "google.protobuf.BytesValue"
  26. WrapperDouble = "google.protobuf.DoubleValue"
  27. WrapperFloat = "google.protobuf.FloatValue"
  28. WrapperInt32 = "google.protobuf.Int32Value"
  29. WrapperInt64 = "google.protobuf.Int64Value"
  30. WrapperString = "google.protobuf.StringValue"
  31. WrapperUInt32 = "google.protobuf.UInt32Value"
  32. WrapperUInt64 = "google.protobuf.UInt64Value"
  33. WrapperVoid = "google.protobuf.EMPTY"
  34. )
  35. var WRAPPER_TYPES = map[string]struct{}{
  36. WrapperBool: {},
  37. WrapperBytes: {},
  38. WrapperDouble: {},
  39. WrapperFloat: {},
  40. WrapperInt32: {},
  41. WrapperInt64: {},
  42. WrapperString: {},
  43. WrapperUInt32: {},
  44. WrapperUInt64: {},
  45. }
  46. var (
  47. fieldConverterIns = &FieldConverter{}
  48. mf = dynamic.NewMessageFactoryWithDefaults()
  49. )
  50. type FieldConverter struct{}
  51. func GetFieldConverter() *FieldConverter {
  52. return fieldConverterIns
  53. }
  54. func (fc *FieldConverter) encodeMap(im *desc.MessageDescriptor, i interface{}) (*dynamic.Message, error) {
  55. result := mf.NewDynamicMessage(im)
  56. fields := im.GetFields()
  57. if m, ok := i.(map[string]interface{}); ok {
  58. for _, field := range fields {
  59. v, ok := m[field.GetName()]
  60. if !ok {
  61. return nil, fmt.Errorf("field %s not found", field.GetName())
  62. }
  63. fv, err := fc.EncodeField(field, v)
  64. if err != nil {
  65. return nil, err
  66. }
  67. result.SetFieldByName(field.GetName(), fv)
  68. }
  69. }
  70. return result, nil
  71. }
  72. func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
  73. fn := field.GetName()
  74. ft := field.GetType()
  75. if field.IsRepeated() {
  76. var (
  77. result interface{}
  78. err error
  79. )
  80. switch ft {
  81. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_DOUBLE):
  82. result, err = cast.ToFloat64Slice(v, cast.STRICT)
  83. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FLOAT):
  84. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  85. r, err := cast.ToFloat64(input, sn)
  86. if err != nil {
  87. return 0, nil
  88. } else {
  89. return float32(r), nil
  90. }
  91. }, "float", cast.STRICT)
  92. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT32):
  93. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  94. r, err := cast.ToInt(input, sn)
  95. if err != nil {
  96. return 0, nil
  97. } else {
  98. return int32(r), nil
  99. }
  100. }, "int", cast.STRICT)
  101. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT64):
  102. result, err = cast.ToInt64Slice(v, cast.STRICT)
  103. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT32):
  104. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  105. r, err := cast.ToUint64(input, sn)
  106. if err != nil {
  107. return 0, nil
  108. } else {
  109. return uint32(r), nil
  110. }
  111. }, "uint", cast.STRICT)
  112. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT64):
  113. result, err = cast.ToUint64Slice(v, cast.STRICT)
  114. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BOOL):
  115. result, err = cast.ToBoolSlice(v, cast.STRICT)
  116. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_STRING):
  117. result, err = cast.ToStringSlice(v, cast.STRICT)
  118. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BYTES):
  119. result, err = cast.ToBytesSlice(v, cast.STRICT)
  120. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_MESSAGE):
  121. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  122. r, err := cast.ToStringMap(v)
  123. if err == nil {
  124. return fc.encodeMap(field.GetMessageType(), r)
  125. } else {
  126. return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
  127. }
  128. }, "map", cast.STRICT)
  129. default:
  130. return nil, fmt.Errorf("invalid type for field '%s'", fn)
  131. }
  132. if err != nil {
  133. err = fmt.Errorf("failed to encode field '%s':%v", fn, err)
  134. }
  135. return result, err
  136. } else {
  137. return fc.encodeSingleField(field, v)
  138. }
  139. }
  140. func (fc *FieldConverter) encodeSingleField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
  141. fn := field.GetName()
  142. switch field.GetType() {
  143. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_DOUBLE):
  144. r, err := cast.ToFloat64(v, cast.STRICT)
  145. if err == nil {
  146. return r, nil
  147. } else {
  148. return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
  149. }
  150. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FLOAT):
  151. r, err := cast.ToFloat64(v, cast.STRICT)
  152. if err == nil {
  153. return float32(r), nil
  154. } else {
  155. return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
  156. }
  157. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT32):
  158. r, err := cast.ToInt(v, cast.STRICT)
  159. if err == nil {
  160. return int32(r), nil
  161. } else {
  162. return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
  163. }
  164. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT64):
  165. r, err := cast.ToInt64(v, cast.STRICT)
  166. if err == nil {
  167. return r, nil
  168. } else {
  169. return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
  170. }
  171. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT32):
  172. r, err := cast.ToUint64(v, cast.STRICT)
  173. if err == nil {
  174. return uint32(r), nil
  175. } else {
  176. return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
  177. }
  178. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT64):
  179. r, err := cast.ToUint64(v, cast.STRICT)
  180. if err == nil {
  181. return r, nil
  182. } else {
  183. return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
  184. }
  185. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BOOL):
  186. r, err := cast.ToBool(v, cast.STRICT)
  187. if err == nil {
  188. return r, nil
  189. } else {
  190. return nil, fmt.Errorf("invalid type for bool type field '%s': %v", fn, err)
  191. }
  192. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_STRING):
  193. r, err := cast.ToString(v, cast.STRICT)
  194. if err == nil {
  195. return r, nil
  196. } else {
  197. return nil, fmt.Errorf("invalid type for string type field '%s': %v", fn, err)
  198. }
  199. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BYTES):
  200. r, err := cast.ToBytes(v, cast.STRICT)
  201. if err == nil {
  202. return r, nil
  203. } else {
  204. return nil, fmt.Errorf("invalid type for bytes type field '%s': %v", fn, err)
  205. }
  206. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_MESSAGE):
  207. r, err := cast.ToStringMap(v)
  208. if err == nil {
  209. return fc.encodeMap(field.GetMessageType(), r)
  210. } else {
  211. return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
  212. }
  213. default:
  214. return nil, fmt.Errorf("invalid type for field '%s'", fn)
  215. }
  216. }
  217. func (fc *FieldConverter) DecodeField(src interface{}, field *desc.FieldDescriptor, sn cast.Strictness) (interface{}, error) {
  218. var (
  219. r interface{}
  220. e error
  221. )
  222. fn := field.GetName()
  223. switch field.GetType() {
  224. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_DOUBLE), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FLOAT):
  225. if field.IsRepeated() {
  226. r, e = cast.ToFloat64Slice(src, sn)
  227. } else {
  228. r, e = cast.ToFloat64(src, sn)
  229. }
  230. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT64):
  231. if field.IsRepeated() {
  232. r, e = cast.ToInt64Slice(src, sn)
  233. } else {
  234. r, e = cast.ToInt64(src, sn)
  235. }
  236. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BOOL):
  237. if field.IsRepeated() {
  238. r, e = cast.ToBoolSlice(src, sn)
  239. } else {
  240. r, e = cast.ToBool(src, sn)
  241. }
  242. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_STRING):
  243. if field.IsRepeated() {
  244. r, e = cast.ToStringSlice(src, sn)
  245. } else {
  246. r, e = cast.ToString(src, sn)
  247. }
  248. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BYTES):
  249. if field.IsRepeated() {
  250. r, e = cast.ToBytesSlice(src, sn)
  251. } else {
  252. r, e = cast.ToBytes(src, sn)
  253. }
  254. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_MESSAGE):
  255. if field.IsRepeated() {
  256. r, e = cast.ToTypedSlice(src, func(input interface{}, ssn cast.Strictness) (interface{}, error) {
  257. return fc.decodeSubMessage(input, field.GetMessageType(), ssn)
  258. }, "map", sn)
  259. } else {
  260. r, e = fc.decodeSubMessage(src, field.GetMessageType(), sn)
  261. }
  262. default:
  263. return nil, fmt.Errorf("unsupported type for %s", fn)
  264. }
  265. if e != nil {
  266. e = fmt.Errorf("invalid type of return value for '%s': %v", fn, e)
  267. }
  268. return r, e
  269. }
  270. func (fc *FieldConverter) decodeSubMessage(input interface{}, ft *desc.MessageDescriptor, sn cast.Strictness) (interface{}, error) {
  271. var m = map[string]interface{}{}
  272. switch v := input.(type) {
  273. case map[interface{}]interface{}:
  274. for k, val := range v {
  275. m[cast.ToStringAlways(k)] = val
  276. }
  277. return fc.DecodeMap(m, ft, sn)
  278. case map[string]interface{}:
  279. return fc.DecodeMap(v, ft, sn)
  280. case proto.Message:
  281. message, err := dynamic.AsDynamicMessage(v)
  282. if err != nil {
  283. return nil, err
  284. }
  285. return fc.DecodeMessage(message, ft), nil
  286. case *dynamic.Message:
  287. return fc.DecodeMessage(v, ft), nil
  288. default:
  289. return nil, fmt.Errorf("cannot decode %[1]T(%[1]v) to map", input)
  290. }
  291. }
  292. func (fc *FieldConverter) DecodeMap(src map[string]interface{}, ft *desc.MessageDescriptor, sn cast.Strictness) (map[string]interface{}, error) {
  293. result := make(map[string]interface{})
  294. for _, field := range ft.GetFields() {
  295. val, ok := src[field.GetName()]
  296. if !ok {
  297. continue
  298. }
  299. err := fc.decodeMessageField(val, field, result, sn)
  300. if err != nil {
  301. return nil, err
  302. }
  303. }
  304. return result, nil
  305. }
  306. func (fc *FieldConverter) decodeMessageField(src interface{}, field *desc.FieldDescriptor, result map[string]interface{}, sn cast.Strictness) error {
  307. if f, err := fc.DecodeField(src, field, sn); err != nil {
  308. return err
  309. } else {
  310. result[field.GetName()] = f
  311. return nil
  312. }
  313. }
  314. func (fc *FieldConverter) DecodeMessage(message *dynamic.Message, outputType *desc.MessageDescriptor) interface{} {
  315. if _, ok := WRAPPER_TYPES[outputType.GetFullyQualifiedName()]; ok {
  316. return message.GetFieldByNumber(1)
  317. } else if WrapperVoid == outputType.GetFullyQualifiedName() {
  318. return nil
  319. }
  320. result := make(map[string]interface{})
  321. for _, field := range outputType.GetFields() {
  322. fc.decodeMessageField(message.GetField(field), field, result, cast.STRICT)
  323. }
  324. return result
  325. }