完善 OAuth2TokenServiceImplTest 单元测试

pull/2/head
YunaiV 2022-05-25 00:19:46 +08:00
parent 65d2dffe1a
commit 522d70a29b
6 changed files with 330 additions and 5 deletions

View File

@ -9,7 +9,6 @@ import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler; import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.experimental.Accessors;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
@ -22,11 +21,10 @@ import java.util.List;
* *
* @author * @author
*/ */
@TableName("system_oauth2_access_token") @TableName(value = "system_oauth2_access_token", autoResultMap = true)
@KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 @KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Accessors(chain = true)
public class OAuth2AccessTokenDO extends TenantBaseDO { public class OAuth2AccessTokenDO extends TenantBaseDO {
/** /**

View File

@ -18,7 +18,7 @@ import java.util.List;
* *
* @author * @author
*/ */
@TableName("system_oauth2_refresh_token") @TableName(value = "system_oauth2_refresh_token", autoResultMap = true)
// 由于 Oracle 的 SEQ 的名字长度有限制,所以就先用 system_oauth2_access_token_seq 吧,反正也没啥问题 // 由于 Oracle 的 SEQ 的名字长度有限制,所以就先用 system_oauth2_access_token_seq 吧,反正也没啥问题
@KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 @KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
@Data @Data

View File

@ -25,7 +25,7 @@ public interface OAuth2AccessTokenMapper extends BaseMapperX<OAuth2AccessTokenDO
return selectPage(reqVO, new LambdaQueryWrapperX<OAuth2AccessTokenDO>() return selectPage(reqVO, new LambdaQueryWrapperX<OAuth2AccessTokenDO>()
.eqIfPresent(OAuth2AccessTokenDO::getUserId, reqVO.getUserId()) .eqIfPresent(OAuth2AccessTokenDO::getUserId, reqVO.getUserId())
.eqIfPresent(OAuth2AccessTokenDO::getUserType, reqVO.getUserType()) .eqIfPresent(OAuth2AccessTokenDO::getUserType, reqVO.getUserType())
.eqIfPresent(OAuth2AccessTokenDO::getClientId, reqVO.getClientId()) .likeIfPresent(OAuth2AccessTokenDO::getClientId, reqVO.getClientId())
.gt(OAuth2AccessTokenDO::getExpiresTime, new Date()) .gt(OAuth2AccessTokenDO::getExpiresTime, new Date())
.orderByDesc(OAuth2AccessTokenDO::getId)); .orderByDesc(OAuth2AccessTokenDO::getId));
} }

View File

@ -0,0 +1,289 @@
package cn.iocoder.yudao.module.system.service.oauth2;
import cn.hutool.core.util.RandomUtil;
import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.date.DateUtils;
import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder;
import cn.iocoder.yudao.framework.test.core.ut.BaseDbAndRedisUnitTest;
import cn.iocoder.yudao.module.system.controller.admin.oauth2.vo.token.OAuth2AccessTokenPageReqVO;
import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2AccessTokenDO;
import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ClientDO;
import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2RefreshTokenDO;
import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2AccessTokenMapper;
import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2RefreshTokenMapper;
import cn.iocoder.yudao.module.system.dal.redis.oauth2.OAuth2AccessTokenRedisDAO;
import org.assertj.core.util.Lists;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Import;
import javax.annotation.Resource;
import java.time.Duration;
import java.util.Date;
import java.util.List;
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.addTime;
import static cn.iocoder.yudao.framework.common.util.object.ObjectUtils.cloneIgnoreId;
import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertPojoEquals;
import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertServiceException;
import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
/**
* {@link OAuth2TokenServiceImpl}
*
* @author
*/
@Import({OAuth2TokenServiceImpl.class, OAuth2AccessTokenRedisDAO.class})
public class OAuth2TokenServiceImplTest extends BaseDbAndRedisUnitTest {
@Resource
private OAuth2TokenServiceImpl oauth2TokenService;
@Resource
private OAuth2AccessTokenMapper oauth2AccessTokenMapper;
@Resource
private OAuth2RefreshTokenMapper oauth2RefreshTokenMapper;
@Resource
private OAuth2AccessTokenRedisDAO oauth2AccessTokenRedisDAO;
@MockBean
private OAuth2ClientService oauth2ClientService;
@Test
public void testCreateAccessToken() {
TenantContextHolder.setTenantId(0L);
// 准备参数
Long userId = randomLongId();
Integer userType = RandomUtil.randomEle(UserTypeEnum.values()).getValue();
String clientId = randomString();
List<String> scopes = Lists.newArrayList("read", "write");
// mock 方法
OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId)
.setAccessTokenValiditySeconds(30).setRefreshTokenValiditySeconds(60);
when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
// 调用
OAuth2AccessTokenDO accessTokenDO = oauth2TokenService.createAccessToken(userId, userType, clientId, scopes);
// 断言访问令牌
OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken());
assertPojoEquals(accessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted");
assertEquals(userId, accessTokenDO.getUserId());
assertEquals(userType, accessTokenDO.getUserType());
assertEquals(clientId, accessTokenDO.getClientId());
assertEquals(scopes, accessTokenDO.getScopes());
assertFalse(DateUtils.isExpired(accessTokenDO.getExpiresTime()));
// 断言访问令牌的缓存
OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken());
assertPojoEquals(accessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted");
// 断言刷新令牌
OAuth2RefreshTokenDO refreshTokenDO = oauth2RefreshTokenMapper.selectList().get(0);
assertPojoEquals(accessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted");
assertFalse(DateUtils.isExpired(refreshTokenDO.getExpiresTime()));
}
@Test
public void testRefreshAccessToken_null() {
// 准备参数
String refreshToken = randomString();
String clientId = randomString();
// mock 方法
// 调用,并断言
assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
new ErrorCode(400, "无效的刷新令牌"));
}
@Test
public void testRefreshAccessToken_clientIdError() {
// 准备参数
String refreshToken = randomString();
String clientId = randomString();
// mock 方法
OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId);
when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
// mock 数据(访问令牌)
OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
.setRefreshToken(refreshToken).setClientId("error");
oauth2RefreshTokenMapper.insert(refreshTokenDO);
// 调用,并断言
assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
new ErrorCode(400, "刷新令牌的客户端编号不正确"));
}
@Test
public void testRefreshAccessToken_expired() {
// 准备参数
String refreshToken = randomString();
String clientId = randomString();
// mock 方法
OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId);
when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
// mock 数据(访问令牌)
OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
.setRefreshToken(refreshToken).setClientId(clientId)
.setExpiresTime(addTime(Duration.ofDays(-1)));
oauth2RefreshTokenMapper.insert(refreshTokenDO);
// 调用,并断言
assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
new ErrorCode(401, "刷新令牌已过期"));
}
@Test
public void testRefreshAccessToken_success() {
TenantContextHolder.setTenantId(0L);
// 准备参数
String refreshToken = randomString();
String clientId = randomString();
// mock 方法
OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId)
.setAccessTokenValiditySeconds(30);
when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
// mock 数据(访问令牌)
OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
.setRefreshToken(refreshToken).setClientId(clientId)
.setExpiresTime(addTime(Duration.ofDays(1)));
oauth2RefreshTokenMapper.insert(refreshTokenDO);
// mock 数据(访问令牌)
OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class).setRefreshToken(refreshToken);
oauth2AccessTokenMapper.insert(accessTokenDO);
oauth2AccessTokenRedisDAO.set(accessTokenDO);
// 调用
OAuth2AccessTokenDO newAccessTokenDO = oauth2TokenService.refreshAccessToken(refreshToken, clientId);
// 断言,老的访问令牌被删除
assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken()));
assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken()));
// 断言,新的访问令牌
OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(newAccessTokenDO.getAccessToken());
assertPojoEquals(newAccessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted");
assertPojoEquals(newAccessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted",
"creator", "updater");
assertFalse(DateUtils.isExpired(newAccessTokenDO.getExpiresTime()));
// 断言,新的访问令牌的缓存
OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(newAccessTokenDO.getAccessToken());
assertPojoEquals(newAccessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted");
}
@Test
public void testGetAccessToken() {
// mock 数据(访问令牌)
OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
.setExpiresTime(addTime(Duration.ofDays(1)));
oauth2AccessTokenMapper.insert(accessTokenDO);
// 准备参数
String accessToken = accessTokenDO.getAccessToken();
// 调用
OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken);
// 断言
assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
"creator", "updater");
assertPojoEquals(accessTokenDO, oauth2AccessTokenRedisDAO.get(accessToken), "createTime", "updateTime", "deleted",
"creator", "updater");
}
@Test
public void testCheckAccessToken_null() {
// 调研,并断言
assertServiceException(() -> oauth2TokenService.checkAccessToken(randomString()),
new ErrorCode(401, "访问令牌不存在"));
}
@Test
public void testCheckAccessToken_expired() {
// mock 数据(访问令牌)
OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
.setExpiresTime(addTime(Duration.ofDays(-1)));
oauth2AccessTokenMapper.insert(accessTokenDO);
// 准备参数
String accessToken = accessTokenDO.getAccessToken();
// 调研,并断言
assertServiceException(() -> oauth2TokenService.checkAccessToken(accessToken),
new ErrorCode(401, "访问令牌已过期"));
}
@Test
public void testCheckAccessToken_success() {
// mock 数据(访问令牌)
OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
.setExpiresTime(addTime(Duration.ofDays(1)));
oauth2AccessTokenMapper.insert(accessTokenDO);
// 准备参数
String accessToken = accessTokenDO.getAccessToken();
// 调研,并断言
OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken);
// 断言
assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
"creator", "updater");
}
@Test
public void testRemoveAccessToken_null() {
// 调用,并断言
assertNull(oauth2TokenService.removeAccessToken(randomString()));
}
@Test
public void testRemoveAccessToken_success() {
// mock 数据(访问令牌)
OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
.setExpiresTime(addTime(Duration.ofDays(1)));
oauth2AccessTokenMapper.insert(accessTokenDO);
// mock 数据(刷新令牌)
OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
.setRefreshToken(accessTokenDO.getRefreshToken());
oauth2RefreshTokenMapper.insert(refreshTokenDO);
// 调用
OAuth2AccessTokenDO result = oauth2TokenService.removeAccessToken(accessTokenDO.getAccessToken());
assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
"creator", "updater");
// 断言数据
assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken()));
assertNull(oauth2RefreshTokenMapper.selectByRefreshToken(accessTokenDO.getRefreshToken()));
assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken()));
}
@Test
public void testGetAccessTokenPage() {
// mock 数据
OAuth2AccessTokenDO dbAccessToken = randomPojo(OAuth2AccessTokenDO.class, o -> { // 等会查询到
o.setUserId(10L);
o.setUserType(1);
o.setClientId("test_client");
o.setExpiresTime(DateUtils.addTime(Duration.ofDays(1)));
});
oauth2AccessTokenMapper.insert(dbAccessToken);
// 测试 userId 不匹配
oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserId(20L)));
// 测试 userType 不匹配
oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserType(2)));
// 测试 userType 不匹配
oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setClientId("it_client")));
// 测试 expireTime 不匹配
oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setExpiresTime(new Date())));
// 准备参数
OAuth2AccessTokenPageReqVO reqVO = new OAuth2AccessTokenPageReqVO();
reqVO.setUserId(10L);
reqVO.setUserType(1);
reqVO.setClientId("test");
// 调用
PageResult<OAuth2AccessTokenDO> pageResult = oauth2TokenService.getAccessTokenPage(reqVO);
// 断言
assertEquals(1, pageResult.getTotal());
assertEquals(1, pageResult.getList().size());
assertPojoEquals(dbAccessToken, pageResult.getList().get(0));
}
}

