using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using InABox.Clients;
using InABox.Core;
using Logger = InABox.Core.Logger;

namespace InABox.Rpc
{
    /// <summary>
    /// Base class for client-side RPC transports. Outgoing requests are tracked by
    /// <see cref="RpcMessage.Id"/> so that incoming replies can be matched to the waiting caller;
    /// messages that match no pending request are dispatched to <see cref="OnMessage"/>.
    /// Concrete subclasses supply the wire transport (connect, send, disconnect, clone).
    /// </summary>
    public abstract class RpcClientTransport : IRpcClientTransport
    {
        private readonly RpcClientSession _session = new RpcClientSession();

        // Pending requests: request id -> event that is signalled once a reply (or a
        // DISCONNECTED marker written by DoClose) has been stored in _responses.
        private readonly ConcurrentDictionary _events = new ConcurrentDictionary();

        // Replies awaiting collection by GetResponse, keyed by request id.
        private readonly ConcurrentDictionary _responses = new ConcurrentDictionary();

        private const int DefaultRequestTimeout = 5 * 60 * 1000; // 5 minutes

        public event RpcTransportOpenEvent? OnOpen;

        /// <summary>Raises <see cref="OnOpen"/> for this transport's session.</summary>
        protected void DoOpen() => OnOpen?.Invoke(this, new RpcTransportOpenArgs(_session));

        public event RpcTransportCloseEvent? OnClose;

        /// <summary>
        /// Fails every pending request with <c>RpcError.DISCONNECTED</c>, wakes its waiter,
        /// and then raises <see cref="OnClose"/>.
        /// </summary>
        /// <param name="type">The kind of close being reported to <see cref="OnClose"/> handlers.</param>
        protected void DoClose(RpcTransportCloseEventType type)
        {
            foreach (var pending in _events)
            {
                // TryAdd: a genuine reply that has already been stored is not overwritten by the marker.
                _responses.TryAdd(pending.Key, new RpcMessage() { Error = RpcError.DISCONNECTED });
                pending.Value.Set();
            }
            OnClose?.Invoke(this, new RpcTransportCloseArgs(_session, type));
        }

        public event RpcTransportExceptionEvent? OnException;

        /// <summary>Raises <see cref="OnException"/> with <paramref name="e"/>.</summary>
        protected void DoException(Exception e) => OnException?.Invoke(this, new RpcTransportExceptionArgs(_session, e));

        public event RpcTransportMessageEvent? OnMessage;

        /// <summary>Raises <see cref="OnMessage"/> with <paramref name="message"/>.</summary>
        protected void DoMessage(RpcMessage message) => OnMessage?.Invoke(this, new RpcTransportMessageArgs(_session, message));

        public abstract bool Connect(CancellationToken ct = default);

        public abstract bool IsConnected();

        public abstract bool IsSecure();

        public abstract string? ServerName();

        public abstract void Send(RpcMessage message);

        /// <summary>
        /// Handles a message received from the server.
        /// </summary>
        /// <remarks>
        /// If the message answers a pending request, it is stored and that request's waiter is woken.
        /// Otherwise it is passed to <see cref="OnMessage"/> on the thread pool, and any exception
        /// from the handlers is logged.
        /// </remarks>
        /// <param name="message">The received message; <c>null</c> is ignored.</param>
        public void Accept(RpcMessage? message)
        {
            if (message == null)
                return;

            if (_events.TryGetValue(message.Id, out var ev))
            {
                _responses[message.Id] = message;
                ev.Set();
                return;
            }

            Task.Run(() => DoMessage(message)).ContinueWith(task =>
            {
                if (task.Exception != null)
                {
                    Logger.Send(LogType.Error, "", $"Error in RPC Client Push: {CoreUtils.FormatException(task.Exception)}");
                }
            });
        }

        public abstract void Disconnect();

        /// <summary>
        /// Sends command <c>TCommand</c> with <paramref name="parameters"/> and blocks until the reply
        /// arrives, the transport closes, or <c>DefaultRequestTimeout</c> (5 minutes) elapses.
        /// </summary>
        /// <param name="parameters">Command parameters, serialized with <c>BinarySerializationSettings.Latest</c>.</param>
        /// <returns>The deserialized command result.</returns>
        /// <exception cref="RpcException">
        /// The transport is not connected, or the reply carries an error (including TIMEOUT and DISCONNECTED).
        /// </exception>
        /// <exception cref="Exception">The reply was null or its payload could not be deserialized.</exception>
        public TResult Send(TParameters parameters) where TCommand : IRpcCommand where TParameters : IRpcCommandParameters, ISerializeBinary where TResult : IRpcCommandResult, ISerializeBinary, new()
        {
            CheckConnection();

            var request = new RpcMessage()
            {
                Id = Guid.NewGuid(),
                Command = typeof(TCommand).Name,
                Payload = Serialization.WriteBinary(parameters, BinarySerializationSettings.Latest)
            };

            var ev = Queue(request.Id);
            try
            {
                Send(request);
            }
            catch
            {
                // The request may never have reached the server; stop tracking it so its
                // event (and any later marker/reply) does not linger in the pending tables.
                Dequeue(request.Id);
                throw;
            }

            var response = GetResponse(request.Id, ev, DefaultRequestTimeout)
                ?? throw new Exception($"{typeof(TCommand).Name}({request.Id}) returned NULL");
            if (response.Error != RpcError.NONE)
                throw new RpcException($"Server error in {typeof(TCommand).Name}({request.Id})", response.Error);

            var result = Serialization.ReadBinary(response.Payload, BinarySerializationSettings.Latest)
                ?? throw new Exception($"Cannot Deserialize {typeof(TCommand).Name}({request.Id})");
            return result;
        }

