Prechádzať zdrojové kódy

feat(httppull): refresh token on pulling (#2018)

Signed-off-by: xjasonlyu <xjasonlyu@gmail.com>
Jason Lyu 1 rok pred
rodič
commit
88e58e1c50

+ 7 - 3
internal/io/http/client.go

@@ -35,9 +35,10 @@ import (
 // ClientConf is the configuration for http client
 // It is shared by httppull source and rest sink to configure their http client
 type ClientConf struct {
-	config      *RawConf
-	accessConf  *AccessTokenConf
-	refreshConf *RefreshTokenConf
+	config            *RawConf
+	accessConf        *AccessTokenConf
+	refreshConf       *RefreshTokenConf
+	tokenLastUpdateAt time.Time
 
 	tokens map[string]interface{}
 	client *http.Client
@@ -256,6 +257,8 @@ func (cc *ClientConf) auth(ctx api.StreamContext) error {
 			if err != nil {
 				return err
 			}
+		} else {
+			cc.tokenLastUpdateAt = time.Now()
 		}
 	} else {
 		return fmt.Errorf("fail to get access token: %v", e)
@@ -286,6 +289,7 @@ func (cc *ClientConf) refresh(ctx api.StreamContext) error {
 				cc.tokens[k] = v
 			}
 		}
+		cc.tokenLastUpdateAt = time.Now()
 		return nil
 	} else if cc.accessConf != nil {
 		return cc.auth(ctx)

+ 8 - 19
internal/io/http/httppull_source.go

@@ -20,7 +20,6 @@ import (
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/pkg/httpx"
 	"github.com/lf-edge/ekuiper/pkg/api"
-	"github.com/lf-edge/ekuiper/pkg/infra"
 )
 
 type PullSource struct {
@@ -34,24 +33,6 @@ func (hps *PullSource) Configure(device string, props map[string]interface{}) er
 
 func (hps *PullSource) Open(ctx api.StreamContext, consumer chan<- api.SourceTuple, errCh chan<- error) {
 	ctx.GetLogger().Infof("Opening HTTP pull source with conf %+v", hps.config)
-	// trigger refresh token timer
-	if hps.accessConf != nil && hps.accessConf.ExpireInSecond > 0 {
-		go infra.SafeRun(func() error {
-			ctx.GetLogger().Infof("Starting refresh token for %d seconds", hps.accessConf.ExpireInSecond/2)
-			ticker := time.NewTicker(time.Duration(hps.accessConf.ExpireInSecond/2) * time.Second)
-			defer ticker.Stop()
-			for {
-				select {
-				case <-ticker.C:
-					ctx.GetLogger().Debugf("Refreshing token")
-					hps.refresh(ctx)
-				case <-ctx.Done():
-					ctx.GetLogger().Infof("Closing refresh token timer")
-					return nil
-				}
-			}
-		})
-	}
 	hps.initTimerPull(ctx, consumer, errCh)
 }
 
@@ -75,6 +56,14 @@ func (hps *PullSource) initTimerPull(ctx api.StreamContext, consumer chan<- api.
 			if err != nil {
 				continue
 			}
+			// check oAuth token expiration
+			if hps.accessConf != nil && hps.accessConf.ExpireInSecond > 0 &&
+				int(time.Now().Sub(hps.tokenLastUpdateAt).Abs().Seconds()) >= hps.accessConf.ExpireInSecond {
+				ctx.GetLogger().Debugf("Refreshing token")
+				if err := hps.refresh(ctx); err != nil {
+					ctx.GetLogger().Warnf("Refresh token error: %v", err)
+				}
+			}
 			ctx.GetLogger().Debugf("rest sink sending request url: %s, headers: %v, body %s", hps.config.Url, headers, hps.config.Body)
 			if resp, e := httpx.Send(logger, hps.client, hps.config.BodyType, hps.config.Method, hps.config.Url, headers, true, []byte(hps.config.Body)); e != nil {
 				logger.Warnf("Found error %s when trying to reach %v ", e, hps)