fieldConverterSingleton.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. // Copyright 2022 EMQ Technologies Co., Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package protobuf
import (
	"fmt"
	"math"

	"github.com/golang/protobuf/proto"
	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/dynamic"
	"github.com/lf-edge/ekuiper/pkg/cast"
)
  // Fully-qualified names of the protobuf well-known wrapper messages and of
// google.protobuf.Empty. DecodeMessage compares them against
// desc.MessageDescriptor.GetFullyQualifiedName, so each value must match the
// descriptor name exactly, including letter case.
const (
	WrapperBool   = "google.protobuf.BoolValue"
	WrapperBytes  = "google.protobuf.BytesValue"
	WrapperDouble = "google.protobuf.DoubleValue"
	WrapperFloat  = "google.protobuf.FloatValue"
	WrapperInt32  = "google.protobuf.Int32Value"
	WrapperInt64  = "google.protobuf.Int64Value"
	WrapperString = "google.protobuf.StringValue"
	WrapperUInt32 = "google.protobuf.UInt32Value"
	WrapperUInt64 = "google.protobuf.UInt64Value"
	// WrapperVoid names the empty message. It is declared as `Empty` (not
	// `EMPTY`) in google/protobuf/empty.proto; a differently-cased value never
	// matches a real descriptor.
	WrapperVoid = "google.protobuf.Empty"
)
  35. var WRAPPER_TYPES = map[string]struct{}{
  36. WrapperBool: {},
  37. WrapperBytes: {},
  38. WrapperDouble: {},
  39. WrapperFloat: {},
  40. WrapperInt32: {},
  41. WrapperInt64: {},
  42. WrapperString: {},
  43. WrapperUInt32: {},
  44. WrapperUInt64: {},
  45. }
  46. var (
  47. fieldConverterIns = &FieldConverter{}
  48. mf = dynamic.NewMessageFactoryWithDefaults()
  49. )
  50. type FieldConverter struct{}
  51. func GetFieldConverter() *FieldConverter {
  52. return fieldConverterIns
  53. }
  54. func (fc *FieldConverter) encodeMap(im *desc.MessageDescriptor, i interface{}) (*dynamic.Message, error) {
  55. result := mf.NewDynamicMessage(im)
  56. fields := im.GetFields()
  57. if m, ok := i.(map[string]interface{}); ok {
  58. for _, field := range fields {
  59. v, ok := m[field.GetName()]
  60. if !ok {
  61. if field.IsRequired() {
  62. return nil, fmt.Errorf("field %s not found", field.GetName())
  63. } else {
  64. v = field.GetDefaultValue()
  65. }
  66. }
  67. fv, err := fc.EncodeField(field, v)
  68. if err != nil {
  69. return nil, err
  70. }
  71. result.SetFieldByName(field.GetName(), fv)
  72. }
  73. }
  74. return result, nil
  75. }
  76. func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
  77. fn := field.GetName()
  78. ft := field.GetType()
  79. if field.IsRepeated() {
  80. var (
  81. result interface{}
  82. err error
  83. )
  84. switch ft {
  85. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_DOUBLE):
  86. result, err = cast.ToFloat64Slice(v, cast.STRICT)
  87. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FLOAT):
  88. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  89. r, err := cast.ToFloat64(input, sn)
  90. if err != nil {
  91. return 0, nil
  92. } else {
  93. return float32(r), nil
  94. }
  95. }, "float", cast.STRICT)
  96. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT32):
  97. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  98. r, err := cast.ToInt(input, sn)
  99. if err != nil {
  100. return 0, nil
  101. } else {
  102. return int32(r), nil
  103. }
  104. }, "int", cast.STRICT)
  105. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT64):
  106. result, err = cast.ToInt64Slice(v, cast.STRICT)
  107. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT32):
  108. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  109. r, err := cast.ToUint64(input, sn)
  110. if err != nil {
  111. return 0, nil
  112. } else {
  113. return uint32(r), nil
  114. }
  115. }, "uint", cast.STRICT)
  116. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT64):
  117. result, err = cast.ToUint64Slice(v, cast.STRICT)
  118. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BOOL):
  119. result, err = cast.ToBoolSlice(v, cast.STRICT)
  120. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_STRING):
  121. result, err = cast.ToStringSlice(v, cast.STRICT)
  122. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BYTES):
  123. result, err = cast.ToBytesSlice(v, cast.STRICT)
  124. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_MESSAGE):
  125. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  126. r, err := cast.ToStringMap(v)
  127. if err == nil {
  128. return fc.encodeMap(field.GetMessageType(), r)
  129. } else {
  130. return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
  131. }
  132. }, "map", cast.STRICT)
  133. default:
  134. return nil, fmt.Errorf("invalid type for field '%s'", fn)
  135. }
  136. if err != nil {
  137. err = fmt.Errorf("failed to encode field '%s':%v", fn, err)
  138. }
  139. return result, err
  140. } else {
  141. return fc.encodeSingleField(field, v)
  142. }
  143. }
  144. func (fc *FieldConverter) encodeSingleField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
  145. fn := field.GetName()
  146. switch field.GetType() {
  147. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_DOUBLE):
  148. r, err := cast.ToFloat64(v, cast.STRICT)
  149. if err == nil {
  150. return r, nil
  151. } else {
  152. return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
  153. }
  154. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FLOAT):
  155. r, err := cast.ToFloat64(v, cast.STRICT)
  156. if err == nil {
  157. return float32(r), nil
  158. } else {
  159. return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
  160. }
  161. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT32):
  162. r, err := cast.ToInt(v, cast.STRICT)
  163. if err == nil {
  164. return int32(r), nil
  165. } else {
  166. return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
  167. }
  168. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT64):
  169. r, err := cast.ToInt64(v, cast.STRICT)
  170. if err == nil {
  171. return r, nil
  172. } else {
  173. return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
  174. }
  175. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT32):
  176. r, err := cast.ToUint64(v, cast.STRICT)
  177. if err == nil {
  178. return uint32(r), nil
  179. } else {
  180. return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
  181. }
  182. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT64):
  183. r, err := cast.ToUint64(v, cast.STRICT)
  184. if err == nil {
  185. return r, nil
  186. } else {
  187. return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
  188. }
  189. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BOOL):
  190. r, err := cast.ToBool(v, cast.STRICT)
  191. if err == nil {
  192. return r, nil
  193. } else {
  194. return nil, fmt.Errorf("invalid type for bool type field '%s': %v", fn, err)
  195. }
  196. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_STRING):
  197. r, err := cast.ToString(v, cast.STRICT)
  198. if err == nil {
  199. return r, nil
  200. } else {
  201. return nil, fmt.Errorf("invalid type for string type field '%s': %v", fn, err)
  202. }
  203. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BYTES):
  204. r, err := cast.ToBytes(v, cast.STRICT)
  205. if err == nil {
  206. return r, nil
  207. } else {
  208. return nil, fmt.Errorf("invalid type for bytes type field '%s': %v", fn, err)
  209. }
  210. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_MESSAGE):
  211. r, err := cast.ToStringMap(v)
  212. if err == nil {
  213. return fc.encodeMap(field.GetMessageType(), r)
  214. } else {
  215. return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
  216. }
  217. default:
  218. return nil, fmt.Errorf("invalid type for field '%s'", fn)
  219. }
  220. }
  221. func (fc *FieldConverter) DecodeField(src interface{}, field *desc.FieldDescriptor, sn cast.Strictness) (interface{}, error) {
  222. var (
  223. r interface{}
  224. e error
  225. )
  226. fn := field.GetName()
  227. switch field.GetType() {
  228. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_DOUBLE), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FLOAT):
  229. if field.IsRepeated() {
  230. r, e = cast.ToFloat64Slice(src, sn)
  231. } else {
  232. r, e = cast.ToFloat64(src, sn)
  233. }
  234. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_INT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SFIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_SINT64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT32), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_FIXED64), dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_UINT64):
  235. if field.IsRepeated() {
  236. r, e = cast.ToInt64Slice(src, sn)
  237. } else {
  238. r, e = cast.ToInt64(src, sn)
  239. }
  240. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BOOL):
  241. if field.IsRepeated() {
  242. r, e = cast.ToBoolSlice(src, sn)
  243. } else {
  244. r, e = cast.ToBool(src, sn)
  245. }
  246. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_STRING):
  247. if field.IsRepeated() {
  248. r, e = cast.ToStringSlice(src, sn)
  249. } else {
  250. r, e = cast.ToString(src, sn)
  251. }
  252. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_BYTES):
  253. if field.IsRepeated() {
  254. r, e = cast.ToBytesSlice(src, sn)
  255. } else {
  256. r, e = cast.ToBytes(src, sn)
  257. }
  258. case dpb.FieldDescriptorProto_Type(dpb.FieldDescriptorProto_TYPE_MESSAGE):
  259. if field.IsRepeated() {
  260. r, e = cast.ToTypedSlice(src, func(input interface{}, ssn cast.Strictness) (interface{}, error) {
  261. return fc.decodeSubMessage(input, field.GetMessageType(), ssn)
  262. }, "map", sn)
  263. } else {
  264. r, e = fc.decodeSubMessage(src, field.GetMessageType(), sn)
  265. }
  266. default:
  267. return nil, fmt.Errorf("unsupported type for %s", fn)
  268. }
  269. if e != nil {
  270. e = fmt.Errorf("invalid type of return value for '%s': %v", fn, e)
  271. }
  272. return r, e
  273. }
  274. func (fc *FieldConverter) decodeSubMessage(input interface{}, ft *desc.MessageDescriptor, sn cast.Strictness) (interface{}, error) {
  275. m := map[string]interface{}{}
  276. switch v := input.(type) {
  277. case map[interface{}]interface{}:
  278. for k, val := range v {
  279. m[cast.ToStringAlways(k)] = val
  280. }
  281. return fc.DecodeMap(m, ft, sn)
  282. case map[string]interface{}:
  283. return fc.DecodeMap(v, ft, sn)
  284. case proto.Message:
  285. message, err := dynamic.AsDynamicMessage(v)
  286. if err != nil {
  287. return nil, err
  288. }
  289. return fc.DecodeMessage(message, ft), nil
  290. default:
  291. return nil, fmt.Errorf("cannot decode %[1]T(%[1]v) to map", input)
  292. }
  293. }
  294. func (fc *FieldConverter) DecodeMap(src map[string]interface{}, ft *desc.MessageDescriptor, sn cast.Strictness) (map[string]interface{}, error) {
  295. result := make(map[string]interface{})
  296. for _, field := range ft.GetFields() {
  297. val, ok := src[field.GetName()]
  298. if !ok {
  299. continue
  300. }
  301. err := fc.decodeMessageField(val, field, result, sn)
  302. if err != nil {
  303. return nil, err
  304. }
  305. }
  306. return result, nil
  307. }
  308. func (fc *FieldConverter) decodeMessageField(src interface{}, field *desc.FieldDescriptor, result map[string]interface{}, sn cast.Strictness) error {
  309. if f, err := fc.DecodeField(src, field, sn); err != nil {
  310. return err
  311. } else {
  312. result[field.GetName()] = f
  313. return nil
  314. }
  315. }
  316. func (fc *FieldConverter) DecodeMessage(message *dynamic.Message, outputType *desc.MessageDescriptor) interface{} {
  317. if _, ok := WRAPPER_TYPES[outputType.GetFullyQualifiedName()]; ok {
  318. return message.GetFieldByNumber(1)
  319. } else if WrapperVoid == outputType.GetFullyQualifiedName() {
  320. return nil
  321. }
  322. result := make(map[string]interface{})
  323. for _, field := range outputType.GetFields() {
  324. fc.decodeMessageField(message.GetField(field), field, result, cast.STRICT)
  325. }
  326. return result
  327. }