using System.Reflection;
using GenHTTP.Api.Content;
using GenHTTP.Api.Protocol;
using GenHTTP.Modules.IO;
using GenHTTP.Modules.IO.FileSystem;
using GenHTTP.Modules.IO.Streaming;
using InABox.Clients;
using InABox.Core;
using InABox.Database;
using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;

namespace InABox.API
{
    /// <summary>
    /// GenHTTP handler exposing the InABox REST API.
    /// GET/HEAD requests serve metadata (supported operations/classes, server info, ping) and update files;
    /// POST requests serve authentication ("validate", "check_2fa") and generic per-entity database operations.
    /// </summary>
    public class RestHandler : IHandler
    {
        /// <summary>Endpoint names (e.g. "ListFoo", "QueryMultiple") returned by the "operations" route.</summary>
        private readonly List<string> endpoints;

        /// <summary>Entity names, with '.' replaced by '_', returned by the "classes" route.</summary>
        private readonly List<string> operations;

        /// <summary>
        /// Builds the advertised endpoint and class lists from every <see cref="Entity"/> subclass that
        /// implements <see cref="IRemotable"/> and whose entity name is reported by
        /// <c>DbFactory.SupportedTypes()</c>.
        /// </summary>
        /// <param name="parent">The parent handler in the GenHTTP handler chain.</param>
        public RestHandler(IHandler parent)
        {
            Parent = parent;

            endpoints = new();
            operations = new();

            var types = CoreUtils.TypeList(
                x => x.IsSubclassOf(typeof(Entity))
                     && x.GetInterfaces().Contains(typeof(IRemotable))
            );
            var dbTypes = DbFactory.SupportedTypes();

            foreach (var t in types)
            {
                // The database layer identifies types by entity name with '.' replaced by '_'.
                var entityName = t.EntityName().Replace(".", "_");
                if (!dbTypes.Contains(entityName))
                    continue;

                operations.Add(entityName);

                // NOTE(review): "Load{0}" is advertised here, but methodMap has no "Load" entry, so a POST to
                // such an endpoint is answered with 404 by HandleDatabaseRequest — confirm this is intentional.
                endpoints.Add(string.Format("List{0}", t.Name));
                endpoints.Add(string.Format("Load{0}", t.Name));
                endpoints.Add(string.Format("Save{0}", t.Name));
                endpoints.Add(string.Format("MultiSave{0}", t.Name));
                endpoints.Add(string.Format("Delete{0}", t.Name));
                endpoints.Add(string.Format("MultiDelete{0}", t.Name));
            }

            endpoints.Add("QueryMultiple");
        }

        /// <summary>
        /// Reads the serialization options from the request's query string.
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <returns>
        /// Options built as follows:
        /// <list type="bullet">
        /// <item>Binary settings come from <c>serializationVersion</c>; the default is <c>V1_0</c>.</item>
        /// <item>The request format comes from <c>format</c>; the default is the enum default.</item>
        /// <item>The response format comes from <c>responseFormat</c>; the default is JSON.</item>
        /// </list>
        /// Values that do not parse as a <see cref="SerializationFormat"/> are ignored.
        /// </returns>
        private RequestData GetRequestData(IRequest request)
        {
            BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
            if (request.Query.TryGetValue("serializationVersion", out var versionString))
            {
                settings = BinarySerializationSettings.ConvertVersionString(versionString);
            }
            var data = new RequestData(settings);

            if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
            {
                data.RequestFormat = format;
            }
            data.ResponseFormat = SerializationFormat.Json;
            if (request.Query.TryGetValue("responseFormat", out formatString) && Enum.TryParse<SerializationFormat>(formatString, out format))
            {
                data.ResponseFormat = format;
            }

            return data;
        }

        /// <summary>
        /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
        /// </summary>
        /// <remarks>
        /// Dispatch rules:
        /// <list type="bullet">
        /// <item>GET/HEAD serve metadata and update files. A leading "update" path segment is skipped first.</item>
        /// <item>POST serves authentication and database endpoints.</item>
        /// <item>Any other method gets 405 with an <c>Allow</c> header.</item>
        /// </list>
        /// Unknown endpoints get 404. Exceptions thrown while handling are logged and turned into 500.
        /// </remarks>
        /// <param name="request">The incoming request.</param>
        /// <returns>The completed response.</returns>
        public ValueTask<IResponse?> HandleAsync(IRequest request)
        {
            try
            {
                switch (request.Method.KnownMethod)
                {
                    case RequestMethod.GET:
                    case RequestMethod.HEAD:
                        var current = request.Target.Current?.Value;
                        if (String.Equals(current, "update"))
                        {
                            request.Target.Advance();
                            current = request.Target.Current?.Value;
                        }
                        switch (current)
                        {
                            case "operations" or "supported_operations":
                                return new ValueTask<IResponse?>(GetSupportedOperations(request).Build());
                            case "classes" or "supported_classes":
                                return new ValueTask<IResponse?>(GetSupportedClasses(request).Build());
                            case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                                return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
                            case "info":
                                return new ValueTask<IResponse?>(GetServerInfo(request).Build());
                            case "ping":
                                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
                        }
                        Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                            string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());

                    case RequestMethod.POST:
                        var target = request.Target.Current;
                        if (target is not null)
                        {
                            var data = GetRequestData(request);
                            return target.Value switch
                            {
                                "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
                                "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
                                _ => HandleDatabaseRequest(request, data),
                            };
                        }
                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                }
                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
            }
            catch (Exception e)
            {
                Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
            }
        }

