浏览代码

Interim Update

Frank van den Bos 2 年之前
父节点
当前提交
55392e8812

+ 16 - 14
InABox.Server/RPC/Transports/Socket/RPCServerSocketTransport.cs

@@ -1,7 +1,10 @@
+using System.Net.Security;
+using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using InABox.Core;
 using WebSocketSharp;
 using WebSocketSharp.Server;
+using Logger = InABox.Core.Logger;
 
 namespace InABox.Rpc
 {
@@ -15,26 +18,25 @@ namespace InABox.Rpc
         
         public override bool IsSecure() => Certificate != null;
         
-        public RpcServerSocketTransport(int port, string? certificatefile = null) 
+        public RpcServerSocketTransport(int port, X509Certificate2? certificate = null) 
         {
-            if (File.Exists(certificatefile))
-            {
-                try
-                {
-                    var certificate = new X509Certificate2(certificatefile);
-                    if (certificate.NotAfter > DateTime.Now)
-                        Certificate = certificate;
-                }
-                catch
-                {
-                }
-            }
+            
+            Certificate = certificate;
             
             _server = new WebSocketServer(port, Certificate != null);
             _server.SslConfiguration.ServerCertificate = Certificate;
+            _server.SslConfiguration.ClientCertificateRequired = false;
+            _server.SslConfiguration.CheckCertificateRevocation = false;
+            _server.SslConfiguration.ClientCertificateValidationCallback = WSSCallback;
+            _server.SslConfiguration.EnabledSslProtocols = SslProtocols.Tls12;
             _server?.AddWebSocketService("/", () => new RpcServerSocketConnection() { Transport = this } );
         }
-        
+
+        private bool WSSCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslpolicyerrors)
+        {
+            return true;
+        }
+
         public override void Start()
         {
             _server?.Start();

+ 417 - 0
InABox.Server/Rest/RestHandler.cs

@@ -0,0 +1,417 @@
+using System.Reflection;
+using GenHTTP.Api.Content;
+using GenHTTP.Api.Protocol;
+using GenHTTP.Modules.IO;
+using GenHTTP.Modules.IO.FileSystem;
+using GenHTTP.Modules.IO.Streaming;
+using InABox.Clients;
+using InABox.Core;
+using InABox.Database;
+using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
+
+namespace InABox.API
+{
+    public class RestHandler : IHandler
+    {
+        private readonly List<string> endpoints;
+        private readonly List<string> operations;
+
+        //private int? WebSocketPort;
+
+        public RestHandler(IHandler parent) //, int? webSocketPort)
+        {
+            // WebSocketPort = webSocketPort;
+
+            Parent = parent;
+
+            endpoints = new();
+            operations = new();
+
+            var types = CoreUtils.TypeList(
+                x => x.IsSubclassOf(typeof(Entity))
+                     && x.GetInterfaces().Contains(typeof(IRemotable))
+            );
+            var DBTypes = DbFactory.SupportedTypes();
+
+            foreach (var t in types)
+                if (DBTypes.Contains(t.EntityName().Replace(".", "_")))
+                {
+                    operations.Add(t.EntityName().Replace(".", "_"));
+
+                    endpoints.Add(string.Format("List{0}", t.Name));
+                    endpoints.Add(string.Format("Load{0}", t.Name));
+                    endpoints.Add(string.Format("Save{0}", t.Name));
+                    endpoints.Add(string.Format("MultiSave{0}", t.Name));
+                    endpoints.Add(string.Format("Delete{0}", t.Name));
+                    endpoints.Add(string.Format("MultiDelete{0}", t.Name));
+                }
+
+            endpoints.Add("QueryMultiple");
+        }
+        
+        private RequestData GetRequestData(IRequest request)
+        {
+            BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
+            if (request.Query.TryGetValue("serializationVersion", out var versionString))
+            {
+                settings = BinarySerializationSettings.ConvertVersionString(versionString);
+            }
+
+            var data = new RequestData(settings);
+            if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
+            {
+                data.RequestFormat = format;
+            }
+            data.ResponseFormat = SerializationFormat.Json;
+            if (request.Query.TryGetValue("responseFormat", out formatString) && Enum.TryParse<SerializationFormat>(formatString, out format))
+            {
+                data.ResponseFormat = format;
+            }
+
+            return data;
+        }
+
+        /// <summary>
+        /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
+        /// </summary>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        public ValueTask<IResponse?> HandleAsync(IRequest request)
+        {
+            try
+            {
+                switch (request.Method.KnownMethod)
+                {
+                    case RequestMethod.GET:
+                    case RequestMethod.HEAD:
+                        var current = request.Target.Current?.Value;
+                        if (String.Equals(current,"update"))
+                        {
+                            request.Target.Advance();
+                            current = request.Target.Current?.Value;
+                        }
+                        switch (current)
+                        {
+                            case "operations" or "supported_operations":
+                                return new ValueTask<IResponse?>(GetSupportedOperations(request).Build());
+                            case "classes" or "supported_classes":
+                                return new ValueTask<IResponse?>(GetSupportedClasses(request).Build());
+                            case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
+                                return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
+                            case "info":
+                                return new ValueTask<IResponse?>(GetServerInfo(request).Build());
+                            case "ping":
+                                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
+                        }
+
+                        Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
+                            string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
+                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                    case RequestMethod.POST:
+                        var target = request.Target.Current;
+                        if (target is not null)
+                        {
+                            var data = GetRequestData(request);
+                            return target.Value switch
+                            {
+                                "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
+                                "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
+                                //"notify" or "push" => new ValueTask<IResponse?>(GetPush(request, data).Build()),
+                                _ => HandleDatabaseRequest(request, data),
+                            };
+                        }
+                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                }
+
+                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
+            }
+            catch(Exception e)
+            {
+                Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
+                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
+            }
+        }
+        
+        /// <summary>
+        /// Returns a JSON list of operation names; used for checking support of operations client side
+        /// </summary>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        private IResponseBuilder GetSupportedOperations(IRequest request)
+        {
+            var serialized = Core.Serialization.Serialize(endpoints, true) ?? "";
+            return request.Respond()
+                .Type(new FlexibleContentType(ContentType.ApplicationJson))
+                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
+        }
+
+        /// <summary>
+        /// Returns a JSON list of class names; used for checking support of operations client side
+        /// </summary>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        private IResponseBuilder GetSupportedClasses(IRequest request)
+        {
+            var serialized = Serialization.Serialize(operations) ?? "";
+            return request.Respond()
+                .Type(new FlexibleContentType(ContentType.ApplicationJson))
+                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
+        }
+        
+        /// <summary>
+        /// Returns the Splash Logo and Color Scheme for this Database
+        /// </summary>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        private IResponseBuilder GetServerInfo(IRequest request)
+        {
+            var data = GetRequestData(request);
+            InfoResponse response = RestService.Info(new InfoRequest());
+            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
+        }
+
+        /// <summary>
+        /// Gets port for web socket
+        /// </summary>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        // private IResponseBuilder GetPush(IRequest request, RequestData data)
+        // {
+        //     var requestObj = Deserialize<PushRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        //     if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
+        //     {
+        //         return request.Respond().Status(ResponseStatus.NotFound);
+        //     }
+        //     var response = new PushResponse
+        //     {
+        //         Status = StatusCode.OK,
+        //         SocketPort = WebSocketPort
+        //     };
+        //     return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
+        // }
+
+        #region Authentication
+
+        private IResponseBuilder Validate(IRequest request, RequestData data)
+        {
+            var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+            var response = RestService.Validate(requestObj);
+
+            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
+        }
+
+        private IResponseBuilder Check2FA(IRequest request, RequestData data)
+        {
+            var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+            var response = RestService.Check2FA(requestObj);
+
+            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
+        }
+
+        #endregion
+
+        #region Database
+
+        private static MethodInfo GetMethod(string name) =>
+            typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
+            ?? throw new Exception($"Invalid method '{name}'");
+
+        private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
+        {
+            new("List", GetMethod(nameof(List))),
+            new("Save", GetMethod(nameof(Save))),
+            new("Delete", GetMethod(nameof(Delete))),
+            new("MultiSave", GetMethod(nameof(MultiSave))),
+            new("MultiDelete", GetMethod(nameof(MultiDelete)))
+        };
+
+        private class RequestData
+        {
+            public SerializationFormat RequestFormat { get; set; }
+            public SerializationFormat ResponseFormat { get; set; }
+
+            public BinarySerializationSettings BinarySerializationSettings { get; set; }
+
+            public RequestData(BinarySerializationSettings binarySerializationSettings)
+            {
+                BinarySerializationSettings = binarySerializationSettings;
+            }
+        }
+
+        private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
+        {
+            var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+            return RestService<T>.List(requestObject);
+        }
+        private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
+        {
+            var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+            return RestService<T>.Save(requestObject);
+        }
+        private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
+        {
+            var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+            return RestService<T>.Delete(requestObject);
+        }
+        private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
+        {
+            var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+            return RestService<T>.MultiSave(requestObject);
+        }
+        private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
+        {
+            var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+            return RestService<T>.MultiDelete(requestObject);
+        }
+        private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
+        {
+            var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+
+            return RestService.QueryMultiple(requestObject, false);
+        }
+
+        private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
+        {
+            if (stream is null)
+                throw new Exception("Stream is null");
+            if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
+            {
+                return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
+            }
+            else
+            {
+                var str = new StreamReader(stream).ReadToEnd();
+                return Serialization.Deserialize<T>(str, strict)
+                       ?? throw new Exception("Deserialization failed");
+            }
+        }
+
+        private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
+        {
+            if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
+            {
+                var stream = new MemoryStream();
+                binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
+
+                var response = request.Respond()
+                    .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
+                    .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
+                return response;
+            }
+            else
+            {
+                var serialized = Serialization.Serialize(result);
+
+                var response = request.Respond()
+                    .Type(new FlexibleContentType(ContentType.ApplicationJson))
+                    .Content(new ResourceContent(Resource.FromString(serialized).Build()));
+                return response;
+            }
+        }
+
+        /// <summary>
+        /// Handler for all database requests
+        /// </summary>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
+        {
+            var endpoint = request.Target.Current?.Value ?? "";
+            if (endpoint.StartsWith("QueryMultiple"))
+            {
+                var result = QueryMultiple(request, requestData);
+                return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
+            }
+
+            foreach (var (name, method) in methodMap)
+                if (endpoint.Length > name.Length && endpoint.StartsWith(name))
+                {
+                    var entityName = endpoint[name.Length..];
+
+                    var entityType = GetEntity(entityName);
+                    if (entityType != null)
+                    {
+                        if (entityType.IsAssignableTo(typeof(ISecure)))
+                        {
+                            Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
+                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                        }
+
+                        var resolvedMethod = method.MakeGenericMethod(entityType);
+                        var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
+
+                        return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
+                    }
+
+                    Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
+                        $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
+                    return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                }
+
+            Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
+            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+        }
+
+        private Dictionary<string, Type>? _persistentRemotable;
+
+        private Type? GetEntity(string entityName)
+        {
+            _persistentRemotable ??= CoreUtils.TypeList(
+                e => e.IsSubclassOf(typeof(Entity)) &&
+                     e.GetInterfaces().Contains(typeof(IRemotable)) &&
+                     e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
+            return _persistentRemotable.GetValueOrDefault(entityName);
+        }
+
+        #endregion
+
+        #region Installer
+
+        
+        private IResponseBuilder GetUpdateFile(IRequest request)
+        {
+            var endpoint = request.Target.Current;
+            string? filename = null;
+            switch (endpoint?.Value)
+            {
+                case "version":
+                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/version.txt");
+                    break;
+                case "releasenotes" or "release_notes":
+                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/Release Notes.txt");
+                    break;
+                case "install" or "install_desktop":
+                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/PRSDesktopSetup.exe");
+                    break;
+            }
+
+            if (filename is null) return request.Respond().Status(ResponseStatus.NotFound);
+
+            if (File.Exists(filename))
+            {
+                return request.Respond()
+                    .Header("Content-Disposition", $"attachment; filename={Path.GetFileName(filename)}")
+                    .Content(new ResourceContent(
+                        new FileResource(new FileInfo(filename), null, null)));
+            }
+            return request.Respond().Status(ResponseStatus.NotFound);
+        }
+
+        #endregion
+
+        #region GenHTTP stuff
+        public IHandler Parent { get; }
+
+        public ValueTask PrepareAsync()
+        {
+            return new ValueTask();
+        }
+
+        public IEnumerable<ContentElement> GetContent(IRequest request)
+        {
+            return Enumerable.Empty<ContentElement>();
+        }
+
+        #endregion
+    }
+}

+ 27 - 0
InABox.Server/Rest/RestHandlerBuilder.cs

@@ -0,0 +1,27 @@
+using GenHTTP.Api.Content;
+
+namespace InABox.API
+{
+    public class RestHandlerBuilder : IHandlerBuilder<RestHandlerBuilder>
+    {
+        private readonly List<IConcernBuilder> _Concerns = new();
+
+        //private int? WebSocketPort;
+
+        // public RestHandlerBuilder(int? webSocketPort)
+        // {
+        //     WebSocketPort = webSocketPort;
+        // }
+
+        public RestHandlerBuilder Add(IConcernBuilder concern)
+        {
+            _Concerns.Add(concern);
+            return this;
+        }
+
+        public IHandler Build(IHandler parent)
+        {
+            return Concerns.Chain(parent, _Concerns, p => new RestHandler(p)); //, WebSocketPort));
+        }
+    }
+}

+ 43 - 543
InABox.Server/Rest/RestListener.cs

@@ -1,572 +1,72 @@
 using System.Net;
-using System.Reflection;
 using System.Security.Cryptography.X509Certificates;
-using GenHTTP.Api.Content;
 using GenHTTP.Api.Infrastructure;
-using GenHTTP.Api.Protocol;
 using GenHTTP.Engine;
-using GenHTTP.Modules.IO;
-using GenHTTP.Modules.IO.FileSystem;
-using GenHTTP.Modules.IO.Streaming;
 using GenHTTP.Modules.Practices;
-using InABox.Clients;
-using InABox.Core;
-using InABox.Database;
-using InABox.Remote.Shared;
-using InABox.Server.WebSocket;
-using InABox.WebSocket.Shared;
-using NPOI.POIFS.Crypt.Dsig;
-using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
 
 namespace InABox.API
 {
-    public class RestHandler : IHandler
-    {
-        private readonly List<string> endpoints;
-        private readonly List<string> operations;
-
-        private int? WebSocketPort;
-
-        public RestHandler(IHandler parent, int? webSocketPort)
-        {
-            WebSocketPort = webSocketPort;
-
-            Parent = parent;
-
-            endpoints = new();
-            operations = new();
-
-            var types = CoreUtils.TypeList(
-                x => x.IsSubclassOf(typeof(Entity))
-                     && x.GetInterfaces().Contains(typeof(IRemotable))
-            );
-            var DBTypes = DbFactory.SupportedTypes();
-
-            foreach (var t in types)
-                if (DBTypes.Contains(t.EntityName().Replace(".", "_")))
-                {
-                    operations.Add(t.EntityName().Replace(".", "_"));
-
-                    endpoints.Add(string.Format("List{0}", t.Name));
-                    endpoints.Add(string.Format("Load{0}", t.Name));
-                    endpoints.Add(string.Format("Save{0}", t.Name));
-                    endpoints.Add(string.Format("MultiSave{0}", t.Name));
-                    endpoints.Add(string.Format("Delete{0}", t.Name));
-                    endpoints.Add(string.Format("MultiDelete{0}", t.Name));
-                }
-
-            endpoints.Add("QueryMultiple");
-        }
-        
-        private RequestData GetRequestData(IRequest request)
-        {
-            BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
-            if (request.Query.TryGetValue("serializationVersion", out var versionString))
-            {
-                settings = BinarySerializationSettings.ConvertVersionString(versionString);
-            }
-
-            var data = new RequestData(settings);
-            if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
-            {
-                data.RequestFormat = format;
-            }
-            data.ResponseFormat = SerializationFormat.Json;
-            if (request.Query.TryGetValue("responseFormat", out formatString) && Enum.TryParse<SerializationFormat>(formatString, out format))
-            {
-                data.ResponseFormat = format;
-            }
-
-            return data;
-        }
-
-        /// <summary>
-        /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
-        /// </summary>
-        /// <param name="request"></param>
-        /// <returns></returns>
-        public ValueTask<IResponse?> HandleAsync(IRequest request)
-        {
-            try
-            {
-                switch (request.Method.KnownMethod)
-                {
-                    case RequestMethod.GET:
-                    case RequestMethod.HEAD:
-                        var current = request.Target.Current?.Value;
-                        if (String.Equals(current,"update"))
-                        {
-                            request.Target.Advance();
-                            current = request.Target.Current?.Value;
-                        }
-                        switch (current)
-                        {
-                            case "operations" or "supported_operations":
-                                return new ValueTask<IResponse?>(GetSupportedOperations(request).Build());
-                            case "classes" or "supported_classes":
-                                return new ValueTask<IResponse?>(GetSupportedClasses(request).Build());
-                            case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
-                                return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
-                            case "info":
-                                return new ValueTask<IResponse?>(GetServerInfo(request).Build());
-                            case "ping":
-                                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
-                        }
-
-                        Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
-                            string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
-                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
-                    case RequestMethod.POST:
-                        var target = request.Target.Current;
-                        if (target is not null)
-                        {
-                            var data = GetRequestData(request);
-                            return target.Value switch
-                            {
-                                "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
-                                "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
-                                "notify" or "push" => new ValueTask<IResponse?>(GetPush(request, data).Build()),
-                                _ => HandleDatabaseRequest(request, data),
-                            };
-                        }
-                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
-                }
-
-                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
-            }
-            catch(Exception e)
-            {
-                Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
-                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
-            }
-        }
-        
-        /// <summary>
-        /// Returns a JSON list of operation names; used for checking support of operations client side
-        /// </summary>
-        /// <param name="request"></param>
-        /// <returns></returns>
-        private IResponseBuilder GetSupportedOperations(IRequest request)
-        {
-            var serialized = Core.Serialization.Serialize(endpoints, true) ?? "";
-            return request.Respond()
-                .Type(new FlexibleContentType(ContentType.ApplicationJson))
-                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
-        }
-
-        /// <summary>
-        /// Returns a JSON list of class names; used for checking support of operations client side
-        /// </summary>
-        /// <param name="request"></param>
-        /// <returns></returns>
-        private IResponseBuilder GetSupportedClasses(IRequest request)
-        {
-            var serialized = Serialization.Serialize(operations) ?? "";
-            return request.Respond()
-                .Type(new FlexibleContentType(ContentType.ApplicationJson))
-                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
-        }
-        
-        /// <summary>
-        /// Returns the Splash Logo and Color Scheme for this Database
-        /// </summary>
-        /// <param name="request"></param>
-        /// <returns></returns>
-        private IResponseBuilder GetServerInfo(IRequest request)
-        {
-            var data = GetRequestData(request);
-            InfoResponse response = RestService.Info(new InfoRequest());
-            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
-        }
-
-        /// <summary>
-        /// Gets port for web socket
-        /// </summary>
-        /// <param name="request"></param>
-        /// <returns></returns>
-        private IResponseBuilder GetPush(IRequest request, RequestData data)
-        {
-            var requestObj = Deserialize<PushRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
-            {
-                return request.Respond().Status(ResponseStatus.NotFound);
-            }
-            var response = new PushResponse
-            {
-                Status = StatusCode.OK,
-                SocketPort = WebSocketPort
-            };
-            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
-        }
-
-        #region Authentication
-
-        private IResponseBuilder Validate(IRequest request, RequestData data)
-        {
-            var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            var response = RestService.Validate(requestObj);
-
-            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
-        }
-
-        private IResponseBuilder Check2FA(IRequest request, RequestData data)
-        {
-            var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            var response = RestService.Check2FA(requestObj);
-
-            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
-        }
-
-        #endregion
-
-        #region Database
-
-        private static MethodInfo GetMethod(string name) =>
-            typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
-            ?? throw new Exception($"Invalid method '{name}'");
-
-        private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
-        {
-            new("List", GetMethod(nameof(List))),
-            new("Save", GetMethod(nameof(Save))),
-            new("Delete", GetMethod(nameof(Delete))),
-            new("MultiSave", GetMethod(nameof(MultiSave))),
-            new("MultiDelete", GetMethod(nameof(MultiDelete)))
-        };
-
-        private class RequestData
-        {
-            public SerializationFormat RequestFormat { get; set; }
-            public SerializationFormat ResponseFormat { get; set; }
-
-            public BinarySerializationSettings BinarySerializationSettings { get; set; }
-
-            public RequestData(BinarySerializationSettings binarySerializationSettings)
-            {
-                BinarySerializationSettings = binarySerializationSettings;
-            }
-        }
-
-        private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
-        {
-            var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            return RestService<T>.List(requestObject);
-        }
-        private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
-        {
-            var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            return RestService<T>.Save(requestObject);
-        }
-        private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
-        {
-            var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            return RestService<T>.Delete(requestObject);
-        }
-        private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
-        {
-            var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            return RestService<T>.MultiSave(requestObject);
-        }
-        private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
-        {
-            var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-            return RestService<T>.MultiDelete(requestObject);
-        }
-        private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
-        {
-            var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
-
-            return RestService.QueryMultiple(requestObject, false);
-        }
-
-        private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
-        {
-            if (stream is null)
-                throw new Exception("Stream is null");
-            if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
-            {
-                return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
-            }
-            else
-            {
-                var str = new StreamReader(stream).ReadToEnd();
-                return Serialization.Deserialize<T>(str, strict)
-                    ?? throw new Exception("Deserialization failed");
-            }
-        }
-
-        private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
-        {
-            if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
-            {
-                var stream = new MemoryStream();
-                binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
-
-                var response = request.Respond()
-                    .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
-                    .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
-                return response;
-            }
-            else
-            {
-                var serialized = Serialization.Serialize(result);
-
-                var response = request.Respond()
-                    .Type(new FlexibleContentType(ContentType.ApplicationJson))
-                    .Content(new ResourceContent(Resource.FromString(serialized).Build()));
-                return response;
-            }
-        }
-
-        /// <summary>
-        /// Handler for all database requests
-        /// </summary>
-        /// <param name="request"></param>
-        /// <returns></returns>
-        private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
-        {
-            var endpoint = request.Target.Current?.Value ?? "";
-            if (endpoint.StartsWith("QueryMultiple"))
-            {
-                var result = QueryMultiple(request, requestData);
-                return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
-            }
-
-            foreach (var (name, method) in methodMap)
-                if (endpoint.Length > name.Length && endpoint.StartsWith(name))
-                {
-                    var entityName = endpoint[name.Length..];
-
-                    var entityType = GetEntity(entityName);
-                    if (entityType != null)
-                    {
-                        if (entityType.IsAssignableTo(typeof(ISecure)))
-                        {
-                            Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
-                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
-                        }
-
-                        var resolvedMethod = method.MakeGenericMethod(entityType);
-                        var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
-
-                        return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
-                    }
-
-                    Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
-                        $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
-                    return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
-                }
-
-            Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
-            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
-        }
-
-        private Dictionary<string, Type>? _persistentRemotable;
-
-        private Type? GetEntity(string entityName)
-        {
-            _persistentRemotable ??= CoreUtils.TypeList(
-                e => e.IsSubclassOf(typeof(Entity)) &&
-                     e.GetInterfaces().Contains(typeof(IRemotable)) &&
-                     e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
-            return _persistentRemotable.GetValueOrDefault(entityName);
-        }
-
-        #endregion
-
-        #region Installer
-
-        
-        private IResponseBuilder GetUpdateFile(IRequest request)
-        {
-            var endpoint = request.Target.Current;
-            string? filename = null;
-            switch (endpoint?.Value)
-            {
-                case "version":
-                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/version.txt");
-                    break;
-                case "releasenotes" or "release_notes":
-                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/Release Notes.txt");
-                    break;
-                case "install" or "install_desktop":
-                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/PRSDesktopSetup.exe");
-                    break;
-            }
-
-            if (filename is null) return request.Respond().Status(ResponseStatus.NotFound);
-
-            if (File.Exists(filename))
-            {
-                return request.Respond()
-                    .Header("Content-Disposition", $"attachment; filename={Path.GetFileName(filename)}")
-                    .Content(new ResourceContent(
-                        new FileResource(new FileInfo(filename), null, null)));
-            }
-            return request.Respond().Status(ResponseStatus.NotFound);
-        }
-
-        #endregion
-
-        #region GenHTTP stuff
-        public IHandler Parent { get; }
-
-        public ValueTask PrepareAsync()
-        {
-            return new ValueTask();
-        }
-
-        public IEnumerable<ContentElement> GetContent(IRequest request)
-        {
-            return Enumerable.Empty<ContentElement>();
-        }
-
-        #endregion
-    }
-
-    public class RestHandlerBuilder : IHandlerBuilder<RestHandlerBuilder>
-    {
-        private readonly List<IConcernBuilder> _Concerns = new();
-
-        private int? WebSocketPort;
-
-        public RestHandlerBuilder(int? webSocketPort)
-        {
-            WebSocketPort = webSocketPort;
-        }
-
-        public RestHandlerBuilder Add(IConcernBuilder concern)
-        {
-            _Concerns.Add(concern);
-            return this;
-        }
-
-        public IHandler Build(IHandler parent)
-        {
-            return Concerns.Chain(parent, _Concerns, p => new RestHandler(p, WebSocketPort));
-        }
-    }
-
-    class RestPusher : IPusher
-    {
-        private WebSocketServer SocketServer;
-
-        public int Port => SocketServer.Port;
-
-        public RestPusher(int port)
-        {
-            SocketServer = new WebSocketServer(port);
-            SocketServer.Poll += SocketServer_Poll;
-        }
-
-        private void SocketServer_Poll(PushState.Session session)
-        {
-            PushManager.Poll(session.SessionID);
-        }
-
-        public void Start()
-        {
-            SocketServer.Start();
-        }
-
-        public void Stop()
-        {
-            SocketServer.Stop();
-        }
-
-        public void PushToAll<TPush>(TPush push) where TPush : BaseObject
-        {
-            SocketServer.Push(push);
-        }
-
-        public void PushToSession(Guid session, Type TPush, BaseObject push)
-        {
-            SocketServer.Push(session, TPush, push);
-        }
-
-        public void PushToSession<TPush>(Guid session, TPush push) where TPush : BaseObject
-        {
-            SocketServer.Push(session, push);
-        }
-
-        public IEnumerable<Guid> GetUserSessions(Guid userID)
-        {
-            return CredentialsCache.GetUserSessions(userID);
-        }
-
-        public IEnumerable<Guid> GetSessions(Platform platform)
-        {
-            return SocketServer.GetSessions(platform);
-        }
-    }
-
     public static class RestListener
     {
-        private static IServerHost? host;
-        private static X509Certificate2? certificate;
-        private static RestPusher? pusher;
+        private static IServerHost? _host;
+        //private static RestPusher? _pusher;
 
-        public static X509Certificate2? Certificate { get => certificate; }
+        public static X509Certificate2? Certificate { get; private set; }
 
         public static void Start()
         {
-            host?.Start();
-            pusher?.Start();
+            _host?.Start();
+            //_pusher?.Start();
         }
 
         public static void Stop()
         {
-            host?.Stop();
-            pusher?.Stop();
-        }
-
-        public static void InitCertificate(ushort port, X509Certificate2 certificate)
-        {
-            RestListener.certificate = certificate;
-            RestService.IsHTTPS = true;
-            host?.Bind(IPAddress.Any, port, certificate);
-        }
-        public static void InitCertificate(ushort port, string certificateFile)
-        {
-            InitCertificate(port, new X509Certificate2(certificateFile));
-        }
-
-        public static void InitPort(ushort port)
-        {
-            RestService.IsHTTPS = false;
-            host?.Bind(IPAddress.Any, port);
+            _host?.Stop();
+            //_pusher?.Stop();
         }
 
-        /// <summary>
-        /// Clears certificate and host information, and stops the listener.
-        /// </summary>
-        public static void Clear()
+        public static void Init(ushort port, X509Certificate2? cert)
         {
-            host?.Stop();
-            host = null;
-
-            pusher?.Stop();
-            pusher = null;
-
-            certificate = null;
+            _host = Host.Create();
+            
+            _host.Handler(new RestHandlerBuilder()).Defaults().Backlog(1024);
+            
+            Certificate = cert;
+            RestService.IsHTTPS = cert != null;
+            if (cert != null)
+                _host?.Bind(IPAddress.Any, port, cert);
+            else
+                _host?.Bind(IPAddress.Any, port);
         }
+        
+        // /// <summary>
+        // /// Clears certificate and host information, and stops the listener.
+        // /// </summary>
+        // public static void Clear()
+        // {
+        //     _host?.Stop();
+        //     _host = null;
+        //
+        //     //_pusher?.Stop();
+        //     //_pusher = null;
+        //
+        //     Certificate = null;
+        // }
 
         /// <summary>
         /// Initialise rest listener, and set up web socket port if non-zero.
         /// </summary>
         /// <param name="webSocketPort">The web-socket port to use, or 0 for no websocket.</param>
-        public static void Init(int webSocketPort)
-        {
-            if(webSocketPort != 0)
-            {
-                pusher = new RestPusher(webSocketPort);
-                PushManager.AddPusher(pusher);
-            }
-
-            host = Host.Create();
-
-            host.Handler(new RestHandlerBuilder(pusher?.Port))
-                .Defaults().Backlog(1024);
-        }
+        // public static void Init() //int webSocketPort)
+        // {
+        //     if(webSocketPort != 0)
+        //     {
+        //         _pusher = new RestPusher(webSocketPort);
+        //         PushManager.AddPusher(_pusher);
+        //     }
+        //
+        //     _host = Host.Create();
+        //     _host.Handler(new RestHandlerBuilder(_pusher?.Port)).Defaults().Backlog(1024);
+        // }
     }
 }

+ 57 - 0
InABox.Server/Rest/RestPusher.cs

@@ -0,0 +1,57 @@
+namespace InABox.API
+{
+
+    // class RestPusher : IPusher
+    // {
+    //     private WebSocketServer SocketServer;
+    //
+    //     public int Port => SocketServer.Port;
+    //
+    //     public RestPusher(int port)
+    //     {
+    //         SocketServer = new WebSocketServer(port);
+    //         SocketServer.Poll += SocketServer_Poll;
+    //     }
+    //
+    //     private void SocketServer_Poll(PushState.Session session)
+    //     {
+    //         PushManager.Poll(session.SessionID);
+    //     }
+    //
+    //     public void Start()
+    //     {
+    //         SocketServer.Start();
+    //     }
+    //
+    //     public void Stop()
+    //     {
+    //         SocketServer.Stop();
+    //     }
+    //
+    //     public void PushToAll<TPush>(TPush push) where TPush : BaseObject
+    //     {
+    //         SocketServer.Push(push);
+    //     }
+    //
+    //     public void PushToSession(Guid session, Type TPush, BaseObject push)
+    //     {
+    //         SocketServer.Push(session, TPush, push);
+    //     }
+    //
+    //     public void PushToSession<TPush>(Guid session, TPush push) where TPush : BaseObject
+    //     {
+    //         SocketServer.Push(session, push);
+    //     }
+    //
+    //     public IEnumerable<Guid> GetUserSessions(Guid userID)
+    //     {
+    //         return CredentialsCache.GetUserSessions(userID);
+    //     }
+    //
+    //     public IEnumerable<Guid> GetSessions(Platform platform)
+    //     {
+    //         return SocketServer.GetSessions(platform);
+    //     }
+    // }
+    
+}