更新依赖项,添加单元测试,优化用户信息和令牌管理功能,增强中间件以支持 JWT 验证,新增访问令牌模型并更新数据库迁移。
This commit is contained in:
3
go.mod
3
go.mod
@@ -8,6 +8,7 @@ require (
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/datatypes v1.2.5
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
@@ -16,12 +17,14 @@ require (
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gorm.io/driver/mysql v1.5.6 // indirect
|
||||
)
|
||||
|
||||
|
||||
3
go.sum
3
go.sum
@@ -96,8 +96,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
|
||||
@@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -154,7 +153,7 @@ func (h *OIDCHandler) Userinfo(c *gin.Context) {
|
||||
|
||||
// 准备返回的声明
|
||||
claims := gin.H{
|
||||
"sub": fmt.Sprintf("%d", user.ID),
|
||||
"sub": user.Username,
|
||||
}
|
||||
|
||||
// 根据授权范围添加相应的声明
|
||||
@@ -164,6 +163,7 @@ func (h *OIDCHandler) Userinfo(c *gin.Context) {
|
||||
claims["name"] = user.Username
|
||||
case "email":
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
2
main.go
2
main.go
@@ -87,7 +87,7 @@ func main() {
|
||||
r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration)
|
||||
r.GET("/authorize", oidcHandler.Authorize)
|
||||
r.POST("/token", oidcHandler.Token)
|
||||
r.GET("/userinfo", oidcHandler.Userinfo)
|
||||
r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo)
|
||||
r.GET("/jwks", oidcHandler.JWKS)
|
||||
|
||||
// 客户端注册端点
|
||||
|
||||
425
main_test.go
Normal file
425
main_test.go
Normal file
@@ -0,0 +1,425 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/cookie"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"oidc-oauth2-server/handlers"
|
||||
"oidc-oauth2-server/middleware"
|
||||
"oidc-oauth2-server/models"
|
||||
"oidc-oauth2-server/services"
|
||||
)
|
||||
|
||||
func setupTestRouter() (*gin.Engine, *gorm.DB, error) {
|
||||
// 设置测试模式
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 初始化测试数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 运行数据库迁移
|
||||
if err := models.AutoMigrate(db); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 初始化服务
|
||||
authService := services.NewAuthService(db)
|
||||
oauthService, err := services.NewOAuthService(db)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
clientService := services.NewClientService(db)
|
||||
tokenService := services.NewTokenService(db, oauthService.GetKeyManager())
|
||||
|
||||
// 设置路由
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
|
||||
// 设置模板
|
||||
r.SetFuncMap(template.FuncMap{
|
||||
"subtract": func(a, b int) int { return a - b },
|
||||
"add": func(a, b int) int { return a + b },
|
||||
})
|
||||
r.LoadHTMLGlob("templates/*")
|
||||
|
||||
// 设置会话存储
|
||||
store := cookie.NewStore([]byte("test_secret"))
|
||||
r.Use(sessions.Sessions("oidc_session", store))
|
||||
|
||||
// 设置路由处理器
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// 设置认证相关路由
|
||||
authHandler := handlers.NewAuthHandler(authService)
|
||||
r.GET("/login", authHandler.ShowLogin)
|
||||
r.POST("/login", authHandler.HandleLogin)
|
||||
r.GET("/signup", authHandler.ShowSignup)
|
||||
r.POST("/signup", authHandler.HandleSignup)
|
||||
|
||||
// 设置 OIDC 相关路由
|
||||
oidcHandler := handlers.NewOIDCHandler("http://localhost:8080", oauthService, authService)
|
||||
r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration)
|
||||
r.GET("/authorize", oidcHandler.Authorize)
|
||||
r.POST("/token", oidcHandler.Token)
|
||||
r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo)
|
||||
r.GET("/jwks", oidcHandler.JWKS)
|
||||
|
||||
// 设置客户端注册路由
|
||||
registrationHandler := handlers.NewRegistrationHandler(clientService)
|
||||
r.POST("/register", registrationHandler.Register)
|
||||
|
||||
// 设置令牌管理路由
|
||||
tokenHandler := handlers.NewTokenHandler(tokenService)
|
||||
r.POST("/revoke", tokenHandler.Revoke)
|
||||
r.POST("/introspect", tokenHandler.Introspect)
|
||||
|
||||
return r, db, nil
|
||||
}
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ok", response["status"])
|
||||
}
|
||||
|
||||
func TestLoginPage(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/login", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestLoginAuthentication(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试用户
|
||||
password := "testpass"
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testUser := &models.User{
|
||||
Username: "testuser",
|
||||
Password: string(hashedPassword),
|
||||
Email: "test@example.com",
|
||||
}
|
||||
result := db.Create(testUser)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 测试登录请求 - 使用表单数据
|
||||
form := url.Values{}
|
||||
form.Add("username", "testuser")
|
||||
form.Add("password", "wrongpass")
|
||||
form.Add("client_id", "test_client")
|
||||
form.Add("response_type", "code")
|
||||
form.Add("scope", "openid profile")
|
||||
form.Add("redirect_uri", "http://localhost:8080/callback")
|
||||
form.Add("state", "test_state")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 检查是否返回错误页面
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "用户名或密码错误")
|
||||
|
||||
// 测试成功登录
|
||||
form.Set("password", password) // 使用正确的密码
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 检查重定向
|
||||
assert.Equal(t, 302, w.Code)
|
||||
location := w.Header().Get("Location")
|
||||
assert.Contains(t, location, "/authorize")
|
||||
}
|
||||
|
||||
func TestOIDCConfiguration(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response, "issuer")
|
||||
}
|
||||
|
||||
func TestUserRegistration(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 测试注册请求
|
||||
form := url.Values{}
|
||||
form.Add("username", "newuser")
|
||||
form.Add("password", "newpass")
|
||||
form.Add("email", "newuser@example.com")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/signup", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 检查重定向到登录页面
|
||||
assert.Equal(t, 302, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Location"), "/login")
|
||||
|
||||
// 验证用户是否已创建
|
||||
var user models.User
|
||||
err = db.Where("username = ?", "newuser").First(&user).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newuser@example.com", user.Email)
|
||||
}
|
||||
|
||||
func TestTokenEndpoint(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试客户端
|
||||
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
|
||||
grantTypes, _ := json.Marshal([]string{"authorization_code"})
|
||||
responseTypes, _ := json.Marshal([]string{"code"})
|
||||
client := &models.Client{
|
||||
ClientID: "test_client",
|
||||
ClientSecret: "test_secret",
|
||||
RedirectURIs: datatypes.JSON(redirectURIs),
|
||||
GrantTypes: datatypes.JSON(grantTypes),
|
||||
ResponseTypes: datatypes.JSON(responseTypes),
|
||||
TokenEndpointAuthMethod: "client_secret_post",
|
||||
}
|
||||
result := db.Create(client)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试用户
|
||||
user := &models.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
result = db.Create(user)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试授权码
|
||||
code := &models.AuthorizationCode{
|
||||
Code: "test_code",
|
||||
ClientID: client.ClientID,
|
||||
UserID: user.ID,
|
||||
Scope: "openid profile email",
|
||||
RedirectURI: "http://localhost:8080/callback",
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
result = db.Create(code)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 测试令牌请求
|
||||
form := url.Values{}
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code", code.Code)
|
||||
form.Add("client_id", client.ClientID)
|
||||
form.Add("client_secret", client.ClientSecret)
|
||||
form.Add("redirect_uri", code.RedirectURI)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response, "access_token")
|
||||
assert.Contains(t, response, "id_token")
|
||||
|
||||
// 使用获取到的访问令牌测试用户信息端点
|
||||
accessToken := response["access_token"].(string)
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "/userinfo", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Userinfo response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var userinfoResponse map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &userinfoResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "testuser", userinfoResponse["sub"])
|
||||
assert.Equal(t, "test@example.com", userinfoResponse["email"])
|
||||
}
|
||||
|
||||
func TestUserinfoEndpoint(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试用户
|
||||
user := &models.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
result := db.Create(user)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试客户端
|
||||
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
|
||||
grantTypes, _ := json.Marshal([]string{"authorization_code"})
|
||||
responseTypes, _ := json.Marshal([]string{"code"})
|
||||
client := &models.Client{
|
||||
ClientID: "test_client",
|
||||
ClientSecret: "test_secret",
|
||||
RedirectURIs: datatypes.JSON(redirectURIs),
|
||||
GrantTypes: datatypes.JSON(grantTypes),
|
||||
ResponseTypes: datatypes.JSON(responseTypes),
|
||||
TokenEndpointAuthMethod: "client_secret_post",
|
||||
}
|
||||
result = db.Create(client)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试授权码
|
||||
code := &models.AuthorizationCode{
|
||||
Code: "test_code",
|
||||
ClientID: client.ClientID,
|
||||
UserID: user.ID,
|
||||
Scope: "openid profile email",
|
||||
RedirectURI: "http://localhost:8080/callback",
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
result = db.Create(code)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 测试令牌请求
|
||||
form := url.Values{}
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code", code.Code)
|
||||
form.Add("client_id", client.ClientID)
|
||||
form.Add("client_secret", client.ClientSecret)
|
||||
form.Add("redirect_uri", code.RedirectURI)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var tokenResponse map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &tokenResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, tokenResponse, "access_token")
|
||||
|
||||
// 测试用户信息请求
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "/userinfo", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenResponse["access_token"].(string))
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "testuser", response["sub"])
|
||||
assert.Equal(t, "test@example.com", response["email"])
|
||||
}
|
||||
|
||||
func TestClientRegistration(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 测试客户端注册请求
|
||||
registrationRequest := map[string]interface{}{
|
||||
"client_name": "Test Client",
|
||||
"redirect_uris": []string{"http://localhost:8080/callback"},
|
||||
"response_types": []string{"code"},
|
||||
"grant_types": []string{"authorization_code"},
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(registrationRequest)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response, "client_id")
|
||||
assert.Contains(t, response, "client_secret")
|
||||
}
|
||||
@@ -2,13 +2,18 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"oidc-oauth2-server/models"
|
||||
"oidc-oauth2-server/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func BearerAuth(jwtSecret []byte) gin.HandlerFunc {
|
||||
func BearerAuth(keyManager *services.KeyManager, db *gorm.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
@@ -30,10 +35,10 @@ func BearerAuth(jwtSecret []byte) gin.HandlerFunc {
|
||||
// 解析和验证令牌
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名算法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return jwtSecret, nil
|
||||
return keyManager.GetPrivateKey().Public(), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -50,8 +55,18 @@ func BearerAuth(jwtSecret []byte) gin.HandlerFunc {
|
||||
|
||||
// 将令牌中的声明存储在上下文中
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
c.Set("user_id", uint(claims["sub"].(float64)))
|
||||
c.Set("scope", claims["scope"].(string))
|
||||
if sub, ok := claims["sub"].(string); ok {
|
||||
if userID, err := strconv.ParseUint(sub, 10, 32); err == nil {
|
||||
var user models.User
|
||||
if err := db.Where("id = ?", uint(userID)).First(&user).Error; err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set("user_id", user.ID)
|
||||
c.Set("scope", claims["scope"].(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
||||
22
models/access_token.go
Normal file
22
models/access_token.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AccessToken 表示 OAuth2 访问令牌
|
||||
type AccessToken struct {
|
||||
gorm.Model
|
||||
Token string `gorm:"uniqueIndex;not null"`
|
||||
UserID uint `gorm:"not null"`
|
||||
ClientID string `gorm:"not null"`
|
||||
Scope string `gorm:"type:text"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
IsRevoked bool `gorm:"default:false"`
|
||||
}
|
||||
|
||||
func (t *AccessToken) TableName() string {
|
||||
return "oauth_access_tokens"
|
||||
}
|
||||
@@ -11,5 +11,6 @@ func AutoMigrate(db *gorm.DB) error {
|
||||
&Client{},
|
||||
&AuthorizationCode{},
|
||||
&Admin{},
|
||||
&AccessToken{},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -174,7 +174,7 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error)
|
||||
}
|
||||
|
||||
// 生成访问令牌
|
||||
accessToken, err := s.generateAccessToken(user, client, authCode.Scope)
|
||||
accessToken, err := s.GenerateAccessToken(user, client, authCode.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -197,17 +197,18 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OAuthService) generateAccessToken(user *models.User, client *models.Client, scope string) (string, error) {
|
||||
func (s *OAuthService) GenerateAccessToken(user *models.User, client *models.Client, scope string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"sub": fmt.Sprintf("%d", user.ID),
|
||||
"iss": client.ClientID,
|
||||
"aud": client.ClientID,
|
||||
"exp": now.Add(s.tokenTTL).Unix(),
|
||||
"iat": now.Unix(),
|
||||
"iss": client.ClientID,
|
||||
"scope": scope,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
return token.SignedString(s.keyManager.GetPrivateKey())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user