426 lines
12 KiB
Go
426 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"html/template"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-contrib/sessions"
|
|
"github.com/gin-contrib/sessions/cookie"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gorm.io/datatypes"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
|
|
"oidc-oauth2-server/handlers"
|
|
"oidc-oauth2-server/middleware"
|
|
"oidc-oauth2-server/models"
|
|
"oidc-oauth2-server/services"
|
|
)
|
|
|
|
func setupTestRouter() (*gin.Engine, *gorm.DB, error) {
|
|
// 设置测试模式
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
// 初始化测试数据库
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// 运行数据库迁移
|
|
if err := models.AutoMigrate(db); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// 初始化服务
|
|
authService := services.NewAuthService(db)
|
|
oauthService, err := services.NewOAuthService(db)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
clientService := services.NewClientService(db)
|
|
tokenService := services.NewTokenService(db, oauthService.GetKeyManager())
|
|
|
|
// 设置路由
|
|
r := gin.New()
|
|
r.Use(gin.Recovery())
|
|
|
|
// 设置模板
|
|
r.SetFuncMap(template.FuncMap{
|
|
"subtract": func(a, b int) int { return a - b },
|
|
"add": func(a, b int) int { return a + b },
|
|
})
|
|
r.LoadHTMLGlob("templates/*")
|
|
|
|
// 设置会话存储
|
|
store := cookie.NewStore([]byte("test_secret"))
|
|
r.Use(sessions.Sessions("oidc_session", store))
|
|
|
|
// 设置路由处理器
|
|
r.GET("/health", func(c *gin.Context) {
|
|
c.JSON(200, gin.H{"status": "ok"})
|
|
})
|
|
|
|
// 设置认证相关路由
|
|
authHandler := handlers.NewAuthHandler(authService)
|
|
r.GET("/login", authHandler.ShowLogin)
|
|
r.POST("/login", authHandler.HandleLogin)
|
|
r.GET("/signup", authHandler.ShowSignup)
|
|
r.POST("/signup", authHandler.HandleSignup)
|
|
|
|
// 设置 OIDC 相关路由
|
|
oidcHandler := handlers.NewOIDCHandler("http://localhost:8080", oauthService, authService)
|
|
r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration)
|
|
r.GET("/authorize", oidcHandler.Authorize)
|
|
r.POST("/token", oidcHandler.Token)
|
|
r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo)
|
|
r.GET("/jwks", oidcHandler.JWKS)
|
|
|
|
// 设置客户端注册路由
|
|
registrationHandler := handlers.NewRegistrationHandler(clientService)
|
|
r.POST("/register", registrationHandler.Register)
|
|
|
|
// 设置令牌管理路由
|
|
tokenHandler := handlers.NewTokenHandler(tokenService)
|
|
r.POST("/revoke", tokenHandler.Revoke)
|
|
r.POST("/introspect", tokenHandler.Introspect)
|
|
|
|
return r, db, nil
|
|
}
|
|
|
|
func TestHealthCheck(t *testing.T) {
|
|
r, _, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/health", nil)
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 200, w.Code)
|
|
|
|
var response map[string]string
|
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "ok", response["status"])
|
|
}
|
|
|
|
func TestLoginPage(t *testing.T) {
|
|
r, _, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/login", nil)
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 200, w.Code)
|
|
}
|
|
|
|
func TestLoginAuthentication(t *testing.T) {
|
|
r, db, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
// 创建测试用户
|
|
password := "testpass"
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
assert.NoError(t, err)
|
|
|
|
testUser := &models.User{
|
|
Username: "testuser",
|
|
Password: string(hashedPassword),
|
|
Email: "test@example.com",
|
|
}
|
|
result := db.Create(testUser)
|
|
assert.NoError(t, result.Error)
|
|
|
|
// 测试登录请求 - 使用表单数据
|
|
form := url.Values{}
|
|
form.Add("username", "testuser")
|
|
form.Add("password", "wrongpass")
|
|
form.Add("client_id", "test_client")
|
|
form.Add("response_type", "code")
|
|
form.Add("scope", "openid profile")
|
|
form.Add("redirect_uri", "http://localhost:8080/callback")
|
|
form.Add("state", "test_state")
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
r.ServeHTTP(w, req)
|
|
|
|
// 检查是否返回错误页面
|
|
assert.Equal(t, 200, w.Code)
|
|
assert.Contains(t, w.Body.String(), "用户名或密码错误")
|
|
|
|
// 测试成功登录
|
|
form.Set("password", password) // 使用正确的密码
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest("POST", "/login", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
r.ServeHTTP(w, req)
|
|
|
|
// 检查重定向
|
|
assert.Equal(t, 302, w.Code)
|
|
location := w.Header().Get("Location")
|
|
assert.Contains(t, location, "/authorize")
|
|
}
|
|
|
|
func TestOIDCConfiguration(t *testing.T) {
|
|
r, _, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 200, w.Code)
|
|
|
|
var response map[string]interface{}
|
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Contains(t, response, "issuer")
|
|
}
|
|
|
|
func TestUserRegistration(t *testing.T) {
|
|
r, db, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
// 测试注册请求
|
|
form := url.Values{}
|
|
form.Add("username", "newuser")
|
|
form.Add("password", "newpass")
|
|
form.Add("email", "newuser@example.com")
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("POST", "/signup", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
r.ServeHTTP(w, req)
|
|
|
|
// 检查重定向到登录页面
|
|
assert.Equal(t, 302, w.Code)
|
|
assert.Contains(t, w.Header().Get("Location"), "/login")
|
|
|
|
// 验证用户是否已创建
|
|
var user models.User
|
|
err = db.Where("username = ?", "newuser").First(&user).Error
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "newuser@example.com", user.Email)
|
|
}
|
|
|
|
func TestTokenEndpoint(t *testing.T) {
|
|
r, db, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
// 创建测试客户端
|
|
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
|
|
grantTypes, _ := json.Marshal([]string{"authorization_code"})
|
|
responseTypes, _ := json.Marshal([]string{"code"})
|
|
client := &models.Client{
|
|
ClientID: "test_client",
|
|
ClientSecret: "test_secret",
|
|
RedirectURIs: datatypes.JSON(redirectURIs),
|
|
GrantTypes: datatypes.JSON(grantTypes),
|
|
ResponseTypes: datatypes.JSON(responseTypes),
|
|
TokenEndpointAuthMethod: "client_secret_post",
|
|
}
|
|
result := db.Create(client)
|
|
assert.NoError(t, result.Error)
|
|
|
|
// 创建测试用户
|
|
user := &models.User{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
}
|
|
result = db.Create(user)
|
|
assert.NoError(t, result.Error)
|
|
|
|
// 创建测试授权码
|
|
code := &models.AuthorizationCode{
|
|
Code: "test_code",
|
|
ClientID: client.ClientID,
|
|
UserID: user.ID,
|
|
Scope: "openid profile email",
|
|
RedirectURI: "http://localhost:8080/callback",
|
|
ExpiresAt: time.Now().Add(10 * time.Minute),
|
|
Used: false,
|
|
}
|
|
result = db.Create(code)
|
|
assert.NoError(t, result.Error)
|
|
|
|
// 测试令牌请求
|
|
form := url.Values{}
|
|
form.Add("grant_type", "authorization_code")
|
|
form.Add("code", code.Code)
|
|
form.Add("client_id", client.ClientID)
|
|
form.Add("client_secret", client.ClientSecret)
|
|
form.Add("redirect_uri", code.RedirectURI)
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != 200 {
|
|
t.Logf("Response body: %s", w.Body.String())
|
|
}
|
|
assert.Equal(t, 200, w.Code)
|
|
|
|
var response map[string]interface{}
|
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Contains(t, response, "access_token")
|
|
assert.Contains(t, response, "id_token")
|
|
|
|
// 使用获取到的访问令牌测试用户信息端点
|
|
accessToken := response["access_token"].(string)
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest("GET", "/userinfo", nil)
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != 200 {
|
|
t.Logf("Userinfo response body: %s", w.Body.String())
|
|
}
|
|
assert.Equal(t, 200, w.Code)
|
|
|
|
var userinfoResponse map[string]interface{}
|
|
err = json.Unmarshal(w.Body.Bytes(), &userinfoResponse)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "testuser", userinfoResponse["sub"])
|
|
assert.Equal(t, "test@example.com", userinfoResponse["email"])
|
|
}
|
|
|
|
func TestUserinfoEndpoint(t *testing.T) {
|
|
r, db, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
// 创建测试用户
|
|
user := &models.User{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
}
|
|
result := db.Create(user)
|
|
assert.NoError(t, result.Error)
|
|
|
|
// 创建测试客户端
|
|
redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"})
|
|
grantTypes, _ := json.Marshal([]string{"authorization_code"})
|
|
responseTypes, _ := json.Marshal([]string{"code"})
|
|
client := &models.Client{
|
|
ClientID: "test_client",
|
|
ClientSecret: "test_secret",
|
|
RedirectURIs: datatypes.JSON(redirectURIs),
|
|
GrantTypes: datatypes.JSON(grantTypes),
|
|
ResponseTypes: datatypes.JSON(responseTypes),
|
|
TokenEndpointAuthMethod: "client_secret_post",
|
|
}
|
|
result = db.Create(client)
|
|
assert.NoError(t, result.Error)
|
|
|
|
// 创建测试授权码
|
|
code := &models.AuthorizationCode{
|
|
Code: "test_code",
|
|
ClientID: client.ClientID,
|
|
UserID: user.ID,
|
|
Scope: "openid profile email",
|
|
RedirectURI: "http://localhost:8080/callback",
|
|
ExpiresAt: time.Now().Add(10 * time.Minute),
|
|
Used: false,
|
|
}
|
|
result = db.Create(code)
|
|
assert.NoError(t, result.Error)
|
|
|
|
// 测试令牌请求
|
|
form := url.Values{}
|
|
form.Add("grant_type", "authorization_code")
|
|
form.Add("code", code.Code)
|
|
form.Add("client_id", client.ClientID)
|
|
form.Add("client_secret", client.ClientSecret)
|
|
form.Add("redirect_uri", code.RedirectURI)
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != 200 {
|
|
t.Logf("Response body: %s", w.Body.String())
|
|
}
|
|
assert.Equal(t, 200, w.Code)
|
|
|
|
var tokenResponse map[string]interface{}
|
|
err = json.Unmarshal(w.Body.Bytes(), &tokenResponse)
|
|
assert.NoError(t, err)
|
|
assert.Contains(t, tokenResponse, "access_token")
|
|
|
|
// 测试用户信息请求
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest("GET", "/userinfo", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenResponse["access_token"].(string))
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != 200 {
|
|
t.Logf("Response body: %s", w.Body.String())
|
|
}
|
|
assert.Equal(t, 200, w.Code)
|
|
|
|
var response map[string]interface{}
|
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "testuser", response["sub"])
|
|
assert.Equal(t, "test@example.com", response["email"])
|
|
}
|
|
|
|
func TestClientRegistration(t *testing.T) {
|
|
r, _, err := setupTestRouter()
|
|
if err != nil {
|
|
t.Fatalf("Failed to setup test router: %v", err)
|
|
}
|
|
|
|
// 测试客户端注册请求
|
|
registrationRequest := map[string]interface{}{
|
|
"client_name": "Test Client",
|
|
"redirect_uris": []string{"http://localhost:8080/callback"},
|
|
"response_types": []string{"code"},
|
|
"grant_types": []string{"authorization_code"},
|
|
"token_endpoint_auth_method": "client_secret_basic",
|
|
}
|
|
|
|
jsonData, _ := json.Marshal(registrationRequest)
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 201, w.Code)
|
|
|
|
var response map[string]interface{}
|
|
err = json.Unmarshal(w.Body.Bytes(), &response)
|
|
assert.NoError(t, err)
|
|
assert.Contains(t, response, "client_id")
|
|
assert.Contains(t, response, "client_secret")
|
|
}
|