更新依赖项,添加单元测试,优化用户信息和令牌管理功能,增强中间件以支持 JWT 验证,新增访问令牌模型并更新数据库迁移。

This commit is contained in:
2025-04-17 02:04:52 +08:00
parent 83c82f7135
commit d06e45e5d4
10 changed files with 482 additions and 14 deletions

3
go.mod
View File

@@ -8,6 +8,7 @@ require (
github.com/gin-contrib/sessions v0.0.5
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/stretchr/testify v1.10.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.5
gorm.io/driver/sqlite v1.5.7
@@ -16,12 +17,14 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gorm.io/driver/mysql v1.5.6 // indirect
)

3
go.sum
View File

@@ -96,8 +96,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=

View File

@@ -2,7 +2,6 @@ package handlers
import (
"crypto/rsa"
"fmt"
"net/http"
"net/url"
"strings"
@@ -154,7 +153,7 @@ func (h *OIDCHandler) Userinfo(c *gin.Context) {
// 准备返回的声明
claims := gin.H{
"sub": fmt.Sprintf("%d", user.ID),
"sub": user.Username,
}
// 根据授权范围添加相应的声明
@@ -164,6 +163,7 @@ func (h *OIDCHandler) Userinfo(c *gin.Context) {
claims["name"] = user.Username
case "email":
claims["email"] = user.Email
claims["email_verified"] = true
}
}

View File

@@ -87,7 +87,7 @@ func main() {
r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration)
r.GET("/authorize", oidcHandler.Authorize)
r.POST("/token", oidcHandler.Token)
r.GET("/userinfo", oidcHandler.Userinfo)
r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo)
r.GET("/jwks", oidcHandler.JWKS)
// 客户端注册端点

425
main_test.go Normal file
View File

