using GenHTTP.Api.Content;
using GenHTTP.Api.Protocol;
using GenHTTP.Modules.IO;
using GenHTTP.Modules.IO.Streaming;
using GenHTTP.Modules.IO.Strings;
using InABox.Clients;
using InABox.Core;
using InABox.Database;
using InABox.Rpc;
using InABox.Server;
using PRSServices;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;

namespace PRSServer;

/// <summary>
/// Initialisation payload handed to <see cref="HTTPDatabaseProxyHandler"/>: the proxy
/// configuration plus the RPC transport used to reach the database server.
/// </summary>
public class HTTPDatabaseProxyHandlerProperties
{
    /// <summary>Configuration of the HTTP proxy.</summary>
    public HTTPDatabaseProxyProperties Properties { get; set; }

    /// <summary>Transport over which requests are forwarded to the server.</summary>
    public IRpcClientTransport ServerTransport { get; set; }

    public HTTPDatabaseProxyHandlerProperties(HTTPDatabaseProxyProperties properties, IRpcClientTransport serverTransport)
    {
        Properties = properties;
        ServerTransport = serverTransport;
    }
}

// NOTE(review): the base type's generic argument was lost in the source text; it is restored
// here from the parameter type of the Init override below - confirm against InABox.Server.
internal class HTTPDatabaseProxyHandler : Handler<HTTPDatabaseProxyHandlerProperties>
{
    private HTTPDatabaseProxyProperties Properties { get; set; }

    public IRpcClientTransport ServerTransport { get; set; }

    /// <summary>
    /// Copies the proxy configuration and the server transport out of the supplied properties.
    /// </summary>
    /// <param name="properties">Properties supplied when the handler is initialised.</param>
    public override void Init(HTTPDatabaseProxyHandlerProperties properties)
    {
        Properties = properties.Properties;
        ServerTransport = properties.ServerTransport;
    }

    /// <summary>
    /// Reads the serialization options of a request from its query string.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item><c>serializationVersion</c> selects the binary settings; <c>V1_0</c> is used when absent.</item>
    /// <item><c>format</c> sets the request format when it parses as a <see cref="SerializationFormat"/>.</item>
    /// <item><c>responseFormat</c> sets the response format when it parses; otherwise JSON is used.</item>
    /// </list>
    /// </remarks>
    /// <param name="request">The incoming request.</param>
    /// <returns>The serialization options to use for this request.</returns>
    private static RequestData GetRequestData(IRequest request)
    {
        BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
        if (request.Query.TryGetValue("serializationVersion", out var versionString))
        {
            settings = BinarySerializationSettings.ConvertVersionString(versionString);
        }

        var data = new RequestData(settings);

        // The enum type must be given explicitly: it cannot be inferred from 'out var'.
        if (request.Query.TryGetValue("format", out var formatString)
            && Enum.TryParse<SerializationFormat>(formatString, out var format))
        {
            data.RequestFormat = format;
        }

        data.ResponseFormat = SerializationFormat.Json;
        if (request.Query.TryGetValue("responseFormat", out formatString)
            && Enum.TryParse<SerializationFormat>(formatString, out format))
        {
            data.ResponseFormat = format;
        }

        return data;
    }

    // The main handler for the server; an HTTP request comes in, an HTTP response goes out.
// HTTPDatabaseProxyHandler: request entry point.
    /// <summary>
    /// Dispatches an incoming HTTP request by method and by the current segment of its target path.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item>GET/HEAD: an optional leading <c>update</c> segment is skipped, then the update-file,
    /// <c>info</c> and <c>ping</c> endpoints are served; anything else is logged and answered with 404.</item>
    /// <item>POST: <c>validate</c> and <c>check_2fa</c> are handled directly; every other target is
    /// passed to <c>HandleDatabaseRequest</c>. A POST without a target segment gets 404.</item>
    /// <item>Any other method gets 405 with an <c>Allow</c> header.</item>
    /// </list>
    /// Exceptions thrown while handling are logged and turned into a 500 response.
    /// </remarks>
    /// <param name="request">The incoming request.</param>
    /// <returns>The response for the request.</returns>
    public override ValueTask<IResponse?> HandleAsync(IRequest request)
    {
        try
        {
            switch (request.Method.KnownMethod)
            {
                case RequestMethod.GET:
                case RequestMethod.HEAD:
                    var current = request.Target.Current?.Value;
                    if (String.Equals(current, "update"))
                    {
                        // Skip the "update" prefix so that "update/version" resolves like "version".
                        request.Target.Advance();
                        current = request.Target.Current?.Value;
                    }

                    switch (current)
                    {
                        case "operations" or "supported_operations":
                            Logger.Send(LogType.Error, "", "Supported operations is no longer supported");
                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());

                        case "classes" or "supported_classes":
                            Logger.Send(LogType.Error, "", "Supported classes is no longer supported");
                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());

                        case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                            return new ValueTask<IResponse?>(GetUpdateFile(request).Build());

                        case "info":
                            return new ValueTask<IResponse?>(GetServerInfo(request).Build());

                        case "ping":
                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
                    }

                    Logger.Send(
                        LogType.Error,
                        request.Client.IPAddress.ToString(),
                        $"GET/HEAD request to endpoint '{current}' is unresolved, because it does not exist");
                    return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());

                case RequestMethod.POST:
                    var target = request.Target.Current;
                    if (target is not null)
                    {
                        var data = GetRequestData(request);
                        return target.Value switch
                        {
                            "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
                            "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
                            _ => HandleDatabaseRequest(request, data),
                        };
                    }
                    return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
            }

            return new ValueTask<IResponse?>(
                request.Respond()
                    .Status(ResponseStatus.MethodNotAllowed)
                    .Header("Allow", "GET, POST, HEAD")
                    .Build());
        }
        catch (Exception e)
        {
            Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
        }
    }
// HTTPDatabaseProxyHandler: request forwarding, database endpoints and installer downloads.
    /// <summary>
    /// Deserializes the HTTP request body, forwards it to the server as the RPC command
    /// <typeparamref name="TCommand"/>, and serializes the outcome as the HTTP response.
    /// </summary>
    /// <remarks>
    /// When the server reports no error, the payload is read as <typeparamref name="TResult"/>,
    /// converted with <paramref name="convertResponse"/> and given status OK. Otherwise a new
    /// <typeparamref name="TResponse"/> is created with status Unauthenticated or Error, and the
    /// payload (decoded as UTF-8) is added to its messages.
    /// </remarks>
    /// <param name="request">The incoming HTTP request.</param>
    /// <param name="data">Serialization options for this request.</param>
    /// <param name="convertRequest">Builds the RPC parameters from the deserialized request.</param>
    /// <param name="convertResponse">Builds the response from the request and the RPC result.</param>
    /// <returns>A response builder carrying the serialized response.</returns>
    /// <exception cref="Exception">The request body or the server payload could not be deserialized.</exception>
    private IResponseBuilder DoForward<TRequest, TResponse, TCommand, TParameters, TResult>(
        IRequest request,
        RequestData data,
        Func<TRequest, TParameters> convertRequest,
        Func<TRequest, TResult, TResponse> convertResponse
    )
        where TRequest : Request
        where TResponse : Response, new()
        where TCommand : IRpcCommand
        where TParameters : IRpcCommandParameters, ISerializeBinary
        where TResult : IRpcCommandResult, ISerializeBinary, new()
    {
        var requestObj = Deserialize<TRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);

        var internalMessage = new InternalServerMessage
        {
            Session = requestObj.Credentials.Session,
            Payload = convertRequest(requestObj).WriteBinary(BinarySerializationSettings.Latest)
        };

        // NOTE(review): generic arguments were stripped throughout the source text; this call is
        // left exactly as written because nothing visible shows whether Send takes any - confirm.
        var serverResponse = ServerTransport.Send(typeof(TCommand).Name, internalMessage, checkErrors: false);

        TResponse response;
        if (serverResponse.Error == RpcError.NONE)
        {
            var result = Serialization.ReadBinary<TResult>(serverResponse.Payload, BinarySerializationSettings.Latest)
                ?? throw new Exception($"Cannot Deserialize {typeof(TCommand).Name}");

            response = convertResponse(requestObj, result);
            response.Status = StatusCode.OK;
        }
        else
        {
            response = new TResponse
            {
                Status = serverResponse.Error switch
                {
                    RpcError.UNAUTHENTICATED => StatusCode.Unauthenticated,
                    _ => StatusCode.Error
                },
            };
            response.Messages.Add(Encoding.UTF8.GetString(serverResponse.Payload));
        }

        return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
    }

    /// <summary>
    /// Returns the Splash Logo and Color Scheme for this Database
    /// </summary>
    /// <param name="request">The incoming HTTP request.</param>
    /// <returns>A response builder carrying the serialized info response.</returns>
    private IResponseBuilder GetServerInfo(IRequest request)
    {
        var data = GetRequestData(request);

        // NOTE(review): the type arguments were stripped from the source text. InfoRequest,
        // RpcInfoCommand and RpcInfoResult are inferred from the naming used by the
        // Query/Save/Delete calls below - confirm against InABox.
        return DoForward<InfoRequest, InfoResponse, RpcInfoCommand, RpcInfoParameters, RpcInfoResult>(
            request,
            data,
            x => new RpcInfoParameters(),
            (r, x) => new InfoResponse
            {
                Info = x.Info ?? new DatabaseInfo()
            });
    }

    #region Authentication

    /// <summary>
    /// Forwards a credential validation request to the server.
    /// </summary>
    /// <param name="request">The incoming HTTP request.</param>
    /// <param name="data">Serialization options for this request.</param>
    /// <returns>A response builder carrying the serialized validation response.</returns>
    private IResponseBuilder Validate(IRequest request, RequestData data)
    {
        // NOTE(review): type arguments restored by naming analogy (ValidateRequest,
        // RpcValidateCommand, RpcValidateResult are not visible in this file) - confirm.
        return DoForward<ValidateRequest, ValidateResponse, RpcValidateCommand, RpcValidateParameters, RpcValidateResult>(
            request,
            data,
            x => new RpcValidateParameters
            {
                UserID = x.UserID,
                Password = x.Password,
                PIN = x.PIN,
                UsePIN = x.UsePIN,
                SessionID = x.Credentials.Session,
                Platform = x.Credentials.Platform,
                Version = x.Credentials.Version,
            },
            (r, x) => new ValidateResponse
            {
                ValidationStatus = x.Status,
                UserGuid = x.UserGuid,
                UserID = x.UserID,
                SecurityID = x.SecurityID,
                Session = x.SessionID,
                Recipient2FA = x.Recipient2FA,
                PasswordExpiration = x.PasswordExpiration
            });
    }

    /// <summary>
    /// Forwards a two-factor code check to the server.
    /// </summary>
    /// <param name="request">The incoming HTTP request.</param>
    /// <param name="data">Serialization options for this request.</param>
    /// <returns>A response builder carrying the serialized check result.</returns>
    private IResponseBuilder Check2FA(IRequest request, RequestData data)
    {
        // NOTE(review): type arguments restored by naming analogy (Check2FARequest,
        // RpcCheck2FACommand, RpcCheck2FAResult are not visible in this file) - confirm.
        return DoForward<Check2FARequest, Check2FAResponse, RpcCheck2FACommand, RpcCheck2FAParameters, RpcCheck2FAResult>(
            request,
            data,
            x => new RpcCheck2FAParameters
            {
                Code = x.Code,
                SessionId = x.Credentials.Session
            },
            (r, x) => new Check2FAResponse
            {
                Valid = x.Valid
            });
    }

    #endregion

    #region Database

    /// <summary>
    /// Looks up a non-public instance method of this handler by name.
    /// </summary>
    /// <exception cref="Exception">No such method exists.</exception>
    private static MethodInfo GetMethod(string name) =>
        typeof(HTTPDatabaseProxyHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Instance)
        ?? throw new Exception($"Invalid method '{name}'");

    // Endpoint prefix -> generic handler method; the rest of the endpoint names the entity type.
    // NOTE(review): the element type was stripped from the source text ("List>"); Tuple is assumed
    // because it supports both the two-argument target-typed new and the deconstruction used below.
    private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
    {
        new("List", GetMethod(nameof(List))),
        new("Save", GetMethod(nameof(Save))),
        new("Delete", GetMethod(nameof(Delete))),
        new("MultiSave", GetMethod(nameof(MultiSave))),
        new("MultiDelete", GetMethod(nameof(MultiDelete)))
    };

    /// <summary>
    /// Serialization options of a single request, as read by <c>GetRequestData</c>.
    /// </summary>
    private class RequestData
    {
        public SerializationFormat RequestFormat { get; set; }

        public SerializationFormat ResponseFormat { get; set; }

        public BinarySerializationSettings BinarySerializationSettings { get; set; }

        public RequestData(BinarySerializationSettings binarySerializationSettings)
        {
            BinarySerializationSettings = binarySerializationSettings;
        }
    }

    // NOTE(review): in the five entity methods below the request type argument (QueryRequest<T>,
    // SaveRequest<T>, ...) was stripped from the source text and is restored to mirror the
    // visible response type names - confirm against InABox.Clients.

    /// <summary>
    /// Forwards a single-entity query for <typeparamref name="T"/> and returns the first result table.
    /// </summary>
    private IResponseBuilder List<T>(IRequest request, RequestData data) where T : Entity, new()
    {
        return DoForward<QueryRequest<T>, QueryResponse<T>, RpcQueryCommand, RpcQueryParameters, RpcQueryResult>(
            request,
            data,
            x => new RpcQueryParameters
            {
                Queries = new[]
                {
                    new RpcQueryDefinition
                    {
                        Key = typeof(T).Name,
                        Type = typeof(T),
                        Filter = x.Filter,
                        Columns = x.Columns,
                        Sort = x.Sort
                    }
                }
            },
            (r, x) => new QueryResponse<T>
            {
                Items = x.Tables[0].Table
            });
    }

    /// <summary>
    /// Forwards a save of one <typeparamref name="T"/>. Returns either just the changed values or
    /// the submitted item with the server-side changes applied, depending on <c>ReturnOnlyChanged</c>.
    /// </summary>
    private IResponseBuilder Save<T>(IRequest request, RequestData data) where T : Entity, new()
    {
        return DoForward<SaveRequest<T>, SaveResponse<T>, RpcSaveCommand, RpcSaveParameters, RpcSaveResult>(
            request,
            data,
            x => new RpcSaveParameters
            {
                AuditNote = x.AuditNote,
                Items = new[] { x.Item },
                Type = typeof(T)
            },
            (r, x) =>
            {
                if (r.ReturnOnlyChanged)
                {
                    return new SaveResponse<T> { ChangedValues = x.Deltas[0] };
                }

                var deltas = x.Deltas[0];
                r.Item.SetObserving(false);
                foreach (var (key, value) in deltas)
                {
                    if (CoreUtils.TryGetProperty<T>(key, out var property))
                    {
                        // Apply the change to the item itself (as MultiSave does); the previous code
                        // targeted the 'deltas' dictionary, so the returned item never got the changes.
                        CoreUtils.SetPropertyValue(r.Item, key, CoreUtils.ChangeType(value, property.PropertyType));
                    }
                }
                r.Item.CommitChanges();
                r.Item.SetObserving(true);

                return new SaveResponse<T> { Item = r.Item };
            });
    }

    /// <summary>
    /// Forwards the deletion of one <typeparamref name="T"/>, identified by the ID of the submitted item.
    /// </summary>
    private IResponseBuilder Delete<T>(IRequest request, RequestData data) where T : Entity, new()
    {
        return DoForward<DeleteRequest<T>, DeleteResponse<T>, RpcDeleteCommand, RpcDeleteParameters, RpcDeleteResult>(
            request,
            data,
            x => new RpcDeleteParameters
            {
                AuditNote = x.AuditNote,
                IDs = new[] { x.Item.ID },
                Type = typeof(T)
            },
            (r, x) => new DeleteResponse<T>());
    }

    /// <summary>
    /// Forwards a save of several <typeparamref name="T"/> items. Returns either the changed values
    /// per item or the submitted items with the server-side changes applied.
    /// </summary>
    private IResponseBuilder MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
    {
        return DoForward<MultiSaveRequest<T>, MultiSaveResponse<T>, RpcSaveCommand, RpcSaveParameters, RpcSaveResult>(
            request,
            data,
            x => new RpcSaveParameters
            {
                AuditNote = x.AuditNote,
                Items = x.Items,
                Type = typeof(T)
            },
            (r, x) =>
            {
                if (r.ReturnOnlyChanged)
                {
                    return new MultiSaveResponse<T> { ChangedValues = x.Deltas.ToList() };
                }

                for (int i = 0; i < x.Deltas.Length; i++)
                {
                    r.Items[i].SetObserving(false);
                    foreach (var (key, value) in x.Deltas[i])
                    {
                        if (CoreUtils.TryGetProperty<T>(key, out var property))
                        {
                            CoreUtils.SetPropertyValue(r.Items[i], key, CoreUtils.ChangeType(value, property.PropertyType));
                        }
                    }
                    r.Items[i].CommitChanges();
                    r.Items[i].SetObserving(true);
                }

                return new MultiSaveResponse<T> { Items = r.Items };
            });
    }

    /// <summary>
    /// Forwards the deletion of several <typeparamref name="T"/> items, identified by their IDs.
    /// </summary>
    private IResponseBuilder MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
    {
        return DoForward<MultiDeleteRequest<T>, MultiDeleteResponse<T>, RpcDeleteCommand, RpcDeleteParameters, RpcDeleteResult>(
            request,
            data,
            x => new RpcDeleteParameters
            {
                AuditNote = x.AuditNote,
                IDs = x.Items.Select(x => x.ID).ToArray(),
                Type = typeof(T)
            },
            (r, x) => new MultiDeleteResponse<T>());
    }

    /// <summary>
    /// Forwards several queries in one call and returns the result tables keyed by query key.
    /// </summary>
    private IResponseBuilder QueryMultiple(IRequest request, RequestData data)
    {
        // NOTE(review): type arguments were stripped from the source text; MultiQueryRequest is
        // inferred from the MultiQueryResponse name - confirm.
        return DoForward<MultiQueryRequest, MultiQueryResponse, RpcQueryCommand, RpcQueryParameters, RpcQueryResult>(
            request,
            data,
            x => new RpcQueryParameters
            {
                Queries = x.Queries.Select(x => new RpcQueryDefinition
                {
                    Key = x.Key,
                    Type = CoreUtils.GetEntity(x.Value.Type),
                    Filter = x.Value.Filter,
                    Columns = x.Value.Columns,
                    Sort = x.Value.Sort
                }).ToArray()
            },
            (r, x) => new MultiQueryResponse
            {
                Tables = x.Tables.ToDictionary(x => x.Key, x => x.Table)
            });
    }

    /// <summary>
    /// Reads a <typeparamref name="T"/> from the request body: binary when that format is requested
    /// and <typeparamref name="T"/> implements <see cref="ISerializeBinary"/>, text otherwise.
    /// </summary>
    /// <param name="stream">The request body.</param>
    /// <param name="requestFormat">The format the body is encoded in.</param>
    /// <param name="binarySettings">Settings used for binary deserialization.</param>
    /// <param name="strict">Passed through to the text deserializer.</param>
    /// <exception cref="Exception">The stream is null, or text deserialization returned null.</exception>
    private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
    {
        if (stream is null)
        {
            throw new Exception("Stream is null");
        }

        if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
        {
            return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
        }

        // Dispose the reader, but leave the request stream open: this method does not own it.
        using var reader = new StreamReader(stream, leaveOpen: true);
        var str = reader.ReadToEnd();
        return Serialization.Deserialize<T>(str, strict)
            ?? throw new Exception("Deserialization failed");
    }

    /// <summary>
    /// Writes <paramref name="result"/> into a response: as an octet stream when binary is requested
    /// and the result implements <see cref="ISerializeBinary"/>, as JSON otherwise.
    /// </summary>
    private static IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
    {
        if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
        {
            var stream = new MemoryStream();
            binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));

            // The checksum passed here is the stream object's hash code, not a hash of its content.
            return request.Respond()
                .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
        }

        var serialized = Serialization.Serialize(result);
        return request.Respond()
            .Type(new FlexibleContentType(ContentType.ApplicationJson))
            .Content(new ResourceContent(Resource.FromString(serialized).Build()));
    }

    /// <summary>
    /// Handler for all database requests
    /// </summary>
    /// <remarks>
    /// <c>QueryMultiple</c> is handled directly. Otherwise the endpoint is split into a method-map
    /// prefix and an entity name; unknown entities, secure entities and unknown prefixes are logged
    /// and answered with 404.
    /// </remarks>
    /// <param name="request">The incoming HTTP request.</param>
    /// <param name="requestData">Serialization options for this request.</param>
    /// <returns>The response for the request.</returns>
    private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
    {
        var endpoint = request.Target.Current?.Value ?? "";

        // Endpoint names are identifiers, so compare ordinally rather than by current culture.
        if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
        {
            var result = QueryMultiple(request, requestData);
            return new ValueTask<IResponse?>(result.Build());
        }

        foreach (var (name, method) in methodMap)
        {
            if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
            {
                continue;
            }

            var entityName = endpoint[name.Length..];
            var entityType = GetEntity(entityName);
            if (entityType != null)
            {
                if (entityType.IsAssignableTo(typeof(ISecure)))
                {
                    Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                    return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                }

                var resolvedMethod = method.MakeGenericMethod(entityType);

                // The mapped methods are instance methods, so they must be invoked on this handler;
                // invoking them with a null target throws.
                var result = (resolvedMethod.Invoke(this, new object[] { request, requestData }) as IResponseBuilder)!;
                return new ValueTask<IResponse?>(result.Build());
            }

            Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
        }

        Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
            $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
    }

    // Lazily built lookup of entity types by name; see GetEntity.
    private Dictionary<string, Type>? _persistentRemotable;

    /// <summary>
    /// Resolves an entity type by name among the <see cref="Entity"/> subclasses that implement
    /// both <see cref="IRemotable"/> and <see cref="IPersistent"/>.
    /// </summary>
    /// <returns>The matching type, or null when there is none.</returns>
    private Type? GetEntity(string entityName)
    {
        _persistentRemotable ??= CoreUtils.TypeList(
            e => e.IsSubclassOf(typeof(Entity))
                && e.GetInterfaces().Contains(typeof(IRemotable))
                && e.GetInterfaces().Contains(typeof(IPersistent))
        ).ToDictionary(x => x.Name, x => x);

        return _persistentRemotable.GetValueOrDefault(entityName);
    }

    #endregion

    #region Installer

    /// <summary>
    /// Serves the version string, the release notes or the desktop installer, depending on the
    /// current target segment; any other segment yields 404.
    /// </summary>
    private IResponseBuilder GetUpdateFile(IRequest request)
    {
        var endpoint = request.Target.Current;
        switch (endpoint?.Value)
        {
            case "version":
                return request.Respond()
                    .Type(new FlexibleContentType(ContentType.TextPlain))
                    .Content(new ResourceContent(Resource.FromString(Client.Version()).Build()));

            case "releasenotes" or "release_notes":
                return request.Respond()
                    .Type(new FlexibleContentType(ContentType.TextPlain))
                    .Content(new ResourceContent(Resource.FromString(Client.ReleaseNotes()).Build()));

            case "install" or "install_desktop":
                // An empty body is sent when no installer is available.
                return request.Respond()
                    .Header("Content-Disposition", "attachment; filename=PRSDesktopSetup.exe")
                    .Content(new ResourceContent(new ByteArrayResource(
                        Client.Installer() ?? Array.Empty<byte>(),
                        "PRSDesktopSetup.exe",
                        new FlexibleContentType(ContentType.ApplicationOctetStream),
                        null)));
        }

        return request.Respond().Status(ResponseStatus.NotFound);
    }

    #endregion

    #region GenHTTP stuff

    public IHandler Parent { get; }

    public ValueTask PrepareAsync()
    {
        return new ValueTask();
    }

    public IEnumerable<ContentElement> GetContent(IRequest request)
    {
        return Enumerable.Empty<ContentElement>();
    }

    #endregion
}