View File

@ -22,3 +22,5 @@ DELETE FROM "system_tenant_package";
DELETE FROM "system_sensitive_word"; DELETE FROM "system_sensitive_word";
DELETE FROM "system_oauth2_client"; DELETE FROM "system_oauth2_client";
DELETE FROM "system_oauth2_approve"; DELETE FROM "system_oauth2_approve";
DELETE FROM "system_oauth2_access_token";
DELETE FROM "system_oauth2_refresh_token";

View File

@ -511,3 +511,39 @@ CREATE TABLE IF NOT EXISTS "system_oauth2_approve" (
"deleted" bit NOT NULL DEFAULT FALSE, "deleted" bit NOT NULL DEFAULT FALSE,
PRIMARY KEY ("id") PRIMARY KEY ("id")
) COMMENT 'OAuth2 '; ) COMMENT 'OAuth2 ';
CREATE TABLE IF NOT EXISTS "system_oauth2_access_token" (
"id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY,
"user_id" bigint NOT NULL,
"user_type" tinyint NOT NULL,
"access_token" varchar NOT NULL,
"refresh_token" varchar NOT NULL,
"client_id" varchar NOT NULL,
"scopes" varchar NOT NULL,
"approved" bit NOT NULL DEFAULT FALSE,
"expires_time" datetime NOT NULL,
"creator" varchar DEFAULT '',
"create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updater" varchar DEFAULT '',
"update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
"deleted" bit NOT NULL DEFAULT FALSE,
"tenant_id" bigint NOT NULL,
PRIMARY KEY ("id")
) COMMENT 'OAuth2 访';
CREATE TABLE IF NOT EXISTS "system_oauth2_refresh_token" (
"id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY,
"user_id" bigint NOT NULL,
"user_type" tinyint NOT NULL,
"refresh_token" varchar NOT NULL,
"client_id" varchar NOT NULL,
"scopes" varchar NOT NULL,
"approved" bit NOT NULL DEFAULT FALSE,
"expires_time" datetime NOT NULL,
"creator" varchar DEFAULT '',
"create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updater" varchar DEFAULT '',
"update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
"deleted" bit NOT NULL DEFAULT FALSE,
PRIMARY KEY ("id")
) COMMENT 'OAuth2 ';