        /// <summary>Throws <see cref="RpcException"/> (DISCONNECTED) when the transport is not connected.</summary>
        private void CheckConnection() where TCommand : IRpcCommand where TParameters : IRpcCommandParameters where TResult : IRpcCommandResult
        {
            if (!IsConnected())
                throw new RpcException($"Transport Disconnected: {typeof(TCommand).Name}()", RpcError.DISCONNECTED);
        }

        /// <summary>
        /// Registers request <paramref name="id"/> as pending and returns the event that is
        /// signalled when its reply (or a DISCONNECTED marker) is stored.
        /// </summary>
        public ManualResetEventSlim Queue(Guid id)
        {
            var ev = new ManualResetEventSlim();
            _events[id] = ev;
            return ev;
        }

        /// <summary>
        /// Stops tracking request <paramref name="id"/>.
        /// </summary>
        /// <remarks>
        /// Removes the request's wait event. It then removes and returns any reply stored for it,
        /// or <c>null</c> if there is none. The event is removed first so that a reply arriving
        /// concurrently is not parked in <c>_responses</c> with nobody left to collect it.
        /// </remarks>
        private RpcMessage? Dequeue(Guid id)
        {
            _events.TryRemove(id, out _);
            _responses.TryRemove(id, out var response);
            return response;
        }

        /// <summary>
        /// Waits up to <paramref name="timeout"/> milliseconds for the reply to request
        /// <paramref name="id"/>. The request is no longer tracked afterwards, whatever the outcome.
        /// </summary>
        /// <param name="id">Request id previously registered with <see cref="Queue"/>.</param>
        /// <param name="ev">The event <see cref="Queue"/> returned for <paramref name="id"/>.</param>
        /// <param name="timeout">Maximum wait in milliseconds.</param>
        /// <returns>
        /// One of the following:
        /// <list type="bullet">
        /// <item>the stored reply (or DISCONNECTED marker);</item>
        /// <item>a TIMEOUT message if the wait expired;</item>
        /// <item>an UNKNOWN-error message if the event was signalled but no reply was stored.</item>
        /// </list>
        /// </returns>
        /// <remarks>
        /// Because a timed-out request is forgotten, a reply that arrives later is not matched by
        /// <see cref="Accept"/>. It goes to <see cref="OnMessage"/> instead of being retained indefinitely.
        /// </remarks>
        public RpcMessage? GetResponse(Guid id, ManualResetEventSlim ev, int timeout)
        {
            bool signalled;
            try
            {
                // If the reply is already stored there is nothing to wait for.
                signalled = _responses.ContainsKey(id) || ev.Wait(timeout);
            }
            catch (Exception e)
            {
                Logger.Send(LogType.Error, "", e.Message);
                Dequeue(id);
                throw;
            }

            var response = Dequeue(id);
            if (!signalled)
                return new RpcMessage() { Error = RpcError.TIMEOUT };
            return response ?? new RpcMessage() { Error = RpcError.UNKNOWN };
        }

        protected abstract RpcClientTransport Clone();

        /// <summary>Checks whether the server is reachable, using a fresh clone of this transport.</summary>
        /// <returns>
        /// <c>true</c> if the clone reports itself connected after <see cref="Connect"/>;
        /// <c>false</c> on any failure.
        /// </returns>
        public bool Ping()
        {
            RpcClientTransport? transport = null;
            try
            {
                transport = Clone();
                transport.Connect();
                return transport.IsConnected();
            }
            catch
            {
                // Best effort: any failure simply means the server could not be reached.
                return false;
            }
            finally
            {
                DisconnectQuietly(transport);
            }
        }

        /// <summary>Fetches the server's database info, using a fresh clone of this transport.</summary>
        /// <returns>The server's <see cref="DatabaseInfo"/>, or <c>null</c> if any step fails.</returns>
        public DatabaseInfo? Info()
        {
            RpcClientTransport? transport = null;
            try
            {
                transport = Clone();
                transport.Connect();
                return transport.Send(new RpcInfoParameters()).Info;
            }
            catch
            {
                return null;
            }
            finally
            {
                DisconnectQuietly(transport);
            }
        }

        /// <summary>
        /// Disconnects a short-lived probe transport created by <see cref="Clone"/>, ignoring failures.
        /// </summary>
        private static void DisconnectQuietly(RpcClientTransport? transport)
        {
            if (transport == null)
                return;
            try
            {
                transport.Disconnect();
            }
            catch
            {
                // The probe is being discarded; a failed disconnect must not mask its result.
            }
        }
    }
}