diff --git a/TechnitiumLibrary.Tests/README.md b/TechnitiumLibrary.Tests/README.md new file mode 100644 index 0000000..24d3e27 --- /dev/null +++ b/TechnitiumLibrary.Tests/README.md @@ -0,0 +1,130 @@ +# TechnitiumLibrary.Tests + +This project contains the unit and integration-style test coverage for `TechnitiumLibrary.sln`. + +The goal is to keep tests close to the module they cover, make socket-dependent behavior deterministic with local simulators, and steadily improve coverage without changing production implementation just to make tests easier. + +## Running Tests + +Run the full test project: + +```powershell +dotnet test .\TechnitiumLibrary.Tests\TechnitiumLibrary.Tests.csproj +``` + +Run with coverage: + +```powershell +dotnet test .\TechnitiumLibrary.Tests\TechnitiumLibrary.Tests.csproj --collect:"XPlat Code Coverage" -- DataCollectionRunSettings.DataCollectors.DataCollector.Configuration.Format=cobertura +``` + +Run from WSL/Ubuntu: + +```bash +cd /mnt/d/AIProjects/DNS/TechnitiumLibrary +dotnet restore ./TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj +dotnet test ./TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj --no-restore +``` + +When switching between Windows and WSL, run `dotnet restore` in the OS you are about to test from. The generated NuGet assets can contain OS-specific package paths. + +Run a module or class slice: + +```powershell +dotnet test .\TechnitiumLibrary.Tests\TechnitiumLibrary.Tests.csproj --filter "FullyQualifiedName~TechnitiumLibrary.Net.Tor" +dotnet test .\TechnitiumLibrary.Tests\TechnitiumLibrary.Tests.csproj --filter "DnsDatagramTests" +``` + +## OS-Specific Tests + +The test project should remain cross-OS by default. OS-specific tests are allowed only when the production API itself is platform-specific, and they must be guarded so the full test suite still passes on Windows, Linux, and macOS. + +Current OS-specific tests: + +```text +TechnitiumLibrary.Security.Cryptography/KeyAgreementTests.cs + ECDiffieHellmanDerivesSameSecretOnSupportedPlatforms + ECDiffieHellmanUnsupportedHashThrowsOnSupportedPlatforms +``` + +These tests exercise `TechnitiumLibrary.Security.Cryptography.ECDiffieHellman`, which uses `ECDiffieHellmanCng`. They run their assertions only on Windows and return immediately on non-Windows platforms. + +Socket and protocol simulator tests are not considered OS-specific. They must continue to use loopback, ephemeral ports, and local simulators so they can run on all supported operating systems. + +The full test project was verified on Ubuntu under WSL with .NET SDK `10.0.108`: + +```text +Total tests: 356 +Passed: 356 +``` + +## Project Structure + +Tests are grouped by the production assembly or module they cover: + +```text +TechnitiumLibrary.Tests/ + TechnitiumLibrary/ Core library tests + TechnitiumLibrary.ByteTree/ ByteTree tests + TechnitiumLibrary.IO/ IO and stream/package tests + TechnitiumLibrary.Net/ Networking, DNS, HTTP, proxy, socket helpers + TechnitiumLibrary.Net.BitTorrent/ BitTorrent protocol tests + TechnitiumLibrary.Net.Mail/ Mail protocol tests + TechnitiumLibrary.Net.Tor/ Tor controller and hidden service tests + TechnitiumLibrary.Net.UPnP/ UPnP tests + TechnitiumLibrary.Security.Cryptography/ Cryptography tests + TechnitiumLibrary.Security.OTP/ OTP tests + Simulators/ Local protocol/socket simulators used by tests +``` + +Nested folders should mirror the production module when useful. For example, DNS resource-record tests belong under: + +```text +TechnitiumLibrary.Net/Dns/ResourceRecords/ +``` + +## Contribution Guidelines + +When adding tests: + +- Keep production code unchanged unless a real production bug is discovered and explicitly being fixed. +- Put tests in the matching module folder. Avoid large catch-all test files for unrelated behavior. +- Prefer public APIs. Use reflection only when testing socket/protocol behavior would otherwise require unsafe or non-portable setup. +- Name tests by behavior, not implementation detail. +- Keep assertions meaningful. Avoid asserting incidental details such as object hash codes unless the hash behavior is the actual contract. +- Add focused tests first, then broaden only when the covered behavior is shared or high risk. +- All tests must run cross-OS. Avoid Windows-only commands, shell scripts, fixed ports, real network dependencies, or a real Tor/DNS/mail service. + +## Simulator Guidelines + +Socket-related tests should use local simulators instead of external services. + +Simulator expectations: + +- Place simulators under `Simulators//`. +- Bind to `IPAddress.Loopback` and an ephemeral port. +- Avoid fixed ports. +- Implement `IDisposable` and clean up listeners, sockets, streams, tasks, and cancellation tokens. +- Keep protocol behavior scriptable so tests can cover success, error, timeout, malformed, and disconnect scenarios. +- Prefer deterministic command/response queues over sleeps. +- Do not rely on internet access. +- Keep simulators small and protocol-specific. + +Examples: + +```text +Simulators/TechnitiumLibrary.Net/DnsTestServer.cs +Simulators/TechnitiumLibrary.Net.Mail/Pop3TestServer.cs +Simulators/TechnitiumLibrary.Net.Tor/TorControlTestServer.cs +``` + +## Coverage Work + +Coverage improvements should be done module by module. A good coverage PR usually includes: + +- New tests in the correct module folder. +- Simulator improvements when socket behavior is involved. +- A short note about coverage before and after, when coverage is the purpose of the change. +- A full `dotnet test` pass before submission. + +High-value areas for future coverage include DNS parsing/serialization, DNS client behavior, proxy/socket flows, protocol simulators, and error handling paths. diff --git a/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.BitTorrent/TestTrackerClient.cs b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.BitTorrent/TestTrackerClient.cs new file mode 100644 index 0000000..52f8828 --- /dev/null +++ b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.BitTorrent/TestTrackerClient.cs @@ -0,0 +1,39 @@ +using System.Net; +using TechnitiumLibrary.Net.BitTorrent; + +namespace TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net.BitTorrent +{ + internal sealed class TestTrackerClient : TrackerClient + { + public TestTrackerClient(Uri? trackerUri = null, byte[]? infoHash = null, int customUpdateInterval = 0) + : base(trackerUri ?? new Uri("http://tracker.example/announce"), infoHash ?? CreateInfoHash(), CreateClientId(), customUpdateInterval) + { } + + public Exception? ExceptionToThrow { get; set; } + + public TrackerClientEvent LastEvent { get; private set; } + + public IPEndPoint? LastUpdateEndpoint { get; private set; } + + protected override Task UpdateTrackerAsync(TrackerClientEvent @event, IPEndPoint clientEP) + { + LastEvent = @event; + LastUpdateEndpoint = clientEP; + + if (ExceptionToThrow is not null) + throw ExceptionToThrow; + + return Task.CompletedTask; + } + + private static byte[] CreateInfoHash() + { + return Enumerable.Range(0, 20).Select(Convert.ToByte).ToArray(); + } + + private static TrackerClientID CreateClientId() + { + return new TrackerClientID(Enumerable.Range(20, 20).Select(Convert.ToByte).ToArray(), [1, 2, 3, 4], "agent", "gzip", 50, true, true); + } + } +} diff --git a/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.Mail/Pop3TestServer.cs b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.Mail/Pop3TestServer.cs new file mode 100644 index 0000000..34fe983 --- /dev/null +++ b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.Mail/Pop3TestServer.cs @@ -0,0 +1,72 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; + +namespace TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net.Mail +{ + internal sealed class Pop3TestServer : IDisposable + { + private readonly TcpListener _listener = new TcpListener(IPAddress.Loopback, 0); + private readonly string _greeting; + private readonly Queue _responses = new Queue(); + private readonly List _commands = new List(); + private Task? _serverTask; + + public Pop3TestServer(string greeting) + { + _greeting = greeting; + } + + public int Port + { get { return ((IPEndPoint)_listener.LocalEndpoint).Port; } } + + public IReadOnlyList Commands + { get { return _commands; } } + + public void Enqueue(params string[] lines) + { + _responses.Enqueue(lines); + } + + public Task StartAsync() + { + _listener.Start(); + _serverTask = Task.Run(ServeAsync); + return Task.CompletedTask; + } + + private async Task ServeAsync() + { + using TcpClient client = await _listener.AcceptTcpClientAsync(); + using NetworkStream stream = client.GetStream(); + using StreamReader reader = new StreamReader(stream, Encoding.ASCII, false, 1024, leaveOpen: true); + using StreamWriter writer = new StreamWriter(stream, Encoding.ASCII, 1024, leaveOpen: true) { AutoFlush = true, NewLine = "\r\n" }; + + await writer.WriteLineAsync(_greeting); + + while (_responses.Count > 0) + { + string? command = await reader.ReadLineAsync(); + if (command is null) + break; + + _commands.Add(command); + + foreach (string line in _responses.Dequeue()) + await writer.WriteLineAsync(line); + } + } + + public void Dispose() + { + _listener.Stop(); + + try + { + _serverTask?.Wait(TimeSpan.FromSeconds(2)); + } + catch (AggregateException) + { } + } + } +} diff --git a/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.Tor/TorControlTestServer.cs b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.Tor/TorControlTestServer.cs new file mode 100644 index 0000000..9fb482a --- /dev/null +++ b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net.Tor/TorControlTestServer.cs @@ -0,0 +1,79 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; + +namespace TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net.Tor +{ + internal sealed class TorControlTestServer : IDisposable + { + private readonly TcpListener _listener = new TcpListener(IPAddress.Loopback, 0); + private readonly Queue _responses = new Queue(); + private readonly List _commands = new List(); + private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource(); + private Task? _serverTask; + private TcpClient? _client; + + public int Port + { get { return ((IPEndPoint)_listener.LocalEndpoint).Port; } } + + public IReadOnlyList Commands + { get { return _commands; } } + + public void Enqueue(params string[] lines) + { + _responses.Enqueue(lines); + } + + public void Start() + { + _listener.Start(); + _serverTask = Task.Run(ServeAsync); + } + + private async Task ServeAsync() + { + try + { + _client = await _listener.AcceptTcpClientAsync(_cancellationTokenSource.Token); + + using NetworkStream stream = _client.GetStream(); + using StreamReader reader = new StreamReader(stream, Encoding.ASCII, false, 1024, leaveOpen: true); + using StreamWriter writer = new StreamWriter(stream, Encoding.ASCII, 1024, leaveOpen: true) { AutoFlush = true, NewLine = "\n" }; + + while (!_cancellationTokenSource.IsCancellationRequested && (_responses.Count > 0)) + { + string? command = await reader.ReadLineAsync(_cancellationTokenSource.Token); + if (command is null) + break; + + _commands.Add(command); + + foreach (string line in _responses.Dequeue()) + await writer.WriteLineAsync(line); + } + } + catch (OperationCanceledException) + { } + catch (ObjectDisposedException) + { } + catch (IOException) + { } + } + + public void Dispose() + { + _cancellationTokenSource.Cancel(); + _client?.Dispose(); + _listener.Stop(); + + try + { + _serverTask?.Wait(TimeSpan.FromSeconds(2)); + } + catch + { } + + _cancellationTokenSource.Dispose(); + } + } +} diff --git a/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net/DnsTestServer.cs b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net/DnsTestServer.cs new file mode 100644 index 0000000..9b9d986 --- /dev/null +++ b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net/DnsTestServer.cs @@ -0,0 +1,262 @@ +using System.Net; +using System.Net.Sockets; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net +{ + internal sealed class DnsTestServer : IDisposable + { + private readonly TcpListener _tcpListener; + private readonly UdpClient _udpClient; + private readonly Dictionary>> _answers = new Dictionary>>(); + private readonly Dictionary _responseCodes = new Dictionary(); + private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource(); + private Task? _udpTask; + private Task? _tcpTask; + + public DnsTestServer() + { + for (int i = 0; ; i++) + { + TcpListener tcpListener = new TcpListener(IPAddress.Loopback, 0); + tcpListener.Start(); + int port = ((IPEndPoint)tcpListener.LocalEndpoint).Port; + + try + { + _udpClient = new UdpClient(new IPEndPoint(IPAddress.Loopback, port)); + _tcpListener = tcpListener; + Port = port; + break; + } + catch (SocketException) when (i < 10) + { + tcpListener.Stop(); + } + } + } + + public int Port { get; } + + public int UdpQueryCount { get; private set; } + + public int TcpQueryCount { get; private set; } + + public bool TruncateUdpResponses { get; set; } + + public bool DropUdpResponses { get; set; } + + public void AddAddress(string domain, IPAddress address) + { + DnsResourceRecordType type = address.AddressFamily == AddressFamily.InterNetwork ? DnsResourceRecordType.A : DnsResourceRecordType.AAAA; + + _answers[new QuestionKey(domain, type)] = question => + [ + new DnsResourceRecord( + question.Name, + type, + DnsClass.IN, + 60, + type == DnsResourceRecordType.A ? new DnsARecordData(address) : new DnsAAAARecordData(address)) + ]; + } + + public void AddCNameAddress(string aliasDomain, string canonicalDomain, IPAddress address) + { + _answers[new QuestionKey(aliasDomain, DnsResourceRecordType.A)] = question => + [ + new DnsResourceRecord(question.Name, DnsResourceRecordType.CNAME, DnsClass.IN, 60, new DnsCNAMERecordData(canonicalDomain)), + new DnsResourceRecord(canonicalDomain, DnsResourceRecordType.A, DnsClass.IN, 60, new DnsARecordData(address)) + ]; + } + + public void AddMx(string domain, ushort preference, string exchange) + { + _answers[new QuestionKey(domain, DnsResourceRecordType.MX)] = question => + [ + new DnsResourceRecord(question.Name, DnsResourceRecordType.MX, DnsClass.IN, 60, new DnsMXRecordData(preference, exchange)) + ]; + } + + public void SetResponseCode(string domain, DnsResourceRecordType type, DnsResponseCode responseCode) + { + _responseCodes[new QuestionKey(domain, type)] = responseCode; + } + + public void Start() + { + _udpTask = Task.Run(ServeUdpAsync); + _tcpTask = Task.Run(ServeTcpAsync); + } + + private async Task ServeUdpAsync() + { + while (!_cancellationTokenSource.IsCancellationRequested) + { + try + { + UdpReceiveResult result = await _udpClient.ReceiveAsync(_cancellationTokenSource.Token); + UdpQueryCount++; + + if (DropUdpResponses) + continue; + + using MemoryStream requestStream = new MemoryStream(result.Buffer); + DnsDatagram request = DnsDatagram.ReadFrom(requestStream); + DnsDatagram response = TruncateUdpResponses ? CreateTruncatedResponse(request) : CreateResponse(request); + using MemoryStream responseStream = new MemoryStream(); + response.WriteTo(responseStream); + + await _udpClient.SendAsync(responseStream.ToArray(), result.RemoteEndPoint, _cancellationTokenSource.Token); + } + catch (OperationCanceledException) + { + break; + } + catch (ObjectDisposedException) + { + break; + } + } + } + + private async Task ServeTcpAsync() + { + while (!_cancellationTokenSource.IsCancellationRequested) + { + try + { + TcpClient client = await _tcpListener.AcceptTcpClientAsync(_cancellationTokenSource.Token); + _ = Task.Run(() => ServeTcpClientAsync(client)); + } + catch (OperationCanceledException) + { + break; + } + catch (ObjectDisposedException) + { + break; + } + } + } + + private async Task ServeTcpClientAsync(TcpClient client) + { + using (client) + using (NetworkStream stream = client.GetStream()) + { + while (!_cancellationTokenSource.IsCancellationRequested) + { + DnsDatagram request; + + try + { + request = await DnsDatagram.ReadFromTcpAsync(stream, cancellationToken: _cancellationTokenSource.Token); + } + catch + { + break; + } + + TcpQueryCount++; + + DnsDatagram response = CreateResponse(request); + await response.WriteToTcpAsync(stream, cancellationToken: _cancellationTokenSource.Token); + await stream.FlushAsync(_cancellationTokenSource.Token); + } + } + } + + private DnsDatagram CreateTruncatedResponse(DnsDatagram request) + { + return new DnsDatagram( + request.Identifier, + true, + DnsOpcode.StandardQuery, + true, + true, + request.RecursionDesired, + true, + false, + false, + DnsResponseCode.NoError, + request.Question); + } + + private DnsDatagram CreateResponse(DnsDatagram request) + { + DnsResponseCode responseCode = DnsResponseCode.NoError; + IReadOnlyList answers = Array.Empty(); + + if (request.Question.Count > 0) + { + DnsQuestionRecord question = request.Question[0]; + QuestionKey key = new QuestionKey(question.Name, question.Type); + + if (_responseCodes.TryGetValue(key, out DnsResponseCode configuredResponseCode)) + responseCode = configuredResponseCode; + + if ((responseCode == DnsResponseCode.NoError) && _answers.TryGetValue(key, out var answerFactory)) + answers = answerFactory(question); + } + + return new DnsDatagram( + request.Identifier, + true, + DnsOpcode.StandardQuery, + true, + false, + request.RecursionDesired, + true, + false, + false, + responseCode, + request.Question, + answers); + } + + public void Dispose() + { + _cancellationTokenSource.Cancel(); + _udpClient.Dispose(); + _tcpListener.Stop(); + + try + { + Task.WaitAll(new[] { _udpTask, _tcpTask }.Where(task => task is not null).Cast().ToArray(), TimeSpan.FromSeconds(2)); + } + catch + { } + + _cancellationTokenSource.Dispose(); + } + + private readonly struct QuestionKey : IEquatable + { + private readonly string _domain; + private readonly DnsResourceRecordType _type; + + public QuestionKey(string domain, DnsResourceRecordType type) + { + _domain = domain.TrimEnd('.').ToLowerInvariant(); + _type = type; + } + + public bool Equals(QuestionKey other) + { + return (_domain == other._domain) && (_type == other._type); + } + + public override bool Equals(object? obj) + { + return obj is QuestionKey other && Equals(other); + } + + public override int GetHashCode() + { + return HashCode.Combine(_domain, _type); + } + } + } +} diff --git a/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net/UnsupportedEndPoint.cs b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net/UnsupportedEndPoint.cs new file mode 100644 index 0000000..2769a9e --- /dev/null +++ b/TechnitiumLibrary.Tests/Simulators/TechnitiumLibrary.Net/UnsupportedEndPoint.cs @@ -0,0 +1,18 @@ +using System.Net; +using System.Net.Sockets; + +namespace TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net +{ + internal sealed class UnsupportedEndPoint : EndPoint + { + private readonly AddressFamily _addressFamily; + + public UnsupportedEndPoint(AddressFamily addressFamily) + { + _addressFamily = addressFamily; + } + + public override AddressFamily AddressFamily + { get { return _addressFamily; } } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeEnumerationTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeEnumerationTests.cs new file mode 100644 index 0000000..b5799be --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeEnumerationTests.cs @@ -0,0 +1,51 @@ +using TechnitiumLibrary.ByteTree; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.ByteTree +{ + public class ByteTreeEnumerationTests + { + [Fact] + public void EnumeratesInKeyOrderAndReverseKeyOrder() + { + ByteTree tree = new ByteTree(); + tree.Add(new byte[] { 2 }, "two"); + tree.Add(new byte[] { 1 }, "one"); + tree.Add(new byte[] { 1, 1 }, "one-one"); + + Assert.Equal(new[] { "one", "one-one", "two" }, tree.ToArray()); + Assert.Equal(new[] { "two", "one-one", "one" }, tree.GetReverseEnumerable().ToArray()); + } + + [Fact] + public void NonGenericEnumerator_CurrentResetAndFinishedState_Work() + { + ByteTree tree = new ByteTree(); + + Assert.Empty(((System.Collections.IEnumerable)tree).Cast()); + tree.Add(new byte[] { 1 }, "one"); + + System.Collections.IEnumerator enumerator = ((System.Collections.IEnumerable)tree).GetEnumerator(); + Assert.Null(enumerator.Current); + Assert.True(enumerator.MoveNext()); + Assert.Equal("one", enumerator.Current); + enumerator.Reset(); + Assert.True(enumerator.MoveNext()); + Assert.Equal("one", enumerator.Current); + Assert.False(enumerator.MoveNext()); + Assert.Null(enumerator.Current); + } + + [Fact] + public void DeepStemEnumeration_CoversNestedTraversal() + { + ByteTree tree = new ByteTree(); + tree.Add(new byte[] { 1, 1, 1 }, "a"); + tree.Add(new byte[] { 1, 1, 2 }, "b"); + tree.Add(new byte[] { 1, 2 }, "c"); + tree.Add(new byte[] { 2 }, "d"); + + Assert.Equal(new[] { "a", "b", "c", "d" }, tree.ToArray()); + Assert.Equal(new[] { "d", "c", "b", "a" }, tree.GetReverseEnumerable().ToArray()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeMutationTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeMutationTests.cs new file mode 100644 index 0000000..19d75d8 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeMutationTests.cs @@ -0,0 +1,79 @@ +using TechnitiumLibrary.ByteTree; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.ByteTree +{ + public class ByteTreeMutationTests + { + [Fact] + public void AddUpdateRemove_TracksValuesByByteKey() + { + ByteTree tree = new ByteTree(); + byte[] key = new byte[] { 1, 2, 3 }; + + Assert.True(tree.IsEmpty); + tree.Add(key, "one"); + + Assert.False(tree.IsEmpty); + Assert.True(tree.ContainsKey(key)); + Assert.Equal("one", tree[key]); + + string updated = tree.AddOrUpdate(key, "new", (_, existing) => existing + "-updated"); + Assert.Equal("one-updated", updated); + Assert.Equal("one-updated", tree[key]); + + Assert.True(tree.TryRemove(key, out string removed)); + Assert.Equal("one-updated", removed); + Assert.False(tree.ContainsKey(key)); + } + + [Fact] + public void TryAddGetOrAddIndexerAndClear_CoverCommonBranches() + { + ByteTree tree = new ByteTree(); + byte[] key = new byte[] { 1, 2 }; + + Assert.True(tree.TryAdd(key, "one")); + Assert.False(tree.TryAdd(key, "duplicate")); + Assert.True(tree.TryGet(key, out string value)); + Assert.Equal("one", value); + Assert.Equal("one", tree.GetOrAdd(key, "two")); + Assert.Equal("three", tree.GetOrAdd(new byte[] { 1, 3 }, _ => "three")); + + tree[key] = "updated"; + Assert.Equal("updated", tree[key]); + + tree.Clear(); + Assert.True(tree.IsEmpty); + Assert.False(tree.TryGet(key, out _)); + Assert.Throws(() => tree[key]); + } + + [Fact] + public void TryUpdate_UsesReferenceComparison() + { + ByteTree tree = new ByteTree(); + byte[] key = new byte[] { 1 }; + object original = new object(); + object replacement = new object(); + + tree.Add(key, original); + + Assert.False(tree.TryUpdate(new byte[] { 2 }, replacement, original)); + Assert.False(tree.TryUpdate(key, replacement, new object())); + Assert.True(tree.TryUpdate(key, replacement, original)); + Assert.Same(replacement, tree[key]); + } + + [Fact] + public void AddOrUpdate_AddFactoryBranchAndIndexerInsert() + { + ByteTree tree = new ByteTree(); + + Assert.Equal("added", tree.AddOrUpdate(new byte[] { 7 }, _ => "added", (_, existing) => existing)); + tree[new byte[] { 8 }] = "inserted"; + + Assert.Equal("added", tree[new byte[] { 7 }]); + Assert.Equal("inserted", tree[new byte[] { 8 }]); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeValidationTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeValidationTests.cs new file mode 100644 index 0000000..6dd4df0 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeValidationTests.cs @@ -0,0 +1,63 @@ +using TechnitiumLibrary.ByteTree; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.ByteTree +{ + public class ByteTreeValidationTests + { + [Fact] + public void DuplicateAddThrows() + { + ByteTree tree = new ByteTree(); + byte[] key = new byte[] { 42 }; + tree.Add(key, new object()); + + Assert.Throws(() => tree.Add(key, new object())); + } + + [Fact] + public void NullKeysThrow() + { + ByteTree tree = new ByteTree(); + + Assert.Throws(() => tree.Add(null, "value")); + Assert.Throws(() => tree.TryAdd(null, "value")); + Assert.Throws(() => tree.ContainsKey(null)); + Assert.Throws(() => tree.TryGet(null, out _)); + Assert.Throws(() => tree.GetOrAdd(null, "value")); + Assert.Throws(() => tree.AddOrUpdate(null, "value", (_, existing) => existing)); + Assert.Throws(() => tree.TryRemove(null, out _)); + Assert.Throws(() => tree.TryUpdate(null, "new", "old")); + Assert.Throws(() => tree[null]); + } + + [Fact] + public void InvalidKeySpaceThrows() + { + Assert.Throws(() => new ByteTree(-1)); + Assert.Throws(() => new ByteTree(257)); + } + + [Fact] + public void RemoveMissingReturnsFalseAndDefaultValue() + { + ByteTree tree = new ByteTree(); + + Assert.False(tree.TryRemove(new byte[] { 99 }, out string removed)); + Assert.Null(removed); + } + + [Fact] + public void CustomConverterNullKey_ReturnsFalseForTryMethods() + { + RejectingByteTree tree = new RejectingByteTree(); + + Assert.False(tree.TryAdd("reject-add", "value")); + Assert.False(tree.ContainsKey("reject-contains")); + Assert.False(tree.TryGet("reject-get", out string value)); + Assert.Null(value); + Assert.False(tree.TryRemove("reject-remove", out string removed)); + Assert.Null(removed); + Assert.False(tree.TryUpdate("reject-update", "new", "old")); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/RejectingByteTree.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/RejectingByteTree.cs new file mode 100644 index 0000000..a67b887 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/RejectingByteTree.cs @@ -0,0 +1,19 @@ +using TechnitiumLibrary.ByteTree; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.ByteTree +{ + internal sealed class RejectingByteTree : ByteTree + { + public RejectingByteTree() + : base(256) + { } + + protected override byte[] ConvertToByteKey(string key, bool throwException = true) + { + if (key.StartsWith("reject")) + return null!; + + return key.Select(c => (byte)c).ToArray(); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryReaderWriterExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryReaderWriterExtensionsTests.cs new file mode 100644 index 0000000..bdfc6ab --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryReaderWriterExtensionsTests.cs @@ -0,0 +1,35 @@ +using System.Text; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class BinaryReaderWriterExtensionsTests + { + [Fact] + public void WriteBufferAndReadBuffer_RoundtripShortAndLongBuffers() + { + byte[] shortBuffer = new byte[] { 1, 2, 3 }; + byte[] longBuffer = Enumerable.Range(0, 300).Select(i => (byte)i).ToArray(); + using MemoryStream stream = new MemoryStream(); + using BinaryWriter writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true); + + writer.WriteBuffer(shortBuffer); + writer.WriteBuffer(longBuffer, 10, 200); + writer.Flush(); + stream.Position = 0; + + using BinaryReader reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true); + Assert.Equal(shortBuffer, reader.ReadBuffer()); + Assert.Equal(longBuffer.Skip(10).Take(200).ToArray(), reader.ReadBuffer()); + } + + [Fact] + public void ReadLength_UnsupportedLengthThrows() + { + using MemoryStream stream = new MemoryStream(new byte[] { 0x85, 1, 2, 3, 4, 5 }); + using BinaryReader reader = new BinaryReader(stream); + + Assert.Throws(() => reader.ReadLength()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/JointTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/JointTests.cs new file mode 100644 index 0000000..e16dfe2 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/JointTests.cs @@ -0,0 +1,27 @@ +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class JointTests + { + [Fact] + public void Dispose_IsIdempotentAndDisposesBothStreams() + { + MemoryStream stream1 = new MemoryStream(); + MemoryStream stream2 = new MemoryStream(); + using Joint joint = new Joint(stream1, stream2); + int disposingCount = 0; + joint.Disposing += (_, _) => disposingCount++; + + Assert.Same(stream1, joint.Stream1); + Assert.Same(stream2, joint.Stream2); + + joint.Dispose(); + joint.Dispose(); + + Assert.Equal(1, disposingCount); + Assert.False(stream1.CanRead); + Assert.False(stream2.CanRead); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/OffsetStreamTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/OffsetStreamTests.cs new file mode 100644 index 0000000..5bf583d --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/OffsetStreamTests.cs @@ -0,0 +1,96 @@ +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class OffsetStreamTests + { + [Fact] + public void Read_ReadsOnlyConfiguredWindow() + { + byte[] source = new byte[] { 10, 20, 30, 40, 50 }; + using MemoryStream baseStream = new MemoryStream(source); + using OffsetStream offsetStream = new OffsetStream(baseStream, offset: 1, length: 3, readOnly: true); + + byte[] buffer = new byte[10]; + int read = offsetStream.Read(buffer, 0, buffer.Length); + + Assert.Equal(3, read); + Assert.Equal(new byte[] { 20, 30, 40 }, buffer.Take(read).ToArray()); + Assert.Equal(0, offsetStream.Read(buffer, 0, buffer.Length)); + } + + [Fact] + public void Write_StartsAtBaseOffsetAndExpandsVirtualLength() + { + using MemoryStream baseStream = new MemoryStream(new byte[] { 1, 2, 3, 4, 5 }); + using OffsetStream offsetStream = new OffsetStream(baseStream, offset: 2); + + offsetStream.Write(new byte[] { 9, 8, 7 }, 0, 3); + + Assert.Equal(3, offsetStream.Length); + Assert.Equal(new byte[] { 1, 2, 9, 8, 7 }, baseStream.ToArray()); + } + + [Fact] + public void ReadOnlyWrite_Throws() + { + using MemoryStream baseStream = new MemoryStream(new byte[] { 1, 2, 3 }); + using OffsetStream offsetStream = new OffsetStream(baseStream, length: 3, readOnly: true); + + Assert.False(offsetStream.CanWrite); + Assert.Throws(() => offsetStream.Write(new byte[] { 4 }, 0, 1)); + } + + [Fact] + public void SeekSetLengthWriteToAndFlush_CoverSyncPaths() + { + using MemoryStream baseStream = new MemoryStream(new byte[] { 1, 2, 3, 4, 5 }); + using OffsetStream offsetStream = new OffsetStream(baseStream, offset: 1, length: 3); + + Assert.True(offsetStream.CanRead); + Assert.True(offsetStream.CanSeek); + Assert.True(offsetStream.CanWrite); + Assert.False(offsetStream.CanTimeout); + Assert.Same(baseStream, offsetStream.BaseStream); + Assert.Equal(1, offsetStream.BaseStreamOffset); + + Assert.Equal(1, offsetStream.Seek(1, SeekOrigin.Begin)); + Assert.Equal(2, offsetStream.Seek(1, SeekOrigin.Current)); + Assert.Equal(1, offsetStream.Seek(-2, SeekOrigin.End)); + Assert.Throws(() => offsetStream.Seek(3, SeekOrigin.Begin)); + + offsetStream.SetLength(4); + Assert.Equal(4, offsetStream.Length); + offsetStream.Flush(); + + using MemoryStream copy = new MemoryStream(); + offsetStream.WriteTo(copy, 8); + Assert.Equal(new byte[] { 2, 3, 4, 5 }, copy.ToArray()); + } + + [Fact] + public async Task AsyncPaths_ReadWriteFlushAndCopy() + { + using MemoryStream baseStream = new MemoryStream(new byte[10]); + using OffsetStream offsetStream = new OffsetStream(baseStream, offset: 2); + + await offsetStream.WriteAsync(new byte[] { 1, 2, 3 }, 0, 3, default); + await offsetStream.WriteAsync(new byte[] { 4, 5 }.AsMemory(), default); + await offsetStream.FlushAsync(default); + + offsetStream.Position = 0; + byte[] buffer = new byte[5]; + Assert.Equal(5, await offsetStream.ReadAsync(buffer, 0, 5, default)); + Assert.Equal(new byte[] { 1, 2, 3, 4, 5 }, buffer); + + offsetStream.Position = 0; + byte[] memoryBuffer = new byte[5]; + Assert.Equal(5, await offsetStream.ReadAsync(memoryBuffer.AsMemory(), default)); + Assert.Equal(new byte[] { 1, 2, 3, 4, 5 }, memoryBuffer); + + using MemoryStream copy = new MemoryStream(); + await offsetStream.WriteToAsync(copy, 8); + Assert.Equal(new byte[] { 1, 2, 3, 4, 5 }, copy.ToArray()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageItemTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageItemTests.cs new file mode 100644 index 0000000..b783614 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageItemTests.cs @@ -0,0 +1,61 @@ +using System.Text; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class PackageItemTests + { + [Fact] + public void WriteParseExtractAndProperties_Work() + { + DateTime modified = DateTime.UnixEpoch.AddSeconds(1000); + using MemoryStream data = new MemoryStream(Encoding.UTF8.GetBytes("payload")); + PackageItem item = new PackageItem("file.txt", modified, data, PackageItemAttributes.FixedExtractLocation, ExtractLocation.Custom, Path.GetTempPath()); + + using MemoryStream serialized = new MemoryStream(); + item.WriteTo(serialized); + serialized.Position = 0; + + PackageItem parsed = PackageItem.Parse(serialized); + Assert.Equal("file.txt", parsed.Name); + Assert.Equal(modified, parsed.LastModifiedUTC); + Assert.Equal(PackageItemAttributes.FixedExtractLocation, parsed.Attribute); + Assert.Equal(ExtractLocation.Custom, parsed.ExtractTo); + Assert.Equal(Path.GetTempPath(), parsed.ExtractToCustomLocation); + Assert.True(parsed.IsAttributeSet(PackageItemAttributes.FixedExtractLocation)); + + using StreamReader reader = new StreamReader(parsed.DataStream, Encoding.UTF8, leaveOpen: true); + Assert.Equal("payload", reader.ReadToEnd()); + + string dir = Path.Combine(Path.GetTempPath(), "TechnitiumLibraryTests", Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(dir); + try + { + string filePath = Path.Combine(dir, "file.txt"); + PackageItemTransactionLog log = parsed.Extract(filePath); + Assert.Equal(filePath, log.FilePath); + Assert.Null(log.OriginalFilePath); + Assert.Equal("payload", File.ReadAllText(filePath)); + Assert.Null(parsed.Extract(filePath)); + + File.WriteAllText(filePath, "old"); + PackageItemTransactionLog overwriteLog = parsed.Extract(filePath, overwrite: true); + Assert.Equal(filePath, overwriteLog.FilePath); + Assert.NotNull(overwriteLog.OriginalFilePath); + Assert.True(File.Exists(overwriteLog.OriginalFilePath)); + } + finally + { + Directory.Delete(dir, recursive: true); + } + } + + [Fact] + public void InvalidVersionsThrowOrReturnNull() + { + Assert.Throws(() => PackageItem.Parse(new MemoryStream())); + Assert.Null(PackageItem.Parse(new MemoryStream(new byte[] { 0 }))); + Assert.Throws(() => PackageItem.Parse(new MemoryStream(new byte[] { 99 }))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageTests.cs new file mode 100644 index 0000000..02f2c08 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageTests.cs @@ -0,0 +1,59 @@ +using System.Text; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class PackageTests + { + [Fact] + public void CreateOpenItemsAndExtractAll_Work() + { + string dir = Path.Combine(Path.GetTempPath(), "TechnitiumLibraryTests", Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(dir); + try + { + string packagePath = Path.Combine(dir, "test.pkg"); + string extractDir = Path.Combine(dir, "extract"); + Directory.CreateDirectory(extractDir); + + using (Package package = new Package(packagePath, PackageMode.Create)) + { + Assert.Throws(() => _ = package.Items); + package.AddItem(new PackageItem("a.txt", DateTime.UnixEpoch, new MemoryStream(Encoding.UTF8.GetBytes("a")), extractTo: ExtractLocation.Custom, extractToCustomLocation: extractDir)); + package.Close(); + Assert.Throws(() => package.AddItem(new PackageItem("b.txt", new MemoryStream()))); + } + + using Package opened = new Package(packagePath, PackageMode.Open); + Assert.Single(opened.Items); + Assert.Equal("a.txt", opened.Items[0].Name); + Assert.Throws(() => opened.AddItem(new PackageItem("b.txt", new MemoryStream()))); + + opened.ExtractAll(overwrite: true); + Assert.Equal("a", File.ReadAllText(Path.Combine(extractDir, "a.txt"))); + } + finally + { + Directory.Delete(dir, recursive: true); + } + } + + [Fact] + public void InvalidHeadersAndVersionsThrow() + { + Assert.Throws(() => new Package(new MemoryStream(new byte[] { (byte)'X', (byte)'Y', 1 }), PackageMode.Open)); + Assert.Throws(() => new Package(new MemoryStream(new byte[] { (byte)'T', (byte)'P' }), PackageMode.Open)); + Assert.Throws(() => new Package(new MemoryStream(new byte[] { (byte)'T', (byte)'P', 99 }), PackageMode.Open)); + } + + [Fact] + public void GetExtractLocation_CoversStableLocations() + { + string custom = Path.Combine(Path.GetTempPath(), "custom"); + + Assert.Null(Package.GetExtractLocation(ExtractLocation.None, null)); + Assert.Equal(Path.GetTempPath(), Package.GetExtractLocation(ExtractLocation.Temp, null)); + Assert.Equal(custom, Package.GetExtractLocation(ExtractLocation.Custom, custom)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PipeTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PipeTests.cs new file mode 100644 index 0000000..4af9a26 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PipeTests.cs @@ -0,0 +1,44 @@ +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class PipeTests + { + [Fact] + public void TransfersBytesAndSupportsTimeoutAndSeekExceptions() + { + Pipe pipe = new Pipe(); + Stream stream1 = pipe.Stream1; + Stream stream2 = pipe.Stream2; + + Assert.True(stream1.CanRead); + Assert.True(stream1.CanWrite); + Assert.True(stream1.CanTimeout); + Assert.False(stream1.CanSeek); + + stream1.WriteTimeout = 50; + stream2.ReadTimeout = 50; + Assert.Equal(50, stream1.WriteTimeout); + Assert.Equal(50, stream2.ReadTimeout); + + stream1.Write(new byte[] { 1, 2, 3 }, 0, 3); + byte[] buffer = new byte[5]; + Assert.Equal(2, stream2.Read(buffer, 0, 2)); + Assert.Equal(new byte[] { 1, 2 }, buffer.Take(2).ToArray()); + Assert.Equal(1, stream2.Read(buffer, 0, 5)); + Assert.Equal(3, buffer[0]); + Assert.Equal(0, stream2.Read(buffer, 0, 0)); + + Assert.Throws(() => _ = stream1.Length); + Assert.Throws(() => _ = stream1.Position); + Assert.Throws(() => stream1.Position = 0); + Assert.Throws(() => stream1.Seek(0, SeekOrigin.Begin)); + Assert.Throws(() => stream1.SetLength(1)); + Assert.Throws(() => stream2.Read(buffer, 0, 1)); + + stream1.Dispose(); + Assert.Equal(0, stream2.Read(buffer, 0, 1)); + Assert.Throws(() => stream2.Write(new byte[] { 1 }, 0, 1)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/StreamExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/StreamExtensionsTests.cs new file mode 100644 index 0000000..e301ae1 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/StreamExtensionsTests.cs @@ -0,0 +1,84 @@ +using System.Text; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class StreamExtensionsTests + { + [Fact] + public void ReadWriteShortStringAndDateTime_Roundtrips() + { + DateTime date = DateTime.UnixEpoch.AddMilliseconds(123456789); + using MemoryStream stream = new MemoryStream(); + + stream.WriteShortString("hello"); + stream.WriteDateTime(date); + stream.Position = 0; + + Assert.Equal("hello", stream.ReadShortString()); + Assert.Equal(date, stream.ReadDateTime()); + } + + [Fact] + public async Task AsyncReadWriteShortStringAndDateTime_Roundtrips() + { + DateTime date = DateTime.UnixEpoch.AddMilliseconds(987654321); + using MemoryStream stream = new MemoryStream(); + + await stream.WriteShortStringAsync("hello"); + await stream.WriteDateTimeAsync(date); + stream.Position = 0; + + Assert.Equal("hello", await stream.ReadShortStringAsync()); + Assert.Equal(date, await stream.ReadDateTimeAsync()); + } + + [Fact] + public void CopyTo_CopiesExactLengthAndThrowsOnShortSource() + { + using MemoryStream source = new MemoryStream(new byte[] { 1, 2, 3, 4, 5 }); + using MemoryStream destination = new MemoryStream(); + + source.CopyTo(destination, bufferSize: 10, length: 3); + + Assert.Equal(new byte[] { 1, 2, 3 }, destination.ToArray()); + Assert.Throws(() => source.CopyTo(Stream.Null, bufferSize: 4, length: 10)); + } + + [Fact] + public async Task CopyToAsync_CopiesExactLengthAndThrowsOnShortSource() + { + using MemoryStream source = new MemoryStream(new byte[] { 1, 2, 3, 4, 5 }); + using MemoryStream destination = new MemoryStream(); + + await source.CopyToAsync(destination, bufferSize: 10, length: 3); + + Assert.Equal(new byte[] { 1, 2, 3 }, destination.ToArray()); + await Assert.ThrowsAsync(() => source.CopyToAsync(Stream.Null, bufferSize: 4, length: 10)); + } + + [Fact] + public void ReadByteValueAndWriteShortString_ThrowOnInvalidData() + { + using MemoryStream empty = new MemoryStream(); + Assert.Throws(() => empty.ReadByteValue()); + + using MemoryStream stream = new MemoryStream(); + Assert.Throws(() => stream.WriteShortString(new string('x', 256), Encoding.ASCII)); + } + + [Fact] + public async Task AsyncByteAndLongStringPaths_Work() + { + using MemoryStream stream = new MemoryStream(); + + await stream.WriteByteAsync(123); + await stream.WriteShortStringAsync(new string('x', 255), Encoding.ASCII); + stream.Position = 0; + + Assert.Equal(123, await stream.ReadByteValueAsync()); + Assert.Equal(new string('x', 255), await stream.ReadShortStringAsync(Encoding.ASCII)); + await Assert.ThrowsAsync(() => stream.WriteShortStringAsync(new string('x', 256), Encoding.ASCII)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/WriteBufferedStreamTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/WriteBufferedStreamTests.cs new file mode 100644 index 0000000..0f3cdc8 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/WriteBufferedStreamTests.cs @@ -0,0 +1,65 @@ +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + public class WriteBufferedStreamTests + { + private sealed class NonWritableStream : MemoryStream + { + public override bool CanWrite => false; + } + + [Fact] + public async Task BuffersFlushesAndDelegatesReads() + { + using MemoryStream baseStream = new MemoryStream(); + using WriteBufferedStream stream = new WriteBufferedStream(baseStream, bufferSize: 3); + + stream.Write(new byte[] { 1, 2 }, 0, 2); + Assert.Empty(baseStream.ToArray()); + + stream.Write(new byte[] { 3, 4 }, 0, 2); + Assert.Equal(new byte[] { 1, 2, 3 }, baseStream.ToArray()); + + await stream.WriteAsync(new byte[] { 5, 6 }, 0, 2); + await stream.WriteAsync(new byte[] { 7, 8 }.AsMemory()); + await stream.FlushAsync(default); + Assert.Equal(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }, baseStream.ToArray()); + + baseStream.Position = 0; + byte[] read = new byte[2]; + Assert.Equal(2, stream.Read(read, 0, 2)); + Assert.Equal(new byte[] { 1, 2 }, read); + + Span span = stackalloc byte[2]; + Assert.Equal(2, stream.Read(span)); + Assert.Equal(new byte[] { 3, 4 }, span.ToArray()); + + byte[] asyncRead = new byte[2]; + Assert.Equal(2, await stream.ReadAsync(asyncRead, 0, 2, default)); + Assert.Equal(new byte[] { 5, 6 }, asyncRead); + + byte[] memoryRead = new byte[2]; + Assert.Equal(2, await stream.ReadAsync(memoryRead.AsMemory(), default)); + Assert.Equal(new byte[] { 7, 8 }, memoryRead); + } + + [Fact] + public void ThrowsForUnsupportedAndDisposedOperations() + { + Assert.Throws(() => new WriteBufferedStream(new NonWritableStream())); + + MemoryStream baseStream = new MemoryStream(); + WriteBufferedStream stream = new WriteBufferedStream(baseStream); + + Assert.False(stream.CanSeek); + Assert.Throws(() => stream.Position = 1); + Assert.Throws(() => stream.Seek(0, SeekOrigin.Begin)); + Assert.Throws(() => stream.SetLength(1)); + + stream.Dispose(); + Assert.Throws(() => stream.Write(new byte[] { 1 }, 0, 1)); + Assert.Throws(() => stream.Flush()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/BencodingTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/BencodingTests.cs new file mode 100644 index 0000000..57e5425 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/BencodingTests.cs @@ -0,0 +1,135 @@ +using System.Collections.Generic; +using System.IO; +using System.Text; +using TechnitiumLibrary.Net.BitTorrent; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.BitTorrent +{ + public class BencodingTests + { + [Fact] + public void StringRoundtrips() + { + Bencoding value = new Bencoding(BencodingType.String, "spam"); + using MemoryStream stream = new MemoryStream(); + + value.Encode(stream); + + Assert.Equal("4:spam", Encoding.ASCII.GetString(stream.ToArray())); + Bencoding decoded = Bencoding.Decode(stream.ToArray()); + Assert.Equal(BencodingType.String, decoded.Type); + Assert.Equal("spam", decoded.ValueString); + Assert.Equal(decoded.Value, decoded.Value as byte[]); + } + + [Fact] + public void ByteStringRoundtrips() + { + byte[] bytes = [0, 1, 2, 255]; + Bencoding value = new Bencoding(BencodingType.String, bytes); + using MemoryStream stream = new MemoryStream(); + + value.Encode(stream); + stream.Position = 0; + + Bencoding decoded = Bencoding.Decode(stream); + + Assert.Equal(bytes, decoded.Value as byte[]); + } + + [Fact] + public void IntegerRoundtrips() + { + Bencoding value = new Bencoding(BencodingType.Integer, -42L); + using MemoryStream stream = new MemoryStream(); + + value.Encode(stream); + + Assert.Equal("i-42e", Encoding.ASCII.GetString(stream.ToArray())); + Assert.Equal(-42L, Bencoding.Decode(stream.ToArray()).ValueInteger); + } + + [Fact] + public void ListRoundtrips() + { + Bencoding value = new Bencoding(BencodingType.List, new List + { + new Bencoding(BencodingType.String, "alpha"), + new Bencoding(BencodingType.Integer, 7L) + }); + using MemoryStream stream = new MemoryStream(); + + value.Encode(stream); + stream.Position = 0; + + Bencoding decoded = Bencoding.Decode(stream); + + Assert.Equal(BencodingType.List, decoded.Type); + Assert.Equal("alpha", decoded.ValueList[0].ValueString); + Assert.Equal(7L, decoded.ValueList[1].ValueInteger); + } + + [Fact] + public void DictionaryRoundtrips() + { + Bencoding value = new Bencoding( + BencodingType.Dictionary, + new Dictionary + { + ["answer"] = new Bencoding(BencodingType.Integer, 42L), + ["items"] = new Bencoding(BencodingType.List, new List + { + new Bencoding(BencodingType.String, "alpha"), + new Bencoding(BencodingType.Integer, -1L) + }) + }); + + using MemoryStream stream = new MemoryStream(); + value.Encode(stream); + stream.Position = 0; + + Bencoding decoded = Bencoding.Decode(stream); + + Assert.Equal(BencodingType.Dictionary, decoded.Type); + Assert.Equal(42L, decoded.ValueDictionary["answer"].ValueInteger); + Assert.Equal("alpha", decoded.ValueDictionary["items"].ValueList[0].ValueString); + Assert.Equal(-1L, decoded.ValueDictionary["items"].ValueList[1].ValueInteger); + } + + [Fact] + public void DecodeReturnsNullForEndMarkerInsideCollections() + { + using MemoryStream stream = new MemoryStream(new byte[] { (byte)'e' }); + + Assert.Null(Bencoding.Decode(stream)); + } + + [Fact] + public void DecodeRejectsTruncatedIntegerAndString() + { + Assert.Throws(() => Bencoding.Decode(Encoding.ASCII.GetBytes("i42"))); + Assert.ThrowsAny(() => Bencoding.Decode(Encoding.ASCII.GetBytes("4:abc"))); + Assert.ThrowsAny(() => Bencoding.Decode(Encoding.ASCII.GetBytes("x"))); + } + + [Fact] + public void DecodeRejectsDictionaryWithNonStringKey() + { + Assert.ThrowsAny(() => Bencoding.Decode(Encoding.ASCII.GetBytes("di1e4:spame"))); + } + + [Fact] + public void DecodeThrowsOnEmptyStream() + { + Assert.Throws(() => Bencoding.Decode(new MemoryStream())); + } + + [Fact] + public void EncodeRejectsInvalidType() + { + Bencoding value = new Bencoding((BencodingType)99, null); + + Assert.ThrowsAny(() => value.Encode(new MemoryStream())); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/HttpTrackerClientTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/HttpTrackerClientTests.cs new file mode 100644 index 0000000..d690ca5 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/HttpTrackerClientTests.cs @@ -0,0 +1,59 @@ +using System.Collections.Generic; +using System.Net; +using System.Reflection; +using TechnitiumLibrary.Net.BitTorrent; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.BitTorrent +{ + public class HttpTrackerClientTests + { + [Fact] + public void ParseCompactPeersIPv4AddsValidPeersAndSkipsUnusableAddresses() + { + List peers = new List(); + byte[] data = + [ + 0, 0, 0, 0, 0x1A, 0xE1, + 127, 0, 0, 1, 0x1A, 0xE1, + 192, 0, 2, 10, 0x1A, 0xE1 + ]; + + InvokeParser("ParseCompactPeersIPv4", data, peers); + + IPEndPoint peer = Assert.Single(peers); + Assert.Equal(IPAddress.Parse("192.0.2.10"), peer.Address); + Assert.Equal(6881, peer.Port); + } + + [Fact] + public void ParseCompactPeersIPv6AddsValidPeersAndSkipsUnusableAddresses() + { + List peers = new List(); + byte[] data = new byte[54]; + WriteCompactIPv6(data, 0, IPAddress.IPv6Any, 6881); + WriteCompactIPv6(data, 18, IPAddress.IPv6Loopback, 6881); + WriteCompactIPv6(data, 36, IPAddress.Parse("2001:db8::10"), 6881); + + InvokeParser("ParseCompactPeersIPv6", data, peers); + + IPEndPoint peer = Assert.Single(peers); + Assert.Equal(IPAddress.Parse("2001:db8::10"), peer.Address); + Assert.Equal(6881, peer.Port); + } + + private static void InvokeParser(string methodName, byte[] data, List peers) + { + Type type = typeof(TrackerClient).Assembly.GetType("TechnitiumLibrary.Net.BitTorrent.HttpTrackerClient")!; + MethodInfo method = type.GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Static)!; + + method.Invoke(null, [data, peers]); + } + + private static void WriteCompactIPv6(byte[] buffer, int offset, IPAddress address, ushort port) + { + address.GetAddressBytes().CopyTo(buffer, offset); + buffer[offset + 16] = (byte)(port >> 8); + buffer[offset + 17] = (byte)port; + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientExceptionTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientExceptionTests.cs new file mode 100644 index 0000000..3a71dd6 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientExceptionTests.cs @@ -0,0 +1,21 @@ +using TechnitiumLibrary.Net.BitTorrent; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.BitTorrent +{ + public class TrackerClientExceptionTests + { + [Fact] + public void ConstructorsSetMessageAndInnerException() + { + TrackerClientException empty = new TrackerClientException(); + TrackerClientException withMessage = new TrackerClientException("message"); + InvalidOperationException inner = new InvalidOperationException("inner"); + TrackerClientException withInner = new TrackerClientException("outer", inner); + + Assert.NotNull(empty.Message); + Assert.Equal("message", withMessage.Message); + Assert.Equal("outer", withInner.Message); + Assert.Same(inner, withInner.InnerException); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientIDTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientIDTests.cs new file mode 100644 index 0000000..45793d4 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientIDTests.cs @@ -0,0 +1,68 @@ +using System.IO; +using System.Linq; +using System.Text; +using TechnitiumLibrary.Net.BitTorrent; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.BitTorrent +{ + public class TrackerClientIDTests + { + [Fact] + public void WriteAndReadRoundtrips() + { + byte[] peerId = Enumerable.Range(1, 20).Select(Convert.ToByte).ToArray(); + byte[] clientKey = [1, 2, 3, 4]; + TrackerClientID expected = new TrackerClientID(peerId, clientKey, "agent", "gzip", 25, compact: true, noPeerID: false); + using MemoryStream stream = new MemoryStream(); + + expected.WriteTo(stream); + stream.Position = 0; + + TrackerClientID actual = new TrackerClientID(stream); + + Assert.Equal(peerId, actual.PeerID); + Assert.Equal(clientKey, actual.ClientKey); + Assert.Equal("agent", actual.HttpUserAgent); + Assert.Equal("gzip", actual.HttpAcceptEncoding); + Assert.Equal(25, actual.NumWant); + Assert.True(actual.Compact); + Assert.False(actual.NoPeerID); + } + + [Fact] + public void GenerateClientKeyReturnsFourBytes() + { + Assert.Equal(4, TrackerClientID.GenerateClientKey().Length); + } + + [Fact] + public void GeneratePeerIDUsesAzureusStylePrefixAndTwentyBytes() + { + byte[] peerId = TrackerClientID.GeneratePeerID("UT3430"); + + Assert.Equal(20, peerId.Length); + Assert.StartsWith("-UT3430-", Encoding.UTF8.GetString(peerId)); + } + + [Fact] + public void CreateDefaultIDUsesExpectedPublicDefaults() + { + TrackerClientID id = TrackerClientID.CreateDefaultID(); + + Assert.Equal(20, id.PeerID.Length); + Assert.Equal(4, id.ClientKey.Length); + Assert.Equal("uTorrent/343(109551416)(40760)", id.HttpUserAgent); + Assert.Equal("gzip", id.HttpAcceptEncoding); + Assert.Equal(50, id.NumWant); + Assert.True(id.Compact); + Assert.True(id.NoPeerID); + } + + [Fact] + public void ConstructorRejectsInvalidHeaderAndUnsupportedVersion() + { + Assert.ThrowsAny(() => new TrackerClientID(new MemoryStream(Encoding.ASCII.GetBytes("NO")))); + Assert.Throws(() => new TrackerClientID(new MemoryStream(new byte[] { (byte)'I', (byte)'D', 99 }))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientTests.cs new file mode 100644 index 0000000..4a5a7bf --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/TrackerClientTests.cs @@ -0,0 +1,125 @@ +using System.Net; +using TechnitiumLibrary.Net.BitTorrent; +using TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net.BitTorrent; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.BitTorrent +{ + public class TrackerClientTests + { + [Theory] + [InlineData("http://tracker.example/announce")] + [InlineData("https://tracker.example/announce")] + [InlineData("udp://tracker.example:6969/announce")] + public void CreateReturnsClientForSupportedSchemes(string uri) + { + using TrackerClient client = TrackerClient.Create(new Uri(uri), CreateInfoHash(), CreateClientId(), 60); + + Assert.Equal(new Uri(uri), client.TrackerUri); + Assert.Equal(60, client.CustomUpdateInterval); + Assert.Equal(20, client.InfoHash.Length); + Assert.NotNull(client.ClientID); + } + + [Fact] + public void CreateAddsDefaultPortForUdpTrackerWithoutPort() + { + using TrackerClient client = TrackerClient.Create(new Uri("udp://tracker.example/announce"), CreateInfoHash(), CreateClientId()); + + Assert.Equal(80, client.TrackerUri.Port); + } + + [Fact] + public void CreateRejectsUnsupportedScheme() + { + Assert.Throws(() => TrackerClient.Create(new Uri("ftp://tracker.example/announce"), CreateInfoHash(), CreateClientId())); + } + + [Fact] + public void ScheduleUpdateNowMakesNextUpdateDue() + { + using TestTrackerClient client = new TestTrackerClient(customUpdateInterval: 60); + + client.ScheduleUpdateNow(); + + Assert.True(client.NextUpdateIn() <= TimeSpan.Zero); + } + + [Fact] + public async Task UpdateAsyncStoresStateAfterSuccess() + { + using TestTrackerClient client = new TestTrackerClient(); + IPEndPoint endpoint = new IPEndPoint(IPAddress.Parse("192.0.2.1"), 6881); + + await client.UpdateAsync(TrackerClientEvent.Started, endpoint); + + Assert.Equal(endpoint, client.LastClientEP); + Assert.Null(client.LastException); + Assert.Equal(0, client.RetriesDone); + Assert.False(client.IsUpdating); + Assert.Equal(TrackerClientEvent.Started, client.LastEvent); + Assert.Equal(endpoint, client.LastUpdateEndpoint); + } + + [Fact] + public async Task UpdateAsyncStoresInnerExceptionAndRetryStateAfterFailure() + { + using TestTrackerClient client = new TestTrackerClient(); + IPEndPoint endpoint = new IPEndPoint(IPAddress.Parse("192.0.2.1"), 6881); + InvalidOperationException inner = new InvalidOperationException("inner"); + client.ExceptionToThrow = new ApplicationException("outer", inner); + + ApplicationException thrown = await Assert.ThrowsAsync(() => client.UpdateAsync(TrackerClientEvent.Completed, endpoint)); + + Assert.Same(inner, client.LastException); + Assert.Same(inner, thrown.InnerException); + Assert.Equal(1, client.RetriesDone); + Assert.Equal(client.MinimumInterval, client.Interval); + Assert.False(client.IsUpdating); + Assert.Equal(endpoint, client.LastClientEP); + } + + [Fact] + public void EqualsComparesTrackerUriAndInfoHash() + { + byte[] infoHash = CreateInfoHash(); + using TestTrackerClient client = new TestTrackerClient(infoHash: infoHash); + using TestTrackerClient same = new TestTrackerClient(infoHash: infoHash.ToArray()); + using TestTrackerClient differentUri = new TestTrackerClient(trackerUri: new Uri("http://other.example/announce"), infoHash: infoHash.ToArray()); + byte[] differentHash = infoHash.ToArray(); + differentHash[19] = 255; + using TestTrackerClient differentInfoHash = new TestTrackerClient(infoHash: differentHash); + + Assert.True(client.Equals(same)); + Assert.False(client.Equals(null)); + Assert.False(client.Equals("tracker")); + Assert.False(client.Equals(differentUri)); + Assert.False(client.Equals(differentInfoHash)); + Assert.Equal(client.GetHashCode(), client.GetHashCode()); + } + + [Fact] + public void PublicPropertiesExposeMutableSettingsAndTrackerState() + { + using TestTrackerClient client = new TestTrackerClient(); + + client.CustomUpdateInterval = 45; + client.Peers.Add(new IPEndPoint(IPAddress.Parse("192.0.2.2"), 6881)); + + Assert.Equal(45, client.CustomUpdateInterval); + Assert.Null(client.Proxy); + Assert.Single(client.Peers); + Assert.Equal(0, client.Leachers); + Assert.Equal(0, client.Seeders); + } + + private static byte[] CreateInfoHash() + { + return Enumerable.Range(0, 20).Select(Convert.ToByte).ToArray(); + } + + private static TrackerClientID CreateClientId() + { + return new TrackerClientID(Enumerable.Range(20, 20).Select(Convert.ToByte).ToArray(), [1, 2, 3, 4], "agent", "gzip", 50, true, true); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/UdpTrackerClientTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/UdpTrackerClientTests.cs new file mode 100644 index 0000000..fe3a401 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.BitTorrent/UdpTrackerClientTests.cs @@ -0,0 +1,64 @@ +using System.Collections.Generic; +using System.Net; +using System.Reflection; +using TechnitiumLibrary.Net.BitTorrent; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.BitTorrent +{ + public class UdpTrackerClientTests + { + [Fact] + public void ParsePeersIPv4AddsValidPeersAndSkipsUnusableAddresses() + { + List peers = new List(); + byte[] response = new byte[38]; + WriteUdpIPv4(response, 20, IPAddress.Any, 6881); + WriteUdpIPv4(response, 26, IPAddress.Loopback, 6881); + WriteUdpIPv4(response, 32, IPAddress.Parse("192.0.2.10"), 6881); + + InvokeParser("ParsePeersIPv4", response, response.Length, peers); + + IPEndPoint peer = Assert.Single(peers); + Assert.Equal(IPAddress.Parse("192.0.2.10"), peer.Address); + Assert.Equal(6881, peer.Port); + } + + [Fact] + public void ParsePeersIPv6AddsValidPeersAndSkipsUnusableAddresses() + { + List peers = new List(); + byte[] response = new byte[74]; + WriteUdpIPv6(response, 20, IPAddress.IPv6Any, 6881); + WriteUdpIPv6(response, 38, IPAddress.IPv6Loopback, 6881); + WriteUdpIPv6(response, 56, IPAddress.Parse("2001:db8::10"), 6881); + + InvokeParser("ParsePeersIPv6", response, response.Length, peers); + + IPEndPoint peer = Assert.Single(peers); + Assert.Equal(IPAddress.Parse("2001:db8::10"), peer.Address); + Assert.Equal(6881, peer.Port); + } + + private static void InvokeParser(string methodName, byte[] response, int responseLength, List peers) + { + Type type = typeof(TrackerClient).Assembly.GetType("TechnitiumLibrary.Net.BitTorrent.UdpTrackerClient")!; + MethodInfo method = type.GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Static)!; + + method.Invoke(null, [response, responseLength, peers]); + } + + private static void WriteUdpIPv4(byte[] buffer, int offset, IPAddress address, ushort port) + { + address.GetAddressBytes().CopyTo(buffer, offset); + buffer[offset + 4] = (byte)(port >> 8); + buffer[offset + 5] = (byte)port; + } + + private static void WriteUdpIPv6(byte[] buffer, int offset, IPAddress address, ushort port) + { + address.GetAddressBytes().CopyTo(buffer, offset); + buffer[offset + 16] = (byte)(port >> 8); + buffer[offset + 17] = (byte)port; + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Mail/Pop3ClientTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Mail/Pop3ClientTests.cs new file mode 100644 index 0000000..21e6650 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Mail/Pop3ClientTests.cs @@ -0,0 +1,304 @@ +using System.Net.Security; +using System.Reflection; +using System.Text; +using TechnitiumLibrary.Net.Mail; +using TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net.Mail; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Mail +{ + public class Pop3ClientTests + { + [Fact] + public void StatsFieldsCanBeAssigned() + { + Pop3Stats stats = new Pop3Stats + { + TotalMessages = 3, + TotalSize = 4096 + }; + + Assert.Equal(3, stats.TotalMessages); + Assert.Equal(4096, stats.TotalSize); + } + + [Fact] + public void MessageInfoFieldsCanBeAssigned() + { + Pop3MessageInfo info = new Pop3MessageInfo + { + MessageNumber = 2, + MessageSize = 1024 + }; + + Assert.Equal(2, info.MessageNumber); + Assert.Equal(1024, info.MessageSize); + } + + [Fact] + public void ExceptionsPreserveMessageAndInnerException() + { + Exception inner = new Exception("inner"); + + Pop3Exception baseException = new Pop3Exception("base", inner); + Pop3InvalidUsernamePasswordException invalidCredentials = new Pop3InvalidUsernamePasswordException("bad credentials", inner); + + Assert.Equal("base", baseException.Message); + Assert.Same(inner, baseException.InnerException); + Assert.IsAssignableFrom(invalidCredentials); + Assert.Equal("bad credentials", invalidCredentials.Message); + Assert.Same(inner, invalidCredentials.InnerException); + Assert.NotNull(new Pop3Exception().Message); + Assert.NotNull(new Pop3InvalidUsernamePasswordException().Message); + Assert.Equal("bad", new Pop3InvalidUsernamePasswordException("bad").Message); + } + + [Fact] + public void DisposeCanBeCalledMoreThanOnce() + { + Pop3Client client = new Pop3Client("127.0.0.1", 110, "user", "pass"); + + client.Dispose(); + client.Dispose(); + } + + [Fact] + public void RemoteCertificateValidationCallbackAlwaysAllowsCertificate() + { + using Pop3Client client = new Pop3Client("127.0.0.1", 110, "user", "pass"); + MethodInfo callback = typeof(Pop3Client).GetMethod("RemoteCertificateValidationCallback", BindingFlags.NonPublic | BindingFlags.Instance)!; + + bool result = (bool)callback.Invoke(client, [new object(), null, null, SslPolicyErrors.RemoteCertificateChainErrors])!; + + Assert.True(result); + } + + [Fact] + public async Task ConnectAuthenticatesWithUserPassWhenSecureAuthIsNotPreferred() + { + using Pop3TestServer server = new Pop3TestServer("+OK ready"); + server.Enqueue("+OK user"); + server.Enqueue("+OK pass"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass", preferSecureAuth: false); + + client.Connect(); + + Assert.Equal(["USER user", "PASS pass"], server.Commands); + } + + [Fact] + public async Task ConnectAuthenticatesWithApopWhenTimestampIsAvailable() + { + using Pop3TestServer server = new Pop3TestServer("+OK ready <12345@example>"); + server.Enqueue("+OK apop"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass"); + + client.Connect(); + + string command = Assert.Single(server.Commands); + Assert.StartsWith("APOP user ", command); + Assert.Equal("APOP user " + Convert.ToHexString(System.Security.Cryptography.MD5.HashData(Encoding.ASCII.GetBytes("<12345@example>pass"))).ToLowerInvariant(), command); + } + + [Fact] + public async Task CommandsParseSuccessfulResponses() + { + using Pop3TestServer server = new Pop3TestServer("+OK ready"); + server.Enqueue("+OK user"); + server.Enqueue("+OK pass"); + server.Enqueue("+OK 2 512"); + server.Enqueue("+OK list follows", "1 100", "2 412", "."); + server.Enqueue("+OK message follows", "Subject: Test", "", "Body", "."); + server.Enqueue("+OK top follows", "Subject: Test", "."); + server.Enqueue("+OK deleted"); + server.Enqueue("+OK noop"); + server.Enqueue("+OK reset"); + server.Enqueue("+OK bye"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass", preferSecureAuth: false); + + client.Connect(); + Pop3Stats stats = client.STAT(); + Pop3MessageInfo[] list = client.LIST(); + byte[] message = client.RETR(1); + byte[] top = client.TOP(1, 2); + client.DELE(1); + client.NOOP(); + client.RSET(); + client.QUIT(); + + Assert.Equal(2, stats.TotalMessages); + Assert.Equal(512, stats.TotalSize); + Assert.Equal(2, list.Length); + Assert.Equal(1, list[0].MessageNumber); + Assert.Equal(100, list[0].MessageSize); + Assert.Equal("Subject: Test\r\n\r\nBody\r\n", Encoding.ASCII.GetString(message)); + Assert.Equal("Subject: Test\r\n", Encoding.ASCII.GetString(top)); + Assert.Contains("STAT", server.Commands); + Assert.Contains("LIST", server.Commands); + Assert.Contains("RETR 1", server.Commands); + Assert.Contains("TOP 1 2", server.Commands); + Assert.Contains("DELE 1", server.Commands); + Assert.Contains("NOOP", server.Commands); + Assert.Contains("RSET", server.Commands); + Assert.Contains("QUIT", server.Commands); + } + + [Fact] + public async Task ConnectRejectsSecondConnectionUntilClosed() + { + using Pop3TestServer server = new Pop3TestServer("+OK ready"); + server.Enqueue("+OK user"); + server.Enqueue("+OK pass"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass", preferSecureAuth: false); + + client.Connect(); + + Assert.Throws(() => client.Connect()); + client.Close(); + client.Close(); + } + + [Fact] + public async Task ConnectThrowsWhenGreetingIsError() + { + using Pop3TestServer server = new Pop3TestServer("-ERR down"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass"); + + Pop3Exception exception = Assert.Throws(() => client.Connect()); + Assert.Equal("Server returned: down", exception.Message); + } + + [Theory] + [InlineData("-ERR bad user", typeof(Pop3InvalidUsernamePasswordException))] + [InlineData("+OK user", typeof(Pop3InvalidUsernamePasswordException))] + public async Task ConnectThrowsInvalidCredentialsWhenAuthenticationFails(string firstAuthResponse, Type expectedExceptionType) + { + using Pop3TestServer server = new Pop3TestServer("+OK ready"); + server.Enqueue(firstAuthResponse); + if (firstAuthResponse.StartsWith("+OK", StringComparison.Ordinal)) + server.Enqueue("-ERR bad pass"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass", preferSecureAuth: false); + + Exception exception = Assert.Throws(expectedExceptionType, () => client.Connect()); + Assert.StartsWith("Server returned: bad", exception.Message); + } + + [Fact] + public async Task CommandThrowsWhenServerReturnsError() + { + using Pop3TestServer server = new Pop3TestServer("+OK ready"); + server.Enqueue("+OK user"); + server.Enqueue("+OK pass"); + server.Enqueue("-ERR stat failed"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass", preferSecureAuth: false); + + client.Connect(); + Pop3Exception exception = Assert.Throws(() => client.STAT()); + + Assert.Equal("Server returned: stat failed", exception.Message); + } + + [Theory] + [InlineData("QUIT")] + [InlineData("LIST")] + [InlineData("RETR")] + [InlineData("TOP")] + [InlineData("DELE")] + [InlineData("NOOP")] + [InlineData("RSET")] + public async Task CommandsThrowWhenServerReturnsError(string commandName) + { + using Pop3TestServer server = new Pop3TestServer("+OK ready"); + server.Enqueue("+OK user"); + server.Enqueue("+OK pass"); + server.Enqueue("-ERR command failed"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass", preferSecureAuth: false); + + client.Connect(); + Pop3Exception exception = Assert.Throws(() => InvokeCommand(client, commandName)); + + Assert.Equal("Server returned: command failed", exception.Message); + } + + [Fact] + public async Task CommandThrowsWhenServerClosesConnection() + { + using Pop3TestServer server = new Pop3TestServer("+OK ready"); + server.Enqueue("+OK user"); + server.Enqueue("+OK pass"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass", preferSecureAuth: false); + + client.Connect(); + Exception exception = Record.Exception(() => client.STAT()); + + Assert.True(exception is Pop3Exception or IOException); + if (exception is Pop3Exception pop3Exception) + Assert.Equal("No response from server.", pop3Exception.Message); + } + + [Fact] + public async Task ConnectThrowsInvalidCredentialsWhenApopFails() + { + using Pop3TestServer server = new Pop3TestServer("+OK ready <12345@example>"); + server.Enqueue("-ERR bad apop"); + await server.StartAsync(); + + using Pop3Client client = new Pop3Client("127.0.0.1", server.Port, "user", "pass"); + + Pop3InvalidUsernamePasswordException exception = Assert.Throws(() => client.Connect()); + + Assert.Equal("Server returned: bad apop", exception.Message); + } + + private static void InvokeCommand(Pop3Client client, string commandName) + { + switch (commandName) + { + case "QUIT": + client.QUIT(); + break; + + case "LIST": + client.LIST(); + break; + + case "RETR": + client.RETR(1); + break; + + case "TOP": + client.TOP(1, 1); + break; + + case "DELE": + client.DELE(1); + break; + + case "NOOP": + client.NOOP(); + break; + + case "RSET": + client.RSET(); + break; + } + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Mail/SmtpClientExTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Mail/SmtpClientExTests.cs new file mode 100644 index 0000000..22349e2 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Mail/SmtpClientExTests.cs @@ -0,0 +1,176 @@ +using System.Net; +using System.Net.Mail; +using System.Net.Security; +using System.Reflection; +using TechnitiumLibrary.Net.Mail; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Mail +{ + public class SmtpClientExTests + { + [Fact] + public void ConstructorsAndPropertiesExposeConfiguredValues() + { + using SmtpClientEx defaultClient = new SmtpClientEx(); + using SmtpClientEx hostClient = new SmtpClientEx("smtp.example"); + using SmtpClientEx hostPortClient = new SmtpClientEx("smtp.example", 2525); + + Assert.Null(defaultClient.Host); + Assert.Equal("smtp.example", hostClient.Host); + Assert.Equal(25, hostClient.Port); + Assert.Equal("smtp.example", hostPortClient.Host); + Assert.Equal(2525, hostPortClient.Port); + + hostPortClient.Host = "mail.example"; + hostPortClient.Port = 587; + hostPortClient.SmtpOverTls = true; + hostPortClient.IgnoreCertificateErrors = true; + hostPortClient.DnsClient = null; + hostPortClient.Proxy = null; + + Assert.Equal("mail.example", hostPortClient.Host); + Assert.Equal(587, hostPortClient.Port); + Assert.True(hostPortClient.SmtpOverTls); + Assert.True(hostPortClient.IgnoreCertificateErrors); + Assert.Null(hostPortClient.DnsClient); + Assert.Null(hostPortClient.Proxy); + } + + [Fact] + public void LocalHostNameCanBeSetAndRandomized() + { + using SmtpClientEx client = new SmtpClientEx(); + + client.LocalHostName = "local.example"; + Assert.Equal("local.example", client.LocalHostName); + + client.SetRandomLocalHostName(); + + Assert.False(string.IsNullOrWhiteSpace(client.LocalHostName)); + Assert.Equal(8, client.LocalHostName.Length); + } + + [Fact] + public void StaticIgnoreCertificateErrorsForProxyCanBeChanged() + { + bool original = SmtpClientEx.IgnoreCertificateErrorsForProxy; + + try + { + SmtpClientEx.IgnoreCertificateErrorsForProxy = true; + Assert.True(SmtpClientEx.IgnoreCertificateErrorsForProxy); + + SmtpClientEx.IgnoreCertificateErrorsForProxy = false; + Assert.False(SmtpClientEx.IgnoreCertificateErrorsForProxy); + } + finally + { + SmtpClientEx.IgnoreCertificateErrorsForProxy = original; + } + } + + [Fact] + public void SendAsyncOverloadsAreNotSupported() + { + using SmtpClientEx client = new SmtpClientEx(); + using MailMessage message = new MailMessage(); + + Assert.Throws(() => client.SendAsync("from@example.com", "to@example.com", "subject", "body", new object())); + Assert.Throws(() => client.SendAsync(message, new object())); + } + + [Fact] + public async Task SendMailAsyncRejectsMessageWithoutRecipientsBeforeNetworkAccess() + { + using SmtpClientEx client = new SmtpClientEx("smtp.example", 25); + using MailMessage message = new MailMessage(); + message.From = new MailAddress("from@example.com"); + + ArgumentException exception = await Assert.ThrowsAsync(() => client.SendMailAsync(message)); + + Assert.Equal("Message does not contain receipent email address.", exception.Message); + } + + [Fact] + public async Task SendMailAsyncThrowsWhenDisposed() + { + SmtpClientEx client = new SmtpClientEx("smtp.example", 25); + using MailMessage message = new MailMessage("from@example.com", "to@example.com", "subject", "body"); + + client.Dispose(); + client.Dispose(); + + await Assert.ThrowsAsync(() => client.SendMailAsync(message)); + } + + [Fact] + public async Task SendMailAsyncUsesBaseClientForPickupDirectoryDelivery() + { + string pickupDirectory = Path.Combine(Path.GetTempPath(), "TechnitiumLibraryTests", Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(pickupDirectory); + + try + { + using SmtpClientEx client = new SmtpClientEx("smtp.example", 25); + client.DeliveryMethod = SmtpDeliveryMethod.SpecifiedPickupDirectory; + client.PickupDirectoryLocation = pickupDirectory; + + await client.SendMailAsync("from@example.com", "to@example.com", "subject", "body"); + + string messageFile = Assert.Single(Directory.GetFiles(pickupDirectory)); + string message = File.ReadAllText(messageFile); + Assert.Contains("subject", message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("body", message, StringComparison.OrdinalIgnoreCase); + } + finally + { + if (Directory.Exists(pickupDirectory)) + Directory.Delete(pickupDirectory, recursive: true); + } + } + + [Fact] + public void SendUsesBaseClientForPickupDirectoryDelivery() + { + string pickupDirectory = Path.Combine(Path.GetTempPath(), "TechnitiumLibraryTests", Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(pickupDirectory); + + try + { + using SmtpClientEx client = new SmtpClientEx("smtp.example", 25); + client.DeliveryMethod = SmtpDeliveryMethod.SpecifiedPickupDirectory; + client.PickupDirectoryLocation = pickupDirectory; + + client.Send("from@example.com", "to@example.com", "subject", "body"); + + Assert.Single(Directory.GetFiles(pickupDirectory)); + } + finally + { + if (Directory.Exists(pickupDirectory)) + Directory.Delete(pickupDirectory, recursive: true); + } + } + + [Fact] + public void ServerCertificateValidationCallbackHandlesSmtpClientAndFallbackSender() + { + using SmtpClientEx client = new SmtpClientEx(); + MethodInfo callback = typeof(SmtpClientEx).GetMethod("ServerCertificateValidationCallback", BindingFlags.NonPublic | BindingFlags.Static)!; + + Assert.True(InvokeCertificateCallback(callback, client, SslPolicyErrors.None)); + Assert.False(InvokeCertificateCallback(callback, client, SslPolicyErrors.RemoteCertificateChainErrors)); + + client.IgnoreCertificateErrors = true; + Assert.True(InvokeCertificateCallback(callback, client, SslPolicyErrors.RemoteCertificateChainErrors)); + + Assert.True(InvokeCertificateCallback(callback, new object(), SslPolicyErrors.None)); + Assert.False(InvokeCertificateCallback(callback, new object(), SslPolicyErrors.RemoteCertificateChainErrors)); + } + + private static bool InvokeCertificateCallback(MethodInfo callback, object sender, SslPolicyErrors sslPolicyErrors) + { + return (bool)callback.Invoke(null, [sender, null, null, sslPolicyErrors])!; + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Tor/TorControllerTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Tor/TorControllerTests.cs new file mode 100644 index 0000000..1d0df9a --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Tor/TorControllerTests.cs @@ -0,0 +1,175 @@ +using System.Net; +using System.Net.Sockets; +using System.Reflection; +using TechnitiumLibrary.Net.Tor; +using TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net.Tor; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Tor +{ + public class TorControllerTests + { + [Fact] + public void ConstructorRequiresExistingTorExecutablePath() + { + Assert.Throws(() => new TorController(Path.Combine(Path.GetTempPath(), Guid.NewGuid() + ".exe"))); + + using TorController controller = CreateController(); + + Assert.Equal(typeof(TorController).Assembly.Location, controller.TorExecutableFile); + Assert.False(controller.IsRunning); + } + + [Fact] + public void ConfigurationPropertiesCanBeSetBeforeStart() + { + using TorController controller = CreateController(); + NetworkCredential credential = new NetworkCredential("user", "pass"); + IPEndPoint socksEndpoint = new IPEndPoint(IPAddress.Loopback, 19050); + + controller.ControlPort = 19051; + controller.Socks5EndPoint = socksEndpoint; + controller.ProxyType = TorProxyType.Socks5; + controller.ProxyHost = "proxy.example.test"; + controller.ProxyPort = 1080; + controller.ProxyCredential = credential; + + Assert.Equal(19051, controller.ControlPort); + Assert.Same(socksEndpoint, controller.Socks5EndPoint); + Assert.Equal(TorProxyType.Socks5, controller.ProxyType); + Assert.Equal("proxy.example.test", controller.ProxyHost); + Assert.Equal(1080, controller.ProxyPort); + Assert.Same(credential, controller.ProxyCredential); + } + + [Fact] + public async Task SignalCommandsWriteTorControlCommandsAndAcceptSuccessResponses() + { + using TorControlTestServer server = CreateStartedServer(); + server.Enqueue("250 OK"); + server.Enqueue("250 OK"); + server.Enqueue("250 OK"); + server.Enqueue("250 OK"); + using TorController controller = await CreateConnectedControllerAsync(server); + + controller.SwitchCircuit(); + controller.ClearDnsCache(); + controller.ImmediateShutdown(); + controller.Shutdown(); + + Assert.Equal( + [ + "SIGNAL NEWNYM", + "SIGNAL CLEARDNSCACHE", + "SIGNAL HALT", + "SIGNAL SHUTDOWN" + ], server.Commands); + } + + [Fact] + public async Task SignalCommandsThrowWhenControlServerReturnsError() + { + using TorControlTestServer server = CreateStartedServer(); + server.Enqueue("551 bad signal"); + using TorController controller = await CreateConnectedControllerAsync(server); + + TorControllerException exception = Assert.Throws(() => controller.SwitchCircuit()); + + Assert.Equal("Server returned: 551 bad signal", exception.Message); + Assert.Equal("SIGNAL NEWNYM", Assert.Single(server.Commands)); + } + + [Fact] + public async Task CreateHiddenServiceParsesMultiLineResponse() + { + using TorControlTestServer server = CreateStartedServer(); + server.Enqueue( + "250-ServiceID=examplehiddenservice", + "250-PrivateKey=ED25519-V3:private-key", + "250-ClientAuth=alice:cookie", + "250 OK"); + using TorController controller = await CreateConnectedControllerAsync(server); + + TorHiddenServiceInfo info = controller.CreateHiddenService(443, new IPEndPoint(IPAddress.Loopback, 8443), "alice", "cookie"); + + Assert.Equal("examplehiddenservice", info.ServiceId); + Assert.Equal("ED25519-V3:private-key", info.PrivateKey); + Assert.Equal("alice", info.ClientBasicAuthenticationUsername); + Assert.Equal("cookie", info.ClientBasicAuthenticationCookie); + Assert.Equal("ADD_ONION NEW:BEST Flags=BasicAuth Port=443,127.0.0.1:8443 ClientAuth=alice:cookie", Assert.Single(server.Commands)); + } + + [Fact] + public async Task CreateHiddenServiceWithPrivateKeyWritesExpectedCommand() + { + using TorControlTestServer server = CreateStartedServer(); + server.Enqueue("250-ServiceID=restoredservice", "250 OK"); + using TorController controller = await CreateConnectedControllerAsync(server); + + TorHiddenServiceInfo info = controller.CreateHiddenService(80, "ED25519-V3:private-key", new IPEndPoint(IPAddress.Loopback, 8080), "bob"); + + Assert.Equal("restoredservice", info.ServiceId); + Assert.Equal("ADD_ONION ED25519-V3:private-key Flags=BasicAuth Port=80,127.0.0.1:8080 ClientAuth=bob:", Assert.Single(server.Commands)); + } + + [Fact] + public async Task HiddenServiceParsingThrowsOnErrorResponse() + { + using TorControlTestServer server = CreateStartedServer(); + server.Enqueue("551 onion failed"); + using TorController controller = await CreateConnectedControllerAsync(server); + + TorControllerException exception = Assert.Throws(() => controller.CreateHiddenService(80)); + + Assert.Equal("Server returned: 551 onion failed", exception.Message); + } + + [Fact] + public async Task DeleteHiddenServiceWritesCommandAndHandlesErrors() + { + using TorControlTestServer server = CreateStartedServer(); + server.Enqueue("250 OK"); + server.Enqueue("552 unknown service"); + using TorController controller = await CreateConnectedControllerAsync(server); + + controller.DeleteHiddenService("serviceid"); + TorControllerException exception = Assert.Throws(() => controller.DeleteHiddenService("missing")); + + Assert.Equal("Server returned: 552 unknown service", exception.Message); + Assert.Equal(["DEL_ONION serviceid", "DEL_ONION missing"], server.Commands); + } + + private static TorControlTestServer CreateStartedServer() + { + TorControlTestServer server = new TorControlTestServer(); + server.Start(); + return server; + } + + private static TorController CreateController() + { + return new TorController(typeof(TorController).Assembly.Location); + } + + private static async Task CreateConnectedControllerAsync(TorControlTestServer server) + { + TorController controller = CreateController(); + Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await socket.ConnectAsync(IPAddress.Loopback, server.Port); + NetworkStream stream = new NetworkStream(socket, ownsSocket: true); + StreamReader reader = new StreamReader(stream); + StreamWriter writer = new StreamWriter(stream) { AutoFlush = true }; + + SetField(controller, "_socket", socket); + SetField(controller, "_sR", reader); + SetField(controller, "_sW", writer); + + return controller; + } + + private static void SetField(TorController controller, string fieldName, object value) + { + typeof(TorController).GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance)!.SetValue(controller, value); + } + + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Tor/TorProjectTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Tor/TorProjectTests.cs new file mode 100644 index 0000000..f9a0106 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Tor/TorProjectTests.cs @@ -0,0 +1,31 @@ +using TechnitiumLibrary.Net.Tor; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Tor +{ + public class TorProjectTests + { + [Fact] + public void TorControllerException_PreservesMessageAndInnerException() + { + InvalidOperationException inner = new InvalidOperationException("inner"); + TorControllerException defaultException = new TorControllerException(); + TorControllerException messageException = new TorControllerException("controller failed"); + TorControllerException ex = new TorControllerException("controller failed", inner); + + Assert.NotNull(defaultException.Message); + Assert.Equal("controller failed", messageException.Message); + Assert.Equal("controller failed", ex.Message); + Assert.Same(inner, ex.InnerException); + } + + [Fact] + public void TorProxyType_ValuesRemainStable() + { + Assert.Equal(0, (int)TorProxyType.None); + Assert.Equal(1, (int)TorProxyType.Http); + Assert.Equal(2, (int)TorProxyType.Https); + Assert.Equal(3, (int)TorProxyType.Socks4); + Assert.Equal(4, (int)TorProxyType.Socks5); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.UPnP/UPnPProjectTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.UPnP/UPnPProjectTests.cs new file mode 100644 index 0000000..dc7644a --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.UPnP/UPnPProjectTests.cs @@ -0,0 +1,43 @@ +using System.Net; +using System.Net.Sockets; +using TechnitiumLibrary.Net.UPnP; +using TechnitiumLibrary.Net.UPnP.Networking; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.UPnP +{ + public class UPnPProjectTests + { + [Fact] + public void GenericPortMappingEntry_ExposesExternalAndInternalEndpointData() + { + IPAddress remoteHost = IPAddress.Parse("203.0.113.10"); + IPAddress internalClient = IPAddress.Parse("192.168.1.5"); + GenericPortMappingEntry entry = new GenericPortMappingEntry( + remoteHost, + externalPort: 8443, + ProtocolType.Tcp, + internalPort: 443, + internalClient, + enabled: true, + description: "https", + leaseDuration: 3600); + + Assert.Equal(new IPEndPoint(remoteHost, 8443), entry.ExternalEP); + Assert.Equal(new IPEndPoint(internalClient, 443), entry.InternalEP); + Assert.Equal(ProtocolType.Tcp, entry.Protocol); + Assert.True(entry.Enabled); + Assert.Equal("https", entry.Description); + Assert.Equal(3600, entry.LeaseDuration); + } + + [Fact] + public void UPnPException_PreservesMessageAndInnerException() + { + Exception inner = new Exception("inner"); + UPnPException ex = new UPnPException("failed", inner); + + Assert.Equal("failed", ex.Message); + Assert.Same(inner, ex.InnerException); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/DnsDatagramTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/DnsDatagramTests.cs new file mode 100644 index 0000000..dc9376b --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/DnsDatagramTests.cs @@ -0,0 +1,238 @@ +using System.Net; +using System.Text.Json; +using TechnitiumLibrary.Net; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.EDnsOptions; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Dns +{ + public class DnsDatagramTests + { + [Fact] + public void DatagramRoundTripsQuestionsRecordsAndEdns() + { + DnsDatagram datagram = CreateResponse( + answers: + [ + new DnsResourceRecord("alias.example.test", DnsResourceRecordType.CNAME, DnsClass.IN, 60, new DnsCNAMERecordData("example.test")), + new DnsResourceRecord("example.test", DnsResourceRecordType.A, DnsClass.IN, 60, new DnsARecordData(IPAddress.Parse("192.0.2.10"))) + ], + authority: + [ + new DnsResourceRecord("example.test", DnsResourceRecordType.NS, DnsClass.IN, 300, new DnsNSRecordData("ns1.example.test")) + ], + additional: + [ + new DnsResourceRecord("ns1.example.test", DnsResourceRecordType.A, DnsClass.IN, 300, new DnsARecordData(IPAddress.Parse("192.0.2.53"))) + ]); + + DnsDatagram parsed = RoundTrip(datagram); + + Assert.Equal(datagram.Identifier, parsed.Identifier); + Assert.True(parsed.IsResponse); + Assert.True(parsed.AuthoritativeAnswer); + Assert.True(parsed.RecursionDesired); + Assert.True(parsed.RecursionAvailable); + Assert.True(parsed.AuthenticData); + Assert.True(parsed.CheckingDisabled); + Assert.True(parsed.DnssecOk); + Assert.Equal(DnsResponseCode.NoError, parsed.RCODE); + Assert.Single(parsed.Question); + Assert.Equal(2, parsed.Answer.Count); + Assert.Single(parsed.Authority); + Assert.Equal(2, parsed.Additional.Count); + Assert.NotNull(parsed.EDNS); + Assert.Equal(DnsResourceRecordType.A, parsed.GetLastAnswerRecord().Type); + Assert.Equal(DnsResourceRecordType.NS, parsed.FindFirstAuthorityType()); + Assert.False(parsed.IsFirstAuthoritySOA()); + Assert.False(parsed.IsFirstAuthoritySOAOrFWDOrAPP()); + AssertJsonCanBeWritten(parsed); + } + + [Fact] + public async Task DatagramRoundTripsThroughTcpFrame() + { + DnsDatagram datagram = CreateResponse( + answers: + [ + new DnsResourceRecord("example.test", DnsResourceRecordType.A, DnsClass.IN, 60, new DnsARecordData(IPAddress.Parse("192.0.2.10"))) + ]); + using MemoryStream stream = new MemoryStream(); + + await datagram.WriteToTcpAsync(stream); + stream.Position = 0; + DnsDatagram parsed = await DnsDatagram.ReadFromTcpAsync(stream); + + Assert.Equal(datagram.Identifier, parsed.Identifier); + Assert.Equal(IPAddress.Parse("192.0.2.10"), ((DnsARecordData)parsed.Answer[0].RDATA).Address); + } + + [Fact] + public void CloneHelpersPreserveAndRemoveExpectedSections() + { + DnsDatagram datagram = CreateResponse( + additional: + [ + new DnsResourceRecord("ns1.example.test", DnsResourceRecordType.A, DnsClass.IN, 300, new DnsARecordData(IPAddress.Parse("192.0.2.53"))) + ]); + + DnsDatagram clone = datagram.Clone(); + DnsDatagram withoutEdns = datagram.CloneWithoutEDns(); + DnsDatagram withoutGlue = datagram.CloneWithoutGlueRecords(); + DnsDatagram withoutEcs = datagram.CloneWithoutEDnsClientSubnet(); + + Assert.NotSame(datagram, clone); + Assert.Equal(datagram.Identifier, clone.Identifier); + Assert.Null(withoutEdns.EDNS); + Assert.Single(withoutEdns.Additional); + Assert.Single(withoutGlue.Additional); + Assert.Equal(DnsResourceRecordType.OPT, withoutGlue.Additional[0].Type); + Assert.Null(withoutEcs.GetEDnsClientSubnetOption()); + Assert.NotNull(datagram.GetEDnsClientSubnetOption()); + } + + [Fact] + public void ShadowClientSubnetOverridesAndCanBeHidden() + { + DnsDatagram datagram = CreateResponse(); + + datagram.SetShadowEDnsClientSubnetOption(new NetworkAddress(IPAddress.Parse("198.51.100.0"), 24), advancedForwardingClientSubnet: true); + + EDnsClientSubnetOptionData shadow = datagram.GetEDnsClientSubnetOption()!; + Assert.Equal(IPAddress.Parse("198.51.100.0"), shadow.Address); + Assert.Equal(24, shadow.ScopePrefixLength); + Assert.True(shadow.AdvancedForwardingClientSubnet); + Assert.Equal(IPAddress.Parse("192.0.2.0"), datagram.GetEDnsClientSubnetOption(noShadow: true)!.Address); + + datagram.ShadowHideEDnsClientSubnetOption(); + + Assert.Null(datagram.GetEDnsClientSubnetOption()); + Assert.NotNull(datagram.GetEDnsClientSubnetOption(noShadow: true)); + } + + [Fact] + public void BlockedAndReferrerResponseDetectionUsesResponseShape() + { + DnsDatagram blockedByEde = CreateResponse( + options: + [ + new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, "policy")) + ], + authoritativeAnswer: false); + DnsDatagram blockedByAddress = CreateResponse( + answers: + [ + new DnsResourceRecord("example.test", DnsResourceRecordType.A, DnsClass.IN, 60, new DnsARecordData(IPAddress.Any)) + ], + options: [], + authoritativeAnswer: false); + DnsDatagram referrer = CreateResponse( + answers: [], + authority: + [ + new DnsResourceRecord("example.test", DnsResourceRecordType.NS, DnsClass.IN, 300, new DnsNSRecordData("ns1.example.test")) + ], + options: []); + + Assert.True(blockedByEde.IsBlockedResponse()); + Assert.True(blockedByAddress.IsBlockedResponse()); + Assert.False(referrer.IsBlockedResponse()); + Assert.True(referrer.IsReferrerResponse()); + } + + [Fact] + public void MetadataAndDnssecStatusCanBeAppliedAndSerialized() + { + DnsDatagram datagram = CreateResponse( + answers: + [ + new DnsResourceRecord("example.test", DnsResourceRecordType.A, DnsClass.IN, 60, new DnsARecordData(IPAddress.Parse("192.0.2.10"))) + ]); + NameServerAddress server = new NameServerAddress(IPAddress.Parse("192.0.2.53"), DnsTransportProtocol.Udp); + + datagram.SetMetadata(server, 12.5); + datagram.SetDnssecStatusForAllRecords(DnssecStatus.Secure); + + Assert.Same(server, datagram.Metadata.NameServer); + Assert.Equal(DnsTransportProtocol.Udp, datagram.Metadata.Protocol); + Assert.Equal(12.5, datagram.Metadata.RoundTripTime); + Assert.Equal(DnssecStatus.Secure, datagram.Answer[0].DnssecStatus); + Assert.Equal(DnssecStatus.Indeterminate, datagram.Additional.Last().DnssecStatus); + Assert.Throws(() => datagram.SetMetadata(server)); + + using MemoryStream cacheStream = new MemoryStream(); + using (BinaryWriter writer = new BinaryWriter(cacheStream, System.Text.Encoding.UTF8, leaveOpen: true)) + datagram.Metadata.WriteTo(writer); + + cacheStream.Position = 0; + using BinaryReader reader = new BinaryReader(cacheStream); + DnsDatagramMetadata parsed = new DnsDatagramMetadata(reader); + + Assert.Equal(datagram.Metadata.NameServer, parsed.NameServer); + Assert.Equal(datagram.Metadata.DatagramSize, parsed.DatagramSize); + Assert.Equal(datagram.Metadata.RoundTripTime, parsed.RoundTripTime); + AssertJsonCanBeWritten(datagram); + } + + [Fact] + public void SplitRejectsUnsupportedDatagrams() + { + DnsDatagram query = new DnsDatagram(1, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, [new DnsQuestionRecord("example.test", DnsResourceRecordType.A, DnsClass.IN)]); + DnsDatagram response = CreateResponse(); + + Assert.Throws(() => query.Split()); + Assert.Throws(() => response.Split()); + Assert.False(query.IsZoneTransfer); + } + + private static DnsDatagram CreateResponse( + IReadOnlyList? answers = null, + IReadOnlyList? authority = null, + IReadOnlyList? additional = null, + IReadOnlyList? options = null, + bool authoritativeAnswer = true) + { + options ??= + [ + new EDnsOption(EDnsOptionCode.EDNS_CLIENT_SUBNET, new EDnsClientSubnetOptionData(24, 0, IPAddress.Parse("192.0.2.0"))) + ]; + + return new DnsDatagram( + 0x1234, + true, + DnsOpcode.StandardQuery, + authoritativeAnswer, + false, + true, + true, + true, + true, + DnsResponseCode.NoError, + [new DnsQuestionRecord("example.test", DnsResourceRecordType.A, DnsClass.IN)], + answers ?? [], + authority ?? [], + additional ?? [], + DnsDatagram.EDNS_DEFAULT_UDP_PAYLOAD_SIZE, + EDnsHeaderFlags.DNSSEC_OK, + options); + } + + private static DnsDatagram RoundTrip(DnsDatagram datagram) + { + using MemoryStream stream = new MemoryStream(); + datagram.WriteTo(stream); + stream.Position = 0; + return DnsDatagram.ReadFrom(stream); + } + + private static void AssertJsonCanBeWritten(DnsDatagram datagram) + { + using MemoryStream stream = new MemoryStream(); + using (Utf8JsonWriter writer = new Utf8JsonWriter(stream)) + datagram.SerializeTo(writer); + + Assert.True(stream.Length > 0); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/NameServerAddressTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/NameServerAddressTests.cs new file mode 100644 index 0000000..458628c --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/NameServerAddressTests.cs @@ -0,0 +1,147 @@ +using System.Net; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Dns +{ + public class NameServerAddressTests + { + [Theory] + [InlineData("192.0.2.53", DnsTransportProtocol.Udp, "192.0.2.53", 53, true)] + [InlineData("192.0.2.53:5353", DnsTransportProtocol.Udp, "192.0.2.53", 5353, false)] + [InlineData("[2001:db8::53]", DnsTransportProtocol.Udp, "2001:db8::53", 53, true)] + [InlineData("tcp://dns.example.test:5353", DnsTransportProtocol.Tcp, "dns.example.test", 5353, false)] + [InlineData("tls://dns.example.test", DnsTransportProtocol.Tls, "dns.example.test", 853, true)] + [InlineData("quic://dns.example.test", DnsTransportProtocol.Quic, "dns.example.test", 853, true)] + [InlineData("https://dns.example.test/dns-query", DnsTransportProtocol.Https, "dns.example.test", 443, true)] + [InlineData("h3://dns.example.test/dns-query", DnsTransportProtocol.Https, "dns.example.test", 443, true)] + public void ParseGuessesProtocolHostPortAndDefaultPort(string value, DnsTransportProtocol protocol, string host, int port, bool isDefaultPort) + { + NameServerAddress address = NameServerAddress.Parse(value); + + Assert.Equal(protocol, address.Protocol); + Assert.Equal(host, address.Host); + Assert.Equal(port, address.Port); + Assert.Equal(isDefaultPort, address.IsDefaultPort); + Assert.Equal(value, address.OriginalAddress); + Assert.NotNull(address.EndPoint); + Assert.Contains(host.Trim('[', ']'), address.ToString()); + } + + [Fact] + public void ParseSupportsPinnedIpAddressBesideDomainOrDohEndpoint() + { + NameServerAddress domainPinned = NameServerAddress.Parse("dns.example.test (192.0.2.53)"); + NameServerAddress dohPinned = NameServerAddress.Parse("https://dns.example.test/dns-query (192.0.2.54)"); + + Assert.Equal("dns.example.test", domainPinned.DomainEndPoint!.Address); + Assert.Equal(IPAddress.Parse("192.0.2.53"), domainPinned.IPEndPoint!.Address); + Assert.Equal(DnsTransportProtocol.Udp, domainPinned.Protocol); + Assert.Equal("dns.example.test (192.0.2.53)", domainPinned.ToString()); + + Assert.Equal(DnsTransportProtocol.Https, dohPinned.Protocol); + Assert.Equal(new Uri("https://dns.example.test/dns-query"), dohPinned.DoHEndPoint); + Assert.Equal(IPAddress.Parse("192.0.2.54"), dohPinned.IPEndPoint!.Address); + Assert.Equal("https://dns.example.test/dns-query (192.0.2.54)", dohPinned.ToString()); + } + + [Fact] + public void ExplicitProtocolValidationRejectsMismatchedAddressKinds() + { + Assert.Throws(() => NameServerAddress.Parse("https://dns.example.test/dns-query", DnsTransportProtocol.Udp)); + Assert.Throws(() => NameServerAddress.Parse("192.0.2.53:53", DnsTransportProtocol.Tls)); + Assert.Throws(() => NameServerAddress.Parse("192.0.2.53:853", DnsTransportProtocol.Udp)); + Assert.Throws(() => NameServerAddress.Parse("dns.example.test:853 (192.0.2.53:53)")); + Assert.Throws(() => NameServerAddress.Parse("[not-ipv6]")); + } + + [Fact] + public void ConstructorsCloneAndBinarySerializationPreserveAddress() + { + NameServerAddress original = new NameServerAddress("dns.example.test", new IPEndPoint(IPAddress.Parse("192.0.2.53"), 5353), DnsTransportProtocol.Tcp); + NameServerAddress clonedIp = original.Clone(IPAddress.Parse("198.51.100.53")); + NameServerAddress clonedTls = original.Clone(DnsTransportProtocol.Tls); + NameServerAddress sameProtocol = original.Clone(DnsTransportProtocol.Tcp); + + Assert.Equal(DnsTransportProtocol.Tcp, original.Protocol); + Assert.Equal("dns.example.test", original.DomainEndPoint!.Address); + Assert.Equal(5353, original.Port); + Assert.Equal(IPAddress.Parse("198.51.100.53"), clonedIp.IPEndPoint!.Address); + Assert.Equal(5353, clonedIp.Port); + Assert.Equal(DnsTransportProtocol.Tls, clonedTls.Protocol); + Assert.Equal(5353, clonedTls.Port); + Assert.Same(original, sameProtocol); + Assert.NotEqual(original, clonedIp); + + using MemoryStream stream = new MemoryStream(); + using (BinaryWriter writer = new BinaryWriter(stream, System.Text.Encoding.UTF8, leaveOpen: true)) + original.WriteTo(writer); + + stream.Position = 0; + using BinaryReader reader = new BinaryReader(stream); + NameServerAddress parsed = new NameServerAddress(reader); + + Assert.Equal(original, parsed); + Assert.Equal(original.GetHashCode(), parsed.GetHashCode()); + } + + [Fact] + public void GetNameServersFromResponseUsesGlueAndFiltersLoopback() + { + DnsDatagram response = new DnsDatagram( + 1, + true, + DnsOpcode.StandardQuery, + true, + false, + true, + true, + false, + false, + DnsResponseCode.NoError, + [new DnsQuestionRecord("example.test", DnsResourceRecordType.NS, DnsClass.IN)], + [ + new DnsResourceRecord("example.test", DnsResourceRecordType.NS, DnsClass.IN, 300, new DnsNSRecordData("ns1.example.test")), + new DnsResourceRecord("example.test", DnsResourceRecordType.NS, DnsClass.IN, 300, new DnsNSRecordData("ns2.example.test")) + ], + [], + [ + new DnsResourceRecord("ns1.example.test", DnsResourceRecordType.A, DnsClass.IN, 300, new DnsARecordData(IPAddress.Parse("192.0.2.53"))), + new DnsResourceRecord("ns1.example.test", DnsResourceRecordType.AAAA, DnsClass.IN, 300, new DnsAAAARecordData(IPAddress.Parse("2001:db8::53"))), + new DnsResourceRecord("ns2.example.test", DnsResourceRecordType.A, DnsClass.IN, 300, new DnsARecordData(IPAddress.Loopback)) + ]); + + List servers = NameServerAddress.GetNameServersFromResponse(response, IPv6Mode.Enabled, filterLoopbackAddresses: true); + + Assert.Equal(2, servers.Count); + Assert.Contains(servers, server => server.IPEndPoint!.Address.Equals(IPAddress.Parse("192.0.2.53"))); + Assert.Contains(servers, server => server.IPEndPoint!.Address.Equals(IPAddress.Parse("2001:db8::53"))); + Assert.DoesNotContain(servers, server => IPAddress.IsLoopback(server.IPEndPoint!.Address)); + } + + [Fact] + public void GetNameServersFromResponseReturnsDomainEndpointWhenGlueIsMissing() + { + DnsDatagram response = new DnsDatagram( + 1, + true, + DnsOpcode.StandardQuery, + true, + false, + true, + true, + false, + false, + DnsResponseCode.NoError, + [new DnsQuestionRecord("www.example.test", DnsResourceRecordType.A, DnsClass.IN)], + [], + [new DnsResourceRecord("example.test", DnsResourceRecordType.NS, DnsClass.IN, 300, new DnsNSRecordData("ns.example.test"))]); + + NameServerAddress server = Assert.Single(NameServerAddress.GetNameServersFromResponse(response, IPv6Mode.Disabled, filterLoopbackAddresses: false)); + + Assert.Null(server.IPEndPoint); + Assert.Equal("ns.example.test", server.DomainEndPoint!.Address); + Assert.Equal(53, server.Port); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/ResourceRecords/DnsResourceRecordDataTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/ResourceRecords/DnsResourceRecordDataTests.cs new file mode 100644 index 0000000..6dd390f --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/Dns/ResourceRecords/DnsResourceRecordDataTests.cs @@ -0,0 +1,129 @@ +using System.Net; +using System.Text.Json; +using TechnitiumLibrary.Net; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.Dnssec; +using TechnitiumLibrary.Net.Dns.EDnsOptions; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Dns.ResourceRecords +{ + public class DnsResourceRecordDataTests + { + public static IEnumerable RecordDataRoundTripCases() + { + yield return Case(DnsResourceRecordType.A, new DnsARecordData(IPAddress.Parse("192.0.2.1"))); + yield return Case(DnsResourceRecordType.AAAA, new DnsAAAARecordData(IPAddress.Parse("2001:db8::1"))); + yield return Case(DnsResourceRecordType.NS, new DnsNSRecordData("ns1.example.test")); + yield return Case(DnsResourceRecordType.CNAME, new DnsCNAMERecordData("target.example.test")); + yield return Case(DnsResourceRecordType.DNAME, new DnsDNAMERecordData("target.example.test")); + yield return Case(DnsResourceRecordType.PTR, new DnsPTRRecordData("ptr.example.test")); + yield return Case(DnsResourceRecordType.ANAME, new DnsANAMERecordData("alias.example.test")); + yield return Case(DnsResourceRecordType.ALIAS, new DnsALIASRecordData(DnsResourceRecordType.A, "alias.example.test")); + yield return Case(DnsResourceRecordType.MX, new DnsMXRecordData(10, "mail.example.test")); + yield return Case(DnsResourceRecordType.SOA, new DnsSOARecordData("ns1.example.test", "hostmaster.example.test", 1, 3600, 600, 604800, 60)); + yield return Case(DnsResourceRecordType.SRV, new DnsSRVRecordData(1, 5, 443, "svc.example.test")); + yield return Case(DnsResourceRecordType.RP, new DnsRPRecordData("admin.example.test", "txt.example.test")); + yield return Case(DnsResourceRecordType.HINFO, new DnsHINFORecordData("x64", "linux")); + yield return Case(DnsResourceRecordType.TXT, new DnsTXTRecordData(["hello", "world"])); + yield return Case(DnsResourceRecordType.NAPTR, new DnsNAPTRRecordData(1, 2, "s", "SIP+D2U", "!^.*$!sip:info@example.test!", "replacement.example.test")); + yield return Case(DnsResourceRecordType.CAA, new DnsCAARecordData(0, "issue", "ca.example.test")); + yield return Case(DnsResourceRecordType.URI, new DnsURIRecordData(10, 1, new Uri("https://example.test/dns"))); + yield return Case(DnsResourceRecordType.DS, new DnsDSRecordData(12345, DnssecAlgorithm.RSASHA256, DnssecDigestType.SHA256, Bytes(32))); + yield return Case(DnsResourceRecordType.DNSKEY, new DnsDNSKEYRecordData(DnsDnsKeyFlag.ZoneKey, 3, DnssecAlgorithm.Unknown, DnssecPublicKey.Parse(DnssecAlgorithm.Unknown, Bytes(8)))); + yield return Case(DnsResourceRecordType.RRSIG, new DnsRRSIGRecordData(DnsResourceRecordType.A, DnssecAlgorithm.RSASHA256, 2, 60, 2000000000, 1900000000, 12345, "example.test", Bytes(16))); + yield return Case(DnsResourceRecordType.NSEC, new DnsNSECRecordData("next.example.test", [DnsResourceRecordType.A, DnsResourceRecordType.AAAA, DnsResourceRecordType.RRSIG])); + yield return Case(DnsResourceRecordType.NSEC3, new DnsNSEC3RecordData(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.OptOut, 2, [1, 2], Bytes(20), [DnsResourceRecordType.NS, DnsResourceRecordType.DS])); + yield return Case(DnsResourceRecordType.NSEC3PARAM, new DnsNSEC3PARAMRecordData(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.None, 2, [1, 2])); + yield return Case(DnsResourceRecordType.SSHFP, new DnsSSHFPRecordData(DnsSSHFPAlgorithm.RSA, DnsSSHFPFingerprintType.SHA256, Bytes(32))); + yield return Case(DnsResourceRecordType.TLSA, new DnsTLSARecordData(DnsTLSACertificateUsage.DANE_EE, DnsTLSASelector.SPKI, DnsTLSAMatchingType.SHA2_256, Bytes(32))); + yield return Case(DnsResourceRecordType.ZONEMD, new DnsZONEMDRecordData(1234, ZoneMdScheme.Simple, ZoneMdHashAlgorithm.SHA384, Bytes(48))); + yield return Case(DnsResourceRecordType.APL, new DnsAPLRecordData(new NetworkAddress(IPAddress.Parse("192.0.2.0"), 24), false)); + yield return Case(DnsResourceRecordType.SVCB, new DnsSVCBRecordData(1, "svc.example.test", new Dictionary + { + [DnsSvcParamKey.ALPN] = new DnsSvcAlpnParamValue(["h2", "dot"]), + [DnsSvcParamKey.Port] = new DnsSvcPortParamValue(853), + [DnsSvcParamKey.IPv4Hint] = new DnsSvcIPv4HintParamValue([IPAddress.Parse("192.0.2.53")]), + [DnsSvcParamKey.IPv6Hint] = new DnsSvcIPv6HintParamValue([IPAddress.Parse("2001:db8::53")]), + [DnsSvcParamKey.DoHPath] = new DnsSvcDoHPathParamValue("/dns-query{?dns}") + })); + yield return Case(DnsResourceRecordType.TSIG, new DnsTSIGRecordData("hmac-sha256", DateTime.UnixEpoch.AddSeconds(1234), 300, Bytes(16), 7, DnsTsigError.NoError, [9, 8])); + yield return Case(DnsResourceRecordType.FWD, new DnsForwarderRecordData(DnsTransportProtocol.Udp, "192.0.2.53", true, DnsForwarderRecordProxyType.NoProxy, null, 0, null, null, 1)); + yield return Case(DnsResourceRecordType.APP, new DnsApplicationRecordData("app", "Namespace.Type", "{\"enabled\":true}")); + } + + [Theory] + [MemberData(nameof(RecordDataRoundTripCases))] + public void RecordDataRoundTripsThroughWireFormat(DnsResourceRecordType type, DnsResourceRecordData recordData) + { + byte[] rData = WriteAndExtractRData(recordData); + + DnsResourceRecordData parsed = DnsResourceRecord.ReadRecordDataFrom(type, rData); + + Assert.Equal(recordData, parsed); + Assert.True(parsed.UncompressedLength > 0); + Assert.NotEmpty(parsed.ToString()); + AssertJsonCanBeWritten(parsed); + } + + [Fact] + public void UnknownRecordDataRoundTripsEmptyAndNonEmptyPayloads() + { + DnsResourceRecordData empty = DnsResourceRecord.ReadRecordDataFrom((DnsResourceRecordType)65000, []); + DnsResourceRecordData payload = DnsResourceRecord.ReadRecordDataFrom((DnsResourceRecordType)65000, [1, 2, 3, 4]); + + Assert.Equal(new DnsUnknownRecordData([]), empty); + Assert.Equal(new DnsUnknownRecordData([1, 2, 3, 4]), payload); + Assert.NotEqual(empty, payload); + AssertJsonCanBeWritten(empty); + AssertJsonCanBeWritten(payload); + } + + [Fact] + public void OptRecordDataRoundTripsKnownAndUnknownOptions() + { + DnsOPTRecordData recordData = new DnsOPTRecordData( + [ + new EDnsOption(EDnsOptionCode.EDNS_CLIENT_SUBNET, new EDnsClientSubnetOptionData(24, 0, IPAddress.Parse("192.0.2.0"))), + new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.NetworkError, "upstream timeout")), + new EDnsOption(EDnsOptionCode.COOKIE, new EDnsUnknownOptionData([1, 2, 3, 4])) + ]); + + DnsOPTRecordData parsed = Assert.IsType(DnsResourceRecord.ReadRecordDataFrom(DnsResourceRecordType.OPT, WriteAndExtractRData(recordData))); + + Assert.Equal(recordData, parsed); + Assert.Equal(3, parsed.Options.Count); + Assert.IsType(parsed.Options[0].Data); + Assert.IsType(parsed.Options[1].Data); + Assert.IsType(parsed.Options[2].Data); + AssertJsonCanBeWritten(parsed); + } + + private static object[] Case(DnsResourceRecordType type, DnsResourceRecordData recordData) + { + return [type, recordData]; + } + + private static byte[] WriteAndExtractRData(DnsResourceRecordData recordData) + { + using MemoryStream stream = new MemoryStream(); + recordData.WriteTo(stream); + byte[] wireFormat = stream.ToArray(); + return wireFormat.Skip(2).ToArray(); + } + + private static byte[] Bytes(int length) + { + return Enumerable.Range(1, length).Select(i => (byte)i).ToArray(); + } + + private static void AssertJsonCanBeWritten(DnsResourceRecordData recordData) + { + using MemoryStream stream = new MemoryStream(); + using (Utf8JsonWriter writer = new Utf8JsonWriter(stream)) + recordData.SerializeTo(writer); + + Assert.True(stream.Length > 0); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DnsClientSimulatorTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DnsClientSimulatorTests.cs new file mode 100644 index 0000000..4df6bf7 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DnsClientSimulatorTests.cs @@ -0,0 +1,125 @@ +using System.Net; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; +using TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + public class DnsClientSimulatorTests + { + [Fact] + public async Task ResolveAsyncUsesUdpSimulatorAndParsesARecord() + { + using DnsTestServer server = new DnsTestServer(); + server.AddAddress("example.test", IPAddress.Parse("192.0.2.10")); + server.Start(); + DnsClient client = CreateClient(server, DnsTransportProtocol.Udp); + + DnsDatagram response = await client.ResolveAsync("example.test", DnsResourceRecordType.A); + IReadOnlyList addresses = DnsClient.ParseResponseA(response); + + IPAddress address = Assert.Single(addresses); + Assert.Equal(IPAddress.Parse("192.0.2.10"), address); + Assert.True(server.UdpQueryCount >= 1); + Assert.Equal(0, server.TcpQueryCount); + } + + [Fact] + public async Task ResolveIPAsyncCombinesAAndAAAAResponsesFromSimulator() + { + using DnsTestServer server = new DnsTestServer(); + server.AddAddress("dual.example.test", IPAddress.Parse("192.0.2.20")); + server.AddAddress("dual.example.test", IPAddress.Parse("2001:db8::20")); + server.Start(); + DnsClient client = CreateClient(server, DnsTransportProtocol.Udp); + + IReadOnlyList addresses = await DnsClient.ResolveIPAsync(client, "dual.example.test", IPv6Mode.Enabled); + + Assert.Contains(IPAddress.Parse("192.0.2.20"), addresses); + Assert.Contains(IPAddress.Parse("2001:db8::20"), addresses); + Assert.Equal(2, addresses.Count); + } + + [Fact] + public async Task TruncatedUdpResponseFallsBackToTcpOnSameSimulator() + { + using DnsTestServer server = new DnsTestServer(); + server.AddAddress("fallback.example.test", IPAddress.Parse("192.0.2.30")); + server.TruncateUdpResponses = true; + server.Start(); + DnsClient client = CreateClient(server, DnsTransportProtocol.Udp); + + DnsDatagram response = await client.ResolveAsync("fallback.example.test", DnsResourceRecordType.A); + + Assert.Equal(IPAddress.Parse("192.0.2.30"), Assert.Single(DnsClient.ParseResponseA(response))); + Assert.Equal(1, server.UdpQueryCount); + Assert.True(server.TcpQueryCount >= 1); + } + + [Fact] + public async Task TcpSimulatorSupportsMxResolution() + { + using DnsTestServer server = new DnsTestServer(); + server.AddMx("example.test", 10, "mail.example.test"); + server.Start(); + DnsClient client = CreateClient(server, DnsTransportProtocol.Tcp); + + IReadOnlyList exchanges = await DnsClient.ResolveMXAsync(client, "example.test"); + + Assert.Equal("mail.example.test", Assert.Single(exchanges)); + Assert.Equal(0, server.UdpQueryCount); + Assert.True(server.TcpQueryCount >= 1); + } + + [Fact] + public async Task CNameChainFromSimulatorIsParsedForAResponse() + { + using DnsTestServer server = new DnsTestServer(); + server.AddCNameAddress("alias.example.test", "target.example.test", IPAddress.Parse("192.0.2.40")); + server.Start(); + DnsClient client = CreateClient(server, DnsTransportProtocol.Udp); + + DnsDatagram response = await client.ResolveAsync("alias.example.test", DnsResourceRecordType.A); + + Assert.Equal(IPAddress.Parse("192.0.2.40"), Assert.Single(DnsClient.ParseResponseA(response))); + } + + [Fact] + public async Task NxDomainResponseFromSimulatorIsReturnedAndParsedAsException() + { + using DnsTestServer server = new DnsTestServer(); + server.SetResponseCode("missing.example.test", DnsResourceRecordType.A, DnsResponseCode.NxDomain); + server.Start(); + DnsClient client = CreateClient(server, DnsTransportProtocol.Udp); + + DnsDatagram response = await client.ResolveAsync("missing.example.test", DnsResourceRecordType.A); + + Assert.Equal(DnsResponseCode.NxDomain, response.RCODE); + Assert.Throws(() => DnsClient.ParseResponseA(response)); + } + + [Fact] + public async Task DroppedUdpResponsesSurfaceNoResponseException() + { + using DnsTestServer server = new DnsTestServer(); + server.DropUdpResponses = true; + server.Start(); + DnsClient client = CreateClient(server, DnsTransportProtocol.Udp); + client.Timeout = 100; + client.Retries = 1; + + await Assert.ThrowsAsync(() => client.ResolveAsync("timeout.example.test", DnsResourceRecordType.A)); + + Assert.True(server.UdpQueryCount >= 1); + } + + private static DnsClient CreateClient(DnsTestServer server, DnsTransportProtocol protocol) + { + return new DnsClient(new NameServerAddress(new IPEndPoint(IPAddress.Loopback, server.Port), protocol)) + { + Timeout = 1000, + Retries = 1 + }; + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DomainEndPointTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DomainEndPointTests.cs new file mode 100644 index 0000000..ff967fa --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DomainEndPointTests.cs @@ -0,0 +1,72 @@ +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + public class DomainEndPointTests + { + [Fact] + public void TryParseRejectsIpAddressAndParsesDomain() + { + Assert.False(DomainEndPoint.TryParse("127.0.0.1:53", out _)); + Assert.True(DomainEndPoint.TryParse("example.com:853", out DomainEndPoint endpoint)); + + Assert.Equal("example.com", endpoint.Address); + Assert.Equal(853, endpoint.Port); + Assert.Equal(System.Net.Sockets.AddressFamily.Unspecified, endpoint.AddressFamily); + Assert.Equal("example.com:853", endpoint.ToString()); + } + + [Fact] + public void TryParseUsesZeroPortWhenPortIsOmitted() + { + Assert.True(DomainEndPoint.TryParse("example.com", out DomainEndPoint endpoint)); + + Assert.Equal("example.com", endpoint.Address); + Assert.Equal(0, endpoint.Port); + } + + [Fact] + public void TryParseRejectsInvalidDomainAndPort() + { + Assert.False(DomainEndPoint.TryParse("bad domain:53", out _)); + Assert.False(DomainEndPoint.TryParse("example.com:not-a-port", out _)); + Assert.False(DomainEndPoint.TryParse("example.com:53:extra", out _)); + } + + [Fact] + public void ConstructorRejectsIpAddressAndNormalizesUnicodeDomain() + { + Assert.Throws(() => new DomainEndPoint("127.0.0.1", 53)); + + DomainEndPoint endpoint = new DomainEndPoint("bücher.example", 443); + + Assert.Equal("xn--bcher-kva.example", endpoint.Address); + Assert.Equal(443, endpoint.Port); + } + + [Fact] + public void GetAddressBytesUsesLengthPrefixedAsciiDomain() + { + DomainEndPoint endpoint = new DomainEndPoint("example.com", 53); + + byte[] address = endpoint.GetAddressBytes(); + + Assert.Equal(12, address.Length); + Assert.Equal(11, address[0]); + Assert.Equal("example.com", System.Text.Encoding.ASCII.GetString(address, 1, address.Length - 1)); + } + + [Fact] + public void EqualsComparesAddressCaseInsensitivelyAndPort() + { + DomainEndPoint endpoint = new DomainEndPoint("Example.com", 53); + + Assert.True(endpoint.Equals(endpoint)); + Assert.True(endpoint.Equals(new DomainEndPoint("example.com", 53))); + Assert.False(endpoint.Equals(null)); + Assert.False(endpoint.Equals("example.com:53")); + Assert.False(endpoint.Equals(new DomainEndPoint("example.net", 53))); + Assert.False(endpoint.Equals(new DomainEndPoint("example.com", 853))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/EndPointExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/EndPointExtensionsTests.cs new file mode 100644 index 0000000..8b9303d --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/EndPointExtensionsTests.cs @@ -0,0 +1,127 @@ +using System.IO; +using System.Net; +using System.Net.Sockets; +using TechnitiumLibrary.Net; +using TechnitiumLibrary.Tests.Simulators.TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + public class EndPointExtensionsTests + { + [Theory] + [InlineData("192.0.2.1", 53)] + [InlineData("2001:db8::1", 853)] + public void WriteAndReadRoundtripsIpEndPoints(string address, int port) + { + IPEndPoint expected = new IPEndPoint(IPAddress.Parse(address), port); + using MemoryStream stream = new MemoryStream(); + using BinaryWriter writer = new BinaryWriter(stream); + + expected.WriteTo(writer); + stream.Position = 0; + + Assert.Equal(expected, EndPointExtensions.ReadFrom(new BinaryReader(stream))); + } + + [Fact] + public void WriteAndReadRoundtripsDomainEndPoint() + { + DomainEndPoint expected = new DomainEndPoint("example.com", 853); + using MemoryStream stream = new MemoryStream(); + using BinaryWriter writer = new BinaryWriter(stream); + + expected.WriteTo(writer); + stream.Position = 0; + + EndPoint actual = EndPointExtensions.ReadFrom(new BinaryReader(stream)); + + Assert.True(expected.Equals(actual)); + } + + [Fact] + public void ReadFromRejectsUnsupportedMarker() + { + using MemoryStream stream = new MemoryStream(new byte[] { 99 }); + + Assert.Throws(() => EndPointExtensions.ReadFrom(new BinaryReader(stream))); + } + + [Fact] + public void AddressAndPortHelpersHandleIpAndDomainEndPoints() + { + EndPoint ip = new IPEndPoint(IPAddress.Parse("192.0.2.1"), 53); + EndPoint domain = new DomainEndPoint("example.com", 853); + + Assert.Equal("192.0.2.1", ip.GetAddress()); + Assert.Equal(53, ip.GetPort()); + ip.SetPort(54); + Assert.Equal(54, ip.GetPort()); + + Assert.Equal("example.com", domain.GetAddress()); + Assert.Equal(853, domain.GetPort()); + domain.SetPort(443); + Assert.Equal(443, domain.GetPort()); + } + + [Fact] + public void GetEndPointCreatesIpOrDomainEndPoint() + { + Assert.IsType(EndPointExtensions.GetEndPoint("192.0.2.1", 53)); + Assert.IsType(EndPointExtensions.GetEndPoint("example.com", 53)); + } + + [Fact] + public void ParseAndTryParseHandleIpDomainAndInvalidValues() + { + Assert.IsType(EndPointExtensions.Parse("192.0.2.1:53")); + Assert.IsType(EndPointExtensions.Parse("example.com:853")); + Assert.True(EndPointExtensions.TryParse("example.com:853", out EndPoint endpoint)); + Assert.Equal(853, endpoint.GetPort()); + Assert.False(EndPointExtensions.TryParse("bad domain:not-a-port", out _)); + Assert.Throws(() => EndPointExtensions.Parse("bad domain:not-a-port")); + } + + [Fact] + public void IsEqualsHandlesNullReferenceFamilyAndValueComparison() + { + EndPoint ip = new IPEndPoint(IPAddress.Parse("192.0.2.1"), 53); + EndPoint matchingIp = new IPEndPoint(IPAddress.Parse("192.0.2.1"), 53); + EndPoint differentFamily = new DomainEndPoint("example.com", 53); + + Assert.True(ip.IsEquals(ip)); + Assert.True(ip.IsEquals(matchingIp)); + Assert.False(ip.IsEquals(null)); + Assert.False(ip.IsEquals(differentFamily)); + Assert.True(differentFamily.IsEquals(new DomainEndPoint("EXAMPLE.com", 53))); + } + + [Fact] + public async Task GetIPEndPointAsyncReturnsExistingIpEndPointWithoutDnsLookup() + { + IPEndPoint expected = new IPEndPoint(IPAddress.Loopback, 53); + + IPEndPoint actual = await expected.GetIPEndPointAsync(); + + Assert.Same(expected, actual); + } + + [Fact] + public async Task UnsupportedEndPointFamilyThrows() + { + EndPoint unsupported = new UnsupportedEndPoint(AddressFamily.Unknown); + + Assert.Throws(() => unsupported.WriteTo(new BinaryWriter(new MemoryStream()))); + Assert.Throws(() => unsupported.GetAddress()); + Assert.Throws(() => unsupported.GetPort()); + Assert.Throws(() => unsupported.SetPort(53)); + Assert.Throws(() => unsupported.IsEquals(new UnsupportedEndPoint(AddressFamily.Unknown))); + await Assert.ThrowsAsync(() => unsupported.GetIPEndPointAsync()); + } + + [Fact] + public async Task UnspecifiedNonDomainEndPointThrowsWhenResolving() + { + await Assert.ThrowsAsync(() => new UnsupportedEndPoint(AddressFamily.Unspecified).GetIPEndPointAsync()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/IPAddressExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/IPAddressExtensionsTests.cs new file mode 100644 index 0000000..5f869f9 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/IPAddressExtensionsTests.cs @@ -0,0 +1,129 @@ +using System.IO; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + public class IPAddressExtensionsTests + { + [Theory] + [InlineData("192.0.2.1")] + [InlineData("2001:db8::1")] + public void WriteAndReadRoundtrips(string addressValue) + { + IPAddress expected = IPAddress.Parse(addressValue); + using MemoryStream stream = new MemoryStream(); + + expected.WriteTo(stream); + stream.Position = 0; + + Assert.Equal(expected, IPAddressExtensions.ReadFrom(stream)); + } + + [Fact] + public void BinaryWriterAndReaderOverloadsRoundtrip() + { + IPAddress expected = IPAddress.Parse("192.0.2.1"); + using MemoryStream stream = new MemoryStream(); + using BinaryWriter writer = new BinaryWriter(stream); + + expected.WriteTo(writer); + stream.Position = 0; + + Assert.Equal(expected, IPAddressExtensions.ReadFrom(new BinaryReader(stream))); + } + + [Fact] + public void ReadFromRejectsUnsupportedMarkerAndEndOfStream() + { + Assert.Throws(() => IPAddressExtensions.ReadFrom(new MemoryStream(new byte[] { 99 }))); + Assert.Throws(() => IPAddressExtensions.ReadFrom(new MemoryStream())); + } + + [Fact] + public void ConvertsIPv4AddressToAndFromNumber() + { + IPAddress address = IPAddress.Parse("192.0.2.1"); + + uint number = address.ConvertIpToNumber(); + + Assert.Equal(0xC0000201u, number); + Assert.Equal(address, IPAddressExtensions.ConvertNumberToIp(number)); + Assert.Throws(() => IPAddress.Parse("2001:db8::1").ConvertIpToNumber()); + } + + [Fact] + public void SubnetMaskHelpersHandleValidAndInvalidPrefixes() + { + Assert.Equal(IPAddress.Any, IPAddressExtensions.GetSubnetMask(0)); + Assert.Equal(IPAddress.Parse("255.255.255.0"), IPAddressExtensions.GetSubnetMask(24)); + Assert.Equal(24, IPAddress.Parse("255.255.255.0").GetSubnetMaskWidth()); + Assert.Throws(() => IPAddressExtensions.GetSubnetMask(33)); + Assert.Throws(() => IPAddress.Parse("ffff:ffff::").GetSubnetMaskWidth()); + } + + [Fact] + public void GetNetworkAddressHandlesIPv4AndIPv6() + { + Assert.Equal(IPAddress.Parse("192.0.2.128"), IPAddress.Parse("192.0.2.200").GetNetworkAddress(25)); + Assert.Same(IPAddress.Loopback, IPAddress.Loopback.GetNetworkAddress(32)); + Assert.Equal(IPAddress.Parse("2001:db8:abcd:1200::"), IPAddress.Parse("2001:db8:abcd:1234::1").GetNetworkAddress(56)); + Assert.Equal(IPAddress.Parse("2001:db8:abcd:1280::"), IPAddress.Parse("2001:db8:abcd:12ff::1").GetNetworkAddress(57)); + + IPAddress ipv6 = IPAddress.Parse("2001:db8::1"); + Assert.Same(ipv6, ipv6.GetNetworkAddress(128)); + Assert.Throws(() => IPAddress.Loopback.GetNetworkAddress(33)); + Assert.Throws(() => ipv6.GetNetworkAddress(129)); + } + + [Theory] + [InlineData(32)] + [InlineData(40)] + [InlineData(48)] + [InlineData(56)] + [InlineData(64)] + [InlineData(96)] + public void MapToIPv6AndMapToIPv4RoundtripForSupportedPrefixes(byte prefixLength) + { + IPAddress ipv4 = IPAddress.Parse("192.0.2.33"); + NetworkAddress prefix = new NetworkAddress(IPAddress.Parse("2001:db8:1234:5678::"), prefixLength); + + IPAddress mapped = ipv4.MapToIPv6(prefix); + + Assert.Equal(System.Net.Sockets.AddressFamily.InterNetworkV6, mapped.AddressFamily); + Assert.Equal(ipv4, mapped.MapToIPv4(prefixLength)); + } + + [Fact] + public void MappingHelpersReturnAlreadyMatchingFamilyAndRejectUnsupportedPrefix() + { + IPAddress ipv4 = IPAddress.Parse("192.0.2.33"); + IPAddress ipv6 = IPAddress.Parse("2001:db8::c000:221"); + + Assert.Same(ipv6, ipv6.MapToIPv6(NetworkAddress.Parse("64:ff9b::/96"))); + Assert.Same(ipv4, ipv4.MapToIPv4(96)); + Assert.Throws(() => ipv4.MapToIPv6(NetworkAddress.Parse("2001:db8::/65"))); + Assert.Throws(() => ipv6.MapToIPv4(65)); + } + + [Fact] + public void ReverseDomainRoundtripsIPv4AndIPv6() + { + IPAddress ipv4 = IPAddress.Parse("192.0.2.1"); + IPAddress ipv6 = IPAddress.Parse("2001:db8::1"); + + Assert.Equal("1.2.0.192.in-addr.arpa", ipv4.GetReverseDomain()); + Assert.Equal(ipv4, IPAddressExtensions.ParseReverseDomain(ipv4.GetReverseDomain())); + Assert.Equal(ipv6, IPAddressExtensions.ParseReverseDomain(ipv6.GetReverseDomain())); + } + + [Fact] + public void TryParseReverseDomainRejectsInvalidDomains() + { + Assert.False(IPAddressExtensions.TryParseReverseDomain("not-a-reverse.example", out _)); + Assert.False(IPAddressExtensions.TryParseReverseDomain("x.2.0.192.in-addr.arpa", out _)); + Assert.False(IPAddressExtensions.TryParseReverseDomain("x.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", out _)); + Assert.Throws(() => IPAddressExtensions.ParseReverseDomain("not-a-reverse.example")); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAccessControlTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAccessControlTests.cs new file mode 100644 index 0000000..aa9a184 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAccessControlTests.cs @@ -0,0 +1,113 @@ +using System.IO; +using System.Linq; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + public class NetworkAccessControlTests + { + [Fact] + public void ParseHandlesAllowAndDenyRules() + { + NetworkAccessControl allow = NetworkAccessControl.Parse("192.0.2.0/24"); + NetworkAccessControl deny = NetworkAccessControl.Parse("! 192.0.2.128/25"); + + Assert.False(allow.Deny); + Assert.True(deny.Deny); + Assert.Equal("192.0.2.0/24", allow.ToString()); + Assert.Equal("!192.0.2.128/25", deny.ToString()); + Assert.False(NetworkAccessControl.TryParse("not-a-network", out _)); + Assert.Throws(() => NetworkAccessControl.Parse("not-a-network")); + } + + [Fact] + public void TryMatchReturnsAllowedStateWhenAddressIsInsideNetwork() + { + NetworkAccessControl allow = NetworkAccessControl.Parse("192.0.2.0/24"); + NetworkAccessControl deny = NetworkAccessControl.Parse("!192.0.2.0/24"); + + Assert.True(allow.TryMatch(IPAddress.Parse("192.0.2.42"), out bool isAllowed)); + Assert.True(isAllowed); + Assert.True(deny.TryMatch(IPAddress.Parse("192.0.2.42"), out isAllowed)); + Assert.False(isAllowed); + Assert.False(allow.TryMatch(IPAddress.Parse("192.0.3.42"), out isAllowed)); + Assert.False(isAllowed); + } + + [Fact] + public void IsAddressAllowedUsesFirstMatchingRuleAndLoopbackFallback() + { + NetworkAccessControl[] acl = + [ + NetworkAccessControl.Parse("!192.0.2.128/25"), + NetworkAccessControl.Parse("192.0.2.0/24") + ]; + + Assert.False(NetworkAccessControl.IsAddressAllowed(IPAddress.Parse("192.0.2.200"), acl)); + Assert.True(NetworkAccessControl.IsAddressAllowed(IPAddress.Parse("192.0.2.42"), acl)); + Assert.False(NetworkAccessControl.IsAddressAllowed(IPAddress.Parse("198.51.100.1"), acl)); + Assert.True(NetworkAccessControl.IsAddressAllowed(IPAddress.Loopback, null, allowLoopbackWhenNoMatch: true)); + Assert.True(NetworkAccessControl.IsAddressAllowed(IPAddress.Parse("::ffff:192.0.2.42"), acl)); + } + + [Fact] + public void WriteAndReadRoundtrips() + { + NetworkAccessControl expected = NetworkAccessControl.Parse("!2001:db8::/32"); + using MemoryStream stream = new MemoryStream(); + + expected.WriteTo(stream); + stream.Position = 0; + + NetworkAccessControl actual = NetworkAccessControl.ReadFrom(stream); + + Assert.Equal(expected, actual); + Assert.Equal(expected.GetHashCode(), actual.GetHashCode()); + } + + [Fact] + public void BinaryWriterAndReaderOverloadsRoundtrip() + { + NetworkAccessControl expected = new NetworkAccessControl(IPAddress.Parse("192.0.2.1"), 32); + using MemoryStream stream = new MemoryStream(); + using BinaryWriter writer = new BinaryWriter(stream); + + expected.WriteTo(writer); + stream.Position = 0; + + NetworkAccessControl actual = NetworkAccessControl.ReadFrom(new BinaryReader(stream)); + + Assert.Equal(expected, actual); + Assert.False(actual.Deny); + Assert.Equal(IPAddress.Parse("192.0.2.1"), actual.NetworkAddress.Address); + } + + [Fact] + public void AplConversionRoundtripsAccessControlList() + { + NetworkAccessControl[] expected = + [ + NetworkAccessControl.Parse("192.0.2.0/24"), + NetworkAccessControl.Parse("!2001:db8::/32") + ]; + + var apl = NetworkAccessControl.ConvertToAPLRecordData(expected); + NetworkAccessControl[] actual = NetworkAccessControl.ConvertFromAPLRecordData(apl).ToArray(); + + Assert.Equal(expected, actual); + } + + [Fact] + public void EqualsHandlesNullReferenceAndDifferentValues() + { + NetworkAccessControl accessControl = NetworkAccessControl.Parse("192.0.2.0/24"); + + Assert.True(accessControl.Equals((object)accessControl)); + Assert.False(accessControl.Equals(null)); + Assert.False(accessControl.Equals((object)"192.0.2.0/24")); + Assert.False(accessControl.Equals(NetworkAccessControl.Parse("!192.0.2.0/24"))); + Assert.False(accessControl.Equals(NetworkAccessControl.Parse("192.0.3.0/24"))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAddressTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAddressTests.cs new file mode 100644 index 0000000..9ea544a --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAddressTests.cs @@ -0,0 +1,108 @@ +using System.IO; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + public class NetworkAddressTests + { + [Fact] + public void ParseNormalizesAndContainsIPv4Addresses() + { + NetworkAddress network = NetworkAddress.Parse("192.168.1.123/24"); + + Assert.Equal(IPAddress.Parse("192.168.1.0"), network.Address); + Assert.Equal((byte)24, network.PrefixLength); + Assert.Equal(System.Net.Sockets.AddressFamily.InterNetwork, network.AddressFamily); + Assert.False(network.IsHostAddress); + Assert.True(network.Contains(IPAddress.Parse("192.168.1.200"))); + Assert.False(network.Contains(IPAddress.Parse("192.168.2.1"))); + Assert.False(network.Contains(IPAddress.Parse("2001:db8::1"))); + Assert.Equal("192.168.1.0/24", network.ToString()); + Assert.Equal(IPAddress.Parse("192.168.1.255"), network.GetLastAddress()); + } + + [Fact] + public void ParseWithoutPrefixCreatesHostAddress() + { + NetworkAddress ipv4 = NetworkAddress.Parse("192.0.2.10"); + NetworkAddress ipv6 = NetworkAddress.Parse("2001:db8::10"); + + Assert.True(ipv4.IsHostAddress); + Assert.True(ipv6.IsHostAddress); + Assert.Equal((byte)32, ipv4.PrefixLength); + Assert.Equal((byte)128, ipv6.PrefixLength); + Assert.Equal("192.0.2.10", ipv4.ToString()); + Assert.Equal("2001:db8::10", ipv6.ToString()); + } + + [Fact] + public void GetLastAddressHandlesIPv6Networks() + { + NetworkAddress network = NetworkAddress.Parse("2001:db8:abcd:1200::1/56"); + + Assert.Equal(IPAddress.Parse("2001:db8:abcd:12ff:ffff:ffff:ffff:ffff"), network.GetLastAddress()); + Assert.True(network.Contains(IPAddress.Parse("2001:db8:abcd:12aa::5"))); + Assert.False(network.Contains(IPAddress.Parse("2001:db8:abcd:1300::1"))); + } + + [Fact] + public void WriteAndReadRoundtrips() + { + NetworkAddress expected = NetworkAddress.Parse("2001:db8::1/64"); + using MemoryStream stream = new MemoryStream(); + + expected.WriteTo(stream); + stream.Position = 0; + + NetworkAddress actual = NetworkAddress.ReadFrom(stream); + + Assert.Equal(expected, actual); + Assert.Equal(expected.GetHashCode(), actual.GetHashCode()); + Assert.Equal("2001:db8::/64", actual.ToString()); + } + + [Fact] + public void BinaryWriterAndReaderOverloadsRoundtrip() + { + NetworkAddress expected = NetworkAddress.Parse("192.0.2.128/25"); + using MemoryStream stream = new MemoryStream(); + using BinaryWriter writer = new BinaryWriter(stream); + + expected.WriteTo(writer); + stream.Position = 0; + + NetworkAddress actual = NetworkAddress.ReadFrom(new BinaryReader(stream)); + + Assert.Equal(expected, actual); + } + + [Fact] + public void ConstructorRejectsInvalidIPv4Prefix() + { + Assert.Throws(() => new NetworkAddress(IPAddress.Loopback, 33)); + } + + [Fact] + public void TryParseRejectsInvalidValues() + { + Assert.False(NetworkAddress.TryParse("not-an-ip/24", out _)); + Assert.False(NetworkAddress.TryParse("192.0.2.1/not-a-prefix", out _)); + Assert.False(NetworkAddress.TryParse("192.0.2.1/33", out _)); + Assert.False(NetworkAddress.TryParse("2001:db8::1/129", out _)); + Assert.Throws(() => NetworkAddress.Parse("192.0.2.1/33")); + } + + [Fact] + public void EqualsHandlesNullReferenceAndDifferentValues() + { + NetworkAddress network = NetworkAddress.Parse("192.0.2.0/24"); + + Assert.True(network.Equals((object)network)); + Assert.False(network.Equals(null)); + Assert.False(network.Equals((object)"192.0.2.0/24")); + Assert.False(network.Equals(NetworkAddress.Parse("192.0.3.0/24"))); + Assert.False(network.Equals(NetworkAddress.Parse("192.0.2.0/25"))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkMapTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkMapTests.cs new file mode 100644 index 0000000..be838a0 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkMapTests.cs @@ -0,0 +1,78 @@ +using System.Net; +using System.Net.Sockets; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + public class NetworkMapTests + { + [Fact] + public void TryGetValueReturnsValueForAddressesInsideIPv4Network() + { + NetworkMap map = new NetworkMap(AddressFamily.InterNetwork); + object value = new object(); + + map.Add("192.0.2.0/24", value); + + Assert.True(map.TryGetValue("192.0.2.0", out object actual)); + Assert.Same(value, actual); + Assert.True(map.TryGetValue(IPAddress.Parse("192.0.2.200"), out actual)); + Assert.Same(value, actual); + Assert.True(map.TryGetValue("192.0.2.255", out actual)); + Assert.Same(value, actual); + Assert.False(map.TryGetValue("192.0.3.1", out actual)); + Assert.Null(actual); + } + + [Fact] + public void RemoveDeletesBothBoundaryEntries() + { + NetworkMap map = new NetworkMap(AddressFamily.InterNetwork); + NetworkAddress network = NetworkAddress.Parse("192.0.2.0/24"); + object value = new object(); + + map.Add(network, value); + + Assert.True(map.Remove(network)); + Assert.False(map.TryGetValue("192.0.2.42", out _)); + Assert.False(map.Remove(network)); + } + + [Fact] + public void TryGetValueReturnsFalseWhenMapIsEmptyOrAddressIsOutsideRange() + { + NetworkMap map = new NetworkMap(AddressFamily.InterNetwork, capacity: 4); + + Assert.False(map.TryGetValue("192.0.2.1", out _)); + + object value = new object(); + map.Add("192.0.2.10/32", value); + + Assert.False(map.TryGetValue("192.0.2.9", out _)); + Assert.False(map.TryGetValue("192.0.2.11", out _)); + } + + [Fact] + public void IPv6NetworksCanBeAddedAndQueried() + { + NetworkMap map = new NetworkMap(AddressFamily.InterNetworkV6); + object value = new object(); + + map.Add("2001:db8::/32", value); + + Assert.True(map.TryGetValue("2001:db8::1234", out object actual)); + Assert.Same(value, actual); + Assert.False(map.TryGetValue("2001:db9::1", out _)); + } + + [Fact] + public void AddRemoveAndLookupRejectWrongAddressFamily() + { + NetworkMap ipv4Map = new NetworkMap(AddressFamily.InterNetwork); + + Assert.Throws(() => ipv4Map.Add("2001:db8::/32", new object())); + Assert.Throws(() => ipv4Map.Remove(NetworkAddress.Parse("2001:db8::/32"))); + Assert.Throws(() => ipv4Map.TryGetValue(IPAddress.Parse("2001:db8::1"), out _)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/AsymmetricCryptoKeyTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/AsymmetricCryptoKeyTests.cs new file mode 100644 index 0000000..c8bac69 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/AsymmetricCryptoKeyTests.cs @@ -0,0 +1,77 @@ +using System.Security.Cryptography; +using System.Text; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class AsymmetricCryptoKeyTests + { + [Fact] + public void RsaKeyEncryptsDecryptsSignsVerifiesAndSerializes() + { + byte[] clearText = Encoding.UTF8.GetBytes("secret"); + byte[] data = Encoding.UTF8.GetBytes("signed data"); + byte[] hash = SHA256.HashData(data); + + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + byte[] cipherText = key.Encrypt(clearText); + byte[] signature = key.Sign(hash, "SHA256"); + + Assert.Equal(AsymmetricEncryptionAlgorithm.RSA, key.Algorithm); + Assert.Equal(clearText, key.Decrypt(cipherText)); + Assert.True(key.Verify(hash, signature, "SHA256")); + Assert.True(key.Verify(new MemoryStream(data), signature, "SHA256")); + Assert.False(key.Verify(SHA256.HashData("other"u8.ToArray()), signature, "SHA256")); + Assert.NotNull(key.GetRSAPublicKey().Modulus); + + string publicKey = key.GetPublicKey(); + Assert.Equal(clearText, AsymmetricCryptoKey.Decrypt(AsymmetricCryptoKey.Encrypt(clearText, AsymmetricEncryptionAlgorithm.RSA, publicKey), AsymmetricEncryptionAlgorithm.RSA, keyXml(key))); + Assert.True(AsymmetricCryptoKey.Verify(hash, AsymmetricCryptoKey.Sign(hash, "SHA256", AsymmetricEncryptionAlgorithm.RSA, keyXml(key)), "SHA256", AsymmetricEncryptionAlgorithm.RSA, publicKey)); + Assert.True(AsymmetricCryptoKey.Verify(new MemoryStream(data), AsymmetricCryptoKey.Sign(new MemoryStream(data), "SHA256", AsymmetricEncryptionAlgorithm.RSA, keyXml(key)), "SHA256", AsymmetricEncryptionAlgorithm.RSA, publicKey)); + + using MemoryStream stream = new MemoryStream(); + key.WriteTo(stream); + stream.Position = 0; + using AsymmetricCryptoKey parsed = new AsymmetricCryptoKey(stream); + + Assert.Equal(clearText, parsed.Decrypt(parsed.Encrypt(clearText))); + Assert.True(parsed.Verify(hash, parsed.Sign(hash, "SHA256"), "SHA256")); + } + + [Fact] + public void CreateUsingImportsRsaParameters() + { + using RSA rsa = RSA.Create(1024); + using AsymmetricCryptoKey key = AsymmetricCryptoKey.CreateUsing(rsa.ExportParameters(true)); + + byte[] clearText = Encoding.UTF8.GetBytes("secret"); + + Assert.Equal(AsymmetricEncryptionAlgorithm.RSA, key.Algorithm); + Assert.Equal(clearText, key.Decrypt(key.Encrypt(clearText))); + } + + [Fact] + public void UnsupportedAlgorithmsAndInvalidFormatsThrow() + { + using AsymmetricCryptoKey dsa = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.DSA, 1024); + + Assert.Throws(() => new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.Unknown, 1024)); + Assert.Throws(() => AsymmetricCryptoKey.Encrypt([1], AsymmetricEncryptionAlgorithm.DSA, dsa.GetPublicKey())); + Assert.Throws(() => AsymmetricCryptoKey.Decrypt([1], AsymmetricEncryptionAlgorithm.DSA, dsa.GetPublicKey())); + Assert.Throws(() => dsa.GetRSAPublicKey()); + Assert.Throws(() => new AsymmetricCryptoKey(new MemoryStream([0, 1, 2]))); + Assert.Throws(() => new AsymmetricCryptoKey(new MemoryStream([.. Encoding.ASCII.GetBytes("AK"), 255]))); + } + + private static string keyXml(AsymmetricCryptoKey key) + { + using MemoryStream stream = new MemoryStream(); + key.WriteTo(stream); + stream.Position = 0; + using BinaryReader reader = new BinaryReader(stream); + _ = reader.ReadBytes(4); + ushort length = reader.ReadUInt16(); + return Encoding.ASCII.GetString(reader.ReadBytes(length)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateProfileTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateProfileTests.cs new file mode 100644 index 0000000..2ba98d3 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateProfileTests.cs @@ -0,0 +1,71 @@ +using System.Net.Mail; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class CertificateProfileTests + { + [Fact] + public void ProfileRoundTripsAllFieldsAndVerificationFlags() + { + CertificateProfile profile = new CertificateProfile( + name: "Alice", + type: CertificateProfileType.Individual, + emailAddress: new MailAddress("alice@example.test"), + website: new Uri("https://example.test/"), + phoneNumber: "+1 555 0100", + streetAddress: "1 Main St", + city: "Vienna", + state: "Vienna", + country: "AT", + postalCode: "1010", + verificationFlags: CertificateProfileFlags.Name | CertificateProfileFlags.EmailAddress | CertificateProfileFlags.Country); + using MemoryStream stream = new MemoryStream(); + + profile.WriteTo(stream); + stream.Position = 0; + CertificateProfile parsed = new CertificateProfile(stream); + + Assert.Equal(profile, parsed); + Assert.Equal(1, parsed.Version); + Assert.Equal(CertificateProfileType.Individual, parsed.Type); + Assert.Equal("Alice", parsed.Name); + Assert.Equal("alice@example.test", parsed.EmailAddress.Address); + Assert.Equal(new Uri("https://example.test/"), parsed.Website); + Assert.Equal("+1 555 0100", parsed.PhoneNumber); + Assert.Equal("1 Main St", parsed.StreetAddress); + Assert.Equal("Vienna", parsed.City); + Assert.Equal("Vienna", parsed.State); + Assert.Equal("AT", parsed.Country); + Assert.Equal("1010", parsed.PostalCode); + Assert.True(parsed.FieldExists(CertificateProfileFlags.Name)); + Assert.True(parsed.IsFieldVerified(CertificateProfileFlags.Name)); + Assert.False(parsed.IsFieldVerified(CertificateProfileFlags.PhoneNumber)); + Assert.Contains("Name (verified): Alice", parsed.ToString()); + Assert.Contains("Email Address (verified): alice@example.test", parsed.ToString()); + } + + [Fact] + public void ProfileMasksVerificationFlagsToExistingFields() + { + CertificateProfile profile = new CertificateProfile(name: "Bob", verificationFlags: CertificateProfileFlags.All); + + Assert.True(profile.FieldExists(CertificateProfileFlags.Name)); + Assert.True(profile.IsFieldVerified(CertificateProfileFlags.Name)); + Assert.False(profile.FieldExists(CertificateProfileFlags.EmailAddress)); + Assert.False(profile.IsFieldVerified(CertificateProfileFlags.EmailAddress)); + Assert.NotEqual(profile, new CertificateProfile(name: "Alice")); + Assert.False(profile.Equals(null)); + Assert.True(profile.Equals(profile)); + Assert.False(profile.Equals("profile")); + Assert.NotEqual(0, profile.GetHashCode()); + } + + [Fact] + public void InvalidProfileFormatThrows() + { + Assert.Throws(() => new CertificateProfile(new MemoryStream([0, 1, 2]))); + Assert.Throws(() => new CertificateProfile(new MemoryStream([.. System.Text.Encoding.ASCII.GetBytes("CP"), 255]))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateStoreTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateStoreTests.cs new file mode 100644 index 0000000..93f3906 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateStoreTests.cs @@ -0,0 +1,76 @@ +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class CertificateStoreTests + { + [Fact] + public void CertificateStoreRoundTripsPlainAndPasswordProtected() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = CreateCertificate(key); + using CertificateStore plainStore = new CertificateStore(cert, key); + using MemoryStream plainStream = new MemoryStream(); + + plainStore.WriteTo(plainStream); + plainStream.Position = 0; + using CertificateStore parsedPlain = new CertificateStore(plainStream); + + Assert.Equal(cert, parsedPlain.Certificate); + Assert.Equal(key.Algorithm, parsedPlain.PrivateKey.Algorithm); + + using CertificateStore protectedStore = new CertificateStore(cert, key, "password"); + using MemoryStream protectedStream = new MemoryStream(); + protectedStore.WriteTo(protectedStream); + + using CertificateStore parsedProtected = new CertificateStore(new MemoryStream(protectedStream.ToArray()), "password"); + + Assert.Equal(cert, parsedProtected.Certificate); + Assert.Equal(key.Algorithm, parsedProtected.PrivateKey.Algorithm); + Assert.Throws(() => new CertificateStore(new MemoryStream(protectedStream.ToArray()), "wrong")); + } + + [Fact] + public void CertificateStoreFileConstructorRoundTrips() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = CreateCertificate(key); + using CertificateStore store = new CertificateStore(cert, key, "password"); + string file = Path.Combine(Path.GetTempPath(), "certificate-store-" + Guid.NewGuid() + ".bin"); + + try + { + store.SaveAs(file); + using CertificateStore parsed = new CertificateStore(file, "password"); + Assert.Equal(cert, parsed.Certificate); + } + finally + { + if (File.Exists(file)) + File.Delete(file); + } + } + + [Fact] + public void InvalidCertificateStoreFormatThrows() + { + Assert.Throws(() => new CertificateStore(new MemoryStream([.. System.Text.Encoding.ASCII.GetBytes("CC"), 0, 0, 1]))); + Assert.Throws(() => new CertificateStore(new MemoryStream([.. System.Text.Encoding.ASCII.GetBytes("CC"), 0, .. System.Text.Encoding.ASCII.GetBytes("CS"), 255]))); + } + + private static Certificate CreateCertificate(AsymmetricCryptoKey key) + { + Certificate cert = new Certificate( + CertificateType.RootCA, + "root", + new CertificateProfile("root"), + CertificateCapability.SignCACertificate, + DateTime.UtcNow.AddMinutes(-1), + DateTime.UtcNow.AddDays(1), + key.Algorithm, + key.GetPublicKey()); + cert.SelfSign("SHA256", key, null); + return cert; + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateTests.cs new file mode 100644 index 0000000..25b8226 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CertificateTests.cs @@ -0,0 +1,105 @@ +using System.Security.Cryptography; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class CertificateTests + { + [Fact] + public void RootCertificateSelfSignsSerializesAndVerifies() + { + using AsymmetricCryptoKey rootKey = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate root = CreateCertificate(CertificateType.RootCA, "root", CertificateCapability.SignCACertificate, rootKey); + + root.SelfSign("SHA256", rootKey, new Uri("https://ca.example.test/revoke")); + + Assert.True(root.IsSigned()); + Assert.False(root.HasExpired()); + Assert.Equal(CertificateType.RootCA, root.Type); + Assert.Equal("root", root.SerialNumber); + Assert.Equal(CertificateCapability.SignCACertificate, root.Capability); + Assert.Equal(AsymmetricEncryptionAlgorithm.RSA, root.PublicKeyEncryptionAlgorithm); + Assert.Equal(new Uri("https://ca.example.test/revoke"), root.RevocationURL); + Assert.Same(root.GetHash("SHA256"), root.GetHash("SHA256")); + + root.Verify([root]); + + Certificate parsed = RoundTrip(root); + Assert.Equal(root, parsed); + Assert.True(parsed.IssuerSignature.Verify(parsed.GetHash(parsed.IssuerSignature.HashAlgorithm), parsed)); + } + + [Fact] + public void CaAndUserCertificatesVerifyThroughTrustedRoot() + { + using AsymmetricCryptoKey rootKey = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + using AsymmetricCryptoKey caKey = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + using AsymmetricCryptoKey userKey = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate root = CreateCertificate(CertificateType.RootCA, "root", CertificateCapability.SignCACertificate, rootKey); + Certificate ca = CreateCertificate(CertificateType.CA, "ca", CertificateCapability.SignAnyUserCertificate, caKey); + Certificate user = CreateCertificate(CertificateType.User, "user", CertificateCapability.UserAuthentication, userKey); + root.SelfSign("SHA256", rootKey, new Uri("https://ca.example.test/revoke")); + + ca.Sign("SHA256", root, rootKey, new Uri("https://ca.example.test/revoke")); + user.Sign("SHA256", ca, caKey, new Uri("https://ca.example.test/revoke")); + + ca.Verify([root]); + user.Verify([root]); + user.VerifyRevocationList(timeout: 1); + + Assert.Equal(root, ca.IssuerSignature.SigningCertificate); + Assert.Equal(ca, user.IssuerSignature.SigningCertificate); + Assert.NotEqual(root, user); + Assert.False(user.Equals(null)); + Assert.False(user.Equals("certificate")); + Assert.True(user.Equals(user)); + Assert.NotEqual(0, user.GetHashCode()); + } + + [Fact] + public void CertificateValidationRejectsInvalidDatesAndCapabilities() + { + using AsymmetricCryptoKey rootKey = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + using AsymmetricCryptoKey userKey = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate root = CreateCertificate(CertificateType.RootCA, "root", CertificateCapability.SignCACertificate, rootKey); + Certificate badRoot = CreateCertificate(CertificateType.RootCA, "bad-root", CertificateCapability.UserAuthentication, rootKey); + Certificate user = CreateCertificate(CertificateType.User, "user", CertificateCapability.UserAuthentication, userKey); + root.SelfSign("SHA256", rootKey, null); + badRoot.SelfSign("SHA256", rootKey, null); + + Assert.Throws(() => new Certificate(CertificateType.User, "bad", new CertificateProfile("bad"), CertificateCapability.UserAuthentication, DateTime.UtcNow.AddDays(1), DateTime.UtcNow, AsymmetricEncryptionAlgorithm.RSA, userKey.GetPublicKey())); + Assert.Throws(() => root.Sign("SHA256", root, rootKey, null)); + Assert.Throws(() => user.Sign("SHA256", user, userKey, null)); + Assert.Throws(() => badRoot.Verify([badRoot])); + Assert.Throws(() => root.Verify([])); + } + + [Fact] + public void CertificateInvalidFormatsThrow() + { + Assert.Throws(() => new Certificate(new MemoryStream([0, 1, 2]))); + Assert.Throws(() => new Certificate(new MemoryStream([.. System.Text.Encoding.ASCII.GetBytes("CE"), 255]))); + } + + private static Certificate CreateCertificate(CertificateType type, string serial, CertificateCapability capability, AsymmetricCryptoKey key) + { + return new Certificate( + type, + serial, + new CertificateProfile(serial, CertificateProfileType.Individual, new System.Net.Mail.MailAddress(serial + "@example.test")), + capability, + DateTime.UtcNow.AddMinutes(-1), + DateTime.UtcNow.AddDays(30), + key.Algorithm, + key.GetPublicKey()); + } + + private static Certificate RoundTrip(Certificate certificate) + { + using MemoryStream stream = new MemoryStream(); + certificate.WriteTo(stream); + stream.Position = 0; + return new Certificate(stream); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptoContainerTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptoContainerTests.cs new file mode 100644 index 0000000..09bfffe --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptoContainerTests.cs @@ -0,0 +1,113 @@ +using System.Text; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class CryptoContainerTests + { + [Fact] + public void PlainTextContainerRoundTripsThroughStreamAndFile() + { + TestCryptoContainer container = new TestCryptoContainer("plain text"); + using MemoryStream stream = new MemoryStream(); + + container.WriteTo(stream); + stream.Position = 0; + using TestCryptoContainer parsed = new TestCryptoContainer(stream); + + Assert.Equal("plain text", parsed.Value); + + string file = Path.Combine(Path.GetTempPath(), "crypto-container-" + Guid.NewGuid() + ".bin"); + try + { + container.SaveAs(file); + using FileStream fileStream = File.OpenRead(file); + using TestCryptoContainer parsedFile = new TestCryptoContainer(fileStream); + Assert.Equal("plain text", parsedFile.Value); + } + finally + { + if (File.Exists(file)) + File.Delete(file); + } + } + + [Fact] + public void PasswordProtectedContainerRoundTripsAndDetectsWrongPasswordOrTampering() + { + TestCryptoContainer container = new TestCryptoContainer("secret", "password"); + using MemoryStream stream = new MemoryStream(); + + container.WriteTo(stream); + byte[] protectedBytes = stream.ToArray(); + + using TestCryptoContainer parsed = new TestCryptoContainer(new MemoryStream(protectedBytes), "password"); + Assert.Equal("secret", parsed.Value); + + Assert.Throws(() => new TestCryptoContainer(new MemoryStream(protectedBytes), "wrong")); + + byte[] tampered = protectedBytes.ToArray(); + tampered[^1] ^= 0x01; + Assert.Throws(() => new TestCryptoContainer(new MemoryStream(tampered), "password")); + } + + [Fact] + public void PasswordCanBeAddedChangedAndRemoved() + { + TestCryptoContainer container = new TestCryptoContainer("secret"); + + Assert.Throws(() => container.ChangePassword("new")); + + container.SetPassword(SymmetricEncryptionAlgorithm.Rijndael, 256, "first"); + using MemoryStream protectedStream = new MemoryStream(); + container.WriteTo(protectedStream); + using TestCryptoContainer protectedParsed = new TestCryptoContainer(new MemoryStream(protectedStream.ToArray()), "first"); + Assert.Equal("secret", protectedParsed.Value); + + container.ChangePassword(null); + using MemoryStream plainStream = new MemoryStream(); + container.WriteTo(plainStream); + using TestCryptoContainer plainParsed = new TestCryptoContainer(new MemoryStream(plainStream.ToArray())); + Assert.Equal("secret", plainParsed.Value); + } + + [Fact] + public void InvalidContainerFormatThrows() + { + Assert.Throws(() => new TestCryptoContainer(new MemoryStream([0, 1, 2]))); + Assert.Throws(() => new TestCryptoContainer(new MemoryStream([.. Encoding.ASCII.GetBytes("CC"), 255]))); + } + + private sealed class TestCryptoContainer : CryptoContainer + { + public TestCryptoContainer(string value) + { + Value = value; + } + + public TestCryptoContainer(string value, string password) + : base(SymmetricEncryptionAlgorithm.Rijndael, 256, password) + { + Value = value; + } + + public TestCryptoContainer(Stream stream, string? password = null) + : base(stream, password) + { } + + public string? Value { get; private set; } + + protected override void ReadPlainTextFrom(Stream s) + { + using StreamReader reader = new StreamReader(s, Encoding.UTF8, leaveOpen: true); + Value = reader.ReadToEnd(); + } + + protected override void WritePlainTextTo(Stream s) + { + byte[] data = Encoding.UTF8.GetBytes(Value ?? string.Empty); + s.Write(data); + } + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptoExceptionTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptoExceptionTests.cs new file mode 100644 index 0000000..23d30b0 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptoExceptionTests.cs @@ -0,0 +1,47 @@ +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class CryptoExceptionTests + { + [Theory] + [MemberData(nameof(ExceptionFactories))] + public void ExceptionsPreserveMessagesAndInnerExceptions(Func defaultFactory, Func messageFactory, Func innerFactory) + { + InvalidOperationException inner = new InvalidOperationException("inner"); + + Exception defaultException = defaultFactory(); + Exception messageException = messageFactory("message"); + Exception innerException = innerFactory("message", inner); + + Assert.NotNull(defaultException.Message); + Assert.Equal("message", messageException.Message); + Assert.Equal("message", innerException.Message); + Assert.Same(inner, innerException.InnerException); + } + + public static IEnumerable ExceptionFactories() + { + yield return + [ + new Func(() => new CryptoException()), + new Func(message => new CryptoException(message)), + new Func((message, inner) => new CryptoException(message, inner)) + ]; + + yield return + [ + new Func(() => new InvalidCryptoContainerException()), + new Func(message => new InvalidCryptoContainerException(message)), + new Func((message, inner) => new InvalidCryptoContainerException(message, inner)) + ]; + + yield return + [ + new Func(() => new InvalidCertificateException()), + new Func(message => new InvalidCertificateException(message)), + new Func((message, inner) => new InvalidCertificateException(message, inner)) + ]; + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptographyProjectTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptographyProjectTests.cs new file mode 100644 index 0000000..9686e71 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/CryptographyProjectTests.cs @@ -0,0 +1,50 @@ +using System.IO; +using System.Security.Cryptography; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class CryptographyProjectTests + { + [Fact] + public void PEMFormat_RsaPrivateKeyRoundtrip() + { + using RSA rsa = RSA.Create(1024); + RSAParameters expected = rsa.ExportParameters(true); + using MemoryStream stream = new MemoryStream(); + + PEMFormat.WriteRSAPrivateKey(expected, stream); + stream.Position = 0; + RSAParameters actual = PEMFormat.ReadRSAPrivateKey(stream); + + Assert.Equal(expected.Modulus, actual.Modulus); + Assert.Equal(expected.Exponent, actual.Exponent); + Assert.Equal(expected.D, actual.D); + Assert.Equal(expected.P, actual.P); + Assert.Equal(expected.Q, actual.Q); + } + + [Fact] + public void PEMFormat_RsaPublicKeyRoundtrip() + { + using RSA rsa = RSA.Create(1024); + RSAParameters expected = rsa.ExportParameters(false); + using MemoryStream stream = new MemoryStream(); + + PEMFormat.WriteRSAPublicKey(expected, stream); + stream.Position = 0; + RSAParameters actual = PEMFormat.ReadRSAPublicKey(stream); + + Assert.Equal(expected.Modulus, actual.Modulus); + Assert.Equal(expected.Exponent, actual.Exponent); + } + + [Fact] + public void PEMFormat_InvalidHeaderThrows() + { + using MemoryStream stream = new MemoryStream(System.Text.Encoding.ASCII.GetBytes("bad pem")); + + Assert.Throws(() => PEMFormat.ReadRSAPublicKey(stream)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/KeyAgreementTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/KeyAgreementTests.cs new file mode 100644 index 0000000..9697f6d --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/KeyAgreementTests.cs @@ -0,0 +1,148 @@ +using System.Numerics; +using System.Security.Cryptography; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class KeyAgreementTests + { + [Theory] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA256, 32)] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA384, 48)] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA512, 64)] + public void HashKeyDerivationProducesExpectedLength(KeyAgreementKeyDerivationHashAlgorithm hashAlgorithm, int expectedLength) + { + TestKeyAgreement agreement = new TestKeyAgreement(KeyAgreementKeyDerivationFunction.Hash, hashAlgorithm, [1, 2, 3]); + + byte[] key = agreement.DeriveKeyMaterial([9, 8, 7]); + + Assert.Equal(expectedLength, key.Length); + Assert.Equal(KeyAgreementKeyDerivationFunction.Hash, agreement.KeyDerivationFunction); + Assert.Equal(hashAlgorithm, agreement.KeyDerivationHashAlgorithm); + Assert.Equal([9, 8, 7], agreement.LastOtherPartyPublicKey); + } + + [Theory] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA256, 32)] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA384, 48)] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA512, 64)] + public void HmacKeyDerivationProducesExpectedLength(KeyAgreementKeyDerivationHashAlgorithm hashAlgorithm, int expectedLength) + { + TestKeyAgreement agreement = new TestKeyAgreement(KeyAgreementKeyDerivationFunction.Hmac, hashAlgorithm, [1, 2, 3]) + { + HmacKey = [4, 5, 6, 7] + }; + + byte[] key = agreement.DeriveKeyMaterial([9, 8, 7]); + + Assert.Equal(expectedLength, key.Length); + Assert.Equal([4, 5, 6, 7], agreement.HmacKey); + } + + [Fact] + public void UnsupportedKeyDerivationSettingsThrow() + { + Assert.Throws(() => new TestKeyAgreement((KeyAgreementKeyDerivationFunction)99, KeyAgreementKeyDerivationHashAlgorithm.SHA256, [1]).DeriveKeyMaterial([2])); + Assert.Throws(() => new TestKeyAgreement(KeyAgreementKeyDerivationFunction.Hash, (KeyAgreementKeyDerivationHashAlgorithm)99, [1]).DeriveKeyMaterial([2])); + Assert.Throws(() => new TestKeyAgreement(KeyAgreementKeyDerivationFunction.Hmac, (KeyAgreementKeyDerivationHashAlgorithm)99, [1]) { HmacKey = [1] }.DeriveKeyMaterial([2])); + } + + [Fact] + public void DiffieHellmanPublicKeysRoundTripAndDeriveSameSecret() + { + BigInteger p = new BigInteger(467); + BigInteger g = new BigInteger(2); + DiffieHellmanPublicKey seed = new DiffieHellmanPublicKey(16, p, g, new BigInteger(5)); + DiffieHellman alice = new DiffieHellman(seed, KeyAgreementKeyDerivationFunction.Hash, KeyAgreementKeyDerivationHashAlgorithm.SHA256); + DiffieHellman bob = new DiffieHellman(seed, KeyAgreementKeyDerivationFunction.Hash, KeyAgreementKeyDerivationHashAlgorithm.SHA256); + + DiffieHellmanPublicKey alicePublic = new DiffieHellmanPublicKey(alice.GetPublicKey()); + DiffieHellmanPublicKey bobPublic = new DiffieHellmanPublicKey(bob.GetPublicKey()); + + Assert.Equal(DiffieHellmanGroupType.None, alice.Group); + Assert.Equal(16, alice.KeySize); + Assert.Equal(p, alice.P); + Assert.Equal(g, alice.G); + Assert.Equal(alicePublic.P, bobPublic.P); + Assert.Equal(alicePublic.G, bobPublic.G); + Assert.Equal(alicePublic.KeySize, bobPublic.KeySize); + Assert.Equal(alice.DeriveKeyMaterial(bob.GetPublicKey()), bob.DeriveKeyMaterial(alice.GetPublicKey())); + } + + [Fact] + public void DiffieHellmanGroupsAndInvalidPublicKeysAreValidated() + { + DiffieHellmanGroup group = DiffieHellmanGroup.GetGroup(DiffieHellmanGroupType.RFC3526_GROUP14_2048BIT); + DiffieHellman alice = new DiffieHellman(DiffieHellmanGroupType.RFC3526_GROUP14_2048BIT, KeyAgreementKeyDerivationFunction.Hmac, KeyAgreementKeyDerivationHashAlgorithm.SHA256) + { + HmacKey = [1, 2, 3] + }; + DiffieHellman bob = new DiffieHellman(DiffieHellmanGroupType.RFC3526_GROUP14_2048BIT, KeyAgreementKeyDerivationFunction.Hmac, KeyAgreementKeyDerivationHashAlgorithm.SHA256) + { + HmacKey = [1, 2, 3] + }; + + Assert.Equal(2048, group.KeySize); + Assert.Equal(DiffieHellmanGroupType.RFC3526_GROUP14_2048BIT, group.Group); + Assert.Equal(new BigInteger(2), group.G); + Assert.Equal(alice.DeriveKeyMaterial(bob.GetPublicKey()), bob.DeriveKeyMaterial(alice.GetPublicKey())); + Assert.Throws(() => DiffieHellmanGroup.GetGroup(DiffieHellmanGroupType.None)); + Assert.Throws(() => new DiffieHellmanPublicKey([0, 1, 2])); + Assert.Throws(() => new DiffieHellmanPublicKey(16, new BigInteger(467), new BigInteger(1), new BigInteger(5))); + Assert.Throws(() => new DiffieHellmanPublicKey(16, new BigInteger(467), new BigInteger(2), new BigInteger(1))); + Assert.Throws(() => alice.DeriveKeyMaterial(new DiffieHellman(DiffieHellmanGroupType.RFC3526_GROUP15_3072BIT, KeyAgreementKeyDerivationFunction.Hmac, KeyAgreementKeyDerivationHashAlgorithm.SHA256) { HmacKey = [1, 2, 3] }.GetPublicKey())); + } + + [Theory] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA256, 32)] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA384, 48)] + [InlineData(KeyAgreementKeyDerivationHashAlgorithm.SHA512, 64)] + public void ECDiffieHellmanDerivesSameSecretOnSupportedPlatforms(KeyAgreementKeyDerivationHashAlgorithm hashAlgorithm, int expectedLength) + { + if (!OperatingSystem.IsWindows()) + return; + + global::TechnitiumLibrary.Security.Cryptography.ECDiffieHellman alice = new global::TechnitiumLibrary.Security.Cryptography.ECDiffieHellman(256, KeyAgreementKeyDerivationFunction.Hash, hashAlgorithm); + global::TechnitiumLibrary.Security.Cryptography.ECDiffieHellman bob = new global::TechnitiumLibrary.Security.Cryptography.ECDiffieHellman(256, KeyAgreementKeyDerivationFunction.Hash, hashAlgorithm); + + byte[] aliceSecret = alice.DeriveKeyMaterial(bob.GetPublicKey()); + byte[] bobSecret = bob.DeriveKeyMaterial(alice.GetPublicKey()); + + Assert.Equal(expectedLength, aliceSecret.Length); + Assert.Equal(aliceSecret, bobSecret); + } + + [Fact] + public void ECDiffieHellmanUnsupportedHashThrowsOnSupportedPlatforms() + { + if (!OperatingSystem.IsWindows()) + return; + + Assert.Throws(() => new global::TechnitiumLibrary.Security.Cryptography.ECDiffieHellman(256, KeyAgreementKeyDerivationFunction.Hash, (KeyAgreementKeyDerivationHashAlgorithm)99)); + } + + private sealed class TestKeyAgreement : KeyAgreement + { + private readonly byte[] _computedKey; + + public TestKeyAgreement(KeyAgreementKeyDerivationFunction kdFunc, KeyAgreementKeyDerivationHashAlgorithm kdHashAlgo, byte[] computedKey) + : base(kdFunc, kdHashAlgo) + { + _computedKey = computedKey; + } + + public byte[]? LastOtherPartyPublicKey { get; private set; } + + public override byte[] GetPublicKey() + { + return [1, 2, 3]; + } + + protected override byte[] ComputeKey(byte[] otherPartyPublicKey) + { + LastOtherPartyPublicKey = otherPartyPublicKey; + return _computedKey; + } + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/ManagedAlgorithmTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/ManagedAlgorithmTests.cs new file mode 100644 index 0000000..0054717 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/ManagedAlgorithmTests.cs @@ -0,0 +1,63 @@ +using System.Security.Cryptography; +using System.Text; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class ManagedAlgorithmTests + { + [Fact] + public void Crc32MatchesKnownCheckValueAndCanBeReused() + { + using CRC32Managed crc32 = new CRC32Managed(); + byte[] input = Encoding.ASCII.GetBytes("123456789"); + + byte[] hash = crc32.ComputeHash(input); + crc32.Initialize(); + byte[] hashAfterReset = crc32.ComputeHash(input); + + Assert.Equal(Convert.FromHexString("CBF43926"), hash); + Assert.Equal(hash, hashAfterReset); + Assert.Equal(32, crc32.HashSize); + } + + [Fact] + public void Rc4EncryptorAndDecryptorRoundTripData() + { + byte[] key = Enumerable.Range(1, 32).Select(i => (byte)i).ToArray(); + byte[] iv = Enumerable.Range(33, 32).Select(i => (byte)i).ToArray(); + byte[] clearText = Encoding.UTF8.GetBytes("The quick brown fox jumps over the lazy dog."); + byte[] cipherText = new byte[clearText.Length]; + byte[] decrypted = new byte[clearText.Length]; + + using RC4Managed rc4 = new RC4Managed(key, iv); + using ICryptoTransform encryptor = rc4.CreateEncryptor(key, iv); + using ICryptoTransform decryptor = rc4.CreateDecryptor(key, iv); + + Assert.Equal(clearText.Length, encryptor.TransformBlock(clearText, 0, clearText.Length, cipherText, 0)); + Assert.Equal(clearText.Length, decryptor.TransformBlock(cipherText, 0, cipherText.Length, decrypted, 0)); + + Assert.NotEqual(clearText, cipherText); + Assert.Equal(clearText, decrypted); + Assert.Empty(encryptor.TransformFinalBlock([], 0, 0)); + Assert.False(encryptor.CanReuseTransform); + Assert.True(encryptor.CanTransformMultipleBlocks); + Assert.Equal(8, encryptor.InputBlockSize); + Assert.Equal(8, encryptor.OutputBlockSize); + } + + [Fact] + public void Rc4ConstructorsGenerateKeyAndIvForLegalKeySizes() + { + using RC4Managed defaultKeySize = new RC4Managed(); + using RC4Managed explicitKeySize = new RC4Managed(128); + + Assert.Equal(256, defaultKeySize.KeySize); + Assert.Equal(32, defaultKeySize.Key.Length); + Assert.Equal(32, defaultKeySize.IV.Length); + Assert.Equal(128, explicitKeySize.KeySize); + Assert.Equal(16, explicitKeySize.Key.Length); + Assert.Equal(16, explicitKeySize.IV.Length); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/PBKDF2Tests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/PBKDF2Tests.cs new file mode 100644 index 0000000..4392c19 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/PBKDF2Tests.cs @@ -0,0 +1,51 @@ +using System.Security.Cryptography; +using System.Text; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class PBKDF2Tests + { + [Theory] + [InlineData("password", "salt", 1, 20, "0C60C80F961F0E71F3A9B524AF6012062FE037A6")] + [InlineData("password", "salt", 2, 20, "EA6C014DC72D6F8CCD1ED92ACE1D41F0D8DE8957")] + [InlineData("password", "salt", 4096, 20, "4B007901B765489ABEAD49D926F721D065A429C1")] + public void HmacSha1MatchesRfc6070Vectors(string password, string salt, int iterations, int length, string expectedHex) + { + using PBKDF2 pbkdf2 = PBKDF2.CreateHMACSHA1(password, Encoding.ASCII.GetBytes(salt), iterations); + + byte[] actual = pbkdf2.GetBytes(length); + + Assert.Equal(Convert.FromHexString(expectedHex), actual); + Assert.Equal(iterations, pbkdf2.IterationCount); + Assert.Equal(Encoding.ASCII.GetBytes(salt), pbkdf2.Salt); + } + + [Fact] + public void HmacSha256MatchesFrameworkImplementation() + { + byte[] salt = [1, 2, 3, 4, 5, 6, 7, 8]; + using PBKDF2 pbkdf2 = PBKDF2.CreateHMACSHA256("password", salt, 1000); + + byte[] actual = pbkdf2.GetBytes(48); + byte[] expected = Rfc2898DeriveBytes.Pbkdf2("password", salt, 1000, HashAlgorithmName.SHA256, 48); + + Assert.Equal(expected, actual); + } + + [Fact] + public void RandomSaltFactoriesCreateRequestedSaltLength() + { + using PBKDF2 fromString = PBKDF2.CreateHMACSHA1("password", saltLength: 16, iterationCount: 2); + using PBKDF2 fromBytes = PBKDF2.CreateHMACSHA256(Encoding.UTF8.GetBytes("password"), saltLength: 24, iterationCount: 3); + + Assert.Equal(16, fromString.Salt.Length); + Assert.Equal(24, fromBytes.Salt.Length); + Assert.Equal(20, fromString.GetBytes(20).Length); + Assert.Equal(32, fromBytes.GetBytes(32).Length); + + fromString.Reset(); + fromBytes.Reset(); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/RevocationCertificateTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/RevocationCertificateTests.cs new file mode 100644 index 0000000..0dfac3b --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/RevocationCertificateTests.cs @@ -0,0 +1,94 @@ +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class RevocationCertificateTests + { + [Fact] + public void RevocationCertificateRoundTripsAndValidates() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = CreateRootCertificate(key); + RevocationCertificate revocation = new RevocationCertificate(cert, "SHA256", key); + using MemoryStream stream = new MemoryStream(); + + revocation.WriteTo(stream); + stream.Position = 0; + RevocationCertificate parsed = new RevocationCertificate(stream); + + Assert.Equal(cert.SerialNumber, parsed.SerialNumber); + Assert.Equal("SHA256", parsed.HashAlgorithm); + Assert.NotEmpty(parsed.Signature); + Assert.True(parsed.RevokedOnUTC <= DateTime.UtcNow.AddSeconds(1)); + } + + [Fact] + public void RevocationCertificateReturnsFalseForInvalidSignature() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = CreateRootCertificate(key); + RevocationCertificate revocation = new RevocationCertificate(cert, "SHA256", key); + + revocation.Signature[0] ^= 0xff; + + Assert.False(revocation.IsValid(cert)); + } + + [Fact] + public void RevocationServerResponsesCanBeWritten() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = CreateRootCertificate(key); + RevocationCertificate revocation = new RevocationCertificate(cert, "SHA256", key); + using MemoryStream found = new MemoryStream(); + using MemoryStream notFound = new MemoryStream(); + + revocation.WriteFoundServerResponseTo(found); + RevocationCertificate.WriteNotFoundServerResponseTo(notFound); + + Assert.Equal(1, found.ToArray()[0]); + Assert.Equal(0, notFound.ToArray()[0]); + found.Position = 1; + RevocationCertificate parsed = new RevocationCertificate(found); + Assert.Equal(cert.SerialNumber, parsed.SerialNumber); + } + + [Fact] + public void RevocationCertificateRejectsMismatchedSerialAndInvalidFormat() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = CreateRootCertificate(key); + Certificate other = new Certificate( + CertificateType.RootCA, + "other", + new CertificateProfile("other"), + CertificateCapability.SignCACertificate, + DateTime.UtcNow.AddMinutes(-1), + DateTime.UtcNow.AddDays(1), + key.Algorithm, + key.GetPublicKey()); + other.SelfSign("SHA256", key, null); + RevocationCertificate revocation = new RevocationCertificate(cert, "SHA256", key); + + Assert.Throws(() => revocation.IsValid(other)); + Assert.Throws(() => new RevocationCertificate(new MemoryStream([0, 1, 2]))); + Assert.Throws(() => new RevocationCertificate(new MemoryStream([.. System.Text.Encoding.ASCII.GetBytes("RC"), 255]))); + Assert.Throws(() => RevocationCertificate.IsRevoked(other, out _)); + } + + private static Certificate CreateRootCertificate(AsymmetricCryptoKey key) + { + Certificate cert = new Certificate( + CertificateType.RootCA, + "root", + new CertificateProfile("root"), + CertificateCapability.SignCACertificate, + DateTime.UtcNow.AddMinutes(-1), + DateTime.UtcNow.AddDays(1), + key.Algorithm, + key.GetPublicKey()); + cert.SelfSign("SHA256", key, new Uri("https://ca.example.test/revoke")); + return cert; + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/SignatureTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/SignatureTests.cs new file mode 100644 index 0000000..3c693c5 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/SignatureTests.cs @@ -0,0 +1,73 @@ +using System.Security.Cryptography; +using System.Text; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class SignatureTests + { + [Fact] + public void SignatureRoundTripsAndVerifiesWithCertificate() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = new Certificate( + CertificateType.RootCA, + "root", + new CertificateProfile("root"), + CertificateCapability.SignCACertificate, + DateTime.UtcNow.AddMinutes(-1), + DateTime.UtcNow.AddDays(1), + key.Algorithm, + key.GetPublicKey()); + cert.SelfSign("SHA256", key, null); + byte[] data = Encoding.UTF8.GetBytes("signed content"); + byte[] hash = SHA256.HashData(data); + + Signature signature = new Signature(hash, "SHA256", cert, key); + using MemoryStream stream = new MemoryStream(); + signature.WriteTo(stream); + stream.Position = 0; + Signature parsed = new Signature(stream); + + Assert.Equal("SHA256", parsed.HashAlgorithm); + Assert.Equal(AsymmetricEncryptionAlgorithm.RSA, parsed.SignatureAlgorithm); + Assert.Equal(cert, parsed.SigningCertificate); + Assert.True(parsed.Verify(hash, cert)); + Assert.True(parsed.Verify(new MemoryStream(data), [cert])); + Assert.False(parsed.Verify(SHA256.HashData("other"u8.ToArray()), cert)); + Assert.Equal(signature, parsed); + Assert.NotEqual(0, parsed.GetHashCode()); + Assert.StartsWith("", parsed.ToString()); + } + + [Fact] + public void SignatureConstructedFromStreamDataVerifies() + { + using AsymmetricCryptoKey key = new AsymmetricCryptoKey(AsymmetricEncryptionAlgorithm.RSA, 1024); + Certificate cert = new Certificate( + CertificateType.RootCA, + "root", + new CertificateProfile("root"), + CertificateCapability.SignCACertificate, + DateTime.UtcNow.AddMinutes(-1), + DateTime.UtcNow.AddDays(1), + key.Algorithm, + key.GetPublicKey()); + cert.SelfSign("SHA256", key, null); + byte[] data = Encoding.UTF8.GetBytes("signed content"); + + Signature signature = new Signature(new MemoryStream(data), "SHA256", cert, key); + + Assert.True(signature.Verify(SHA256.HashData(data), cert)); + Assert.False(signature.Equals(null)); + Assert.False(signature.Equals("signature")); + } + + [Fact] + public void InvalidSignatureFormatThrows() + { + Assert.Throws(() => new Signature(new MemoryStream([0, 1, 2]))); + Assert.Throws(() => new Signature(new MemoryStream([.. Encoding.ASCII.GetBytes("SN"), 255]))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/SymmetricCryptoKeyTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/SymmetricCryptoKeyTests.cs new file mode 100644 index 0000000..adbccb6 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.Cryptography/SymmetricCryptoKeyTests.cs @@ -0,0 +1,80 @@ +using System.Security.Cryptography; +using System.Text; +using TechnitiumLibrary.Security.Cryptography; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.Cryptography +{ + public class SymmetricCryptoKeyTests + { + [Theory] + [InlineData(SymmetricEncryptionAlgorithm.DES, 64)] + [InlineData(SymmetricEncryptionAlgorithm.RC2, 128)] + [InlineData(SymmetricEncryptionAlgorithm.TripleDES, 192)] + [InlineData(SymmetricEncryptionAlgorithm.Rijndael, 256)] + public void GeneratedKeyEncryptsDecryptsAndSerializes(SymmetricEncryptionAlgorithm algorithm, int keySize) + { + using SymmetricCryptoKey key = new SymmetricCryptoKey(algorithm, keySize, PaddingMode.PKCS7); + byte[] clearText = Encoding.UTF8.GetBytes("clear text for " + algorithm); + using MemoryStream cipherText = new MemoryStream(); + using MemoryStream decrypted = new MemoryStream(); + + key.Encrypt(new MemoryStream(clearText), cipherText); + byte[] encrypted = cipherText.ToArray(); + key.Decrypt(new MemoryStream(encrypted), decrypted); + + Assert.Equal(clearText, decrypted.ToArray()); + Assert.Equal(algorithm, key.Algorithm); + Assert.Equal(keySize, key.KeySize); + Assert.True(key.BlockSize > 0); + + using MemoryStream serialized = new MemoryStream(); + key.WriteTo(serialized); + serialized.Position = 0; + using SymmetricCryptoKey parsed = new SymmetricCryptoKey(serialized); + + using MemoryStream parsedCipherText = new MemoryStream(); + using MemoryStream parsedDecrypted = new MemoryStream(); + parsed.Encrypt(new MemoryStream(clearText), parsedCipherText); + parsed.Decrypt(new MemoryStream(parsedCipherText.ToArray()), parsedDecrypted); + + Assert.Equal(clearText, parsedDecrypted.ToArray()); + Assert.Equal(algorithm, parsed.Algorithm); + Assert.Equal(key.IV.Length, parsed.IV.Length); + } + + [Fact] + public void ExplicitKeyAndIvCanBeUsedWithCryptoStreams() + { + byte[] keyBytes = Enumerable.Range(1, 32).Select(i => (byte)i).ToArray(); + byte[] iv = Enumerable.Range(33, 16).Select(i => (byte)i).ToArray(); + byte[] clearText = Encoding.UTF8.GetBytes("stream encryption"); + using SymmetricCryptoKey key = new SymmetricCryptoKey(SymmetricEncryptionAlgorithm.Rijndael, keyBytes, iv, PaddingMode.PKCS7); + using MemoryStream cipherText = new MemoryStream(); + + using (CryptoStream writer = key.GetCryptoStreamWriter(cipherText)) + { + writer.Write(clearText); + writer.FlushFinalBlock(); + } + + using CryptoStream reader = key.GetCryptoStreamReader(new MemoryStream(cipherText.ToArray())); + using MemoryStream decrypted = new MemoryStream(); + reader.CopyTo(decrypted); + + Assert.Equal(clearText, decrypted.ToArray()); + } + + [Fact] + public void GenerateIvChangesIvAndInvalidFormatThrows() + { + using SymmetricCryptoKey key = new SymmetricCryptoKey(SymmetricEncryptionAlgorithm.Rijndael, 256, PaddingMode.PKCS7); + byte[] oldIv = key.IV.ToArray(); + + key.GenerateIV(); + + Assert.NotEqual(oldIv, key.IV); + Assert.Throws(() => new SymmetricCryptoKey(new MemoryStream([0, 1, 2]))); + Assert.Throws(() => new SymmetricCryptoKey(new MemoryStream([.. Encoding.ASCII.GetBytes("SK"), 255]))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/OTPProjectTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/OTPProjectTests.cs new file mode 100644 index 0000000..8714441 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/OTPProjectTests.cs @@ -0,0 +1,157 @@ +using System.Text; +using TechnitiumLibrary.Security.OTP; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.OTP +{ + public class OTPProjectTests + { + [Fact] + public void AuthenticatorKeyUri_ToStringAndParse_RoundtripEscapedValues() + { + AuthenticatorKeyUri expected = new AuthenticatorKeyUri( + "totp", + "Example Issuer", + "user@example.com", + "JBSWY3DPEHPK3PXP", + "SHA256", + 8, + 45); + + AuthenticatorKeyUri actual = AuthenticatorKeyUri.Parse(expected.ToString()); + + Assert.Equal(expected.Type, actual.Type); + Assert.Equal(expected.Issuer, actual.Issuer); + Assert.Equal(expected.AccountName, actual.AccountName); + Assert.Equal(expected.Secret, actual.Secret); + Assert.Equal(expected.Algorithm, actual.Algorithm); + Assert.Equal(expected.Digits, actual.Digits); + Assert.Equal(expected.Period, actual.Period); + } + + [Fact] + public void Authenticator_GeneratesRfc6238TotpVector() + { + AuthenticatorKeyUri keyUri = new AuthenticatorKeyUri( + "totp", + "RFC", + "test", + "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ", + "SHA1", + 8, + 30); + Authenticator authenticator = new Authenticator(keyUri); + + string totp = authenticator.GetTOTP(DateTime.UnixEpoch.AddSeconds(59)); + + Assert.Equal("94287082", totp); + } + + [Theory] + [InlineData("SHA1", "12345678901234567890", "94287082")] + [InlineData("SHA256", "12345678901234567890123456789012", "46119246")] + [InlineData("SHA512", "1234567890123456789012345678901234567890123456789012345678901234", "90693936")] + public void Authenticator_GeneratesRfc6238VectorsForSupportedAlgorithms(string algorithm, string key, string expected) + { + string secret = Base32.ToBase32String(Encoding.ASCII.GetBytes(key), skipPadding: true); + AuthenticatorKeyUri keyUri = new AuthenticatorKeyUri("totp", "RFC", "test", secret, algorithm, 8, 30); + Authenticator authenticator = new Authenticator(keyUri); + + Assert.Equal(expected, authenticator.GetTOTP(DateTime.UnixEpoch.AddSeconds(59))); + } + + [Fact] + public void AuthenticatorKeyUri_GenerateCreatesValidTotpUri() + { + AuthenticatorKeyUri keyUri = AuthenticatorKeyUri.Generate("issuer", "account", keySize: 20, algorithm: "SHA512", digits: 7, period: 60); + + Assert.Equal("totp", keyUri.Type); + Assert.Equal("issuer", keyUri.Issuer); + Assert.Equal("account", keyUri.AccountName); + Assert.Equal("SHA512", keyUri.Algorithm); + Assert.Equal(7, keyUri.Digits); + Assert.Equal(60, keyUri.Period); + Assert.Equal(32, keyUri.Secret.Length); + Assert.NotNull(new Authenticator(keyUri).GetTOTP(DateTime.UnixEpoch.AddSeconds(59))); + } + + [Fact] + public void AuthenticatorKeyUri_ParseRejectsInvalidValues() + { + Assert.Throws(() => AuthenticatorKeyUri.Parse("https://issuer/account?secret=JBSWY3DPEHPK3PXP")); + Assert.Throws(() => AuthenticatorKeyUri.Parse("otpauth://totp/issuer/account?secret=JBSWY3DPEHPK3PXP")); + Assert.Throws(() => AuthenticatorKeyUri.Parse("otpauth://totp/issuer:account?secret=JBSWY3DPEHPK3PXP&digits=abc")); + Assert.Throws(() => AuthenticatorKeyUri.Parse("otpauth://totp/issuer:account?secret=JBSWY3DPEHPK3PXP&period=abc")); + } + + [Fact] + public void Authenticator_RejectsUnsupportedTypeAndAlgorithm() + { + Assert.Throws(() => new Authenticator(new AuthenticatorKeyUri("hotp", "issuer", "account", "JBSWY3DPEHPK3PXP"))); + + Authenticator authenticator = new Authenticator(new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP", "MD5")); + Assert.Throws(() => authenticator.GetTOTP(DateTime.UnixEpoch)); + } + + [Fact] + public void Authenticator_IsTOTPValid_AcceptsCurrentCodeAndRejectsInvalidCode() + { + Authenticator authenticator = new Authenticator(new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP")); + string currentTotp = authenticator.GetTOTP(); + + Assert.True(authenticator.IsTOTPValid(currentTotp, fudge: 0)); + Assert.False(authenticator.IsTOTPValid("000000", fudge: 0)); + } + + [Fact] + public void Authenticator_IsTOTPValid_AcceptsCodesInsideFutureAndPastFudgeWindow() + { + Authenticator authenticator = new Authenticator(new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP", period: 3600)); + string futureTotp = authenticator.GetTOTP(DateTime.UtcNow.AddSeconds(3600)); + string pastTotp = authenticator.GetTOTP(DateTime.UtcNow.AddSeconds(-3600)); + + Assert.True(authenticator.IsTOTPValid(futureTotp, fudge: 1)); + Assert.True(authenticator.IsTOTPValid(pastTotp, fudge: 1)); + } + + [Fact] + public void Authenticator_KeyUri_ReturnsOriginalKeyUri() + { + AuthenticatorKeyUri keyUri = new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP"); + Authenticator authenticator = new Authenticator(keyUri); + + Assert.Same(keyUri, authenticator.KeyUri); + } + + [Fact] + public void AuthenticatorKeyUri_NullAlgorithmDefaultsToSha1() + { + AuthenticatorKeyUri keyUri = new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP", algorithm: null); + + Assert.Equal("SHA1", keyUri.Algorithm); + } + + [Fact] + public void AuthenticatorKeyUri_GetQRCodePngImage_ReturnsPngBytes() + { + AuthenticatorKeyUri keyUri = new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP"); + + byte[] png = keyUri.GetQRCodePngImage(); + + Assert.Equal(new byte[] { 137, 80, 78, 71 }, png.Take(4).ToArray()); + } + + [Theory] + [InlineData(5)] + [InlineData(9)] + public void AuthenticatorKeyUri_InvalidDigitsThrows(int digits) + { + Assert.Throws(() => new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP", digits: digits)); + } + + [Fact] + public void AuthenticatorKeyUri_InvalidPeriodThrows() + { + Assert.Throws(() => new AuthenticatorKeyUri("totp", "issuer", "account", "JBSWY3DPEHPK3PXP", period: -1)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj b/TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj new file mode 100644 index 0000000..f804ccd --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj @@ -0,0 +1,34 @@ + + + + net10.0 + enable + enable + false + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/Base32Tests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/Base32Tests.cs new file mode 100644 index 0000000..5ffaf68 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/Base32Tests.cs @@ -0,0 +1,90 @@ +using System; +using System.Linq; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class Base32Tests + { + [Fact] + public void Roundtrip_Base32_WithPadding() + { + byte[] data = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + string encoded = Base32.ToBase32String(data); + byte[] decoded = Base32.FromBase32String(encoded); + Assert.Equal(data, decoded); + } + + [Fact] + public void Roundtrip_Base32_WithoutPadding() + { + byte[] data = new byte[] { 10, 20, 30 }; + string encoded = Base32.ToBase32String(data, skipPadding: true); + Assert.DoesNotContain("=", encoded); + byte[] decoded = Base32.FromBase32String(encoded); + Assert.Equal(data, decoded); + } + + [Fact] + public void Base32Hex_Roundtrip() + { + var data = new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }; + var enc = Base32.ToBase32HexString(data); + var dec = Base32.FromBase32HexString(enc); + Assert.Equal(data, dec); + } + + [Fact] + public void FromBase32_InvalidPadding_Throws() + { + // The implementation may throw different exception types for malformed input; assert that it fails. + Assert.ThrowsAny(() => Base32.FromBase32String("=====")); + } + + [Theory] + [InlineData(1, "AE======")] + [InlineData(2, "AEBA====")] + [InlineData(3, "AEBAG===")] + [InlineData(4, "AEBAGBA=")] + [InlineData(5, "AEBAGBAF")] + public void ToBase32String_CoversAllPaddingLengths(int length, string expected) + { + byte[] data = Enumerable.Range(1, length).Select(i => (byte)i).ToArray(); + + Assert.Equal(expected, Base32.ToBase32String(data)); + Assert.Equal(expected.TrimEnd('='), Base32.ToBase32String(data, skipPadding: true)); + Assert.Equal(data, Base32.FromBase32String(expected)); + Assert.Equal(data, Base32.FromBase32String(expected.TrimEnd('='))); + } + + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + public void Base32Hex_CoversAllPaddingLengths(int length) + { + byte[] data = Enumerable.Range(1, length).Select(i => (byte)i).ToArray(); + string encoded = Base32.ToBase32HexString(data); + + Assert.Equal(data, Base32.FromBase32HexString(encoded)); + Assert.Equal(data, Base32.FromBase32HexString(encoded.TrimEnd('='))); + } + + [Fact] + public void FromBase32String_SingleCharacterThrowsIndexOutOfRange() + { + Assert.Throws(() => Base32.FromBase32String("A")); + } + + [Theory] + [InlineData("AA==")] + [InlineData("AAA")] + [InlineData("AAAAAA")] + public void FromBase32String_InvalidLengthsThrow(string value) + { + Assert.ThrowsAny(() => Base32.FromBase32String(value)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/BinaryNumberTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/BinaryNumberTests.cs new file mode 100644 index 0000000..8316d72 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/BinaryNumberTests.cs @@ -0,0 +1,183 @@ +using System; +using System.IO; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class BinaryNumberTests + { + [Fact] + public void ParseAndToString_Roundtrip() + { + var hex = "0a0b0c"; + var bn = BinaryNumber.Parse(hex); + Assert.Equal(hex, bn.ToString()); + } + + [Fact] + public void Clone_EqualsOriginal() + { + var data = new byte[] { 1, 2, 3 }; + var bn = new BinaryNumber(data); + var clone = bn.Clone(); + Assert.Equal(bn, clone); + Assert.NotSame(bn.Value, clone.Value); + } + + [Fact] + public void WriteTo_And_ReadFromStream() + { + var bytes = new byte[] { 0xAA, 0xBB }; + var bn = new BinaryNumber(bytes); + + using MemoryStream ms = new MemoryStream(); + bn.WriteTo(ms); + ms.Position = 0; + + using BinaryReader br = new BinaryReader(ms); + var read = new BinaryNumber(br); + Assert.Equal(bn, read); + } + + [Fact] + public void BitwiseOperators_WorkAndCompare() + { + var a = new BinaryNumber(new byte[] { 0xFF, 0x00 }); + var b = new BinaryNumber(new byte[] { 0x0F, 0xFF }); + + var or = a | b; + var and = a & b; + var xor = a ^ b; + + Assert.Equal(new BinaryNumber(new byte[] { 0xFF, 0xFF }), or); + Assert.Equal(new BinaryNumber(new byte[] { 0x0F, 0x00 }), and); + Assert.Equal(new BinaryNumber(new byte[] { 0xF0, 0xFF }), xor); + + Assert.True(a > b || a < b || a == b); // simple sanity + } + + [Fact] + public void CompareTo_DifferentLength_Throws() + { + var a = new BinaryNumber(new byte[] { 1, 2 }); + var b = new BinaryNumber(new byte[] { 1 }); + Assert.Throws(() => a.CompareTo(b)); + } + + [Fact] + public void ShiftOperators_DoNotThrow() + { + var v = new BinaryNumber(new byte[] { 0x01, 0x02, 0x03 }); + var r1 = v >> 4; + var r2 = v << 9; + Assert.Equal(3, r1.Value.Length); + Assert.Equal(3, r2.Value.Length); + } + + [Fact] + public void StaticFactories_CreateExpectedLengthsAndValues() + { + Assert.Equal(20, BinaryNumber.GenerateRandomNumber160().Value.Length); + Assert.Equal(32, BinaryNumber.GenerateRandomNumber256().Value.Length); + Assert.Equal(20, BinaryNumber.MaxValueNumber160().Value.Length); + Assert.All(BinaryNumber.MaxValueNumber160().Value, b => Assert.Equal(0xFF, b)); + + byte[] source = new byte[] { 9, 8, 7, 6 }; + BinaryNumber clone = BinaryNumber.Clone(source, 1, 2); + Assert.Equal(new byte[] { 8, 7 }, clone.Value); + Assert.NotSame(source, clone.Value); + } + + [Fact] + public void StaticEquals_CoversReferenceNullLengthAndByteMismatch() + { + byte[] value = new byte[] { 1, 2 }; + + Assert.True(BinaryNumber.Equals(value, value)); + Assert.False(BinaryNumber.Equals(value, null)); + Assert.False(BinaryNumber.Equals(null, value)); + Assert.False(BinaryNumber.Equals(value, new byte[] { 1 })); + Assert.False(BinaryNumber.Equals(value, new byte[] { 1, 3 })); + Assert.True(BinaryNumber.Equals(value, new byte[] { 1, 2 })); + } + + [Fact] + public void EqualityComparisonAndHashCode_CoverBranches() + { + BinaryNumber a = new BinaryNumber(new byte[] { 1, 2 }); + BinaryNumber same = new BinaryNumber(new byte[] { 1, 2 }); + BinaryNumber smaller = new BinaryNumber(new byte[] { 1, 1 }); + BinaryNumber larger = new BinaryNumber(new byte[] { 1, 3 }); + + Assert.True(a.Equals((object)same)); + Assert.False(a.Equals((object)"not binary")); + Assert.False(a.Equals(null!)); + Assert.Equal(0, a.CompareTo(same)); + Assert.True(a.CompareTo(smaller) > 0); + Assert.True(a.CompareTo(larger) < 0); + Assert.Equal(a.GetHashCode(), same.GetHashCode()); + BinaryNumber sameReference = a; + Assert.True(a == sameReference); + Assert.False(a != sameReference); + } + + [Fact] + public void StreamConstructor_ReadsBinaryNumber() + { + BinaryNumber expected = new BinaryNumber(new byte[] { 4, 5, 6 }); + using MemoryStream stream = new MemoryStream(); + expected.WriteTo(stream); + stream.Position = 0; + + BinaryNumber actual = new BinaryNumber((Stream)stream); + + Assert.Equal(expected, actual); + } + + [Fact] + public void RelationalOperators_CoverEqualLessGreaterAndLengthMismatch() + { + BinaryNumber a = new BinaryNumber(new byte[] { 1, 2 }); + BinaryNumber same = new BinaryNumber(new byte[] { 1, 2 }); + BinaryNumber smaller = new BinaryNumber(new byte[] { 1, 1 }); + BinaryNumber larger = new BinaryNumber(new byte[] { 1, 3 }); + BinaryNumber differentLength = new BinaryNumber(new byte[] { 1 }); + + Assert.True(a == same); + Assert.False(a != same); + Assert.True(a != smaller); + Assert.True(smaller < a); + Assert.False(a < smaller); + Assert.False(a < same); + Assert.True(larger > a); + Assert.False(a > larger); + Assert.False(a > same); + Assert.True(smaller <= a); + Assert.True(a <= same); + Assert.False(larger <= a); + Assert.True(larger >= a); + Assert.True(a >= same); + Assert.False(smaller >= a); + + Assert.Throws(() => _ = a | differentLength); + Assert.Throws(() => _ = a & differentLength); + Assert.Throws(() => _ = a ^ differentLength); + Assert.Throws(() => _ = a < differentLength); + Assert.Throws(() => _ = a > differentLength); + Assert.Throws(() => _ = a <= differentLength); + Assert.Throws(() => _ = a >= differentLength); + } + + [Fact] + public void ShiftAndNotOperators_ProduceExpectedValues() + { + BinaryNumber value = new BinaryNumber(new byte[] { 0x12, 0x34 }); + + Assert.Equal(new BinaryNumber(new byte[] { 0x01, 0x23 }), value >> 4); + Assert.Equal(new BinaryNumber(new byte[] { 0x00, 0x12 }), value >> 8); + Assert.Equal(new BinaryNumber(new byte[] { 0x23, 0x40 }), value << 4); + Assert.Equal(new BinaryNumber(new byte[] { 0x34, 0x00 }), value << 8); + Assert.Equal(new BinaryNumber(new byte[] { 0xED, 0xCB }), ~value); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/CollectionExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/CollectionExtensionsTests.cs new file mode 100644 index 0000000..0f5ebe5 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/CollectionExtensionsTests.cs @@ -0,0 +1,89 @@ +using System; +using System.Collections.Generic; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class CollectionExtensionsTests + { + [Fact] + public void Convert_ListAndCollection() + { + var list = new List { 1, 2, 3 }; + var converted = ((IReadOnlyList)list).Convert(i => i * 2); + Assert.Equal(new int[] { 2, 4, 6 }, converted); + + var set = new HashSet { 4, 5 }; + var conv2 = ((IReadOnlyCollection)set).Convert(i => i + 1); + Assert.Contains(5, conv2); + } + + [Fact] + public void ListEqualsAndHasSameItems() + { + var a = new List { "x", "y" }; + var b = new List { "x", "y" }; + var c = new List { "y", "x" }; + + Assert.True(((IReadOnlyList)a).ListEquals(b)); + Assert.False(((IReadOnlyList)a).ListEquals(c)); + + Assert.True(((IReadOnlyCollection)a).HasSameItems(c)); + } + + [Fact] + public void Interleave_MergesLists() + { + var l1 = new List { 1, 3 }; + var l2 = new List { 2, 4, 5 }; + var inter = ((IReadOnlyList)l1).Interleave(l2); + Assert.Equal(5, inter.Count); + Assert.Equal(new int[] { 1, 2, 3, 4, 5 }, inter); + } + + [Fact] + public void Shuffle_PreservesItems() + { + var values = new List { 1, 2, 3, 4, 5 }; + + values.Shuffle(); + + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, values.OrderBy(x => x).ToArray()); + } + + [Fact] + public void ListEquals_CoversReferenceNullCountAndItemMismatch() + { + IReadOnlyList list = new List { "a", "b" }; + + Assert.True(list.ListEquals(list)); + Assert.False(list.ListEquals(null!)); + Assert.False(((IReadOnlyList)null!).ListEquals(list)); + Assert.False(list.ListEquals(new List { "a" })); + Assert.False(list.ListEquals(new List { "a", "c" })); + } + + [Fact] + public void HasSameItems_CoversReferenceNullCountAndMissingItem() + { + IReadOnlyCollection list = new List { "a", "b" }; + + Assert.True(list.HasSameItems(list)); + Assert.False(list.HasSameItems(null!)); + Assert.False(((IReadOnlyCollection)null!).HasSameItems(list)); + Assert.False(list.HasSameItems(new List { "a" })); + Assert.False(list.HasSameItems(new List { "a", "c" })); + } + + [Fact] + public void GetArrayHashCode_ReturnsZeroForNullAndStableValueForSameItems() + { + IReadOnlyCollection values = new List { 1, 2, 3 }; + IReadOnlyCollection sameValues = new List { 1, 2, 3 }; + + Assert.Equal(0, ((IReadOnlyCollection)null!).GetArrayHashCode()); + Assert.Equal(values.GetArrayHashCode(), sameValues.GetArrayHashCode()); + Assert.NotEqual(0, values.GetArrayHashCode()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/IndependentTaskSchedulerTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/IndependentTaskSchedulerTests.cs new file mode 100644 index 0000000..a96e8a2 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/IndependentTaskSchedulerTests.cs @@ -0,0 +1,57 @@ +using System; +using System.Threading.Tasks; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class IndependentTaskSchedulerTests + { + [Fact] + public async Task TaskScheduled_OnSchedulerExecutes() + { + using var scheduler = new IndependentTaskScheduler(1); + var tcs = new TaskCompletionSource(); + + await Task.Factory.StartNew(() => + { + tcs.SetResult(5); + }, System.Threading.CancellationToken.None, TaskCreationOptions.None, scheduler); + + var result = await tcs.Task.TimeoutAfter(1000); + Assert.Equal(5, result); + } + + [Fact] + public void MaximumConcurrencyLevel_Property() + { + using var scheduler = new IndependentTaskScheduler(3); + Assert.Equal(3, scheduler.MaximumConcurrencyLevel); + } + + [Fact] + public async Task LongRunningTask_ExecutesOnDedicatedThread() + { + using IndependentTaskScheduler scheduler = new IndependentTaskScheduler(maximumConcurrencyLevel: 1); + Task task = Task.Factory.StartNew( + () => Environment.CurrentManagedThreadId, + System.Threading.CancellationToken.None, + TaskCreationOptions.LongRunning, + scheduler); + + Assert.True(await task.TimeoutAfter(1000) > 0); + } + + [Fact] + public void Dispose_IsIdempotent() + { + IndependentTaskScheduler scheduler = new IndependentTaskScheduler(maximumConcurrencyLevel: 1); + + scheduler.Dispose(); + scheduler.Dispose(); + + Assert.Equal(1, scheduler.MaximumConcurrencyLevel); + } + } +} + +// reuse TestHelpers.TimeoutAfter from TaskPoolTests diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/JsonExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/JsonExtensionsTests.cs new file mode 100644 index 0000000..e773523 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/JsonExtensionsTests.cs @@ -0,0 +1,156 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text.Json; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class JsonExtensionsTests + { + [Fact] + public void ReadArrayAndSet_And_GetPropertyValue() + { + using var doc = JsonDocument.Parse("{\"arr\":[\"a\",\"b\"], \"val\": 5, \"flag\": true}"); + var root = doc.RootElement; + + var arr = root.ReadArray("arr"); + Assert.Equal(new string[] { "a", "b" }, arr); + + Assert.True(root.TryReadArray("arr", out string[] arr2)); + Assert.Equal(arr, arr2); + + Assert.Equal(5, root.GetPropertyValue("val", 0)); + Assert.True(root.GetPropertyValue("flag", false)); + Assert.Equal("x", root.GetPropertyValue("missing", "x")); + } + + [Fact] + public void ReadObjectAsMap_WritesStringArray() + { + using var doc = JsonDocument.Parse("{\"map\": {\"k1\": \"v1\"}}"); + var root = doc.RootElement; + + var map = root.ReadObjectAsMap("map", (k, v) => Tuple.Create(k, v.GetString())); + Assert.Equal("v1", map["k1"]); + + using var ms = new MemoryStream(); + using var writer = new Utf8JsonWriter(ms); + writer.WriteStartObject(); + writer.WriteStringArray("arr", new List { 1, 2 }); + writer.WriteEndObject(); + writer.Flush(); + // ensure method runs without throwing + } + + [Fact] + public void GetArray_ReadsJsonArrayElementDirectly() + { + using JsonDocument doc = JsonDocument.Parse("[\"x\",\"y\"]"); + + Assert.Equal(new[] { "x", "y" }, doc.RootElement.GetArray()); + } + + [Fact] + public void TryReadMethods_ReturnFalseWhenPropertyMissing() + { + using JsonDocument doc = JsonDocument.Parse("{}"); + JsonElement root = doc.RootElement; + + Assert.False(root.TryReadArray("missing", out string[] array)); + Assert.Null(array); + Assert.False(root.TryReadArray("missing", int.Parse, out int[] parsed)); + Assert.Null(parsed); + Assert.False(root.TryReadArray("missing", e => e.GetInt32(), out int[] elements)); + Assert.Null(elements); + Assert.False(root.TryReadArrayAsSet("missing", out HashSet set)); + Assert.Null(set); + Assert.False(root.TryReadArrayAsMap("missing", e => Tuple.Create(e.GetProperty("k").GetString()!, e.GetProperty("v").GetInt32()), out Dictionary map)); + Assert.Null(map); + Assert.False(root.TryReadObjectAsMap("missing", (k, v) => Tuple.Create(k, v.GetInt32()), out Dictionary objectMap)); + Assert.Null(objectMap); + } + + [Fact] + public void ReadArrayOverloads_ParseStringsElementsSetsAndMaps() + { + using JsonDocument doc = JsonDocument.Parse("{\"numbers\":[\"1\",\"2\"],\"objects\":[{\"k\":\"a\",\"v\":1},{\"k\":\"b\",\"v\":2},null],\"set\":[\"a\",\"a\",\"b\"]}"); + JsonElement root = doc.RootElement; + + Assert.Equal(new[] { 1, 2 }, root.ReadArray("numbers", int.Parse)); + Assert.True(root.TryReadArray("numbers", int.Parse, out int[] numbers)); + Assert.Equal(new[] { 1, 2 }, numbers); + Assert.Equal(new[] { 1, 2 }, root.ReadArray("numbers", e => int.Parse(e.GetString()!))); + Assert.True(root.TryReadArray("numbers", e => int.Parse(e.GetString()!), out int[] elements)); + Assert.Equal(new[] { 1, 2 }, elements); + + HashSet set = root.ReadArrayAsSet("set"); + Assert.Equal(new[] { "a", "b" }, set.OrderBy(x => x).ToArray()); + Assert.True(root.TryReadArrayAsSet("set", out HashSet set2)); + Assert.Equal(set, set2); + + Dictionary map = root.ReadArrayAsMap("objects", e => + e.ValueKind == JsonValueKind.Null ? null : Tuple.Create(e.GetProperty("k").GetString()!, e.GetProperty("v").GetInt32())); + Assert.Equal(2, map.Count); + Assert.Equal(2, map["b"]); + Assert.True(root.TryReadArrayAsMap("objects", e => + e.ValueKind == JsonValueKind.Null ? null : Tuple.Create(e.GetProperty("k").GetString()!, e.GetProperty("v").GetInt32()), out Dictionary map2)); + Assert.Equal(map, map2); + } + + [Fact] + public void NullArrays_ReturnNullAndInvalidKindsThrow() + { + using JsonDocument doc = JsonDocument.Parse("{\"value\":5,\"items\":null}"); + JsonElement root = doc.RootElement; + + Assert.Null(root.ReadArray("items")); + Assert.Null(root.ReadArray("items", int.Parse)); + Assert.Null(root.ReadArray("items", e => e.GetInt32())); + Assert.Null(root.ReadArrayAsSet("items")); + Assert.Null(root.ReadArrayAsMap("items", e => Tuple.Create("x", 1))); + + Assert.Throws(() => root.ReadArray("value")); + Assert.Throws(() => root.ReadArray("value", int.Parse)); + Assert.Throws(() => root.ReadArray("value", e => e.GetInt32())); + Assert.Throws(() => root.ReadArrayAsSet("value")); + Assert.Throws(() => root.ReadArrayAsMap("value", e => Tuple.Create("x", 1))); + Assert.Throws(() => root.TryReadObjectAsMap("value", (k, v) => Tuple.Create(k, v.GetInt32()), out _)); + } + + [Fact] + public void TryReadObjectAsMap_CoversObjectNullAndSkippedItems() + { + using JsonDocument doc = JsonDocument.Parse("{\"map\":{\"a\":1,\"skip\":2},\"nullMap\":null}"); + JsonElement root = doc.RootElement; + + Assert.True(root.TryReadObjectAsMap("map", (k, v) => k == "skip" ? null : Tuple.Create(k, v.GetInt32()), out Dictionary map)); + Assert.Single(map); + Assert.Equal(1, map["a"]); + + Assert.True(root.TryReadObjectAsMap("nullMap", (k, v) => Tuple.Create(k, v.GetInt32()), out Dictionary nullMap)); + Assert.Null(nullMap); + } + + [Fact] + public void GetPropertyValue_HandlesNumericEnumAndParsedValues() + { + using JsonDocument doc = JsonDocument.Parse("{\"name\":\"abc\",\"flag\":true,\"i\":-5,\"u\":5,\"l\":1234567890123,\"kind\":\"Friday\",\"parsed\":\"10\"}"); + JsonElement root = doc.RootElement; + + Assert.Equal("abc", root.GetPropertyValue("name", "default")); + Assert.Equal(-5, root.GetPropertyValue("i", 0)); + Assert.Equal(5u, root.GetPropertyValue("u", 0u)); + Assert.Equal(1234567890123L, root.GetPropertyValue("l", 0L)); + Assert.Equal(10, root.GetPropertyValue("parsed", int.Parse, 0)); + Assert.Equal(DayOfWeek.Friday, root.GetPropertyEnumValue("kind", DayOfWeek.Monday)); + + Assert.False(root.GetPropertyValue("missingBool", false)); + Assert.Equal(6, root.GetPropertyValue("missingInt", 6)); + Assert.Equal(7u, root.GetPropertyValue("missingUInt", 7u)); + Assert.Equal(8L, root.GetPropertyValue("missingLong", 8L)); + Assert.Equal(9, root.GetPropertyValue("missingParsed", int.Parse, 9)); + Assert.Equal(DayOfWeek.Sunday, root.GetPropertyEnumValue("missingEnum", DayOfWeek.Sunday)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/StringExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/StringExtensionsTests.cs new file mode 100644 index 0000000..293e025 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/StringExtensionsTests.cs @@ -0,0 +1,36 @@ +using System; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class StringExtensionsTests + { + [Fact] + public void Split_ParseIntArray() + { + var input = "1, 2,3"; + var arr = input.Split(s => int.Parse(s), ','); + Assert.Equal(new int[] { 1, 2, 3 }, arr); + } + + [Fact] + public void Join_Enumerable() + { + var joined = new int[] { 1, 2, 3 }.Join(','); + Assert.Equal("1, 2, 3", joined); + } + + [Fact] + public void ParseColonHexString_Valid() + { + var bytes = "0A:FF:10".ParseColonHexString(); + Assert.Equal(new byte[] { 0x0A, 0xFF, 0x10 }, bytes); + } + + [Fact] + public void ParseColonHexString_Invalid_Throws() + { + Assert.Throws(() => "0G:12".ParseColonHexString()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskExtensionsTests.cs new file mode 100644 index 0000000..400c481 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskExtensionsTests.cs @@ -0,0 +1,68 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class TaskExtensionsTests + { + [Fact] + public async Task TimeoutAsync_CompletesBeforeTimeout() + { + await TaskExtensions.TimeoutAsync(ct => Task.CompletedTask, 100); + } + + [Fact] + public async Task TimeoutAsync_ThrowsOnTimeout() + { + await Assert.ThrowsAsync(() => TaskExtensions.TimeoutAsync(async ct => await Task.Delay(500, ct), 50)); + } + + [Fact] + public void Sync_ReturnsValue() + { + var t = Task.FromResult(42); + Assert.Equal(42, t.Sync()); + } + + [Fact] + public async Task TimeoutAsync_GenericCompletesBeforeTimeout() + { + int result = await TaskExtensions.TimeoutAsync(ct => Task.FromResult(123), 100); + + Assert.Equal(123, result); + } + + [Fact] + public async Task TimeoutAsync_GenericThrowsOnTimeout() + { + await Assert.ThrowsAsync(() => TaskExtensions.TimeoutAsync(async ct => + { + await Task.Delay(500, ct); + return 123; + }, 50)); + } + + [Fact] + public async Task TimeoutAsync_ExternalCancellationThrowsOperationCanceled() + { + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(() => TaskExtensions.TimeoutAsync(ct => Task.Delay(500, ct), 50, cts.Token)); + } + + [Fact] + public void Sync_WaitsForTaskAndValueTasks() + { + bool completed = false; + Task.Run(() => completed = true).Sync(); + global::TechnitiumLibrary.TaskExtensions.Sync((Task)Task.CompletedTask); + + Assert.True(completed); + new ValueTask(Task.CompletedTask).Sync(); + Assert.Equal(7, new ValueTask(7).Sync()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskPoolTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskPoolTests.cs new file mode 100644 index 0000000..3bef9f5 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskPoolTests.cs @@ -0,0 +1,61 @@ +using System; +using System.Threading.Tasks; +using Xunit; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + public class TaskPoolTests + { + [Fact] + public async Task TryQueueTask_ExecutesQueuedTask() + { + using var pool = new TaskPool(); + var tcs = new TaskCompletionSource(); + + bool queued = pool.TryQueueTask(async _ => { tcs.SetResult(7); await Task.CompletedTask; }); + Assert.True(queued); + + var result = await tcs.Task.TimeoutAfter(1000); + Assert.Equal(7, result); + } + + [Fact] + public async Task TryQueueTask_WithState_ExecutesQueuedTask() + { + using TaskPool pool = new TaskPool(queueSize: 4, maximumConcurrencyLevel: 1); + TaskCompletionSource tcs = new TaskCompletionSource(); + + Assert.Equal(4, pool.QueueSize); + Assert.Equal(1, pool.MaximumConcurrencyLevel); + Assert.True(pool.TryQueueTask(state => + { + tcs.SetResult((string)state); + return Task.CompletedTask; + }, "state-value")); + + Assert.Equal("state-value", await tcs.Task.TimeoutAfter(1000)); + } + + [Fact] + public void Dispose_IsIdempotentAndRejectsFurtherWrites() + { + TaskPool pool = new TaskPool(queueSize: 1, maximumConcurrencyLevel: 1); + + pool.Dispose(); + pool.Dispose(); + + Assert.False(pool.TryQueueTask(_ => Task.CompletedTask)); + } + } +} + +public static class TestHelpers +{ + public static async Task TimeoutAfter(this Task task, int ms) + { + var delay = Task.Delay(ms); + var finished = await Task.WhenAny(task, delay); + if (finished == delay) throw new TimeoutException(); + return await task; + } +} diff --git a/TechnitiumLibrary.Tests/coverage.runsettings b/TechnitiumLibrary.Tests/coverage.runsettings new file mode 100644 index 0000000..23724d6 --- /dev/null +++ b/TechnitiumLibrary.Tests/coverage.runsettings @@ -0,0 +1,16 @@ + + + + + + + cobertura + [*.Tests]*,[xunit.*]* + CompilerGeneratedAttribute,GeneratedCodeAttribute,ObsoleteAttribute + **/bin/**/*.cs,**/obj/**/*.cs + true + + + + + diff --git a/TechnitiumLibrary.sln b/TechnitiumLibrary.sln index 9cbcda3..0374441 100644 --- a/TechnitiumLibrary.sln +++ b/TechnitiumLibrary.sln @@ -25,6 +25,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TechnitiumLibrary", "Techni EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TechnitiumLibrary.Security.OTP", "TechnitiumLibrary.Security.OTP\TechnitiumLibrary.Security.OTP.csproj", "{72AF4EB6-EB81-4655-9998-8BF24B304614}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TechnitiumLibrary.Tests", "TechnitiumLibrary.Tests\TechnitiumLibrary.Tests.csproj", "{37DA8491-32F6-64D5-76B3-4830986FA4C8}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -75,6 +77,10 @@ Global {72AF4EB6-EB81-4655-9998-8BF24B304614}.Debug|Any CPU.Build.0 = Debug|Any CPU {72AF4EB6-EB81-4655-9998-8BF24B304614}.Release|Any CPU.ActiveCfg = Release|Any CPU {72AF4EB6-EB81-4655-9998-8BF24B304614}.Release|Any CPU.Build.0 = Release|Any CPU + {37DA8491-32F6-64D5-76B3-4830986FA4C8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {37DA8491-32F6-64D5-76B3-4830986FA4C8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {37DA8491-32F6-64D5-76B3-4830986FA4C8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {37DA8491-32F6-64D5-76B3-4830986FA4C8}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE