using System.Reflection;
using GenHTTP.Api.Content;
using GenHTTP.Api.Protocol;
using GenHTTP.Modules.IO;
using GenHTTP.Modules.IO.FileSystem;
using GenHTTP.Modules.IO.Streaming;
using InABox.Clients;
using InABox.Core;
using InABox.Database;
using InABox.Server;
using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;

namespace InABox.API
{
    /// <summary>
    /// GenHTTP handler for the InABox REST API. GET/HEAD requests serve server metadata, update files and a ping;
    /// POST requests serve authentication and entity database operations.
    /// </summary>
    public class RestHandler : IHandler
    {
        /// <summary>Endpoint names returned by the "operations" / "supported_operations" route.</summary>
        private readonly List<string> endpoints;

        /// <summary>Entity names ('.' replaced by '_') returned by the "classes" / "supported_classes" route.</summary>
        private readonly List<string> operations;

        /// <summary>
        /// Builds the advertised endpoint and class lists from every remotable <see cref="Entity"/> subclass
        /// whose entity name is also reported by <see cref="DbFactory.SupportedTypes"/>.
        /// </summary>
        /// <param name="parent">The enclosing GenHTTP handler.</param>
        public RestHandler(IHandler parent)
        {
            Parent = parent;
            endpoints = new();
            operations = new();

            var types = CoreUtils.TypeList(
                x => x.IsSubclassOf(typeof(Entity)) && x.GetInterfaces().Contains(typeof(IRemotable))
            );

            var dbTypes = DbFactory.SupportedTypes();

            foreach (var t in types)
            {
                // Computed once; the original evaluated this expression twice per type.
                var entityName = t.EntityName().Replace(".", "_");
                if (!dbTypes.Contains(entityName))
                    continue;

                operations.Add(entityName);
                endpoints.Add($"List{t.Name}");
                // NOTE(review): "Load" is advertised here but methodMap has no "Load" entry, so
                // HandleDatabaseRequest answers 404 for it. Confirm whether clients depend on this.
                endpoints.Add($"Load{t.Name}");
                endpoints.Add($"Save{t.Name}");
                endpoints.Add($"MultiSave{t.Name}");
                endpoints.Add($"Delete{t.Name}");
                endpoints.Add($"MultiDelete{t.Name}");
            }

            endpoints.Add("QueryMultiple");
        }

        /// <summary>
        /// Reads the serialization options from the query string: "serializationVersion" (defaults to V1_0),
        /// "format" (request body format) and "responseFormat" (defaults to JSON).
        /// Unparseable format values are ignored.
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <returns>The resolved serialization options.</returns>
        private RequestData GetRequestData(IRequest request)
        {
            BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
            if (request.Query.TryGetValue("serializationVersion", out var versionString))
            {
                settings = BinarySerializationSettings.ConvertVersionString(versionString);
            }

            var data = new RequestData(settings);

            if (request.Query.TryGetValue("format", out var formatString)
                && Enum.TryParse<SerializationFormat>(formatString, out var format))
            {
                data.RequestFormat = format;
            }

            data.ResponseFormat = SerializationFormat.Json;
            if (request.Query.TryGetValue("responseFormat", out formatString)
                && Enum.TryParse(formatString, out format))
            {
                data.ResponseFormat = format;
            }

            return data;
        }

        /// <summary>Builds a builder into a synchronously completed handler result.</summary>
        private static ValueTask<IResponse?> Complete(IResponseBuilder builder) => new(builder.Build());

        /// <summary>
        /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
        /// </summary>
        /// <remarks>
        /// GET/HEAD serve metadata, update files and "ping". POST serves "validate", "check_2fa" and the
        /// database endpoints. Any other method gets 405 with an Allow header. Any exception is logged
        /// and answered with 500.
        /// </remarks>
        /// <param name="request">The incoming request.</param>
        /// <returns>The response for the request.</returns>
        public ValueTask<IResponse?> HandleAsync(IRequest request)
        {
            try
            {
                switch (request.Method.KnownMethod)
                {
                    case RequestMethod.GET:
                    case RequestMethod.HEAD:
                        var current = request.Target.Current?.Value;

                        // A leading "update" segment is skipped, so e.g. "update/version" routes like "version".
                        if (string.Equals(current, "update", StringComparison.Ordinal))
                        {
                            request.Target.Advance();
                            current = request.Target.Current?.Value;
                        }

                        switch (current)
                        {
                            case "operations" or "supported_operations":
                                return Complete(GetSupportedOperations(request));
                            case "classes" or "supported_classes":
                                return Complete(GetSupportedClasses(request));
                            case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                                return Complete(GetUpdateFile(request));
                            case "info":
                                return Complete(GetServerInfo(request));
                            case "ping":
                                return Complete(request.Respond().Status(ResponseStatus.OK));
                        }

                        Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                            $"GET/HEAD request to endpoint '{current}' is unresolved, because it does not exist");
                        return Complete(request.Respond().Status(ResponseStatus.NotFound));

                    case RequestMethod.POST:
                        var target = request.Target.Current;
                        if (target is null)
                            return Complete(request.Respond().Status(ResponseStatus.NotFound));

                        var data = GetRequestData(request);
                        return target.Value switch
                        {
                            "validate" => Complete(Validate(request, data)),
                            "check_2fa" => Complete(Check2FA(request, data)),
                            _ => HandleDatabaseRequest(request, data),
                        };
                }

                return Complete(request.Respond()
                    .Status(ResponseStatus.MethodNotAllowed)
                    .Header("Allow", "GET, POST, HEAD"));
            }
            catch (Exception e)
            {
                Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
                return Complete(request.Respond().Status(ResponseStatus.InternalServerError));
            }
        }

