// Files
// oidc-server/main_test.go
//
// 426 lines
// 12 KiB
// Go

package main
import (
"bytes"
"encoding/json"
"html/template"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"gorm.io/datatypes"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"oidc-oauth2-server/handlers"
"oidc-oauth2-server/middleware"
"oidc-oauth2-server/models"
"oidc-oauth2-server/services"
)
// setupTestRouter assembles a Gin engine backed by a fresh in-memory SQLite
// database and registers the health, login/signup, OIDC, client-registration
// and token-management routes used by the tests in this file.
//
// It returns the engine, the database handle (so tests can seed and inspect
// rows), and any error from opening the database, running the migrations, or
// constructing the OAuth service.
func setupTestRouter() (*gin.Engine, *gorm.DB, error) {
	// Switch Gin to test mode.
	gin.SetMode(gin.TestMode)

	// Open the test database.
	// NOTE(review): with the SQLite driver, each pooled connection to a plain
	// ":memory:" DSN is presumably its own empty database; confirm that nothing
	// under test needs more than one connection at a time.
	dialector := sqlite.Open(":memory:")
	database, err := gorm.Open(dialector, &gorm.Config{})
	if err != nil {
		return nil, nil, err
	}

	// Apply the schema migrations.
	err = models.AutoMigrate(database)
	if err != nil {
		return nil, nil, err
	}

	// Build the services.
	authSvc := services.NewAuthService(database)
	oauthSvc, err := services.NewOAuthService(database)
	if err != nil {
		return nil, nil, err
	}
	clientSvc := services.NewClientService(database)
	tokenSvc := services.NewTokenService(database, oauthSvc.GetKeyManager())

	// Create the engine.
	engine := gin.New()
	engine.Use(gin.Recovery())

	// Register the template helpers, then load the templates.
	helpers := template.FuncMap{
		"subtract": func(a, b int) int { return a - b },
		"add":      func(a, b int) int { return a + b },
	}
	engine.SetFuncMap(helpers)
	engine.LoadHTMLGlob("templates/*")

	// Cookie-backed session store.
	sessionStore := cookie.NewStore([]byte("test_secret"))
	engine.Use(sessions.Sessions("oidc_session", sessionStore))

	// Health check.
	engine.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})

	// Authentication routes.
	auth := handlers.NewAuthHandler(authSvc)
	engine.GET("/login", auth.ShowLogin)
	engine.POST("/login", auth.HandleLogin)
	engine.GET("/signup", auth.ShowSignup)
	engine.POST("/signup", auth.HandleSignup)

	// OIDC routes.
	oidc := handlers.NewOIDCHandler("http://localhost:8080", oauthSvc, authSvc)
	engine.GET("/.well-known/openid-configuration", oidc.OpenIDConfiguration)
	engine.GET("/authorize", oidc.Authorize)
	engine.POST("/token", oidc.Token)
	bearer := middleware.BearerAuth(oauthSvc.GetKeyManager(), database)
	engine.GET("/userinfo", bearer, oidc.Userinfo)
	engine.GET("/jwks", oidc.JWKS)

	// Client registration route.
	registration := handlers.NewRegistrationHandler(clientSvc)
	engine.POST("/register", registration.Register)

	// Token management routes.
	tokens := handlers.NewTokenHandler(tokenSvc)
	engine.POST("/revoke", tokens.Revoke)
	engine.POST("/introspect", tokens.Introspect)

	return engine, database, nil
}
// TestHealthCheck verifies that GET /health answers with status 200 and a
// JSON body whose "status" field is "ok".
func TestHealthCheck(t *testing.T) {
	r, _, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	req, err := http.NewRequest(http.MethodGet, "/health", nil)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	var response map[string]string
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		// Stop here: field assertions on an undecoded body only add noise.
		t.Fatalf("Failed to decode response body %q: %v", w.Body.String(), err)
	}
	assert.Equal(t, "ok", response["status"])
}
// TestLoginPage verifies that GET /login answers with status 200.
func TestLoginPage(t *testing.T) {
	r, _, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	req, err := http.NewRequest(http.MethodGet, "/login", nil)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
}
// TestLoginAuthentication posts the login form twice for a seeded user: first
// with a wrong password, expecting a 200 page containing the error message,
// then with the correct password, expecting a 302 redirect whose Location
// contains "/authorize".
//
// Fixture failures (password hashing, user insert, request construction) abort
// the test immediately instead of letting it run against broken state.
func TestLoginAuthentication(t *testing.T) {
	r, db, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	// Create the test user with a bcrypt-hashed password.
	password := "testpass"
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		t.Fatalf("Failed to hash password: %v", err)
	}
	testUser := &models.User{
		Username: "testuser",
		Password: string(hashedPassword),
		Email:    "test@example.com",
	}
	if err := db.Create(testUser).Error; err != nil {
		t.Fatalf("Failed to create test user: %v", err)
	}

	// postLogin submits the form to POST /login and returns the recorded response.
	postLogin := func(form url.Values) *httptest.ResponseRecorder {
		t.Helper()
		req, err := http.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
		if err != nil {
			t.Fatalf("Failed to build login request: %v", err)
		}
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		return w
	}

	// Login request using form data, starting with a wrong password.
	form := url.Values{}
	form.Add("username", "testuser")
	form.Add("password", "wrongpass")
	form.Add("client_id", "test_client")
	form.Add("response_type", "code")
	form.Add("scope", "openid profile")
	form.Add("redirect_uri", "http://localhost:8080/callback")
	form.Add("state", "test_state")

	// The failed login should render the error page.
	w := postLogin(form)
	assert.Equal(t, http.StatusOK, w.Code)
	assert.Contains(t, w.Body.String(), "用户名或密码错误")

	// Retry with the correct password.
	form.Set("password", password)
	w = postLogin(form)

	// The successful login should redirect.
	assert.Equal(t, http.StatusFound, w.Code)
	location := w.Header().Get("Location")
	assert.Contains(t, location, "/authorize")
}
// TestOIDCConfiguration verifies that GET /.well-known/openid-configuration
// answers with status 200 and a JSON document containing an "issuer" key.
func TestOIDCConfiguration(t *testing.T) {
	r, _, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	req, err := http.NewRequest(http.MethodGet, "/.well-known/openid-configuration", nil)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	var response map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		// Stop here: key assertions on an undecoded body only add noise.
		t.Fatalf("Failed to decode response body %q: %v", w.Body.String(), err)
	}
	assert.Contains(t, response, "issuer")
}
// TestUserRegistration posts the signup form, expects a 302 redirect whose
// Location contains "/login", and then checks that a user row with the
// submitted username and email exists in the database.
func TestUserRegistration(t *testing.T) {
	r, db, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	// Signup request.
	form := url.Values{}
	form.Add("username", "newuser")
	form.Add("password", "newpass")
	form.Add("email", "newuser@example.com")

	req, err := http.NewRequest(http.MethodPost, "/signup", strings.NewReader(form.Encode()))
	if err != nil {
		t.Fatalf("Failed to build signup request: %v", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// The handler should redirect to the login page.
	assert.Equal(t, http.StatusFound, w.Code)
	assert.Contains(t, w.Header().Get("Location"), "/login")

	// The user should have been created.
	var user models.User
	if err := db.Where("username = ?", "newuser").First(&user).Error; err != nil {
		// Without a row, comparing fields of the zero-valued user is meaningless.
		t.Fatalf("Failed to load registered user: %v", err)
	}
	assert.Equal(t, "newuser@example.com", user.Email)
}
// TestTokenEndpoint seeds a client, a user and an unused authorization code,
// exchanges the code at POST /token, and then calls GET /userinfo with the
// access token that was issued.
//
// Fixture failures and an unsuccessful token response abort the test right
// away, so later steps never run against missing data; in particular the
// access token is extracted with a checked type assertion instead of one that
// panics when the field is absent.
func TestTokenEndpoint(t *testing.T) {
	r, db, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	// mustJSON encodes a string list for the client's JSON columns.
	mustJSON := func(values []string) datatypes.JSON {
		t.Helper()
		raw, err := json.Marshal(values)
		if err != nil {
			t.Fatalf("Failed to marshal %v: %v", values, err)
		}
		return datatypes.JSON(raw)
	}

	// Create the test client.
	client := &models.Client{
		ClientID:                "test_client",
		ClientSecret:            "test_secret",
		RedirectURIs:            mustJSON([]string{"http://localhost:8080/callback"}),
		GrantTypes:              mustJSON([]string{"authorization_code"}),
		ResponseTypes:           mustJSON([]string{"code"}),
		TokenEndpointAuthMethod: "client_secret_post",
	}
	if err := db.Create(client).Error; err != nil {
		t.Fatalf("Failed to create test client: %v", err)
	}

	// Create the test user.
	user := &models.User{
		Username: "testuser",
		Email:    "test@example.com",
	}
	if err := db.Create(user).Error; err != nil {
		t.Fatalf("Failed to create test user: %v", err)
	}

	// Create the test authorization code.
	code := &models.AuthorizationCode{
		Code:        "test_code",
		ClientID:    client.ClientID,
		UserID:      user.ID,
		Scope:       "openid profile email",
		RedirectURI: "http://localhost:8080/callback",
		ExpiresAt:   time.Now().Add(10 * time.Minute),
		Used:        false,
	}
	if err := db.Create(code).Error; err != nil {
		t.Fatalf("Failed to create authorization code: %v", err)
	}

	// Token request.
	form := url.Values{}
	form.Add("grant_type", "authorization_code")
	form.Add("code", code.Code)
	form.Add("client_id", client.ClientID)
	form.Add("client_secret", client.ClientSecret)
	form.Add("redirect_uri", code.RedirectURI)

	req, err := http.NewRequest(http.MethodPost, "/token", strings.NewReader(form.Encode()))
	if err != nil {
		t.Fatalf("Failed to build token request: %v", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("Token request returned status %d, want %d; response body: %s",
			w.Code, http.StatusOK, w.Body.String())
	}

	var response map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to decode token response %q: %v", w.Body.String(), err)
	}
	assert.Contains(t, response, "id_token")
	accessToken, ok := response["access_token"].(string)
	if !ok {
		t.Fatalf("Token response has no string access_token; response body: %s", w.Body.String())
	}

	// Call the userinfo endpoint with the access token just obtained.
	req, err = http.NewRequest(http.MethodGet, "/userinfo", nil)
	if err != nil {
		t.Fatalf("Failed to build userinfo request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("Userinfo request returned status %d, want %d; response body: %s",
			w.Code, http.StatusOK, w.Body.String())
	}

	var userinfoResponse map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &userinfoResponse); err != nil {
		t.Fatalf("Failed to decode userinfo response %q: %v", w.Body.String(), err)
	}
	assert.Equal(t, "testuser", userinfoResponse["sub"])
	assert.Equal(t, "test@example.com", userinfoResponse["email"])
}
// TestUserinfoEndpoint seeds a user, a client and an unused authorization
// code, obtains an access token from POST /token, and verifies that
// GET /userinfo with that bearer token returns the seeded user's "sub" and
// "email".
//
// Fixture failures and an unsuccessful token response abort the test right
// away; the access token is extracted with a checked type assertion so a
// missing field produces a test failure rather than a panic.
func TestUserinfoEndpoint(t *testing.T) {
	r, db, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	// Create the test user.
	user := &models.User{
		Username: "testuser",
		Email:    "test@example.com",
	}
	if err := db.Create(user).Error; err != nil {
		t.Fatalf("Failed to create test user: %v", err)
	}

	// encodeList encodes a string list for the client's JSON columns.
	encodeList := func(values []string) datatypes.JSON {
		t.Helper()
		raw, err := json.Marshal(values)
		if err != nil {
			t.Fatalf("Failed to marshal %v: %v", values, err)
		}
		return datatypes.JSON(raw)
	}

	// Create the test client.
	client := &models.Client{
		ClientID:                "test_client",
		ClientSecret:            "test_secret",
		RedirectURIs:            encodeList([]string{"http://localhost:8080/callback"}),
		GrantTypes:              encodeList([]string{"authorization_code"}),
		ResponseTypes:           encodeList([]string{"code"}),
		TokenEndpointAuthMethod: "client_secret_post",
	}
	if err := db.Create(client).Error; err != nil {
		t.Fatalf("Failed to create test client: %v", err)
	}

	// Create the test authorization code.
	code := &models.AuthorizationCode{
		Code:        "test_code",
		ClientID:    client.ClientID,
		UserID:      user.ID,
		Scope:       "openid profile email",
		RedirectURI: "http://localhost:8080/callback",
		ExpiresAt:   time.Now().Add(10 * time.Minute),
		Used:        false,
	}
	if err := db.Create(code).Error; err != nil {
		t.Fatalf("Failed to create authorization code: %v", err)
	}

	// Token request.
	form := url.Values{}
	form.Add("grant_type", "authorization_code")
	form.Add("code", code.Code)
	form.Add("client_id", client.ClientID)
	form.Add("client_secret", client.ClientSecret)
	form.Add("redirect_uri", code.RedirectURI)

	req, err := http.NewRequest(http.MethodPost, "/token", strings.NewReader(form.Encode()))
	if err != nil {
		t.Fatalf("Failed to build token request: %v", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("Token request returned status %d, want %d; response body: %s",
			w.Code, http.StatusOK, w.Body.String())
	}

	var tokenResponse map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &tokenResponse); err != nil {
		t.Fatalf("Failed to decode token response %q: %v", w.Body.String(), err)
	}
	accessToken, ok := tokenResponse["access_token"].(string)
	if !ok {
		t.Fatalf("Token response has no string access_token; response body: %s", w.Body.String())
	}

	// Userinfo request.
	req, err = http.NewRequest(http.MethodGet, "/userinfo", nil)
	if err != nil {
		t.Fatalf("Failed to build userinfo request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("Userinfo request returned status %d, want %d; response body: %s",
			w.Code, http.StatusOK, w.Body.String())
	}

	var response map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to decode userinfo response %q: %v", w.Body.String(), err)
	}
	assert.Equal(t, "testuser", response["sub"])
	assert.Equal(t, "test@example.com", response["email"])
}
// TestClientRegistration posts a JSON client registration to POST /register
// and expects status 201 with a JSON body containing "client_id" and
// "client_secret".
func TestClientRegistration(t *testing.T) {
	r, _, err := setupTestRouter()
	if err != nil {
		t.Fatalf("Failed to setup test router: %v", err)
	}

	// Client registration request.
	registrationRequest := map[string]interface{}{
		"client_name":                "Test Client",
		"redirect_uris":              []string{"http://localhost:8080/callback"},
		"response_types":             []string{"code"},
		"grant_types":                []string{"authorization_code"},
		"token_endpoint_auth_method": "client_secret_basic",
	}
	jsonData, err := json.Marshal(registrationRequest)
	if err != nil {
		t.Fatalf("Failed to marshal registration request: %v", err)
	}

	req, err := http.NewRequest(http.MethodPost, "/register", bytes.NewReader(jsonData))
	if err != nil {
		t.Fatalf("Failed to build registration request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	if w.Code != http.StatusCreated {
		// Include the body so a rejected registration explains itself.
		t.Fatalf("Registration returned status %d, want %d; response body: %s",
			w.Code, http.StatusCreated, w.Body.String())
	}

	var response map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to decode registration response %q: %v", w.Body.String(), err)
	}
	assert.Contains(t, response, "client_id")
	assert.Contains(t, response, "client_secret")
}