        /// <summary>
        /// Returns a JSON list of operation names; used for checking support of operations client side
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <returns>A JSON response containing <see cref="endpoints"/>.</returns>
        private IResponseBuilder GetSupportedOperations(IRequest request)
        {
            var serialized = Core.Serialization.Serialize(endpoints, true) ?? "";
            return request.Respond()
                .Type(new FlexibleContentType(ContentType.ApplicationJson))
                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
        }

        /// <summary>
        /// Returns a JSON list of class names; used for checking support of operations client side
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <returns>A JSON response containing <see cref="operations"/>.</returns>
        private IResponseBuilder GetSupportedClasses(IRequest request)
        {
            var serialized = Serialization.Serialize(operations) ?? "";
            return request.Respond()
                .Type(new FlexibleContentType(ContentType.ApplicationJson))
                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
        }

        /// <summary>
        /// Returns the Splash Logo and Color Scheme for this Database
        /// </summary>
        /// <param name="request">The incoming request; its query string selects the response format.</param>
        /// <returns>The serialized result of <c>RestService.Info</c>.</returns>
        private IResponseBuilder GetServerInfo(IRequest request)
        {
            var data = GetRequestData(request);
            InfoResponse response = RestService.Info(new InfoRequest());
            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
        }

        #region Authentication

        /// <summary>Deserializes a <see cref="ValidateRequest"/> from the body and returns <c>RestService.Validate</c>'s result.</summary>
        private IResponseBuilder Validate(IRequest request, RequestData data)
        {
            var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            var response = RestService.Validate(requestObj);

            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
        }

        /// <summary>Deserializes a <see cref="Check2FARequest"/> from the body and returns <c>RestService.Check2FA</c>'s result.</summary>
        private IResponseBuilder Check2FA(IRequest request, RequestData data)
        {
            var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            var response = RestService.Check2FA(requestObj);

            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
        }

        #endregion

        #region Database

        /// <summary>Looks up one of this class's private static generic handlers by name.</summary>
        /// <exception cref="Exception">No such method exists.</exception>
        private static MethodInfo GetMethod(string name) =>
            typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
            ?? throw new Exception($"Invalid method '{name}'");

        /// <summary>
        /// Maps an endpoint prefix to the generic handler that serves it.
        /// The rest of the endpoint after the prefix is the entity type's short name.
        /// </summary>
        private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
        {
            new("List", GetMethod(nameof(List))),
            new("Save", GetMethod(nameof(Save))),
            new("Delete", GetMethod(nameof(Delete))),
            new("MultiSave", GetMethod(nameof(MultiSave))),
            new("MultiDelete", GetMethod(nameof(MultiDelete)))
        };

        /// <summary>Per-request serialization options parsed from the query string.</summary>
        private class RequestData
        {
            public SerializationFormat RequestFormat { get; set; }

            public SerializationFormat ResponseFormat { get; set; }

            public BinarySerializationSettings BinarySerializationSettings { get; set; }

            public RequestData(BinarySerializationSettings binarySerializationSettings)
            {
                BinarySerializationSettings = binarySerializationSettings;
            }
        }

