using System;
using System.Collections.Generic;
using System.Linq;
using FluentValidation;
using FluentValidation.Results;
using Lidarr.Http.REST.Attributes;
using Lidarr.Http.Validation;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Controllers;
using Microsoft.AspNetCore.Mvc.Filters;
using NzbDrone.Core.Datastore;

namespace Lidarr.Http.REST
{
    /// <summary>
    /// Base class for REST controllers that exchange <typeparamref name="TResource"/>.
    /// Before an action runs, it validates POST/PUT resource arguments against the
    /// shared/POST/PUT validators and checks the "id" argument of PUT-by-id and
    /// DELETE-by-id actions. After a get-by-id action throws
    /// <see cref="ModelNotFoundException"/>, it answers with 404 Not Found.
    /// </summary>
    /// <typeparam name="TResource">The resource type handled by the derived controller.</typeparam>
    public abstract class RestController<TResource> : Controller
        where TResource : RestResource, new()
    {
        // Action attributes whose "id" argument must be a positive integer.
        private static readonly List<Type> VALIDATE_ID_ATTRIBUTES = new List<Type>
        {
            typeof(RestPutByIdAttribute),
            typeof(RestDeleteByIdAttribute)
        };

        /// <summary>Rules applied to POST bodies, except on request paths ending in "/test".</summary>
        protected ResourceValidator<TResource> PostValidator { get; private set; }

        /// <summary>Rules applied to PUT bodies; the constructor adds a valid-Id rule.</summary>
        protected ResourceValidator<TResource> PutValidator { get; private set; }

        /// <summary>Rules applied to both POST and PUT bodies unless shared validation is skipped.</summary>
        protected ResourceValidator<TResource> SharedValidator { get; private set; }

        /// <summary>
        /// Throws <see cref="BadRequestException"/> when <paramref name="id"/> is zero or negative.
        /// </summary>
        /// <param name="id">The resource id taken from the request.</param>
        protected void ValidateId(int id)
        {
            if (id <= 0)
            {
                throw new BadRequestException(id + " is not a valid ID");
            }
        }

        protected RestController()
        {
            PostValidator = new ResourceValidator<TResource>();
            PutValidator = new ResourceValidator<TResource>();
            SharedValidator = new ResourceValidator<TResource>();

            // Every PUT must target an existing resource id.
            PutValidator.RuleFor(r => r.Id).ValidId();
        }

        /// <summary>Loads a single resource by id.</summary>
        /// <param name="id">The resource id.</param>
        /// <returns>The matching resource.</returns>
        [RestGetById]
        public abstract TResource GetResourceById(int id);

        /// <summary>
        /// Runs before each action. Validates every <typeparamref name="TResource"/> argument
        /// of POST/PUT requests and validates the "id" argument of PUT-by-id and DELETE-by-id
        /// actions. Both steps honour <see cref="SkipValidationAttribute"/> on the action.
        /// </summary>
        /// <param name="context">The MVC action-executing context.</param>
        /// <exception cref="BadRequestException">Empty body or non-positive id.</exception>
        /// <exception cref="ValidationException">The resource failed validation.</exception>
        public override void OnActionExecuting(ActionExecutingContext context)
        {
            var descriptor = context.ActionDescriptor as ControllerActionDescriptor;

            var skipAttribute = (SkipValidationAttribute)Attribute.GetCustomAttribute(descriptor.MethodInfo, typeof(SkipValidationAttribute), true);
            var skipValidate = skipAttribute?.Skip ?? false;
            var skipShared = skipAttribute?.SkipShared ?? false;

            // Case-insensitive, matching the method checks performed in ValidateResource.
            var method = Request.Method;
            if (string.Equals(method, "POST", StringComparison.OrdinalIgnoreCase) ||
                string.Equals(method, "PUT", StringComparison.OrdinalIgnoreCase))
            {
                // x?.GetType(): a null argument value must not throw here.
                var resourceArgs = context.ActionArguments.Values
                    .Where(x => x?.GetType() == typeof(TResource))
                    .Select(x => x as TResource)
                    .ToList();

                foreach (var resource in resourceArgs)
                {
                    ValidateResource(resource, skipValidate, skipShared);
                }
            }

            var attributes = descriptor.MethodInfo.CustomAttributes;

            // CustomAttributeData.AttributeType is the attribute's own type. Calling GetType()
            // on the CustomAttributeData would return the reflection metadata type and never match.
            if (!skipValidate && attributes.Any(x => VALIDATE_ID_ATTRIBUTES.Contains(x.AttributeType)))
            {
                if (context.ActionArguments.TryGetValue("id", out var idObj) && idObj is int id)
                {
                    ValidateId(id);
                }
            }

            base.OnActionExecuting(context);
        }

        /// <summary>
        /// Runs after each action. When a get-by-id action throws exactly
        /// <see cref="ModelNotFoundException"/>, the response becomes 404 Not Found.
        /// </summary>
        /// <param name="context">The MVC action-executed context.</param>
        public override void OnActionExecuted(ActionExecutedContext context)
        {
            var descriptor = context.ActionDescriptor as ControllerActionDescriptor;
            var attributes = descriptor.MethodInfo.CustomAttributes;

            if (context.Exception?.GetType() == typeof(ModelNotFoundException) &&
                attributes.Any(x => x.AttributeType == typeof(RestGetByIdAttribute)))
            {
                context.Result = new NotFoundResult();

                // MVC rethrows an unhandled exception and ignores Result, so mark it handled.
                context.ExceptionHandled = true;
            }

            base.OnActionExecuted(context);
        }

        /// <summary>
        /// Validates <paramref name="resource"/> for the current request. Shared rules run
        /// unless <paramref name="skipSharedValidate"/> is set. POST rules run unless
        /// <paramref name="skipValidate"/> is set or the path ends in "/test". PUT rules
        /// run for PUT requests.
        /// </summary>
        /// <param name="resource">The incoming resource; null is rejected.</param>
        /// <param name="skipValidate">Skip the POST-specific rules.</param>
        /// <param name="skipSharedValidate">Skip the shared rules.</param>
        /// <exception cref="BadRequestException"><paramref name="resource"/> is null.</exception>
        /// <exception cref="ValidationException">One or more rules failed.</exception>
        protected void ValidateResource(TResource resource, bool skipValidate = false, bool skipSharedValidate = false)
        {
            if (resource is null)
            {
                throw new BadRequestException("Request body can't be empty");
            }

            var errors = new List<ValidationFailure>();

            if (!skipSharedValidate)
            {
                errors.AddRange(SharedValidator.Validate(resource).Errors);
            }

            if (Request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) &&
                !skipValidate &&
                !Request.Path.ToString().EndsWith("/test", StringComparison.OrdinalIgnoreCase))
            {
                errors.AddRange(PostValidator.Validate(resource).Errors);
            }
            else if (Request.Method.Equals("PUT", StringComparison.OrdinalIgnoreCase))
            {
                // NOTE(review): PUT rules run even when skipValidate is set — confirm this is intended.
                errors.AddRange(PutValidator.Validate(resource).Errors);
            }

            if (errors.Any())
            {
                throw new ValidationException(errors);
            }
        }

        /// <summary>
        /// Builds a 202 Accepted response carrying the resource with <paramref name="id"/>
        /// and a location pointing at <see cref="GetResourceById"/>.
        /// </summary>
        /// <param name="id">Id of the accepted resource.</param>
        protected ActionResult Accepted(int id)
        {
            var result = GetResourceById(id);
            return AcceptedAtAction(nameof(GetResourceById), new { id }, result);
        }

        /// <summary>
        /// Builds a 201 Created response carrying the resource with <paramref name="id"/>
        /// and a location pointing at <see cref="GetResourceById"/>.
        /// </summary>
        /// <param name="id">Id of the created resource.</param>
        protected ActionResult Created(int id)
        {
            var result = GetResourceById(id);
            return CreatedAtAction(nameof(GetResourceById), new { id }, result);
        }
    }
}