        /// <summary>
        /// Returns a JSON list of operation names; used for checking support of operations client side.
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <returns>A JSON response builder.</returns>
        private IResponseBuilder GetSupportedOperations(IRequest request)
        {
            var serialized = Core.Serialization.Serialize(endpoints, true) ?? "";
            return request.Respond()
                .Type(new FlexibleContentType(ContentType.ApplicationJson))
                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
        }

        /// <summary>
        /// Returns a JSON list of class names; used for checking support of operations client side.
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <returns>A JSON response builder.</returns>
        private IResponseBuilder GetSupportedClasses(IRequest request)
        {
            var serialized = Serialization.Serialize(operations) ?? "";
            return request.Respond()
                .Type(new FlexibleContentType(ContentType.ApplicationJson))
                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
        }

        /// <summary>
        /// Returns the Splash Logo and Color Scheme for this Database.
        /// </summary>
        /// <param name="request">The incoming request; only its query options are read.</param>
        /// <returns>The serialized <see cref="InfoResponse"/>.</returns>
        private IResponseBuilder GetServerInfo(IRequest request)
        {
            var data = GetRequestData(request);
            InfoResponse response = RestService.Info(new InfoRequest());
            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
        }

        #region Authentication

        /// <summary>Deserializes a validate request from the body and serializes RestService's answer.</summary>
        private IResponseBuilder Validate(IRequest request, RequestData data)
        {
            var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            var response = RestService.Validate(requestObj);
            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
        }

        /// <summary>Deserializes a 2FA check request from the body and serializes RestService's answer.</summary>
        private IResponseBuilder Check2FA(IRequest request, RequestData data)
        {
            var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            var response = RestService.Check2FA(requestObj);
            return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
        }

        #endregion

        #region Database

        /// <summary>Looks up one of this class's private static generic operation methods by name.</summary>
        /// <exception cref="Exception">Thrown when no such method exists.</exception>
        private static MethodInfo GetMethod(string name) =>
            typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
            ?? throw new Exception($"Invalid method '{name}'");

        /// <summary>
        /// Endpoint prefix to generic operation method. The rest of the endpoint name after the prefix
        /// is the entity type name.
        /// </summary>
        private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
        {
            new("List", GetMethod(nameof(List))),
            new("Save", GetMethod(nameof(Save))),
            new("Delete", GetMethod(nameof(Delete))),
            new("MultiSave", GetMethod(nameof(MultiSave))),
            new("MultiDelete", GetMethod(nameof(MultiDelete)))
        };

        /// <summary>Per-request serialization options resolved from the query string.</summary>
        private class RequestData
        {
            public SerializationFormat RequestFormat { get; set; }

            public SerializationFormat ResponseFormat { get; set; }

            public BinarySerializationSettings BinarySerializationSettings { get; set; }

            public RequestData(BinarySerializationSettings binarySerializationSettings)
            {
                BinarySerializationSettings = binarySerializationSettings;
            }
        }