        private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService<T>.List(requestObject);
        }

        private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService<T>.Save(requestObject);
        }

        private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService<T>.Delete(requestObject);
        }

        private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService<T>.MultiSave(requestObject);
        }

        private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService<T>.MultiDelete(requestObject);
        }

        private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
        {
            var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService.QueryMultiple(requestObject, false);
        }

        /// <summary>
        /// Reads a request object from the body stream.
        /// Binary decoding is used only when the binary format is requested AND <typeparamref name="T"/>
        /// implements <see cref="ISerializeBinary"/>; otherwise the body is read as text and deserialized as JSON.
        /// </summary>
        /// <exception cref="Exception">The stream is null, or JSON deserialization returned null.</exception>
        private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
        {
            if (stream is null)
                throw new Exception("Stream is null");
            if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
            {
                return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
            }
            else
            {
                // The reader is deliberately not disposed: that would also close the request's content stream.
                var str = new StreamReader(stream).ReadToEnd();
                return Serialization.Deserialize<T>(str, strict)
                       ?? throw new Exception("Deserialization failed");
            }
        }

        /// <summary>
        /// Builds the response body for <paramref name="result"/>.
        /// Binary (octet-stream) encoding is used when requested and supported by the result; otherwise JSON.
        /// </summary>
        private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
        {
            if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
            {
                var stream = new MemoryStream();
                binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));

                // Writing leaves the position at the end; rewind so the body is read from the start.
                stream.Position = 0;

                var response = request.Respond()
                    .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                    .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
                return response;
            }
            else
            {
                // Same null-guard as GetSupportedOperations/GetSupportedClasses.
                var serialized = Serialization.Serialize(result) ?? "";
                var response = request.Respond()
                    .Type(new FlexibleContentType(ContentType.ApplicationJson))
                    .Content(new ResourceContent(Resource.FromString(serialized).Build()));
                return response;
            }
        }

        /// <summary>
        /// Handler for all database requests
        /// </summary>
        /// <remarks>
        /// The endpoint is either "QueryMultiple" or a <see cref="methodMap"/> prefix followed by an entity's
        /// short type name (e.g. "ListCustomer"). Entities implementing <see cref="ISecure"/> are refused with 404.
        /// Prefix matching is ordinal, because the endpoint comes from the URL (untrusted input).
        /// </remarks>
        /// <param name="request">The incoming POST request.</param>
        /// <param name="requestData">Serialization options parsed from the query string.</param>
        /// <returns>The serialized result, or 404 when the endpoint or entity cannot be resolved.</returns>
        private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
        {
            var endpoint = request.Target.Current?.Value ?? "";

            if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
            {
                var result = QueryMultiple(request, requestData);
                return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
            }

            foreach (var (name, method) in methodMap)
            {
                if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                    continue;

                var entityName = endpoint[name.Length..];
                var entityType = GetEntity(entityName);
                if (entityType != null)
                {
                    if (entityType.IsAssignableTo(typeof(ISecure)))
                    {
                        Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                    }

                    var resolvedMethod = method.MakeGenericMethod(entityType);
                    var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
                    return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
                }

                Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                    $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
            }

            Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
        }

        /// <summary>Lazily built map from short type name to persistent, remotable entity type.</summary>
        private Dictionary<string, Type>? _persistentRemotable;

        /// <summary>
        /// Resolves an entity type by its short name. Only <see cref="Entity"/> subclasses implementing
        /// both <see cref="IRemotable"/> and <see cref="IPersistent"/> are considered.
        /// </summary>
        /// <returns>The matching type, or null if there is none.</returns>
        private Type? GetEntity(string entityName)
        {
            // NOTE(review): this lazy initialisation is not synchronised, so concurrent first calls may each build the map.
            // ToDictionary also throws if two such types share a short Name — confirm names are unique.
            _persistentRemotable ??= CoreUtils.TypeList(
                e => e.IsSubclassOf(typeof(Entity)) &&
                     e.GetInterfaces().Contains(typeof(IRemotable)) &&
                     e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);

            return _persistentRemotable.GetValueOrDefault(entityName);
        }

        #endregion

        #region Installer

        /// <summary>
        /// Serves a file from the "update" folder under <c>CoreUtils.GetCommonAppData()</c> as an attachment:
        /// <list type="bullet">
        /// <item>"version" serves version.txt.</item>
        /// <item>"releasenotes" or "release_notes" serves Release Notes.txt.</item>
        /// <item>"install" or "install_desktop" serves PRSDesktopSetup.exe.</item>
        /// </list>
        /// Returns 404 for any other endpoint, or when the file does not exist.
        /// </summary>
        private IResponseBuilder GetUpdateFile(IRequest request)
        {
            var endpoint = request.Target.Current;
            string? filename = null;
            switch (endpoint?.Value)
            {
                case "version":
                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/version.txt");
                    break;
                case "releasenotes" or "release_notes":
                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/Release Notes.txt");
                    break;
                case "install" or "install_desktop":
                    filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/PRSDesktopSetup.exe");
                    break;
            }

            if (filename is null) return request.Respond().Status(ResponseStatus.NotFound);

            if (File.Exists(filename))
            {
                // The filename must be a quoted-string: "Release Notes.txt" contains a space,
                // which an unquoted header token cannot hold.
                return request.Respond()
                    .Header("Content-Disposition", $"attachment; filename=\"{Path.GetFileName(filename)}\"")
                    .Content(new ResourceContent(
                        new FileResource(new FileInfo(filename), null, null)));
            }
            return request.Respond().Status(ResponseStatus.NotFound);
        }

        #endregion

        #region GenHTTP stuff

        public IHandler Parent { get; }

        public ValueTask PrepareAsync()
        {
            return new ValueTask();
        }

        public IEnumerable<ContentElement> GetContent(IRequest request)
        {
            return Enumerable.Empty<ContentElement>();
        }

        #endregion
    }
}