using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using InABox.Clients;
using InABox.Core;

namespace InABox.Rpc
{
    /// <summary>
    /// Client for <typeparamref name="TEntity"/> that performs validation, query, save, delete and ping
    /// operations by sending the matching parameter objects through an <see cref="IRpcClientTransport"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the command/result type arguments passed to the transport's Send method
    /// (Rpc*Command / Rpc*Result), the IQueryDef dictionary value type and the non-generic
    /// IRpcCommand constraint are assumed from the naming of the parameter classes used here —
    /// confirm against their declarations.
    /// </remarks>
    /// <typeparam name="TEntity">Entity type this client operates on.</typeparam>
    public class RpcClient<TEntity> : BaseClient<TEntity> where TEntity : Entity, new()
    {
        private readonly IRpcClientTransport _transport;

        // Wait handles for requests issued through Send(RpcMessage), keyed by request id.
        private readonly ConcurrentDictionary<Guid, ManualResetEventSlim> _events = new ConcurrentDictionary<Guid, ManualResetEventSlim>();

        // Responses waiting to be collected by GetResult, keyed by request id.
        // NOTE(review): nothing in this class adds to this dictionary or signals the events above,
        // so as written Send(RpcMessage) can only end in TIMEOUT/UNKNOWN — confirm where responses
        // are meant to be delivered.
        private readonly ConcurrentDictionary<Guid, RpcMessage> _responses = new ConcurrentDictionary<Guid, RpcMessage>();

        private const int DefaultRequestTimeout = 5 * 60 * 1000; // 5 minutes

        /// <summary>
        /// Creates the client, subscribes to the transport's message event and connects the
        /// transport if it is not connected yet.
        /// </summary>
        /// <param name="transport">Transport used for all requests.</param>
        /// <exception cref="ArgumentNullException"><paramref name="transport"/> is null.</exception>
        public RpcClient(IRpcClientTransport transport)
        {
            _transport = transport ?? throw new ArgumentNullException(nameof(transport));
            _transport.OnMessage += Transport_Message;
            if (!_transport.IsConnected())
            {
                _transport.Connect();
            }
        }

        ~RpcClient()
        {
            // A finalizer still runs when the constructor threw, so the field may be unset here;
            // an exception escaping a finalizer would take the process down.
            if (_transport != null)
            {
                _transport.OnMessage -= Transport_Message;
            }
        }

        #region TransportManagement

        /// <summary>Returns true when a transport is present and reports itself as connected.</summary>
        public override bool IsConnected() => _transport?.IsConnected() == true;

        // I'm assuming this is for unexpected messages (notifications, etc)?
        private void Transport_Message(IRpcTransport transport, RpcTransportMessageArgs e)
        {
            RaiseLogEvent(LogType.Error, "", "Message received: ({0}) -> {1}", e.Message.Command, e.Message.Payload);
        }

        /// <summary>
        /// Serializes <paramref name="parameters"/> into a request named after
        /// <typeparamref name="TCommand"/>, sends it and deserializes the response payload.
        /// </summary>
        /// <param name="parameters">Command parameters, written with the latest binary settings.</param>
        /// <returns>The deserialized result.</returns>
        /// <exception cref="Exception">
        /// The response carries an error other than <c>RpcError.NONE</c>, or the payload deserializes to null.
        /// </exception>
        public TResult Send<TCommand, TParameters, TResult>(TParameters parameters)
            where TCommand : IRpcCommand
            where TParameters : ISerializeBinary
            where TResult : ISerializeBinary, new()
        {
            var request = new RpcMessage()
            {
                // Guid.NewGuid(), not new Guid(): the parameterless constructor always produces
                // Guid.Empty, which would make every request share the same key in _events.
                Id = Guid.NewGuid(),
                Command = typeof(TCommand).Name,
                Payload = parameters.WriteBinary(BinarySerializationSettings.Latest)
            };

            var response = Send(request);
            if (response.Error != RpcError.NONE)
            {
                throw new Exception($"Exception in {typeof(TCommand).Name}({request.Id}): {response.Error}");
            }

            var result = Serialization.ReadBinary<TResult>(response.Payload, BinarySerializationSettings.Latest);
            if (result == null)
            {
                throw new Exception($"{typeof(TCommand).Name}({request.Id}) returned NULL");
            }

            return result;
        }

        /// <summary>
        /// Registers a wait handle for <paramref name="request"/>, hands the request to the
        /// transport and waits for the result.
        /// </summary>
        /// <param name="request">Message to send; its Id is used to match the response.</param>
        /// <param name="timeout">Maximum wait in milliseconds.</param>
        /// <returns>The response, or a message whose Error is TIMEOUT or UNKNOWN.</returns>
        public RpcMessage Send(RpcMessage request, int timeout = DefaultRequestTimeout)
        {
            var ev = Queue(request.Id);
            try
            {
                _transport.Send(request);
            }
            catch
            {
                // The request never left, so nobody will collect the registration.
                _events.TryRemove(request.Id, out _);
                throw;
            }

            return GetResult(request.Id, ev, timeout);
        }

