diff --git a/Jellyfin.Server/Filters/SecurityRequirementsOperationFilter.cs b/Jellyfin.Server/Filters/SecurityRequirementsOperationFilter.cs index fb9f6d0a6e..fb0bd817ce 100644 --- a/Jellyfin.Server/Filters/SecurityRequirementsOperationFilter.cs +++ b/Jellyfin.Server/Filters/SecurityRequirementsOperationFilter.cs @@ -1,91 +1,105 @@ using System; using System.Collections.Generic; using System.Linq; +using Jellyfin.Api.Auth.DefaultAuthorizationPolicy; using Jellyfin.Api.Constants; +using Jellyfin.Extensions; using Microsoft.AspNetCore.Authorization; using Microsoft.OpenApi.Models; using Swashbuckle.AspNetCore.SwaggerGen; -namespace Jellyfin.Server.Filters +namespace Jellyfin.Server.Filters; + +/// +/// Security requirement operation filter. +/// +public class SecurityRequirementsOperationFilter : IOperationFilter { + private const string DefaultAuthPolicy = "DefaultAuthorization"; + private static readonly Type _attributeType = typeof(AuthorizeAttribute); + + private readonly IAuthorizationPolicyProvider _authorizationPolicyProvider; + /// - /// Security requirement operation filter. + /// Initializes a new instance of the class. /// - public class SecurityRequirementsOperationFilter : IOperationFilter + /// The authorization policy provider. + public SecurityRequirementsOperationFilter(IAuthorizationPolicyProvider authorizationPolicyProvider) { - /// - public void Apply(OpenApiOperation operation, OperationFilterContext context) - { - var requiredScopes = new List(); + _authorizationPolicyProvider = authorizationPolicyProvider; + } - var requiresAuth = false; - // Add all method scopes. - foreach (var attribute in context.MethodInfo.GetCustomAttributes(true)) - { - if (attribute is not AuthorizeAttribute authorizeAttribute) - { - continue; - } + /// + public void Apply(OpenApiOperation operation, OperationFilterContext context) + { + var requiredScopes = new List(); - requiresAuth = true; - if (authorizeAttribute.Policy is not null - && !requiredScopes.Contains(authorizeAttribute.Policy, StringComparer.Ordinal)) - { - requiredScopes.Add(authorizeAttribute.Policy); - } + var requiresAuth = false; + // Add all method scopes. + foreach (var authorizeAttribute in context.MethodInfo.GetCustomAttributes(_attributeType, true).Cast()) + { + requiresAuth = true; + var policy = authorizeAttribute.Policy ?? DefaultAuthPolicy; + if (!requiredScopes.Contains(policy, StringComparer.Ordinal)) + { + requiredScopes.Add(policy); } + } - // Add controller scopes if any. - var controllerAttributes = context.MethodInfo.DeclaringType?.GetCustomAttributes(true); - if (controllerAttributes is not null) + // Add controller scopes if any. + var controllerAttributes = context.MethodInfo.DeclaringType?.GetCustomAttributes(_attributeType, true).Cast(); + if (controllerAttributes is not null) + { + foreach (var authorizeAttribute in controllerAttributes) { - foreach (var attribute in controllerAttributes) + requiresAuth = true; + var policy = authorizeAttribute.Policy ?? DefaultAuthPolicy; + if (!requiredScopes.Contains(policy, StringComparer.Ordinal)) { - if (attribute is not AuthorizeAttribute authorizeAttribute) - { - continue; - } - - requiresAuth = true; - if (authorizeAttribute.Policy is not null - && !requiredScopes.Contains(authorizeAttribute.Policy, StringComparer.Ordinal)) - { - requiredScopes.Add(authorizeAttribute.Policy); - } + requiredScopes.Add(policy); } } + } - if (!requiresAuth) - { - return; - } + if (!requiresAuth) + { + return; + } - if (!operation.Responses.ContainsKey("401")) - { - operation.Responses.Add("401", new OpenApiResponse { Description = "Unauthorized" }); - } + if (!operation.Responses.ContainsKey("401")) + { + operation.Responses.Add("401", new OpenApiResponse { Description = "Unauthorized" }); + } - if (!operation.Responses.ContainsKey("403")) - { - operation.Responses.Add("403", new OpenApiResponse { Description = "Forbidden" }); - } + if (!operation.Responses.ContainsKey("403")) + { + operation.Responses.Add("403", new OpenApiResponse { Description = "Forbidden" }); + } - var scheme = new OpenApiSecurityScheme + var scheme = new OpenApiSecurityScheme + { + Reference = new OpenApiReference { - Reference = new OpenApiReference - { - Type = ReferenceType.SecurityScheme, - Id = AuthenticationSchemes.CustomAuthentication - } - }; + Type = ReferenceType.SecurityScheme, + Id = AuthenticationSchemes.CustomAuthentication + }, + }; - operation.Security = new List + // Add DefaultAuthorization scope to any endpoint that has a policy with a requirement that is a subset of DefaultAuthorization. + if (!requiredScopes.Contains(DefaultAuthPolicy.AsSpan(), StringComparison.Ordinal)) + { + foreach (var scope in requiredScopes) { - new OpenApiSecurityRequirement + var authorizationPolicy = _authorizationPolicyProvider.GetPolicyAsync(scope).GetAwaiter().GetResult(); + if (authorizationPolicy is not null + && authorizationPolicy.Requirements.Any(r => r is DefaultAuthorizationRequirement)) { - [scheme] = requiredScopes + requiredScopes.Add(DefaultAuthPolicy); + break; } - }; + } } + + operation.Security = [new OpenApiSecurityRequirement { [scheme] = requiredScopes }]; } }