/// <summary>
/// Proxy engine that exposes the database over HTTPS through <see cref="HTTPDatabaseProxyHandler"/>.
/// </summary>
// NOTE(review): generic arguments were stripped throughout the source text. The base type is left
// exactly as written; if DatabaseProxyEngine is generic its argument must be restored - confirm.
internal class HTTPDatabaseProxyEngine : DatabaseProxyEngine
{
    // NOTE(review): Listener's type arguments are assumed. The handler type is referenced nowhere
    // else in this file, so it must reach the listener as a type argument - confirm against InABox.Server.
    private Listener<HTTPDatabaseProxyHandler, HTTPDatabaseProxyHandlerProperties>? Listener;

    /// <summary>
    /// Creates the listener on the configured port and starts it. Failures are logged, not rethrown.
    /// </summary>
    protected override void RunProxy()
    {
        Logger.Send(LogType.Information, "", "Starting Listener on port " + Properties.ListenPort);
        try
        {
            Listener = new Listener<HTTPDatabaseProxyHandler, HTTPDatabaseProxyHandlerProperties>(
                new HTTPDatabaseProxyHandlerProperties(Properties, ServerTransport));
            Listener.InitHTTPS((ushort)Properties.ListenPort, CertificateFileName());
            Listener.Start();
        }
        catch (Exception eListen)
        {
            Logger.Send(LogType.Error, ClientFactory.UserID, eListen.Message);
        }
    }

    // The configured certificate file when one is set, otherwise the certificate engine's file.
    private string CertificateFileName() =>
        !string.IsNullOrWhiteSpace(Properties.CertificateFile)
            ? Properties.CertificateFile
            : CertificateEngine.CertificateFile;

    /// <summary>
    /// Stops the listener if one was created.
    /// </summary>
    public override void Stop()
    {
        Logger.Send(LogType.Information, "", "Stopping");
        Listener?.Stop();
    }
}