Browse Source

sms 缓存,使用 guava 替代 job 扫描,目的:提升启动速度,加快缓存失效

YunaiV 1 year ago
parent
commit
48fd4972da

+ 2 - 2
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/convert/sms/SmsChannelConvert.java

@@ -32,8 +32,8 @@ public interface SmsChannelConvert {
 
     PageResult<SmsChannelRespVO> convertPage(PageResult<SmsChannelDO> page);
 
-    List<SmsChannelProperties> convertList02(List<SmsChannelDO> list);
-
     List<SmsChannelSimpleRespVO> convertList03(List<SmsChannelDO> list);
 
+    SmsChannelProperties convert02(SmsChannelDO channel);
+
 }

+ 3 - 5
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/mysql/sms/SmsChannelMapper.java

@@ -6,9 +6,6 @@ import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelPageReqVO;
 import cn.iocoder.yudao.module.system.dal.dataobject.sms.SmsChannelDO;
 import org.apache.ibatis.annotations.Mapper;
-import org.apache.ibatis.annotations.Select;
-
-import java.time.LocalDateTime;
 
 @Mapper
 public interface SmsChannelMapper extends BaseMapperX<SmsChannelDO> {
@@ -21,7 +18,8 @@ public interface SmsChannelMapper extends BaseMapperX<SmsChannelDO> {
                 .orderByDesc(SmsChannelDO::getId));
     }
 
-    @Select("SELECT COUNT(*) FROM system_sms_channel WHERE update_time > #{maxUpdateTime}")
-    Long selectCountByUpdateTimeGt(LocalDateTime maxTime);
+    default SmsChannelDO selectByCode(String code) {
+        return selectOne(SmsChannelDO::getCode, code);
+    }
 
 }

+ 18 - 6
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/sms/SmsChannelService.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.system.service.sms;
 
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.sms.core.client.SmsClient;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelCreateReqVO;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelPageReqVO;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelUpdateReqVO;
@@ -13,16 +14,11 @@ import java.util.List;
  * 短信渠道 Service 接口
  *
  * @author zzf
- * @date 2021/1/25 9:24
+ * @since 2021/1/25 9:24
  */
 public interface SmsChannelService {
 
     /**
-     * 初始化短信客户端
-     */
-    void initLocalCache();
-
-    /**
      * 创建短信渠道
      *
      * @param createReqVO 创建信息
@@ -67,4 +63,20 @@ public interface SmsChannelService {
      */
     PageResult<SmsChannelDO> getSmsChannelPage(SmsChannelPageReqVO pageReqVO);
 
+    /**
+     * 获得短信客户端
+     *
+     * @param id 编号
+     * @return 短信客户端
+     */
+    SmsClient getSmsClient(Long id);
+
+    /**
+     * 获得短信客户端
+     *
+     * @param code 编码
+     * @return 短信客户端
+     */
+    SmsClient getSmsClient(String code);
+
 }

+ 80 - 55
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/sms/SmsChannelServiceImpl.java

@@ -1,7 +1,9 @@
 package cn.iocoder.yudao.module.system.service.sms;
 
-import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.common.util.cache.CacheUtils;
+import cn.iocoder.yudao.framework.sms.core.client.SmsClient;
 import cn.iocoder.yudao.framework.sms.core.client.SmsClientFactory;
 import cn.iocoder.yudao.framework.sms.core.property.SmsChannelProperties;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelCreateReqVO;
@@ -10,20 +12,18 @@ import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannel
 import cn.iocoder.yudao.module.system.convert.sms.SmsChannelConvert;
 import cn.iocoder.yudao.module.system.dal.dataobject.sms.SmsChannelDO;
 import cn.iocoder.yudao.module.system.dal.mysql.sms.SmsChannelMapper;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
 import lombok.Getter;
+import lombok.SneakyThrows;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.scheduling.annotation.Scheduled;
 import org.springframework.stereotype.Service;
 