        /// <summary>Creates a wait handle and registers it under <paramref name="id"/>.</summary>
        /// <param name="id">Request id.</param>
        /// <returns>The newly registered, unsignalled event.</returns>
        public ManualResetEventSlim Queue(Guid id)
        {
            var ev = new ManualResetEventSlim();
            _events[id] = ev;
            return ev;
        }

        /// <summary>
        /// Returns the stored response for <paramref name="id"/>, waiting on <paramref name="ev"/>
        /// for up to <paramref name="timeout"/> milliseconds if none is stored yet.
        /// </summary>
        /// <param name="id">Request id.</param>
        /// <param name="ev">Wait handle obtained from <see cref="Queue"/>.</param>
        /// <param name="timeout">Maximum wait in milliseconds.</param>
        /// <returns>
        /// The stored response; a TIMEOUT message if the wait elapsed; an UNKNOWN message if the
        /// event was signalled but no response was stored.
        /// </returns>
        public RpcMessage GetResult(Guid id, ManualResetEventSlim ev, int timeout)
        {
            try
            {
                // Response already stored: no need to wait.
                if (_responses.TryRemove(id, out var early))
                {
                    return early;
                }

                try
                {
                    if (!ev.Wait(timeout))
                    {
                        return new RpcMessage() { Id = id, Error = RpcError.TIMEOUT };
                    }
                }
                catch (Exception e)
                {
                    RaiseLogEvent(LogType.Error, "", e.Message);
                    throw;
                }

                _responses.TryRemove(id, out var response);
                return response ?? new RpcMessage() { Id = id, Error = RpcError.UNKNOWN };
            }
            finally
            {
                // Drop the registration on every exit path (including timeout and exceptions)
                // so abandoned requests do not accumulate in _events.
                _events.TryRemove(id, out _);
            }
        }

        #endregion

        #region Client Interface

        /// <summary>Requests the database information from the transport.</summary>
        public override DatabaseInfo Info()
        {
            var result = _transport.Send<RpcInfoCommand, RpcInfoParameters, RpcInfoResult>(new RpcInfoParameters());
            return result.Info;
        }

        private static string[]? _types;

        /// <summary>
        /// Returns the entity names (with '.' replaced by '_') of all entities in
        /// <c>CoreUtils.Entities</c> that implement <c>IPersistent</c>. The list is computed once and cached.
        /// </summary>
        public override IEnumerable<string> SupportedTypes()
        {
            _types ??= CoreUtils.Entities
                .Where(x => x.GetInterfaces().Contains(typeof(IPersistent)))
                .Select(x => x.EntityName().Replace(".", "_"))
                .ToArray();
            return _types;
        }

        #region Validate & 2FA

        // Builds the validation request shared by all DoValidate overloads.
        private static RpcValidateParameters CreateValidateParameters(string userid, string password, string pin, bool usePin, Guid session)
        {
            return new RpcValidateParameters()
            {
                UserID = userid,
                Password = password,
                PIN = pin,
                UsePIN = usePin,
                SessionID = session,
                Platform = ClientFactory.Platform,
                Version = ClientFactory.Version
            };
        }

        /// <summary>Validates with a user id and password.</summary>
        protected override IValidationData DoValidate(string userid, string password, Guid session = default)
        {
            var parameters = CreateValidateParameters(userid, password, "", false, session);
            return _transport.Send<RpcValidateCommand, RpcValidateParameters, RpcValidateResult>(parameters);
        }

        /// <summary>
        /// Validates with a PIN. The user id and password fields carry the current UTC tick count
        /// encrypted with two fixed passphrases.
        /// </summary>
        protected override IValidationData DoValidate(string pin, Guid session = default)
        {
            var ticks = DateTime.UtcNow.Ticks.ToString();

            // NOTE(review): hard-coded passphrases in source; presumably they must match the
            // receiving side, so they are left untouched — consider moving them out of the code.
            var parameters = CreateValidateParameters(
                Encryption.Encrypt(ticks, "wCq9rryEJEuHIifYrxRjxg", true),
                Encryption.Encrypt(ticks, "7mhvLnqMwkCAzN+zNGlyyg", true),
                pin,
                true,
                session);
            return _transport.Send<RpcValidateCommand, RpcValidateParameters, RpcValidateResult>(parameters);
        }

        /// <summary>Validates using only the session id.</summary>
        protected override IValidationData DoValidate(Guid session = default)
        {
            var parameters = CreateValidateParameters("", "", "", false, session);
            return _transport.Send<RpcValidateCommand, RpcValidateParameters, RpcValidateResult>(parameters);
        }