        private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService.List(requestObject);
        }

        private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService.Save(requestObject);
        }

        private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService.Delete(requestObject);
        }

        private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService.MultiSave(requestObject);
        }

        private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
        {
            var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService.MultiDelete(requestObject);
        }

        private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
        {
            var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
            return RestService.QueryMultiple(requestObject, false);
        }

        /// <summary>
        /// Reads a request object from the body. Uses binary deserialization when the request format is
        /// Binary and <typeparamref name="T"/> implements <see cref="ISerializeBinary"/>; otherwise reads
        /// the body as text and deserializes it as JSON.
        /// </summary>
        /// <exception cref="Exception">The stream is null, or JSON deserialization returned null.</exception>
        private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
        {
            if (stream is null)
                throw new Exception("Stream is null");

            if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
            {
                return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
            }

            // Dispose the reader's own buffers, but leave the request stream (owned by GenHTTP) open.
            using var reader = new StreamReader(stream, leaveOpen: true);
            var str = reader.ReadToEnd();
            return Serialization.Deserialize<T>(str, strict) ?? throw new Exception("Deserialization failed");
        }

        /// <summary>
        /// Serializes <paramref name="result"/> as an octet stream when Binary was requested and the result
        /// supports binary serialization; otherwise serializes it as JSON.
        /// </summary>
        private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
        {
            if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
            {
                var stream = new MemoryStream();
                binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));

                // The writer leaves the position at the end. Rewind so the body is streamed from the
                // first byte, whether or not the content implementation seeks back itself.
                stream.Position = 0;

                return request.Respond()
                    .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                    .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
            }

            var serialized = Serialization.Serialize(result);
            return request.Respond()
                .Type(new FlexibleContentType(ContentType.ApplicationJson))
                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
        }

        /// <summary>
        /// Handler for all database requests. Endpoints starting with "QueryMultiple" run a multi-query.
        /// Otherwise the endpoint is split into an operation prefix from <see cref="methodMap"/> plus an
        /// entity type name, and the matching generic operation is invoked for that type.
        /// Unknown operations, unknown entities and <see cref="ISecure"/> entities get 404.
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <param name="requestData">Serialization options for the request and response.</param>
        /// <returns>The response for the request.</returns>
        private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
        {
            var endpoint = request.Target.Current?.Value ?? "";

            // Ordinal: endpoint names are identifiers, not culture-sensitive text.
            if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
            {
                var result = QueryMultiple(request, requestData);
                return Complete(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result));
            }

            foreach (var (name, method) in methodMap)
            {
                // The endpoint must be strictly longer than the prefix so that an entity name remains.
                if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                    continue;

                var entityName = endpoint[name.Length..];
                var entityType = GetEntity(entityName);

                if (entityType is null)
                {
                    Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                        $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                    return Complete(request.Respond().Status(ResponseStatus.NotFound));
                }

                if (entityType.IsAssignableTo(typeof(ISecure)))
                {
                    Logger.Send(LogType.Error, "",
                        $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                    return Complete(request.Respond().Status(ResponseStatus.NotFound));
                }

                var resolvedMethod = method.MakeGenericMethod(entityType);
                var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
                return Complete(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result));
            }

            Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
            return Complete(request.Respond().Status(ResponseStatus.NotFound));
        }

        /// <summary>Lazily built map from short type name to remotable, persistent entity type.</summary>
        private Dictionary<string, Type>? _persistentRemotable;

        /// <summary>
        /// Resolves an entity type by its short name among <see cref="Entity"/> subclasses that implement
        /// both <see cref="IRemotable"/> and <see cref="IPersistent"/>.
        /// </summary>
        /// <returns>The type, or <c>null</c> when no such entity exists.</returns>
        private Type? GetEntity(string entityName)
        {
            // NOTE(review): ToDictionary throws if two matching types share the same short Name
            // (e.g. in different namespaces). Confirm names are unique.
            _persistentRemotable ??= CoreUtils.TypeList(
                    e => e.IsSubclassOf(typeof(Entity))
                         && e.GetInterfaces().Contains(typeof(IRemotable))
                         && e.GetInterfaces().Contains(typeof(IPersistent)))
                .ToDictionary(x => x.Name, x => x);

            return _persistentRemotable.TryGetValue(entityName, out var type) ? type : null;
        }

        #endregion

        #region Installer

        /// <summary>
        /// Serves the update version file, release notes or installer as an attachment, chosen by the
        /// current path segment. Returns 404 when the segment is unknown or the file does not exist.
        /// </summary>
        private IResponseBuilder GetUpdateFile(IRequest request)
        {
            string? filename = request.Target.Current?.Value switch
            {
                "version" => UpdateData.GetUpdateVersionFile(),
                "releasenotes" or "release_notes" => UpdateData.GetReleaseNotesFile(),
                "install" or "install_desktop" => UpdateData.GetUpdateInstallerFile(),
                _ => null
            };

            if (filename is null || !File.Exists(filename))
                return request.Respond().Status(ResponseStatus.NotFound);

            return request.Respond()
                .Header("Content-Disposition", $"attachment; filename={Path.GetFileName(filename)}")
                .Content(new ResourceContent(
                    new FileResource(new FileInfo(filename), null, null)));
        }

        #endregion

        #region GenHTTP stuff

        public IHandler Parent { get; }

        public ValueTask PrepareAsync()
        {
            return new ValueTask();
        }

        public IEnumerable<ContentElement> GetContent(IRequest request)
        {
            return Enumerable.Empty<ContentElement>();
        }

        #endregion
    }
}