更新依赖项,添加单元测试,优化用户信息和令牌管理功能,增强中间件以支持 JWT 验证,新增访问令牌模型并更新数据库迁移。
This commit is contained in:
425
main_test.go
Normal file
425
main_test.go
Normal file
@@ -0,0 +1,425 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/cookie"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"oidc-oauth2-server/handlers"
|
||||
"oidc-oauth2-server/middleware"
|
||||
"oidc-oauth2-server/models"
|
||||
"oidc-oauth2-server/services"
|
||||
)
|
||||
|
||||
func setupTestRouter() (*gin.Engine, *gorm.DB, error) {
|
||||
// 设置测试模式
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 初始化测试数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 运行数据库迁移
|
||||
if err := models.AutoMigrate(db); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 初始化服务
|
||||
authService := services.NewAuthService(db)
|
||||
oauthService, err := services.NewOAuthService(db)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
clientService := services.NewClientService(db)
|
||||
tokenService := services.NewTokenService(db, oauthService.GetKeyManager())
|
||||
|
||||
// 设置路由
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
|
||||
// 设置模板
|
||||
r.SetFuncMap(template.FuncMap{
|
||||
"subtract": func(a, b int) int { return a - b },
|
||||
"add": func(a, b int) int { return a + b },
|
||||
})
|
||||
r.LoadHTMLGlob("templates/*")
|
||||
|
||||
// 设置会话存储
|
||||
store := cookie.NewStore([]byte("test_secret"))
|
||||
r.Use(sessions.Sessions("oidc_session", store))
|
||||
|
||||
// 设置路由处理器
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// 设置认证相关路由
|
||||
authHandler := handlers.NewAuthHandler(authService)
|
||||
r.GET("/login", authHandler.ShowLogin)
|
||||
r.POST("/login", authHandler.HandleLogin)
|
||||
r.GET("/signup", authHandler.ShowSignup)
|
||||
r.POST("/signup", authHandler.HandleSignup)
|
||||
|
||||
// 设置 OIDC 相关路由
|
||||
oidcHandler := handlers.NewOIDCHandler("http://localhost:8080", oauthService, authService)
|
||||
r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration)
|
||||
r.GET("/authorize", oidcHandler.Authorize)
|
||||
r.POST("/token", oidcHandler.Token)
|
||||
r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo)
|
||||
r.GET("/jwks", oidcHandler.JWKS)
|
||||
|
||||
// 设置客户端注册路由
|
||||
registrationHandler := handlers.NewRegistrationHandler(clientService)
|
||||
r.POST("/register", registrationHandler.Register)
|
||||
|
||||
// 设置令牌管理路由
|
||||
tokenHandler := handlers.NewTokenHandler(tokenService)
|
||||
r.POST("/revoke", tokenHandler.Revoke)
|
||||
r.POST("/introspect", tokenHandler.Introspect)
|
||||
|
||||
return r, db, nil
|
||||
}
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]string
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ok", response["status"])
|
||||
}
|
||||
|
||||
func TestLoginPage(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/login", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestLoginAuthentication(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试用户
|
||||
password := "testpass"
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testUser := &models.User{
|
||||
Username: "testuser",
|
||||
Password: string(hashedPassword),
|
||||
Email: "test@example.com",
|
||||
}
|
||||
result := db.Create(testUser)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 测试登录请求 - 使用表单数据
|
||||
form := url.Values{}
|
||||
form.Add("username", "testuser")
|
||||
form.Add("password", "wrongpass")
|
||||
form.Add("client_id", "test_client")
|
||||
form.Add("response_type", "code")
|
||||
form.Add("scope", "openid profile")
|
||||
form.Add("redirect_uri", "http://localhost:8080/callback")
|
||||
form.Add("state", "test_state")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 检查是否返回错误页面
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "用户名或密码错误")
|
||||
|
||||
// 测试成功登录
|
||||
form.Set("password", password) // 使用正确的密码
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 检查重定向
|
||||
assert.Equal(t, 302, w.Code)
|
||||
location := w.Header().Get("Location")
|
||||
assert.Contains(t, location, "/authorize")
|
||||
}
|
||||
|
||||
func TestOIDCConfiguration(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response, "issuer")
|
||||
}
|
||||
|
||||
func TestUserRegistration(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 测试注册请求
|
||||
form := url.Values{}
|
||||
form.Add("username", "newuser")
|
||||
form.Add("password", "newpass")
|
||||
form.Add("email", "newuser@example.com")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/signup", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 检查重定向到登录页面
|
||||
assert.Equal(t, 302, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Location"), "/login")
|
||||
|
||||
// 验证用户是否已创建
|
||||
var user models.User
|
||||
err = db.Where("username = ?", "newuser").First(&user).Error
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newuser@example.com", user.Email)
|
||||
}
|
||||
|
||||
func TestTokenEndpoint(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试客户端
|
||||
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
|
||||
grantTypes, _ := json.Marshal([]string{"authorization_code"})
|
||||
responseTypes, _ := json.Marshal([]string{"code"})
|
||||
client := &models.Client{
|
||||
ClientID: "test_client",
|
||||
ClientSecret: "test_secret",
|
||||
RedirectURIs: datatypes.JSON(redirectURIs),
|
||||
GrantTypes: datatypes.JSON(grantTypes),
|
||||
ResponseTypes: datatypes.JSON(responseTypes),
|
||||
TokenEndpointAuthMethod: "client_secret_post",
|
||||
}
|
||||
result := db.Create(client)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试用户
|
||||
user := &models.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
result = db.Create(user)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试授权码
|
||||
code := &models.AuthorizationCode{
|
||||
Code: "test_code",
|
||||
ClientID: client.ClientID,
|
||||
UserID: user.ID,
|
||||
Scope: "openid profile email",
|
||||
RedirectURI: "http://localhost:8080/callback",
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
result = db.Create(code)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 测试令牌请求
|
||||
form := url.Values{}
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code", code.Code)
|
||||
form.Add("client_id", client.ClientID)
|
||||
form.Add("client_secret", client.ClientSecret)
|
||||
form.Add("redirect_uri", code.RedirectURI)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response, "access_token")
|
||||
assert.Contains(t, response, "id_token")
|
||||
|
||||
// 使用获取到的访问令牌测试用户信息端点
|
||||
accessToken := response["access_token"].(string)
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "/userinfo", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Userinfo response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var userinfoResponse map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &userinfoResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "testuser", userinfoResponse["sub"])
|
||||
assert.Equal(t, "test@example.com", userinfoResponse["email"])
|
||||
}
|
||||
|
||||
func TestUserinfoEndpoint(t *testing.T) {
|
||||
r, db, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试用户
|
||||
user := &models.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
result := db.Create(user)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试客户端
|
||||
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
|
||||
grantTypes, _ := json.Marshal([]string{"authorization_code"})
|
||||
responseTypes, _ := json.Marshal([]string{"code"})
|
||||
client := &models.Client{
|
||||
ClientID: "test_client",
|
||||
ClientSecret: "test_secret",
|
||||
RedirectURIs: datatypes.JSON(redirectURIs),
|
||||
GrantTypes: datatypes.JSON(grantTypes),
|
||||
ResponseTypes: datatypes.JSON(responseTypes),
|
||||
TokenEndpointAuthMethod: "client_secret_post",
|
||||
}
|
||||
result = db.Create(client)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 创建测试授权码
|
||||
code := &models.AuthorizationCode{
|
||||
Code: "test_code",
|
||||
ClientID: client.ClientID,
|
||||
UserID: user.ID,
|
||||
Scope: "openid profile email",
|
||||
RedirectURI: "http://localhost:8080/callback",
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
result = db.Create(code)
|
||||
assert.NoError(t, result.Error)
|
||||
|
||||
// 测试令牌请求
|
||||
form := url.Values{}
|
||||
form.Add("grant_type", "authorization_code")
|
||||
form.Add("code", code.Code)
|
||||
form.Add("client_id", client.ClientID)
|
||||
form.Add("client_secret", client.ClientSecret)
|
||||
form.Add("redirect_uri", code.RedirectURI)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var tokenResponse map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &tokenResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, tokenResponse, "access_token")
|
||||
|
||||
// 测试用户信息请求
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "/userinfo", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenResponse["access_token"].(string))
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "testuser", response["sub"])
|
||||
assert.Equal(t, "test@example.com", response["email"])
|
||||
}
|
||||
|
||||
func TestClientRegistration(t *testing.T) {
|
||||
r, _, err := setupTestRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test router: %v", err)
|
||||
}
|
||||
|
||||
// 测试客户端注册请求
|
||||
registrationRequest := map[string]interface{}{
|
||||
"client_name": "Test Client",
|
||||
"redirect_uris": []string{"http://localhost:8080/callback"},
|
||||
"response_types": []string{"code"},
|
||||
"grant_types": []string{"authorization_code"},
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(registrationRequest)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response, "client_id")
|
||||
assert.Contains(t, response, "client_secret")
|
||||
}
|
||||
Reference in New Issue
Block a user