        /// <summary>Sends a 2FA code for checking; a null session is sent as <c>Guid.Empty</c>.</summary>
        /// <returns>The Valid flag of the result.</returns>
        protected override bool DoCheck2FA(string code, Guid? session)
        {
            var parameters = new RpcCheck2FAParameters()
            {
                Code = code,
                SessionId = session ?? Guid.Empty,
            };
            var result = _transport.Send<RpcCheck2FACommand, RpcCheck2FAParameters, RpcCheck2FAResult>(parameters);
            return result.Valid;
        }

        #endregion

        /// <summary>Runs a single query for <typeparamref name="TEntity"/> and returns its table.</summary>
        protected override CoreTable DoQuery(Filter<TEntity>? filter, Columns<TEntity>? columns, SortOrder<TEntity>? sort = null)
        {
            var parameters = new RpcQueryParameters()
            {
                Queries = new RpcQueryDefinition[]
                {
                    new RpcQueryDefinition()
                    {
                        Type = typeof(TEntity),
                        Filter = filter,
                        Columns = columns,
                        Sort = sort
                    }
                }
            };
            var result = _transport.Send<RpcQueryCommand, RpcQueryParameters, RpcQueryResult>(parameters);
            return result.Tables[0].Table;
        }

        /// <summary>Queries with no column restriction and converts every row to an entity.</summary>
        protected override TEntity[] DoLoad(Filter<TEntity>? filter = null, SortOrder<TEntity>? sort = null)
        {
            return DoQuery(filter, null, sort).Rows.Select(r => r.ToObject<TEntity>()).ToArray();
        }

        /// <summary>Sends all <paramref name="queries"/> in one request and returns the tables by key.</summary>
        protected override Dictionary<string, CoreTable> DoQueryMultiple(Dictionary<string, IQueryDef> queries)
        {
            var parameters = new RpcQueryParameters()
            {
                Queries = queries.Select(kvp => new RpcQueryDefinition()
                {
                    Key = kvp.Key,
                    Type = kvp.Value.Type,
                    Filter = kvp.Value.Filter,
                    Columns = kvp.Value.Columns,
                    Sort = kvp.Value.SortOrder
                }).ToArray()
            };
            var response = _transport.Send<RpcQueryCommand, RpcQueryParameters, RpcQueryResult>(parameters);

            var result = new Dictionary<string, CoreTable>();
            foreach (var table in response.Tables)
            {
                result[table.Key] = table.Table;
            }
            return result;
        }

        /// <summary>Saves a single entity.</summary>
        protected override void DoSave(TEntity entity, string auditnote)
        {
            DoSave(new TEntity[] { entity }, auditnote);
        }

        /// <summary>
        /// Saves the entities and applies the returned per-entity deltas back onto the local
        /// instances, committing their changes.
        /// </summary>
        protected override void DoSave(IEnumerable<TEntity> entities, string auditnote)
        {
            var items = entities.ToArray();
            var parameters = new RpcSaveParameters()
            {
                Type = typeof(TEntity),
                Items = items
            };
            var result = _transport.Send<RpcSaveCommand, RpcSaveParameters, RpcSaveResult>(parameters);

            for (int i = 0; i < result.Deltas.Length; i++)
            {
                var item = items[i];
                item.SetObserving(false);
                try
                {
                    foreach (var (key, value) in result.Deltas[i])
                    {
                        if (CoreUtils.TryGetProperty<TEntity>(key, out var property))
                        {
                            CoreUtils.SetPropertyValue(item, key, CoreUtils.ChangeType(value, property.PropertyType));
                        }
                    }
                    item.CommitChanges();
                }
                finally
                {
                    // Re-enable observing even if applying a delta throws.
                    item.SetObserving(true);
                }
            }
        }

        /// <summary>Deletes a single entity.</summary>
        protected override void DoDelete(TEntity entity, string auditnote)
        {
            DoDelete(new TEntity[] { entity }, auditnote);
        }

        /// <summary>Deletes the entities by sending their IDs together with the audit note.</summary>
        protected override void DoDelete(IList<TEntity> entities, string auditnote)
        {
            var parameters = new RpcDeleteParameters()
            {
                Type = typeof(TEntity),
                IDs = entities.Select(x => x.ID).ToArray(),
                AuditNote = auditnote
            };
            _transport.Send<RpcDeleteCommand, RpcDeleteParameters, RpcDeleteResult>(parameters);
        }

        /// <summary>Sends a ping; returns false instead of throwing when the send fails.</summary>
        protected override bool DoPing()
        {
            try
            {
                _transport.Send<RpcPingCommand, RpcPingParameters, RpcPingResult>(new RpcPingParameters());
                return true;
            }
            catch (Exception)
            {
                // Deliberate: any failure means "not reachable".
                return false;
            }
        }

        #endregion
    }
}