package services

import (
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"gorm.io/datatypes"
	"gorm.io/gorm"

	"oidc-oauth2-server/models"
)

// ErrClientNotFound is returned by DeleteClient when no active client has
// the given client ID.
var ErrClientNotFound = errors.New("client not found")

// ClientService manages OAuth 2.0 client records stored through GORM.
type ClientService struct {
	db *gorm.DB
}

// ClientResponse is the client metadata returned by ClientService methods.
//
// NOTE(review): RFC 7591 defines client_id_issued_at and
// client_secret_expires_at as numbers of seconds since the epoch (0 meaning
// the secret never expires). These fields are time.Time, which encoding/json
// renders as RFC 3339 strings (the zero value as "0001-01-01T00:00:00Z").
// Confirm what registered clients expect before changing the wire format.
type ClientResponse struct {
	ClientID                string    `json:"client_id"`
	ClientSecret            string    `json:"client_secret,omitempty"`
	RegistrationAccessToken string    `json:"registration_access_token,omitempty"`
	RegistrationClientURI   string    `json:"registration_client_uri,omitempty"`
	ClientIDIssuedAt        time.Time `json:"client_id_issued_at"`
	ClientSecretExpiresAt   time.Time `json:"client_secret_expires_at"`
	RedirectURIs            []string  `json:"redirect_uris"`
	TokenEndpointAuthMethod string    `json:"token_endpoint_auth_method"`
	GrantTypes              []string  `json:"grant_types"`
	ResponseTypes           []string  `json:"response_types"`
	ClientName              string    `json:"client_name"`
	ClientURI               string    `json:"client_uri"`
	LogoURI                 string    `json:"logo_uri"`
	Scope                   string    `json:"scope"`
	Contacts                []string  `json:"contacts"`
	TosURI                  string    `json:"tos_uri"`
	PolicyURI               string    `json:"policy_uri"`
	SoftwareID              string    `json:"software_id"`
	SoftwareVersion         string    `json:"software_version"`
}

// ClientRegistrationRequest carries the client metadata supplied when
// registering or updating a client. The binding tag marks redirect_uris as
// required; presumably this is enforced by the HTTP layer's request binding
// (verify against the handler).
type ClientRegistrationRequest struct {
	RedirectURIs            []string `json:"redirect_uris" binding:"required"`
	TokenEndpointAuthMethod string   `json:"token_endpoint_auth_method"`
	GrantTypes              []string `json:"grant_types"`
	ResponseTypes           []string `json:"response_types"`
	ClientName              string   `json:"client_name"`
	ClientURI               string   `json:"client_uri"`
	LogoURI                 string   `json:"logo_uri"`
	Scope                   string   `json:"scope"`
	Contacts                []string `json:"contacts"`
	TosURI                  string   `json:"tos_uri"`
	PolicyURI               string   `json:"policy_uri"`
	SoftwareID              string   `json:"software_id"`
	SoftwareVersion         string   `json:"software_version"`
}

// NewClientService returns a ClientService that stores clients in db.
func NewClientService(db *gorm.DB) *ClientService {
	return &ClientService{db: db}
}

// applyDefaults fills in grant_types, response_types and
// token_endpoint_auth_method when r leaves them empty.
func (r *ClientRegistrationRequest) applyDefaults() {
	if len(r.GrantTypes) == 0 {
		r.GrantTypes = []string{"authorization_code"}
	}
	if len(r.ResponseTypes) == 0 {
		r.ResponseTypes = []string{"code"}
	}
	if r.TokenEndpointAuthMethod == "" {
		r.TokenEndpointAuthMethod = "client_secret_basic"
	}
}

// applyTo copies every metadata field of r onto client. List-valued fields
// are stored JSON-encoded, and the scope string is stored as a one-element
// JSON array (GetClient joins the array back into a single string).
// Identity, secret, activity and timestamp fields are left untouched.
func (r *ClientRegistrationRequest) applyTo(client *models.Client) {
	// json.Marshal cannot fail for []string: strings always encode (invalid
	// UTF-8 is replaced with U+FFFD), so discarding the error is safe.
	encode := func(v []string) datatypes.JSON {
		raw, _ := json.Marshal(v)
		return datatypes.JSON(raw)
	}

	client.RedirectURIs = encode(r.RedirectURIs)
	client.GrantTypes = encode(r.GrantTypes)
	client.ResponseTypes = encode(r.ResponseTypes)
	client.Scopes = encode([]string{r.Scope})
	client.Contacts = encode(r.Contacts)

	client.TokenEndpointAuthMethod = r.TokenEndpointAuthMethod
	client.ClientName = r.ClientName
	client.ClientURI = r.ClientURI
	client.LogoURI = r.LogoURI
	client.TosURI = r.TosURI
	client.PolicyURI = r.PolicyURI
	client.SoftwareID = r.SoftwareID
	client.SoftwareVersion = r.SoftwareVersion
}