-import javax.annotation.PostConstruct;
 import javax.annotation.Resource;
-import java.time.LocalDateTime;
-import java.util.Collections;
+import java.time.Duration;
 import java.util.List;
-import java.util.concurrent.TimeUnit;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
-import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.getMaxValue;
 import static cn.iocoder.yudao.module.system.enums.ErrorCodeConstants.SMS_CHANNEL_HAS_CHILDREN;
 import static cn.iocoder.yudao.module.system.enums.ErrorCodeConstants.SMS_CHANNEL_NOT_EXISTS;
 
@@ -37,10 +37,44 @@ import static cn.iocoder.yudao.module.system.enums.ErrorCodeConstants.SMS_CHANNE
 public class SmsChannelServiceImpl implements SmsChannelService {
 
     /**
-     * 短信渠道列表的缓存
+     * {@link SmsClient} 缓存,通过它异步刷新 smsClientFactory
      */
     @Getter
-    private volatile List<SmsChannelDO> channelCache = Collections.emptyList();
+    private final LoadingCache<Long, SmsClient> idClientCache = CacheUtils.buildAsyncReloadingCache(Duration.ofSeconds(10L),
+            new CacheLoader<Long, SmsClient>() {
+
+                @Override
+                public SmsClient load(Long id) {
+                    // 查询,然后尝试刷新
+                    SmsChannelDO channel = smsChannelMapper.selectById(id);
+                    if (channel != null) {
+                        SmsChannelProperties properties = SmsChannelConvert.INSTANCE.convert02(channel);
+                        smsClientFactory.createOrUpdateSmsClient(properties);
+                    }
+                    return smsClientFactory.getSmsClient(id);
+                }
+
+            });
+
+    /**
+     * {@link SmsClient} 缓存,通过它异步刷新 smsClientFactory
+     */
+    @Getter
+    private final LoadingCache<String, SmsClient> codeClientCache = CacheUtils.buildAsyncReloadingCache(Duration.ofSeconds(60L),
+            new CacheLoader<String, SmsClient>() {
+
+                @Override
+                public SmsClient load(String code) {
+                    // 查询,然后尝试刷新
+                    SmsChannelDO channel = smsChannelMapper.selectByCode(code);
+                    if (channel != null) {
+                        SmsChannelProperties properties = SmsChannelConvert.INSTANCE.convert02(channel);
+                        smsClientFactory.createOrUpdateSmsClient(properties);
+                    }
+                    return smsClientFactory.getSmsClient(code);
+                }
+
+            });
 
     @Resource
     private SmsClientFactory smsClientFactory;
@@ -52,65 +86,32 @@ public class SmsChannelServiceImpl implements SmsChannelService {
     private SmsTemplateService smsTemplateService;
 
     @Override
-    @PostConstruct
-    public void initLocalCache() {
-        // 第一步:查询数据
-        List<SmsChannelDO> channels = smsChannelMapper.selectList();
-        log.info("[initLocalCache][缓存短信渠道,数量为:{}]", channels.size());
-
-        // 第二步:构建缓存:创建或更新短信 Client
-        List<SmsChannelProperties> propertiesList = SmsChannelConvert.INSTANCE.convertList02(channels);
-        propertiesList.forEach(properties -> smsClientFactory.createOrUpdateSmsClient(properties));
-        this.channelCache = channels;
-    }
-
-    /**
-     * 通过定时任务轮询,刷新缓存
-     *
-     * 目的:多节点部署时,通过轮询”通知“所有节点,进行刷新
-     */
-    @Scheduled(initialDelay = 60, fixedRate = 60, timeUnit = TimeUnit.SECONDS)
-    public void refreshLocalCache() {
-        // 情况一:如果缓存里没有数据,则直接刷新缓存
-        if (CollUtil.isEmpty(channelCache)) {
-            initLocalCache();
-            return;
-        }
-
-        // 情况二,如果缓存里数据,则通过 updateTime 判断是否有数据变更,有变更则刷新缓存
-        LocalDateTime maxTime = getMaxValue(channelCache, SmsChannelDO::getUpdateTime);
-        if (smsChannelMapper.selectCountByUpdateTimeGt(maxTime) > 0) {
-            initLocalCache();
-        }
-    }
-
-    @Override
     public Long createSmsChannel(SmsChannelCreateReqVO createReqVO) {
         // 插入
-        SmsChannelDO smsChannel = SmsChannelConvert.INSTANCE.convert(createReqVO);
-        smsChannelMapper.insert(smsChannel);
+        SmsChannelDO channel = SmsChannelConvert.INSTANCE.convert(createReqVO);
+        smsChannelMapper.insert(channel);
 
-        // 刷新缓存
-        initLocalCache();
-        return smsChannel.getId();
+        // 清空缓存
+        clearCache(channel.getId(), null);
+        return channel.getId();
     }
 
     @Override
     public void updateSmsChannel(SmsChannelUpdateReqVO updateReqVO) {
         // 校验存在
-        validateSmsChannelExists(updateReqVO.getId());
+        SmsChannelDO channel = validateSmsChannelExists(updateReqVO.getId());
         // 更新
         SmsChannelDO updateObj = SmsChannelConvert.INSTANCE.convert(updateReqVO);
         smsChannelMapper.updateById(updateObj);
 
-        // 刷新缓存
-        initLocalCache();
+        // 清空缓存
+        clearCache(updateReqVO.getId(), channel.getCode());
     }
 
     @Override
     public void deleteSmsChannel(Long id) {
         // 校验存在
-        validateSmsChannelExists(id);
+        SmsChannelDO channel = validateSmsChannelExists(id);
         // 校验是否有在使用该账号的模版
         if (smsTemplateService.countByChannelId(id) > 0) {
             throw exception(SMS_CHANNEL_HAS_CHILDREN);
@@ -118,14 +119,28 @@ public class SmsChannelServiceImpl implements SmsChannelService {
         // 删除
         smsChannelMapper.deleteById(id);
 
-        // 刷新缓存
-        initLocalCache();
+        // 清空缓存
+        clearCache(id, channel.getCode());
+    }
+
+    /**
+     * 清空指定渠道编号的缓存
+     *
+     * @param id 渠道编号
+     */
+    private void clearCache(Long id, String code) {
+        idClientCache.invalidate(id);
+        if (StrUtil.isNotEmpty(code)) {
+            codeClientCache.invalidate(code);
+        }
     }
 
-    private void validateSmsChannelExists(Long id) {
-        if (smsChannelMapper.selectById(id) == null) {
+    private SmsChannelDO validateSmsChannelExists(Long id) {
+        SmsChannelDO channel = smsChannelMapper.selectById(id);
+        if (channel == null) {
             throw exception(SMS_CHANNEL_NOT_EXISTS);
         }
+        return channel;
     }
 
     @Override
@@ -143,4 +158,14 @@ public class SmsChannelServiceImpl implements SmsChannelService {
         return smsChannelMapper.selectPage(pageReqVO);
     }
 
+    @Override
+    public SmsClient getSmsClient(Long id) {
+        return idClientCache.getUnchecked(id);
+    }
+
+    @Override
+    public SmsClient getSmsClient(String code) {
+        return codeClientCache.getUnchecked(code);
+    }
+
 }

+ 3 - 8
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/service/sms/SmsSendServiceImpl.java

@@ -8,7 +8,6 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
 import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
 import cn.iocoder.yudao.framework.sms.core.client.SmsClient;
-import cn.iocoder.yudao.framework.sms.core.client.SmsClientFactory;
 import cn.iocoder.yudao.framework.sms.core.client.SmsCommonResult;
 import cn.iocoder.yudao.framework.sms.core.client.dto.SmsReceiveRespDTO;
 import cn.iocoder.yudao.framework.sms.core.client.dto.SmsSendRespDTO;
@@ -50,9 +49,6 @@ public class SmsSendServiceImpl implements SmsSendService {
     private SmsLogService smsLogService;
 
     @Resource
-    private SmsClientFactory smsClientFactory;
-
-    @Resource
     private SmsProducer smsProducer;
 
     @Override
@@ -95,7 +91,6 @@ public class SmsSendServiceImpl implements SmsSendService {
         // 创建发送日志。如果模板被禁用,则不发送短信,只记录日志
         Boolean isSend = CommonStatusEnum.ENABLE.getStatus().equals(template.getStatus())
                 && CommonStatusEnum.ENABLE.getStatus().equals(smsChannel.getStatus());
-        ;
         String content = smsTemplateService.formatSmsTemplateContent(template.getContent(), templateParams);
         Long sendLogId = smsLogService.createSmsLog(mobile, userId, userType, isSend, template, content, templateParams);
 
@@ -132,7 +127,7 @@ public class SmsSendServiceImpl implements SmsSendService {
     /**
      * 将参数模板,处理成有序的 KeyValue 数组
      * <p>
-     * 原因是,部分短信平台并不是使用 key 作为参数,而是数组下标,例如说腾讯云 https://cloud.tencent.com/document/product/382/39023
+     * 原因是,部分短信平台并不是使用 key 作为参数,而是数组下标,例如说 <a href="https://cloud.tencent.com/document/product/382/39023">腾讯云</a>
      *
      * @param template       短信模板
      * @param templateParams 原始参数
@@ -160,7 +155,7 @@ public class SmsSendServiceImpl implements SmsSendService {
     @Override
     public void doSendSms(SmsSendMessage message) {
         // 获得渠道对应的 SmsClient 客户端
-        SmsClient smsClient = smsClientFactory.getSmsClient(message.getChannelId());
+        SmsClient smsClient = smsChannelService.getSmsClient(message.getChannelId());
         Assert.notNull(smsClient, "短信客户端({}) 不存在", message.getChannelId());
         // 发送短信
         SmsCommonResult<SmsSendRespDTO> sendResult = smsClient.sendSms(message.getLogId(), message.getMobile(),
@@ -173,7 +168,7 @@ public class SmsSendServiceImpl implements SmsSendService {
     @Override
     public void receiveSmsStatus(String channelCode, String text) throws Throwable {
         // 获得渠道对应的 SmsClient 客户端
-        SmsClient smsClient = smsClientFactory.getSmsClient(channelCode);
+        SmsClient smsClient = smsChannelService.getSmsClient(channelCode);
         Assert.notNull(smsClient, "短信客户端({}) 不存在", channelCode);
         // 解析内容
         List<SmsReceiveRespDTO> receiveResults = smsClient.parseSmsReceiveStatus(text);

+ 56 - 22
yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/sms/SmsChannelServiceTest.java

@@ -2,11 +2,14 @@ package cn.iocoder.yudao.module.system.service.sms;
 
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.sms.core.client.SmsClient;
 import cn.iocoder.yudao.framework.sms.core.client.SmsClientFactory;
+import cn.iocoder.yudao.framework.sms.core.property.SmsChannelProperties;
 import cn.iocoder.yudao.framework.test.core.ut.BaseDbUnitTest;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelCreateReqVO;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelPageReqVO;
 import cn.iocoder.yudao.module.system.controller.admin.sms.vo.channel.SmsChannelUpdateReqVO;
+import cn.iocoder.yudao.module.system.convert.sms.SmsChannelConvert;
 import cn.iocoder.yudao.module.system.dal.dataobject.sms.SmsChannelDO;
 import cn.iocoder.yudao.module.system.dal.mysql.sms.SmsChannelMapper;
 import org.junit.jupiter.api.Test;
@@ -19,7 +22,8 @@ import java.util.List;
 import static cn.iocoder.yudao.framework.common.util.date.LocalDateTimeUtils.buildBetweenTime;
 import static cn.iocoder.yudao.framework.common.util.date.LocalDateTimeUtils.buildTime;
 import static cn.iocoder.yudao.framework.common.util.object.ObjectUtils.cloneIgnoreId;
-import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.*;
+import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertPojoEquals;
+import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertServiceException;
 import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.*;
 import static cn.iocoder.yudao.module.system.enums.ErrorCodeConstants.SMS_CHANNEL_HAS_CHILDREN;
 import static cn.iocoder.yudao.module.system.enums.ErrorCodeConstants.SMS_CHANNEL_NOT_EXISTS;
@@ -42,27 +46,6 @@ public class SmsChannelServiceTest extends BaseDbUnitTest {
     private SmsTemplateService smsTemplateService;
 
     @Test
-    public void testInitLocalCache_success() {
-        // mock 数据
-        SmsChannelDO smsChannelDO01 = randomPojo(SmsChannelDO.class);
-        smsChannelMapper.insert(smsChannelDO01);
-        SmsChannelDO smsChannelDO02 = randomPojo(SmsChannelDO.class);
-        smsChannelMapper.insert(smsChannelDO02);
-
-        // 调用
-        smsChannelService.initLocalCache();
-        // 校验调用
-        verify(smsClientFactory, times(1)).createOrUpdateSmsClient(
-                argThat(properties -> isPojoEquals(smsChannelDO01, properties)));
-        verify(smsClientFactory, times(1)).createOrUpdateSmsClient(
-                argThat(properties -> isPojoEquals(smsChannelDO02, properties)));
-        // 断言 channelCache 缓存
-        assertEquals(2, smsChannelService.getChannelCache().size());
-        assertPojoEquals(smsChannelDO01, smsChannelService.getChannelCache().get(0));
-        assertPojoEquals(smsChannelDO02, smsChannelService.getChannelCache().get(1));
-    }
-
-    @Test
     public void testCreateSmsChannel_success() {
         // 准备参数
         SmsChannelCreateReqVO reqVO = randomPojo(SmsChannelCreateReqVO.class, o -> o.setStatus(randomCommonStatus()));
@@ -74,6 +57,9 @@ public class SmsChannelServiceTest extends BaseDbUnitTest {
         // 校验记录的属性是否正确
         SmsChannelDO smsChannel = smsChannelMapper.selectById(smsChannelId);
         assertPojoEquals(reqVO, smsChannel);
+        // 断言 cache
+        assertNull(smsChannelService.getIdClientCache().getIfPresent(smsChannel.getId()));
+        assertNull(smsChannelService.getCodeClientCache().getIfPresent(smsChannel.getCode()));
     }
 
     @Test
@@ -93,6 +79,9 @@ public class SmsChannelServiceTest extends BaseDbUnitTest {
         // 校验是否更新正确
         SmsChannelDO smsChannel = smsChannelMapper.selectById(reqVO.getId()); // 获取最新的
         assertPojoEquals(reqVO, smsChannel);
+        // 断言 cache
+        assertNull(smsChannelService.getIdClientCache().getIfPresent(smsChannel.getId()));
+        assertNull(smsChannelService.getCodeClientCache().getIfPresent(smsChannel.getCode()));
     }
 
     @Test
@@ -116,6 +105,9 @@ public class SmsChannelServiceTest extends BaseDbUnitTest {
         smsChannelService.deleteSmsChannel(id);
         // 校验数据不存在了
         assertNull(smsChannelMapper.selectById(id));
+        // 断言 cache
+        assertNull(smsChannelService.getIdClientCache().getIfPresent(dbSmsChannel.getId()));
+        assertNull(smsChannelService.getCodeClientCache().getIfPresent(dbSmsChannel.getCode()));
     }
 
     @Test
@@ -199,4 +191,46 @@ public class SmsChannelServiceTest extends BaseDbUnitTest {
        assertPojoEquals(dbSmsChannel, pageResult.getList().get(0));
     }
 
+    @Test
+    public void testGetSmsClient_id() {
+        // mock 数据
+        SmsChannelDO channel = randomPojo(SmsChannelDO.class);
+        smsChannelMapper.insert(channel);
+        // mock 参数
+        Long id = channel.getId();
+        // mock 方法
+        SmsClient mockClient = mock(SmsClient.class);
+        when(smsClientFactory.getSmsClient(eq(id))).thenReturn(mockClient);
+
+        // 调用
+        SmsClient client = smsChannelService.getSmsClient(id);
+        // 断言
+        assertSame(client, mockClient);
+        verify(smsClientFactory).createOrUpdateSmsClient(argThat(arg -> {
+            SmsChannelProperties properties = SmsChannelConvert.INSTANCE.convert02(channel);
+            return properties.equals(arg);
+        }));
+    }
+
+    @Test
+    public void testGetSmsClient_code() {
+        // mock 数据
+        SmsChannelDO channel = randomPojo(SmsChannelDO.class);
+        smsChannelMapper.insert(channel);
+        // mock 参数
+        String code = channel.getCode();
+        // mock 方法
+        SmsClient mockClient = mock(SmsClient.class);
+        when(smsClientFactory.getSmsClient(eq(code))).thenReturn(mockClient);
+
+        // 调用
+        SmsClient client = smsChannelService.getSmsClient(code);
+        // 断言
+        assertSame(client, mockClient);
+        verify(smsClientFactory).createOrUpdateSmsClient(argThat(arg -> {
+            SmsChannelProperties properties = SmsChannelConvert.INSTANCE.convert02(channel);
+            return properties.equals(arg);
+        }));
+    }
+
 }

+ 2 - 6
yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/sms/SmsSendServiceImplTest.java

@@ -5,7 +5,6 @@ import cn.iocoder.yudao.framework.common.core.KeyValue;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
 import cn.iocoder.yudao.framework.sms.core.client.SmsClient;
-import cn.iocoder.yudao.framework.sms.core.client.SmsClientFactory;
 import cn.iocoder.yudao.framework.sms.core.client.SmsCommonResult;
 import cn.iocoder.yudao.framework.sms.core.client.dto.SmsReceiveRespDTO;
 import cn.iocoder.yudao.framework.sms.core.client.dto.SmsSendRespDTO;
@@ -52,9 +51,6 @@ public class SmsSendServiceImplTest extends BaseMockitoUnitTest {
     @Mock
     private SmsProducer smsProducer;
 
-    @Mock
-    private SmsClientFactory smsClientFactory;
-
     @Test
     public void testSendSingleSmsToAdmin() {
         // 准备参数
@@ -253,7 +249,7 @@ public class SmsSendServiceImplTest extends BaseMockitoUnitTest {
         SmsSendMessage message = randomPojo(SmsSendMessage.class);
         // mock SmsClientFactory 的方法
         SmsClient smsClient = spy(SmsClient.class);
-        when(smsClientFactory.getSmsClient(eq(message.getChannelId()))).thenReturn(smsClient);
+        when(smsChannelService.getSmsClient(eq(message.getChannelId()))).thenReturn(smsClient);
         // mock SmsClient 的方法
         SmsCommonResult<SmsSendRespDTO> sendResult = randomPojo(SmsCommonResult.class, SmsSendRespDTO.class);
         sendResult.setData(randomPojo(SmsSendRespDTO.class));
@@ -275,7 +271,7 @@ public class SmsSendServiceImplTest extends BaseMockitoUnitTest {
         String text = randomString();
         // mock SmsClientFactory 的方法
         SmsClient smsClient = spy(SmsClient.class);
-        when(smsClientFactory.getSmsClient(eq(channelCode))).thenReturn(smsClient);
+        when(smsChannelService.getSmsClient(eq(channelCode))).thenReturn(smsClient);
         // mock SmsClient 的方法
         List<SmsReceiveRespDTO> receiveResults = randomPojoList(SmsReceiveRespDTO.class);