Spring Boot 项目中使用 Shiro:自定义过滤器和 token 的实现方式
实现步骤主要是以下几步:
1. 在项目中导入maven依赖
<!-- Apache Shiro core library -->
<dependency>
    <groupId>org.apache.shiro</groupId>
    <artifactId>shiro-core</artifactId>
    <version>1.4.0</version>
</dependency>
<!-- Apache Shiro integration for Spring -->
<!-- NOTE(review): 1.4.0 is an old release; check the Shiro security advisories before using it in production. -->
<dependency>
    <groupId>org.apache.shiro</groupId>
    <artifactId>shiro-spring</artifactId>
    <version>1.4.0</version>
</dependency>
2. shiro的核心配置类和代码
- 自定义(令牌实体)token
package com.ratel.fast.modules.sys.oauth2;
import org.apache.shiro.authc.AuthenticationToken;
/*** token*/
public class OAuth2Token implements AuthenticationToken {private String token;public OAuth2Token(String token){this.token = token;}@Overridepublic String getPrincipal() {return token;}@Overridepublic Object getCredentials() {return token;}
}
- token的生成工具
package com.ratel.fast.modules.sys.oauth2;
import com.ratel.fast.common.exception.RRException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.UUID;
/**
 * Generates token strings as lowercase hexadecimal MD5 digests.
 */
public class TokenGenerator {

    /** Lowercase hexadecimal digits, indexed by nibble value (0-15). */
    private static final char[] hexCode = "0123456789abcdef".toCharArray();

    /**
     * Generates a token from a freshly created random UUID.
     *
     * @return the 32-character lowercase hex MD5 digest of the UUID string
     * @throws RRException if the digest cannot be computed
     */
    public static String generateValue() {
        return generateValue(UUID.randomUUID().toString());
    }

    /**
     * Converts bytes to a lowercase hexadecimal string, two characters per byte.
     *
     * @param data the bytes to encode; may be {@code null}
     * @return the hex string, or {@code null} when {@code data} is {@code null}
     */
    public static String toHexString(byte[] data) {
        if (data == null) {
            return null;
        }
        StringBuilder r = new StringBuilder(data.length * 2);
        for (byte b : data) {
            // Mask after shifting so the sign extension of negative bytes is discarded.
            r.append(hexCode[(b >> 4) & 0xF]);
            r.append(hexCode[b & 0xF]);
        }
        return r.toString();
    }

    /**
     * Computes the MD5 digest of {@code param} and returns it as lowercase hex.
     *
     * <p>The input is encoded as UTF-8 so the result does not depend on the
     * platform default charset.
     *
     * @param param the text to digest
     * @return the 32-character lowercase hex digest
     * @throws RRException if hashing fails for any reason, including a
     *         {@code null} {@code param}; the original exception is kept as the cause
     */
    public static String generateValue(String param) {
        try {
            MessageDigest algorithm = MessageDigest.getInstance("MD5");
            byte[] messageDigest = algorithm.digest(param.getBytes(StandardCharsets.UTF_8));
            return toHexString(messageDigest);
        } catch (Exception e) {
            // Deliberately broad: callers rely on every failure surfacing as RRException.
            throw new RRException("生成Token失败", e);
        }
    }
}
3. 自定义的认证源(Realm)
package com.ratel.fast.modules.sys.oauth2;
import com.ratel.fast.modules.sys.service.ShiroService;
import com.ratel.fast.modules.sys.entity.SysUserEntity;
import com.ratel.fast.modules.sys.entity.SysUserTokenEntity;
import com.ratel.fast.modules.sys.service.ShiroService;
import org.apache.shiro.authc.*;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.Set;/*** 认证***/
@Component
public class OAuth2Realm extends AuthorizingRealm {@Autowiredprivate ShiroService shiroService;@Overridepublic boolean supports(AuthenticationToken token) {return token instanceof OAuth2Token;}/*** 授权(验证权限时调用)* 前端在请求带@RequiresPermissions注解 注解的方法时会调用 doGetAuthorizationInfo 这个方法*/@Overrideprotected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {SysUserEntity user = (SysUserEntity)principals.getPrimaryPrincipal();Long userId = user.getUserId();//用户权限列表Set<String> permsSet = shiroService.getUserPermissions(userId);SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();info.setStringPermissions(permsSet);return info;}/*** 认证(登录时调用)* 每次请求的时候都会调用这个方法验证token是否失效和用户是否被锁定*/@Overrideprotected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {String accessToken = (String) token.getPrincipal();//根据accessToken,查询用户信息SysUserTokenEntity tokenEntity = shiroService.queryByToken(accessToken);//token失效if(tokenEntity == null || tokenEntity.getExpireTime().getTime() < System.currentTimeMillis()){throw new IncorrectCredentialsException("token失效,请重新登录");}//查询用户信息SysUserEntity user = shiroService.queryUser(tokenEntity.getUserId());//账号锁定if(user.getStatus() == 0){throw new LockedAccountException("账号已被锁定,请联系管理员");}SimpleAuthenticationInfo info = new SimpleAuthenticationInfo(user, accessToken, getName());return info;}
}
4. 自定义的过滤器
package com.ratel.fast.modules.sys.oauth2;import com.google.gson.Gson;
import com.ratel.fast.common.utils.HttpContextUtils;
import com.ratel.fast.common.utils.R;
import org.apache.commons.lang.StringUtils;
import org.apache.http.HttpStatus;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.web.filter.authc.AuthenticatingFilter;
import org.springframework.web.bind.annotation.RequestMethod;import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;/*** oauth2过滤器***/
public class OAuth2Filter extends AuthenticatingFilter {@Overrideprotected AuthenticationToken createToken(ServletRequest request, ServletResponse response) throws Exception {//获取请求tokenString token = getRequestToken((HttpServletRequest) request);if(StringUtils.isBlank(token)){return null;}return new OAuth2Token(token);}/*** 判断用户是否已经登录,* 如果是options的请求则放行,否则进行调用onAccessDenied进行token认证流程* @param request* @param response* @param mappedValue* @return*/@Overrideprotected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {if(((HttpServletRequest) request).getMethod().equals(RequestMethod.OPTIONS.name())){return true;}return false;}@Overrideprotected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {//获取请求token,如果token不存在,直接返回401String token = getRequestToken((HttpServletRequest) request);if(StringUtils.isBlank(token)){HttpServletResponse httpResponse = (HttpServletResponse) response;httpResponse.setHeader("Access-Control-Allow-Credentials", "true");httpResponse.setHeader("Access-Control-Allow-Origin", HttpContextUtils.getOrigin());String json = new Gson().toJson(R.error(HttpStatus.SC_UNAUTHORIZED, "invalid token"));httpResponse.getWriter().print(json);return false;}return executeLogin(request, response);}@Overrideprotected boolean onLoginFailure(AuthenticationToken token, AuthenticationException e, ServletRequest request, ServletResponse response) {HttpServletResponse httpResponse = (HttpServletResponse) response;httpResponse.setContentType("application/json;charset=utf-8");httpResponse.setHeader("Access-Control-Allow-Credentials", "true");httpResponse.setHeader("Access-Control-Allow-Origin", HttpContextUtils.getOrigin());try {//处理登录失败的异常Throwable throwable = e.getCause() == null ? 
e : e.getCause();R r = R.error(HttpStatus.SC_UNAUTHORIZED, throwable.getMessage());String json = new Gson().toJson(r);httpResponse.getWriter().print(json);} catch (IOException e1) {}return false;}/*** 获取请求的token*/private String getRequestToken(HttpServletRequest httpRequest){//从header中获取tokenString token = httpRequest.getHeader("token");//如果header中不存在token,则从参数中获取tokenif(StringUtils.isBlank(token)){token = httpRequest.getParameter("token");}return token;}}
5. 核心配置类
package com.ratel.fast.config;import com.ratel.fast.modules.sys.oauth2.OAuth2Filter;
import com.ratel.fast.modules.sys.oauth2.OAuth2Realm;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.spring.LifecycleBeanPostProcessor;
import org.apache.shiro.spring.security.interceptor.AuthorizationAttributeSourceAdvisor;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;import javax.servlet.Filter;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;/*** Shiro配置***/
@Configuration
public class ShiroConfig {@Bean("securityManager")public SecurityManager securityManager(OAuth2Realm oAuth2Realm) {DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager();securityManager.setRealm(oAuth2Realm);securityManager.setRememberMeManager(null);return securityManager;}/*** 这个bean的名字必须叫 shiroFilter ,否则启动的时候会报错* @Bean ("shiroFilter") 之后的括号可以不用写,spring默认方法名为的bean的名字* @param securityManager* @return*/@Bean("shiroFilter")public ShiroFilterFactoryBean shiroFilter(SecurityManager securityManager) {ShiroFilterFactoryBean shiroFilter = new ShiroFilterFactoryBean();shiroFilter.setSecurityManager(securityManager);//oauth过滤Map<String, Filter> filters = new HashMap<>();//添加自定义过滤器filters.put("oauth2", new OAuth2Filter());shiroFilter.setFilters(filters);Map<String, String> filterMap = new LinkedHashMap<>();filterMap.put("/webjars/**", "anon");filterMap.put("/druid/**", "anon");filterMap.put("/app/**", "anon");filterMap.put("/sys/login", "anon");filterMap.put("/swagger/**", "anon");filterMap.put("/v2/api-docs", "anon");filterMap.put("/swagger-ui.html", "anon");filterMap.put("/swagger-resources/**", "anon");filterMap.put("/captcha.jpg", "anon");filterMap.put("/aaa.txt", "anon");//使用自定义过滤器拦截除上边以外的所有请求filterMap.put("/**", "oauth2");shiroFilter.setFilterChainDefinitionMap(filterMap);return shiroFilter;}@Bean("lifecycleBeanPostProcessor")public LifecycleBeanPostProcessor lifecycleBeanPostProcessor() {return new LifecycleBeanPostProcessor();}@Beanpublic AuthorizationAttributeSourceAdvisor authorizationAttributeSourceAdvisor(SecurityManager securityManager) {AuthorizationAttributeSourceAdvisor advisor = new AuthorizationAttributeSourceAdvisor();advisor.setSecurityManager(securityManager);return advisor;}}
6. 用户 token 实体类(用于将 token 持久化到数据库)
package com.ratel.fast.modules.sys.entity;import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.io.Serializable;
import java.util.Date;/*** 系统用户Token***/
@Data
@TableName("sys_user_token")
public class SysUserTokenEntity implements Serializable {private static final long serialVersionUID = 1L;//用户ID@TableId(type = IdType.INPUT)private Long userId;//tokenprivate String token;//过期时间private Date expireTime;//更新时间private Date updateTime;}
- 用户登录使用的Controller
package com.ratel.fast.modules.sys.controller;import com.ratel.fast.common.utils.R;
import com.ratel.fast.modules.sys.entity.SysUserEntity;
import com.ratel.fast.modules.sys.form.SysLoginForm;
import com.ratel.fast.modules.sys.service.SysCaptchaService;
import com.ratel.fast.modules.sys.service.SysUserService;
import com.ratel.fast.modules.sys.service.SysUserTokenService;
import org.apache.commons.io.IOUtils;
import org.apache.shiro.crypto.hash.Sha256Hash;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import javax.imageio.ImageIO;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.Map;/*** 登录相关***/
@RestController
public class SysLoginController extends AbstractController {@Autowiredprivate SysUserService sysUserService;@Autowiredprivate SysUserTokenService sysUserTokenService;@Autowiredprivate SysCaptchaService sysCaptchaService;/*** 验证码*/@GetMapping("captcha.jpg")public void captcha(HttpServletResponse response, String uuid)throws IOException {response.setHeader("Cache-Control", "no-store, no-cache");response.setContentType("image/jpeg");//获取图片验证码BufferedImage image = sysCaptchaService.getCaptcha(uuid);ServletOutputStream out = response.getOutputStream();ImageIO.write(image, "jpg", out);IOUtils.closeQuietly(out);}/*** 登录*/@PostMapping("/sys/login")public Map<String, Object> login(@RequestBody SysLoginForm form)throws IOException {boolean captcha = sysCaptchaService.validate(form.getUuid(), form.getCaptcha());if(!captcha){return R.error("验证码不正确");}//用户信息SysUserEntity user = sysUserService.queryByUserName(form.getUsername());//账号不存在、密码错误if(user == null || !user.getPassword().equals(new Sha256Hash(form.getPassword(), user.getSalt()).toHex())) {return R.error("账号或密码不正确");}//账号锁定if(user.getStatus() == 0){return R.error("账号已被锁定,请联系管理员");}//生成token,并保存到数据库R r = sysUserTokenService.createToken(user.getUserId());return r;}/*** 退出*/@PostMapping("/sys/logout")public R logout() {sysUserTokenService.logout(getUserId());return R.ok();}}
至此,我们已经完成了 Shiro 认证流程的全部代码。
记住一点,Shiro 不会去维护用户、维护权限;这些需要我们自己去设计 / 提供;然后通过相应的接口注入给 Shiro 即可。