// RegisterClient creates a new active client from req and returns its
// metadata together with the generated client ID and secret.
//
// Empty grant_types, response_types and token_endpoint_auth_method in req are
// replaced with defaults, and req itself is updated with those defaults. The
// returned ClientSecretExpiresAt is the zero time, meaning the secret does not
// expire. An error is returned if secure random generation or the database
// insert fails; database errors are returned unwrapped.
//
// NOTE(review): the client secret is persisted as-is (not hashed); confirm
// this matches how the token endpoint authenticates clients.
func (s *ClientService) RegisterClient(req *ClientRegistrationRequest) (*ClientResponse, error) {
	// Fail registration outright rather than issuing empty credentials when
	// the system random source is unavailable.
	clientID, err := newSecureToken(32)
	if err != nil {
		return nil, err
	}
	clientSecret, err := newSecureToken(48)
	if err != nil {
		return nil, err
	}

	req.applyDefaults()

	now := time.Now()
	client := &models.Client{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		IsActive:     true,
		CreatedAt:    now,
		UpdatedAt:    now,
	}
	req.applyTo(client)

	if err := s.db.Create(client).Error; err != nil {
		return nil, err
	}

	return &ClientResponse{
		ClientID:                clientID,
		ClientSecret:            clientSecret,
		ClientIDIssuedAt:        client.CreatedAt,
		ClientSecretExpiresAt:   time.Time{}, // zero value: never expires
		RedirectURIs:            req.RedirectURIs,
		TokenEndpointAuthMethod: req.TokenEndpointAuthMethod,
		GrantTypes:              req.GrantTypes,
		ResponseTypes:           req.ResponseTypes,
		ClientName:              req.ClientName,
		ClientURI:               req.ClientURI,
		LogoURI:                 req.LogoURI,
		Scope:                   req.Scope,
		Contacts:                req.Contacts,
		TosURI:                  req.TosURI,
		PolicyURI:               req.PolicyURI,
		SoftwareID:              req.SoftwareID,
		SoftwareVersion:         req.SoftwareVersion,
	}, nil
}

// GetClient returns the metadata of the active client with the given ID; the
// client secret is not included.
//
// When no active client matches, the database error (gorm.ErrRecordNotFound
// for First) is returned unwrapped. A wrapped error is returned if a stored
// JSON list column cannot be decoded. Empty (e.g. NULL) columns decode to nil.
func (s *ClientService) GetClient(clientID string) (*ClientResponse, error) {
	var client models.Client
	if err := s.db.Where("client_id = ? AND is_active = ?", clientID, true).First(&client).Error; err != nil {
		return nil, err
	}

	var redirectURIs, grantTypes, responseTypes, scopes, contacts []string
	columns := []struct {
		name string
		raw  []byte
		dst  *[]string
	}{
		{"redirect_uris", client.RedirectURIs, &redirectURIs},
		{"grant_types", client.GrantTypes, &grantTypes},
		{"response_types", client.ResponseTypes, &responseTypes},
		{"scopes", client.Scopes, &scopes},
		{"contacts", client.Contacts, &contacts},
	}
	for _, col := range columns {
		if len(col.raw) == 0 {
			// json.Unmarshal rejects empty input; treat an empty column as
			// "no values" instead of failing the whole lookup.
			continue
		}
		if err := json.Unmarshal(col.raw, col.dst); err != nil {
			return nil, fmt.Errorf("decoding %s of client %q: %w", col.name, clientID, err)
		}
	}

	return &ClientResponse{
		ClientID:                client.ClientID,
		ClientIDIssuedAt:        client.CreatedAt,
		RedirectURIs:            redirectURIs,
		TokenEndpointAuthMethod: client.TokenEndpointAuthMethod,
		GrantTypes:              grantTypes,
		ResponseTypes:           responseTypes,
		ClientName:              client.ClientName,
		ClientURI:               client.ClientURI,
		LogoURI:                 client.LogoURI,
		// Scopes are written as a one-element array; joining also tolerates
		// empty or multi-element arrays instead of panicking on scopes[0].
		Scope:           strings.Join(scopes, " "),
		Contacts:        contacts,
		TosURI:          client.TosURI,
		PolicyURI:       client.PolicyURI,
		SoftwareID:      client.SoftwareID,
		SoftwareVersion: client.SoftwareVersion,
	}, nil
}

// UpdateClient replaces the metadata of the active client clientID with the
// contents of req and returns the stored result via GetClient.
//
// The same defaults as RegisterClient are applied to a copy of req (req itself
// is not modified). Metadata fields omitted from req are cleared, so callers
// must send the full metadata set. The client ID and secret are unchanged.
// Database errors are returned unwrapped.
func (s *ClientService) UpdateClient(clientID string, req *ClientRegistrationRequest) (*ClientResponse, error) {
	var client models.Client
	if err := s.db.Where("client_id = ? AND is_active = ?", clientID, true).First(&client).Error; err != nil {
		return nil, err
	}

	meta := *req
	meta.applyDefaults()
	meta.applyTo(&client)
	client.UpdatedAt = time.Now()

	if err := s.db.Save(&client).Error; err != nil {
		return nil, err
	}

	return s.GetClient(clientID)
}

// DeleteClient soft-deletes the client by setting is_active to false.
// It returns ErrClientNotFound when no active client has the given ID, so
// deleting an already-deleted client reports not-found regardless of how the
// database driver counts unchanged rows.
func (s *ClientService) DeleteClient(clientID string) error {
	result := s.db.Model(&models.Client{}).
		Where("client_id = ? AND is_active = ?", clientID, true).
		Update("is_active", false)
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		return ErrClientNotFound
	}
	return nil
}

// newSecureToken returns length cryptographically random bytes encoded as
// unpadded URL-safe base64, or an error if the system random source fails.
func newSecureToken(length int) (string, error) {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("generating random token: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

// generateSecureToken is like newSecureToken but returns "" when the random
// source fails. It is kept for existing callers; new code should use
// newSecureToken so the failure cannot be mistaken for a valid token.
func generateSecureToken(length int) string {
	token, err := newSecureToken(length)
	if err != nil {
		return ""
	}
	return token
}