From 1ec59735206e3e97c594f35407fbd1d94cf6644b Mon Sep 17 00:00:00 2001 From: Taloth Saldono Date: Wed, 8 Nov 2017 22:26:45 +0100 Subject: [PATCH] New: Round-robin over available Download Client instead of the first enabled one --- .../Download/DownloadClientProviderFixture.cs | 185 ++++++++++++++++++ .../Download/DownloadClientProvider.cs | 38 +++- 2 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 src/NzbDrone.Core.Test/Download/DownloadClientProviderFixture.cs diff --git a/src/NzbDrone.Core.Test/Download/DownloadClientProviderFixture.cs b/src/NzbDrone.Core.Test/Download/DownloadClientProviderFixture.cs new file mode 100644 index 000000000..a8e60165e --- /dev/null +++ b/src/NzbDrone.Core.Test/Download/DownloadClientProviderFixture.cs @@ -0,0 +1,185 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using FizzWare.NBuilder; +using FluentAssertions; +using Moq; +using NUnit.Framework; +using NzbDrone.Core.Download; +using NzbDrone.Core.Indexers; +using NzbDrone.Core.Test.Framework; + +namespace NzbDrone.Core.Test.Download +{ + [TestFixture] + public class DownloadClientProviderFixture : CoreTest + { + private List _downloadClients; + private List _blockedProviders; + private int _nextId; + + [SetUp] + public void SetUp() + { + _downloadClients = new List(); + _blockedProviders = new List(); + _nextId = 1; + + Mocker.GetMock() + .Setup(v => v.GetAvailableProviders()) + .Returns(_downloadClients); + + Mocker.GetMock() + .Setup(v => v.GetBlockedProviders()) + .Returns(_blockedProviders); + } + + private Mock WithUsenetClient() + { + var mock = new Mock(MockBehavior.Default); + mock.SetupGet(s => s.Definition) + .Returns(Builder + .CreateNew() + .With(v => v.Id = _nextId++) + .Build()); + + _downloadClients.Add(mock.Object); + + mock.SetupGet(v => v.Protocol).Returns(DownloadProtocol.Usenet); + + return mock; + } + + private Mock WithTorrentClient() + { + var mock = new Mock(MockBehavior.Default); + mock.SetupGet(s => s.Definition) + .Returns(Builder + .CreateNew() + .With(v => v.Id = _nextId++) + .Build()); + + _downloadClients.Add(mock.Object); + + mock.SetupGet(v => v.Protocol).Returns(DownloadProtocol.Torrent); + + return mock; + } + + private void GivenBlockedClient(int id) + { + _blockedProviders.Add(new DownloadClientStatus + { + ProviderId = id, + DisabledTill = DateTime.UtcNow.AddHours(3) + }); + } + + [Test] + public void should_roundrobin_over_usenet_client() + { + WithUsenetClient(); + WithUsenetClient(); + WithUsenetClient(); + WithTorrentClient(); + + var client1 = Subject.GetDownloadClient(DownloadProtocol.Usenet); + var client2 = Subject.GetDownloadClient(DownloadProtocol.Usenet); + var client3 = Subject.GetDownloadClient(DownloadProtocol.Usenet); + var client4 = Subject.GetDownloadClient(DownloadProtocol.Usenet); + var client5 = Subject.GetDownloadClient(DownloadProtocol.Usenet); + + client1.Definition.Id.Should().Be(1); + client2.Definition.Id.Should().Be(2); + client3.Definition.Id.Should().Be(3); + client4.Definition.Id.Should().Be(1); + client5.Definition.Id.Should().Be(2); + } + + [Test] + public void should_roundrobin_over_torrent_client() + { + WithUsenetClient(); + WithTorrentClient(); + WithTorrentClient(); + WithTorrentClient(); + + var client1 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client2 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client3 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client4 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client5 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + + client1.Definition.Id.Should().Be(2); + client2.Definition.Id.Should().Be(3); + client3.Definition.Id.Should().Be(4); + client4.Definition.Id.Should().Be(2); + client5.Definition.Id.Should().Be(3); + } + + [Test] + public void should_roundrobin_over_protocol_separately() + { + WithUsenetClient(); + WithTorrentClient(); + WithTorrentClient(); + + var client1 = Subject.GetDownloadClient(DownloadProtocol.Usenet); + var client2 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client3 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client4 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + + client1.Definition.Id.Should().Be(1); + client2.Definition.Id.Should().Be(2); + client3.Definition.Id.Should().Be(3); + client4.Definition.Id.Should().Be(2); + } + + [Test] + public void should_skip_blocked_torrent_client() + { + WithUsenetClient(); + WithTorrentClient(); + WithTorrentClient(); + WithTorrentClient(); + + GivenBlockedClient(3); + + var client1 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client2 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client3 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client4 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client5 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + + client1.Definition.Id.Should().Be(2); + client2.Definition.Id.Should().Be(4); + client3.Definition.Id.Should().Be(2); + client4.Definition.Id.Should().Be(4); + } + + [Test] + public void should_not_skip_blocked_torrent_client_if_all_blocked() + { + WithUsenetClient(); + WithTorrentClient(); + WithTorrentClient(); + WithTorrentClient(); + + GivenBlockedClient(2); + GivenBlockedClient(3); + GivenBlockedClient(4); + + var client1 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client2 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client3 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client4 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + var client5 = Subject.GetDownloadClient(DownloadProtocol.Torrent); + + client1.Definition.Id.Should().Be(2); + client2.Definition.Id.Should().Be(3); + client3.Definition.Id.Should().Be(4); + client4.Definition.Id.Should().Be(2); + } + } +} diff --git a/src/NzbDrone.Core/Download/DownloadClientProvider.cs b/src/NzbDrone.Core/Download/DownloadClientProvider.cs index 7ed7cd5b9..0e799c609 100644 --- a/src/NzbDrone.Core/Download/DownloadClientProvider.cs +++ b/src/NzbDrone.Core/Download/DownloadClientProvider.cs @@ -1,6 +1,8 @@ using System.Linq; using System.Collections.Generic; using NzbDrone.Core.Indexers; +using NzbDrone.Common.Cache; +using NLog; namespace NzbDrone.Core.Download { @@ -13,16 +15,48 @@ namespace NzbDrone.Core.Download public class DownloadClientProvider : IProvideDownloadClient { + private readonly Logger _logger; private readonly IDownloadClientFactory _downloadClientFactory; + private readonly IDownloadClientStatusService _downloadClientStatusService; + private readonly ICached _lastUsedDownloadClient; - public DownloadClientProvider(IDownloadClientFactory downloadClientFactory) + public DownloadClientProvider(IDownloadClientStatusService downloadClientStatusService, IDownloadClientFactory downloadClientFactory, ICacheManager cacheManager, Logger logger) { + _logger = logger; _downloadClientFactory = downloadClientFactory; + _downloadClientStatusService = downloadClientStatusService; + _lastUsedDownloadClient = cacheManager.GetCache(GetType(), "lastDownloadClientId"); } public IDownloadClient GetDownloadClient(DownloadProtocol downloadProtocol) { - return _downloadClientFactory.GetAvailableProviders().FirstOrDefault(v => v.Protocol == downloadProtocol); + var availableProviders = _downloadClientFactory.GetAvailableProviders().Where(v => v.Protocol == downloadProtocol).ToList(); + + if (!availableProviders.Any()) return null; + + var blockedProviders = new HashSet(_downloadClientStatusService.GetBlockedProviders().Select(v => v.ProviderId)); + + if (blockedProviders.Any()) + { + var nonBlockedProviders = availableProviders.Where(v => !blockedProviders.Contains(v.Definition.Id)).ToList(); + + if (nonBlockedProviders.Any()) + { + availableProviders = nonBlockedProviders; + } + else + { + _logger.Trace("No non-blocked Download Client available, retrying blocked one."); + } + } + + var lastId = _lastUsedDownloadClient.Find(downloadProtocol.ToString()); + + var provider = availableProviders.FirstOrDefault(v => v.Definition.Id > lastId) ?? availableProviders.First(); + + _lastUsedDownloadClient.Set(downloadProtocol.ToString(), provider.Definition.Id); + + return provider; } public IEnumerable GetDownloadClients()