@@ -0,0 +1,425 @@
package main
import (
"bytes"
"encoding/json"
"html/template"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"gorm.io/datatypes"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"oidc-oauth2-server/handlers"
"oidc-oauth2-server/middleware"
"oidc-oauth2-server/models"
"oidc-oauth2-server/services"
)
func setupTestRouter() (*gin.Engine, *gorm.DB, error) {
// 设置测试模式
gin.SetMode(gin.TestMode)
// 初始化测试数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
return nil, nil, err
}
// 运行数据库迁移
if err := models.AutoMigrate(db); err != nil {
return nil, nil, err
}
// 初始化服务
authService := services.NewAuthService(db)
oauthService, err := services.NewOAuthService(db)
if err != nil {
return nil, nil, err
}
clientService := services.NewClientService(db)
tokenService := services.NewTokenService(db, oauthService.GetKeyManager())
// 设置路由
r := gin.New()
r.Use(gin.Recovery())
// 设置模板
r.SetFuncMap(template.FuncMap{
"subtract": func(a, b int) int { return a - b },
"add": func(a, b int) int { return a + b },
})
r.LoadHTMLGlob("templates/*")
// 设置会话存储
store := cookie.NewStore([]byte("test_secret"))
r.Use(sessions.Sessions("oidc_session", store))
// 设置路由处理器
r.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
})
// 设置认证相关路由
authHandler := handlers.NewAuthHandler(authService)
r.GET("/login", authHandler.ShowLogin)
r.POST("/login", authHandler.HandleLogin)
r.GET("/signup", authHandler.ShowSignup)
r.POST("/signup", authHandler.HandleSignup)
// 设置 OIDC 相关路由
oidcHandler := handlers.NewOIDCHandler("http://localhost:8080", oauthService, authService)
r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration)
r.GET("/authorize", oidcHandler.Authorize)
r.POST("/token", oidcHandler.Token)
r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo)
r.GET("/jwks", oidcHandler.JWKS)
// 设置客户端注册路由
registrationHandler := handlers.NewRegistrationHandler(clientService)
r.POST("/register", registrationHandler.Register)
// 设置令牌管理路由
tokenHandler := handlers.NewTokenHandler(tokenService)
r.POST("/revoke", tokenHandler.Revoke)
r.POST("/introspect", tokenHandler.Introspect)
return r, db, nil
}
func TestHealthCheck(t *testing.T) {
r, _, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/health", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
var response map[string]string
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, "ok", response["status"])
}
func TestLoginPage(t *testing.T) {
r, _, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/login", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}
func TestLoginAuthentication(t *testing.T) {
r, db, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
// 创建测试用户
password := "testpass"
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
assert.NoError(t, err)
testUser := &models.User{
Username: "testuser",
Password: string(hashedPassword),
Email: "test@example.com",
}
result := db.Create(testUser)
assert.NoError(t, result.Error)
// 测试登录请求 - 使用表单数据
form := url.Values{}
form.Add("username", "testuser")
form.Add("password", "wrongpass")
form.Add("client_id", "test_client")
form.Add("response_type", "code")
form.Add("scope", "openid profile")
form.Add("redirect_uri", "http://localhost:8080/callback")
form.Add("state", "test_state")
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.ServeHTTP(w, req)
// 检查是否返回错误页面
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "用户名或密码错误")
// 测试成功登录
form.Set("password", password) // 使用正确的密码
w = httptest.NewRecorder()
req, _ = http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.ServeHTTP(w, req)
// 检查重定向
assert.Equal(t, 302, w.Code)
location := w.Header().Get("Location")
assert.Contains(t, location, "/authorize")
}
func TestOIDCConfiguration(t *testing.T) {
r, _, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/.well-known/openid-configuration", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
var response map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Contains(t, response, "issuer")
}
func TestUserRegistration(t *testing.T) {
r, db, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
// 测试注册请求
form := url.Values{}
form.Add("username", "newuser")
form.Add("password", "newpass")
form.Add("email", "newuser@example.com")
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/signup", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.ServeHTTP(w, req)
// 检查重定向到登录页面
assert.Equal(t, 302, w.Code)
assert.Contains(t, w.Header().Get("Location"), "/login")
// 验证用户是否已创建
var user models.User
err = db.Where("username = ?", "newuser").First(&user).Error
assert.NoError(t, err)
assert.Equal(t, "newuser@example.com", user.Email)
}
func TestTokenEndpoint(t *testing.T) {
r, db, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
// 创建测试客户端
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
grantTypes, _ := json.Marshal([]string{"authorization_code"})
responseTypes, _ := json.Marshal([]string{"code"})
client := &models.Client{
ClientID: "test_client",
ClientSecret: "test_secret",
RedirectURIs: datatypes.JSON(redirectURIs),
GrantTypes: datatypes.JSON(grantTypes),
ResponseTypes: datatypes.JSON(responseTypes),
TokenEndpointAuthMethod: "client_secret_post",
}
result := db.Create(client)
assert.NoError(t, result.Error)
// 创建测试用户
user := &models.User{
Username: "testuser",
Email: "test@example.com",
}
result = db.Create(user)
assert.NoError(t, result.Error)
// 创建测试授权码
code := &models.AuthorizationCode{
Code: "test_code",
ClientID: client.ClientID,
UserID: user.ID,
Scope: "openid profile email",
RedirectURI: "http://localhost:8080/callback",
ExpiresAt: time.Now().Add(10 * time.Minute),
Used: false,
}
result = db.Create(code)
assert.NoError(t, result.Error)
// 测试令牌请求
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("code", code.Code)
form.Add("client_id", client.ClientID)
form.Add("client_secret", client.ClientSecret)
form.Add("redirect_uri", code.RedirectURI)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Logf("Response body: %s", w.Body.String())
}
assert.Equal(t, 200, w.Code)
var response map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Contains(t, response, "access_token")
assert.Contains(t, response, "id_token")
// 使用获取到的访问令牌测试用户信息端点
accessToken := response["access_token"].(string)
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Logf("Userinfo response body: %s", w.Body.String())
}
assert.Equal(t, 200, w.Code)
var userinfoResponse map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &userinfoResponse)
assert.NoError(t, err)
assert.Equal(t, "testuser", userinfoResponse["sub"])
assert.Equal(t, "test@example.com", userinfoResponse["email"])
}
func TestUserinfoEndpoint(t *testing.T) {
r, db, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
// 创建测试用户
user := &models.User{
Username: "testuser",
Email: "test@example.com",
}
result := db.Create(user)
assert.NoError(t, result.Error)
// 创建测试客户端
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
grantTypes, _ := json.Marshal([]string{"authorization_code"})
responseTypes, _ := json.Marshal([]string{"code"})
client := &models.Client{
ClientID: "test_client",
ClientSecret: "test_secret",
RedirectURIs: datatypes.JSON(redirectURIs),
GrantTypes: datatypes.JSON(grantTypes),
ResponseTypes: datatypes.JSON(responseTypes),
TokenEndpointAuthMethod: "client_secret_post",
}
result = db.Create(client)
assert.NoError(t, result.Error)
// 创建测试授权码
code := &models.AuthorizationCode{
Code: "test_code",
ClientID: client.ClientID,
UserID: user.ID,
Scope: "openid profile email",
RedirectURI: "http://localhost:8080/callback",
ExpiresAt: time.Now().Add(10 * time.Minute),
Used: false,
}
result = db.Create(code)
assert.NoError(t, result.Error)
// 测试令牌请求
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("code", code.Code)
form.Add("client_id", client.ClientID)
form.Add("client_secret", client.ClientSecret)
form.Add("redirect_uri", code.RedirectURI)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Logf("Response body: %s", w.Body.String())
}
assert.Equal(t, 200, w.Code)
var tokenResponse map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &tokenResponse)
assert.NoError(t, err)
assert.Contains(t, tokenResponse, "access_token")
// 测试用户信息请求
w = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+tokenResponse["access_token"].(string))
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Logf("Response body: %s", w.Body.String())
}
assert.Equal(t, 200, w.Code)
var response map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, "testuser", response["sub"])
assert.Equal(t, "test@example.com", response["email"])
}
func TestClientRegistration(t *testing.T) {
r, _, err := setupTestRouter()
if err != nil {
t.Fatalf("Failed to setup test router: %v", err)
}
// 测试客户端注册请求
registrationRequest := map[string]interface{}{
"client_name": "Test Client",
"redirect_uris": []string{"http://localhost:8080/callback"},
"response_types": []string{"code"},
"grant_types": []string{"authorization_code"},
"token_endpoint_auth_method": "client_secret_basic",
}
jsonData, _ := json.Marshal(registrationRequest)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code)
var response map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Contains(t, response, "client_id")
assert.Contains(t, response, "client_secret")
}

