Quellcode durchsuchen

完成 SysSmsServiceTest 方法的单元测试

YunaiV vor 4 Jahren
Ursprung
Commit
71d2f74110

+ 9 - 4
src/main/java/cn/iocoder/dashboard/modules/system/service/sms/impl/SysSmsServiceImpl.java

@@ -18,6 +18,7 @@ import cn.iocoder.dashboard.modules.system.service.sms.SysSmsLogService;
 import cn.iocoder.dashboard.modules.system.service.sms.SysSmsService;
 import cn.iocoder.dashboard.modules.system.service.sms.SysSmsTemplateService;
 import cn.iocoder.dashboard.modules.system.service.user.SysUserService;
+import com.google.common.annotations.VisibleForTesting;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
 import org.springframework.util.Assert;
@@ -77,6 +78,8 @@ public class SysSmsServiceImpl implements SysSmsService {
         SysSmsTemplateDO template = this.checkSmsTemplateValid(templateCode);
         // 校验手机号码是否存在
         mobile = this.checkMobile(mobile);
+        // 构建有序的模板参数。为什么放在这个位置,是提前保证模板参数的正确性,而不是到了插入发送日志
+        List<KeyValue<String, Object>> newTemplateParams = this.buildTemplateParams(template, templateParams);
 
         // 创建发送日志
         Boolean isSend = CommonStatusEnum.ENABLE.getStatus().equals(template.getStatus()); // 如果模板被禁用,则不发送短信,只记录日志
@@ -85,7 +88,6 @@ public class SysSmsServiceImpl implements SysSmsService {
 
         // 发送 MQ 消息,异步执行发送短信
         if (isSend) {
-            List<KeyValue<String, Object>> newTemplateParams = this.buildTemplateParams(template, templateParams);
             smsProducer.sendSmsSendMessage(sendLogId, mobile, template.getChannelId(), template.getApiTemplateId(), newTemplateParams);
         }
         return sendLogId;
@@ -97,7 +99,8 @@ public class SysSmsServiceImpl implements SysSmsService {
         throw new UnsupportedOperationException("暂时不支持该操作,感兴趣可以实现该功能哟!");
     }
 
-    private SysSmsTemplateDO checkSmsTemplateValid(String templateCode) {
+    @VisibleForTesting
+    public SysSmsTemplateDO checkSmsTemplateValid(String templateCode) {
         // 获得短信模板。考虑到效率,从缓存中获取
         SysSmsTemplateDO template = smsTemplateService.getSmsTemplateByCodeFromCache(templateCode);
         // 短信模板不存在
@@ -116,7 +119,8 @@ public class SysSmsServiceImpl implements SysSmsService {
      * @param templateParams 原始参数
      * @return 处理后的参数
      */
-    private List<KeyValue<String, Object>> buildTemplateParams(SysSmsTemplateDO template, Map<String, Object> templateParams) {
+    @VisibleForTesting
+    public List<KeyValue<String, Object>> buildTemplateParams(SysSmsTemplateDO template, Map<String, Object> templateParams) {
         return template.getParams().stream().map(key -> {
             Object value = templateParams.get(key);
             if (value == null) {
@@ -126,7 +130,8 @@ public class SysSmsServiceImpl implements SysSmsService {
         }).collect(Collectors.toList());
     }
 
-    private String checkMobile(String mobile) {
+    @VisibleForTesting
+    public String checkMobile(String mobile) {
         if (StrUtil.isEmpty(mobile)) {
             throw exception(SMS_SEND_MOBILE_NOT_EXISTS);
         }

+ 13 - 0
src/test/java/cn/iocoder/dashboard/BaseMockitoUnitTest.java

@@ -0,0 +1,13 @@
+package cn.iocoder.dashboard;
+
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+/**
+ * 纯 Mockito 的单元测试
+ *
+ * @author 芋道源码
+ */
+@ExtendWith(MockitoExtension.class)
+public class BaseMockitoUnitTest {
+}

+ 201 - 0
src/test/java/cn/iocoder/dashboard/modules/system/service/sms/SysSmsServiceTest.java

@@ -0,0 +1,201 @@
+package cn.iocoder.dashboard.modules.system.service.sms;
+
+import cn.hutool.core.map.MapUtil;
+import cn.iocoder.dashboard.BaseMockitoUnitTest;
+import cn.iocoder.dashboard.common.core.KeyValue;
+import cn.iocoder.dashboard.common.enums.CommonStatusEnum;
+import cn.iocoder.dashboard.common.enums.UserTypeEnum;
+import cn.iocoder.dashboard.framework.sms.core.client.SmsClient;
+import cn.iocoder.dashboard.framework.sms.core.client.SmsClientFactory;
+import cn.iocoder.dashboard.framework.sms.core.client.SmsCommonResult;
+import cn.iocoder.dashboard.framework.sms.core.client.dto.SmsReceiveRespDTO;
+import cn.iocoder.dashboard.framework.sms.core.client.dto.SmsSendRespDTO;
+import cn.iocoder.dashboard.modules.system.dal.dataobject.sms.SysSmsTemplateDO;
+import cn.iocoder.dashboard.modules.system.mq.message.sms.SysSmsSendMessage;
+import cn.iocoder.dashboard.modules.system.mq.producer.sms.SysSmsProducer;
+import cn.iocoder.dashboard.modules.system.service.sms.impl.SysSmsServiceImpl;
+import org.assertj.core.util.Lists;
+import org.junit.jupiter.api.Test;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static cn.hutool.core.util.RandomUtil.randomEle;
+import static cn.iocoder.dashboard.modules.system.enums.SysErrorCodeConstants.*;
+import static cn.iocoder.dashboard.util.AssertUtils.assertServiceException;
+import static cn.iocoder.dashboard.util.RandomUtils.*;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.*;
+
+/**
+ * {@link SysSmsServiceImpl} 的单元测试类
+ *
+ * @author 芋道源码
+ */
+public class SysSmsServiceTest extends BaseMockitoUnitTest {
+
+    @InjectMocks
+    private SysSmsServiceImpl smsService;
+
+    @Mock
+    private SysSmsTemplateService smsTemplateService;
+    @Mock
+    private SysSmsLogService smsLogService;
+    @Mock
+    private SysSmsProducer smsProducer;
+    @Mock
+    private SmsClientFactory smsClientFactory;
+
+    /**
+     * 发送成功,当短信模板开启时
+     */
+    @Test
+    public void testSendSingleSms_successWhenSmsTemplateEnable() {
+        // 准备参数
+        String mobile = randomString();
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String templateCode = randomString();
+        Map<String, Object> templateParams = MapUtil.<String, Object>builder().put("code", "1234")
+                .put("op", "login").build();
+        // mock SmsTemplateService 的方法
+        SysSmsTemplateDO template = randomPojo(SysSmsTemplateDO.class, o -> {
+            o.setStatus(CommonStatusEnum.ENABLE.getStatus());
+            o.setContent("验证码为{code}, 操作为{op}");
+            o.setParams(Lists.newArrayList("code", "op"));
+        });
+        when(smsTemplateService.getSmsTemplateByCodeFromCache(eq(templateCode))).thenReturn(template);
+        String content = randomString();
+        when(smsTemplateService.formatSmsTemplateContent(eq(template.getContent()), eq(templateParams)))
+                .thenReturn(content);
+        // mock SmsLogService 的方法
+        Long smsLogId = randomLongId();
+        when(smsLogService.createSmsLog(eq(mobile), eq(userId), eq(userType), eq(Boolean.TRUE), eq(template),
+                eq(content), eq(templateParams))).thenReturn(smsLogId);
+
+        // 调用
+        Long resultSmsLogId = smsService.sendSingleSms(mobile, userId, userType, templateCode, templateParams);
+        // 断言
+        assertEquals(smsLogId, resultSmsLogId);
+        // 断言调用
+        verify(smsProducer, times(1)).sendSmsSendMessage(eq(smsLogId), eq(mobile),
+                eq(template.getChannelId()), eq(template.getApiTemplateId()),
+                eq(Lists.newArrayList(new KeyValue<>("code", "1234"), new KeyValue<>("op", "login"))));
+    }
+
+    /**
+     * 发送成功,当短信模板关闭时
+     */
+    @Test
+    public void testSendSingleSms_successWhenSmsTemplateDisable() {
+        // 准备参数
+        String mobile = randomString();
+        Long userId = randomLongId();
+        Integer userType = randomEle(UserTypeEnum.values()).getValue();
+        String templateCode = randomString();
+        Map<String, Object> templateParams = MapUtil.<String, Object>builder().put("code", "1234")
+                .put("op", "login").build();
+        // mock SmsTemplateService 的方法
+        SysSmsTemplateDO template = randomPojo(SysSmsTemplateDO.class, o -> {
+            o.setStatus(CommonStatusEnum.DISABLE.getStatus());
+            o.setContent("验证码为{code}, 操作为{op}");
+            o.setParams(Lists.newArrayList("code", "op"));
+        });
+        when(smsTemplateService.getSmsTemplateByCodeFromCache(eq(templateCode))).thenReturn(template);
+        String content = randomString();
+        when(smsTemplateService.formatSmsTemplateContent(eq(template.getContent()), eq(templateParams)))
+                .thenReturn(content);
+        // mock SmsLogService 的方法
+        Long smsLogId = randomLongId();
+        when(smsLogService.createSmsLog(eq(mobile), eq(userId), eq(userType), eq(Boolean.FALSE), eq(template),
+                eq(content), eq(templateParams))).thenReturn(smsLogId);
+
+        // 调用
+        Long resultSmsLogId = smsService.sendSingleSms(mobile, userId, userType, templateCode, templateParams);
+        // 断言
+        assertEquals(smsLogId, resultSmsLogId);
+        // 断言调用
+        verify(smsProducer, times(0)).sendSmsSendMessage(anyLong(), anyString(),
+                anyLong(), any(), anyList());
+    }
+
+    @Test
+    public void testCheckSmsTemplateValid_notExists() {
+        // 准备参数
+        String templateCode = randomString();
+        // mock 方法
+
+        // 调用,并断言异常
+        assertServiceException(() -> smsService.checkSmsTemplateValid(templateCode),
+                SMS_TEMPLATE_NOT_EXISTS);
+    }
+
+    @Test
+    public void testBuildTemplateParams_paramMiss() {
+        // 准备参数
+        SysSmsTemplateDO template = randomPojo(SysSmsTemplateDO.class,
+                o -> o.setParams(Lists.newArrayList("code")));
+        Map<String, Object> templateParams = new HashMap<>();
+        // mock 方法
+
+        // 调用,并断言异常
+        assertServiceException(() -> smsService.buildTemplateParams(template, templateParams),
+                SMS_SEND_MOBILE_TEMPLATE_PARAM_MISS, "code");
+    }
+
+    @Test
+    public void testCheckMobile_notExists() {
+        // 准备参数
+        // mock 方法
+
+        // 调用,并断言异常
+        assertServiceException(() -> smsService.checkMobile(null),
+                SMS_SEND_MOBILE_NOT_EXISTS);
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testDoSendSms() {
+        // 准备参数
+        SysSmsSendMessage message = randomPojo(SysSmsSendMessage.class);
+        // mock SmsClientFactory 的方法
+        SmsClient smsClient = spy(SmsClient.class);
+        when(smsClientFactory.getSmsClient(eq(message.getChannelId()))).thenReturn(smsClient);
+        // mock SmsClient 的方法
+        SmsCommonResult<SmsSendRespDTO> sendResult = randomPojo(SmsCommonResult.class, SmsSendRespDTO.class);
+        when(smsClient.sendSms(eq(message.getLogId()), eq(message.getMobile()), eq(message.getApiTemplateId()),
+                eq(message.getTemplateParams()))).thenReturn(sendResult);
+
+        // 调用
+        smsService.doSendSms(message);
+        // 断言
+        verify(smsLogService, times(1)).updateSmsSendResult(eq(message.getLogId()),
+                eq(sendResult.getCode()), eq(sendResult.getMsg()), eq(sendResult.getApiCode()),
+                eq(sendResult.getApiMsg()), eq(sendResult.getApiRequestId()), eq(sendResult.getData().getSerialNo()));
+    }
+
+    @Test
+    public void testReceiveSmsStatus() throws Throwable {
+        // 准备参数
+        String channelCode = randomString();
+        String text = randomString();
+        // mock SmsClientFactory 的方法
+        SmsClient smsClient = spy(SmsClient.class);
+        when(smsClientFactory.getSmsClient(eq(channelCode))).thenReturn(smsClient);
+        // mock SmsClient 的方法
+        List<SmsReceiveRespDTO> receiveResults = randomPojoList(SmsReceiveRespDTO.class);
+
+        // 调用
+        smsService.receiveSmsStatus(channelCode, text);
+        // 断言
+        receiveResults.forEach(result -> {
+            smsLogService.updateSmsReceiveResult(eq(result.getLogId()), eq(result.getSuccess()),
+                    eq(result.getReceiveTime()), eq(result.getErrorCode()), eq(result.getErrorCode()));
+        });
+    }
+
+}

+ 13 - 14
src/test/java/cn/iocoder/dashboard/util/AopTargetUtils.java

@@ -5,43 +5,42 @@ import org.springframework.aop.framework.AdvisedSupport;
 import org.springframework.aop.framework.AopProxy;
 import org.springframework.aop.support.AopUtils;
 
-import java.lang.reflect.Field;
-
 /**
- * http://www.bubuko.com/infodetail-3471885.html
+ * Spring AOP 工具类
+ *
+ * 参考波克尔 http://www.bubuko.com/infodetail-3471885.html 实现
  */
 public class AopTargetUtils {
 
     /**
-     * 获取 目标对象
+     * 获取代理的目标对象
      *
      * @param proxy 代理对象
-     * @return
-     * @throws Exception
+     * @return 目标对象
      */
     public static Object getTarget(Object proxy) throws Exception {
+        // 不是代理对象
         if (!AopUtils.isAopProxy(proxy)) {
-            return proxy; //不是代理对象
+            return proxy;
         }
+        // Jdk 代理
         if (AopUtils.isJdkDynamicProxy(proxy)) {
             return getJdkDynamicProxyTargetObject(proxy);
-        } else { //cglib
-            return getCglibProxyTargetObject(proxy);
         }
+        // Cglib 代理
+        return getCglibProxyTargetObject(proxy);
     }
 
     private static Object getCglibProxyTargetObject(Object proxy) throws Exception {
         Object dynamicAdvisedInterceptor = BeanUtil.getFieldValue(proxy, "CGLIB$CALLBACK_0");
         AdvisedSupport advisedSupport = (AdvisedSupport) BeanUtil.getFieldValue(dynamicAdvisedInterceptor, "advised");
-        Object target = advisedSupport.getTargetSource().getTarget();
-        return target;
+        return advisedSupport.getTargetSource().getTarget();
     }
 
     private static Object getJdkDynamicProxyTargetObject(Object proxy) throws Exception {
         AopProxy aopProxy = (AopProxy) BeanUtil.getFieldValue(proxy, "h");
         AdvisedSupport advisedSupport = (AdvisedSupport) BeanUtil.getFieldValue(aopProxy, "advised");
-        Object target = advisedSupport.getTargetSource().getTarget();
-        return target;
+        return advisedSupport.getTargetSource().getTarget();
     }
 
-}
+}

+ 8 - 3
src/test/java/cn/iocoder/dashboard/util/RandomUtils.java

@@ -8,9 +8,7 @@ import uk.co.jemos.podam.api.PodamFactory;
 import uk.co.jemos.podam.api.PodamFactoryImpl;
 
 import java.lang.reflect.Type;
-import java.util.Arrays;
-import java.util.Date;
-import java.util.Set;
+import java.util.*;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -98,4 +96,11 @@ public class RandomUtils {
         return pojo;
     }
 
+    @SafeVarargs
+    public static <T> List<T> randomPojoList(Class<T> clazz, Consumer<T>... consumers) {
+        int size = RandomUtil.randomInt(0, RANDOM_COLLECTION_LENGTH);
+        return Stream.iterate(0, i -> i).limit(size).map(o -> randomPojo(clazz, consumers))
+                .collect(Collectors.toList());
+    }
+
 }