View File

@@ -2,13 +2,18 @@ package middleware
import (
"net/http"
"strconv"
"strings"
"oidc-oauth2-server/models"
"oidc-oauth2-server/services"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v4"
"gorm.io/gorm"
)
func BearerAuth(jwtSecret []byte) gin.HandlerFunc {
func BearerAuth(keyManager *services.KeyManager, db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
@@ -30,10 +35,10 @@ func BearerAuth(jwtSecret []byte) gin.HandlerFunc {
// 解析和验证令牌
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// 验证签名算法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, jwt.ErrSignatureInvalid
}
return jwtSecret, nil
return keyManager.GetPrivateKey().Public(), nil
})
if err != nil {
@@ -50,8 +55,18 @@ func BearerAuth(jwtSecret []byte) gin.HandlerFunc {
// 将令牌中的声明存储在上下文中
if claims, ok := token.Claims.(jwt.MapClaims); ok {
c.Set("user_id", uint(claims["sub"].(float64)))
c.Set("scope", claims["scope"].(string))
if sub, ok := claims["sub"].(string); ok {
if userID, err := strconv.ParseUint(sub, 10, 32); err == nil {
var user models.User
if err := db.Where("id = ?", uint(userID)).First(&user).Error; err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
c.Abort()
return
}
c.Set("user_id", user.ID)
c.Set("scope", claims["scope"].(string))
}
}
}
c.Next()

22
models/access_token.go Normal file
View File

@@ -0,0 +1,22 @@
package models
import (
"time"
"gorm.io/gorm"
)
// AccessToken 表示 OAuth2 访问令牌
type AccessToken struct {
gorm.Model
Token string `gorm:"uniqueIndex;not null"`
UserID uint `gorm:"not null"`
ClientID string `gorm:"not null"`
Scope string `gorm:"type:text"`
ExpiresAt time.Time `gorm:"not null"`
IsRevoked bool `gorm:"default:false"`
}
func (t *AccessToken) TableName() string {
return "oauth_access_tokens"
}

View File

@@ -11,5 +11,6 @@ func AutoMigrate(db *gorm.DB) error {
&Client{},
&AuthorizationCode{},
&Admin{},
&AccessToken{},
)
}

View File

View File

@@ -174,7 +174,7 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error)
}
// 生成访问令牌
accessToken, err := s.generateAccessToken(user, client, authCode.Scope)
accessToken, err := s.GenerateAccessToken(user, client, authCode.Scope)
if err != nil {
return nil, err
}
@@ -197,17 +197,18 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error)
}, nil
}
func (s *OAuthService) generateAccessToken(user *models.User, client *models.Client, scope string) (string, error) {
func (s *OAuthService) GenerateAccessToken(user *models.User, client *models.Client, scope string) (string, error) {
now := time.Now()
claims := jwt.MapClaims{
"sub": user.ID,
"sub": fmt.Sprintf("%d", user.ID),
"iss": client.ClientID,
"aud": client.ClientID,
"exp": now.Add(s.tokenTTL).Unix(),
"iat": now.Unix(),
"iss": client.ClientID,
"scope": scope,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
return token.SignedString(s.keyManager.GetPrivateKey())
}