RestHandler.cs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. using System.Reflection;
  2. using GenHTTP.Api.Content;
  3. using GenHTTP.Api.Protocol;
  4. using GenHTTP.Modules.IO;
  5. using GenHTTP.Modules.IO.FileSystem;
  6. using GenHTTP.Modules.IO.Streaming;
  7. using InABox.Clients;
  8. using InABox.Core;
  9. using InABox.Database;
  10. using InABox.Server;
  11. using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  namespace InABox.API
  {
      /// <summary>
      /// GenHTTP handler for the InABox REST API.
      /// GET/HEAD serves discovery (<c>operations</c>, <c>classes</c>), <c>info</c>, <c>ping</c> and
      /// update/installer files; POST serves authentication (<c>validate</c>, <c>check_2fa</c>),
      /// <c>QueryMultiple</c> and the per-entity database operations (<c>List{Entity}</c>,
      /// <c>Save{Entity}</c>, <c>Delete{Entity}</c>, <c>MultiSave{Entity}</c>, <c>MultiDelete{Entity}</c>).
      /// </summary>
      public class RestHandler : IHandler
      {
          // Prefixes combined with each supported entity's type name to build the advertised operation names.
          // NOTE(review): "Load" is advertised, but methodMap has no "Load" entry, so a POST to
          // Load{Entity} ends in 404 inside HandleDatabaseRequest. Kept so the advertised list is
          // unchanged for existing clients — confirm whether it should be implemented or dropped.
          private static readonly string[] AdvertisedOperationPrefixes =
          {
              "List", "Load", "Save", "MultiSave", "Delete", "MultiDelete"
          };

          // Served by the "operations" / "supported_operations" endpoint.
          private readonly List<string> supportedOperations;

          // Served by the "classes" / "supported_classes" endpoint: entity names with '.' replaced by '_'.
          private readonly List<string> supportedClasses;

          /// <summary>
          /// Builds the advertised operation and class lists from every <see cref="Entity"/> subclass
          /// implementing <see cref="IRemotable"/> whose name appears in <c>DbFactory.SupportedTypes()</c>.
          /// </summary>
          /// <param name="parent">The parent handler in the GenHTTP handler chain.</param>
          public RestHandler(IHandler parent)
          {
              Parent = parent;
              supportedOperations = new();
              supportedClasses = new();

              var remotableTypes = CoreUtils.TypeList(
                  x => x.IsSubclassOf(typeof(Entity))
                       && x.GetInterfaces().Contains(typeof(IRemotable))
              );
              var dbTypes = DbFactory.SupportedTypes();
              foreach (var type in remotableTypes)
              {
                  // SupportedTypes() is matched against the entity name with '.' replaced by '_'.
                  var className = type.EntityName().Replace(".", "_");
                  if (!dbTypes.Contains(className))
                  {
                      continue;
                  }

                  supportedClasses.Add(className);
                  foreach (var prefix in AdvertisedOperationPrefixes)
                  {
                      supportedOperations.Add($"{prefix}{type.Name}");
                  }
              }
              supportedOperations.Add("QueryMultiple");
          }

          /// <summary>
          /// Reads serialization options from the query string: <c>serializationVersion</c>
          /// (binary settings, default <c>V1_0</c>), <c>format</c> (request body format) and
          /// <c>responseFormat</c> (default JSON). Values that fail to parse as a
          /// <see cref="SerializationFormat"/> are ignored.
          /// </summary>
          private RequestData GetRequestData(IRequest request)
          {
              BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
              if (request.Query.TryGetValue("serializationVersion", out var versionString))
              {
                  settings = BinarySerializationSettings.ConvertVersionString(versionString);
              }
              var data = new RequestData(settings);
              if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
              {
                  data.RequestFormat = format;
              }
              data.ResponseFormat = SerializationFormat.Json;
              if (request.Query.TryGetValue("responseFormat", out formatString) && Enum.TryParse<SerializationFormat>(formatString, out format))
              {
                  data.ResponseFormat = format;
              }
              return data;
          }

          /// <summary>
          /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
          /// </summary>
          /// <param name="request">The incoming GenHTTP request.</param>
          /// <returns>
          /// The response for the target: 404 for unknown targets, 405 (with an <c>Allow</c> header)
          /// for methods other than GET/HEAD/POST, and 500 if any exception escapes (it is logged).
          /// </returns>
          public ValueTask<IResponse?> HandleAsync(IRequest request)
          {
              try
              {
                  switch (request.Method.KnownMethod)
                  {
                      case RequestMethod.GET:
                      case RequestMethod.HEAD:
                          var current = request.Target.Current?.Value;
                          // An optional leading "update" segment is skipped, so "update/version" resolves like "version".
                          if (String.Equals(current, "update"))
                          {
                              request.Target.Advance();
                              current = request.Target.Current?.Value;
                          }
                          switch (current)
                          {
                              case "operations" or "supported_operations":
                                  return new ValueTask<IResponse?>(GetSupportedOperations(request).Build());
                              case "classes" or "supported_classes":
                                  return new ValueTask<IResponse?>(GetSupportedClasses(request).Build());
                              case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                                  return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
                              case "info":
                                  return new ValueTask<IResponse?>(GetServerInfo(request).Build());
                              case "ping":
                                  return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
                          }
                          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                              $"GET/HEAD request to endpoint '{current}' is unresolved, because it does not exist");
                          return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());

                      case RequestMethod.POST:
                          var target = request.Target.Current;
                          if (target is not null)
                          {
                              var data = GetRequestData(request);
                              return target.Value switch
                              {
                                  "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
                                  "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
                                  _ => HandleDatabaseRequest(request, data),
                              };
                          }
                          return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                  }
                  return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
              }
              catch (Exception e)
              {
                  Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
                  return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
              }
          }

          /// <summary>
          /// Returns a JSON list of operation names (e.g. <c>ListX</c>, <c>SaveX</c>, <c>QueryMultiple</c>);
          /// used for checking support of operations client side.
          /// </summary>
          private IResponseBuilder GetSupportedOperations(IRequest request)
          {
              var serialized = Core.Serialization.Serialize(supportedOperations, true) ?? "";
              return request.Respond()
                  .Type(new FlexibleContentType(ContentType.ApplicationJson))
                  .Content(new ResourceContent(Resource.FromString(serialized).Build()));
          }

          /// <summary>
          /// Returns a JSON list of class names (entity names with '.' replaced by '_');
          /// used for checking support of classes client side.
          /// </summary>
          private IResponseBuilder GetSupportedClasses(IRequest request)
          {
              var serialized = Serialization.Serialize(supportedClasses) ?? "";
              return request.Respond()
                  .Type(new FlexibleContentType(ContentType.ApplicationJson))
                  .Content(new ResourceContent(Resource.FromString(serialized).Build()));
          }

          /// <summary>
          /// Returns the result of <c>RestService.Info</c> (described as the Splash Logo and Color Scheme
          /// for this Database), serialized in the format requested by the query string.
          /// </summary>
          private IResponseBuilder GetServerInfo(IRequest request)
          {
              var data = GetRequestData(request);
              InfoResponse response = RestService.Info(new InfoRequest(), Logger.New());
              return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
          }

          #region Authentication

          /// <summary>Deserializes a <see cref="ValidateRequest"/> from the body and forwards it to <c>RestService.Validate</c>.</summary>
          private IResponseBuilder Validate(IRequest request, RequestData data)
          {
              var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              var response = RestService.Validate(requestObj, Logger.New());
              return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
          }

          /// <summary>Deserializes a <see cref="Check2FARequest"/> from the body and forwards it to <c>RestService.Check2FA</c>.</summary>
          private IResponseBuilder Check2FA(IRequest request, RequestData data)
          {
              var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              var response = RestService.Check2FA(requestObj, Logger.New());
              return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
          }

          #endregion

          #region Database

          /// <summary>
          /// Looks up a private static method of this class by name.
          /// </summary>
          /// <exception cref="InvalidOperationException">No such method exists.</exception>
          private static MethodInfo GetMethod(string name) =>
              typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
              ?? throw new InvalidOperationException($"Invalid method '{name}'");

          // Endpoint prefix -> generic handler; the remainder of the endpoint name selects the entity type.
          private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
          {
              new("List", GetMethod(nameof(List))),
              new("Save", GetMethod(nameof(Save))),
              new("Delete", GetMethod(nameof(Delete))),
              new("MultiSave", GetMethod(nameof(MultiSave))),
              new("MultiDelete", GetMethod(nameof(MultiDelete)))
          };

          /// <summary>Serialization options for a single request, read from its query string.</summary>
          private class RequestData
          {
              /// <summary>Format of the request body.</summary>
              public SerializationFormat RequestFormat { get; set; }

              /// <summary>Format to use for the response body.</summary>
              public SerializationFormat ResponseFormat { get; set; }

              /// <summary>Settings used when either body is in binary format.</summary>
              public BinarySerializationSettings BinarySerializationSettings { get; set; }

              public RequestData(BinarySerializationSettings binarySerializationSettings)
              {
                  BinarySerializationSettings = binarySerializationSettings;
              }
          }

          private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.List(requestObject, Logger.New());
          }

          private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.Save(requestObject, Logger.New());
          }

          private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.Delete(requestObject, Logger.New());
          }

          private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.MultiSave(requestObject, Logger.New());
          }

          private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.MultiDelete(requestObject, Logger.New());
          }

          private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
          {
              var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService.QueryMultiple(requestObject, false, Logger.New());
          }

          /// <summary>
          /// Reads a <typeparamref name="T"/> from the request body. Binary format is used only when
          /// requested AND <typeparamref name="T"/> implements <see cref="ISerializeBinary"/>;
          /// otherwise the body is read as text and passed to <c>Serialization.Deserialize</c>.
          /// </summary>
          /// <exception cref="InvalidOperationException">The stream is null or deserialization returned null.</exception>
          private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
          {
              if (stream is null)
              {
                  throw new InvalidOperationException("Stream is null");
              }

              if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
              {
                  return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
              }

              // leaveOpen: disposing the reader must not close the caller's stream.
              using var reader = new StreamReader(stream, leaveOpen: true);
              var text = reader.ReadToEnd();
              return Serialization.Deserialize<T>(text, strict)
                  ?? throw new InvalidOperationException("Deserialization failed");
          }

          /// <summary>
          /// Builds the response body for <paramref name="result"/>: binary (octet-stream) when requested
          /// and the result implements <see cref="ISerializeBinary"/>, otherwise JSON.
          /// </summary>
          private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
          {
              if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
              {
                  var stream = new MemoryStream();
                  binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
                  // Rewind: after writing, the position is at the end of the payload.
                  stream.Position = 0;
                  return request.Respond()
                      .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                      .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
              }

              var serialized = Serialization.Serialize(result) ?? "";
              return request.Respond()
                  .Type(new FlexibleContentType(ContentType.ApplicationJson))
                  .Content(new ResourceContent(Resource.FromString(serialized).Build()));
          }

          /// <summary>
          /// Handler for all database requests. The endpoint is either <c>QueryMultiple...</c> or
          /// one of the <c>methodMap</c> prefixes followed by an entity type name (e.g. <c>SaveX</c>).
          /// Unknown operations, unknown entities and <see cref="ISecure"/> entities yield 404.
          /// </summary>
          /// <param name="request">The POST request; its current target segment is the endpoint.</param>
          /// <param name="requestData">Serialization options for the request.</param>
          private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
          {
              var endpoint = request.Target.Current?.Value ?? "";
              if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
              {
                  var result = QueryMultiple(request, requestData);
                  return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
              }

              foreach (var (name, method) in methodMap)
              {
                  // The endpoint must be the prefix followed by a non-empty entity name.
                  if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                  {
                      continue;
                  }

                  var entityName = endpoint[name.Length..];
                  var entityType = GetEntity(entityName);
                  if (entityType is null)
                  {
                      Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                          $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                      return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                  }

                  if (entityType.IsAssignableTo(typeof(ISecure)))
                  {
                      Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                      return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                  }

                  var resolvedMethod = method.MakeGenericMethod(entityType);
                  var response = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
                  return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, response).Build());
              }

              Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
          }

          // Lazily built map of type name -> type for Entity subclasses that are both IRemotable and IPersistent.
          private Dictionary<string, Type>? _persistentRemotable;

          /// <summary>
          /// Resolves an entity type by its (unqualified) type name, or returns null if no
          /// remotable, persistent entity has that name.
          /// </summary>
          private Type? GetEntity(string entityName)
          {
              _persistentRemotable ??= CoreUtils.TypeList(
                  e => e.IsSubclassOf(typeof(Entity)) &&
                       e.GetInterfaces().Contains(typeof(IRemotable)) &&
                       e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
              return _persistentRemotable.GetValueOrDefault(entityName);
          }

          #endregion

          #region Installer

          /// <summary>
          /// Serves the file returned by <c>UpdateData</c> for the current target segment
          /// (<c>version</c>, <c>releasenotes</c>/<c>release_notes</c>, <c>install</c>/<c>install_desktop</c>)
          /// as an attachment, or 404 if the target is unknown or the file does not exist.
          /// </summary>
          private IResponseBuilder GetUpdateFile(IRequest request)
          {
              string? filename = request.Target.Current?.Value switch
              {
                  "version" => UpdateData.GetUpdateVersionFile(),
                  "releasenotes" or "release_notes" => UpdateData.GetReleaseNotesFile(),
                  "install" or "install_desktop" => UpdateData.GetUpdateInstallerFile(),
                  _ => null,
              };

              if (filename is null || !File.Exists(filename))
              {
                  return request.Respond().Status(ResponseStatus.NotFound);
              }

              // Quote the filename (RFC 6266) so names containing spaces or separators survive intact.
              var downloadName = Path.GetFileName(filename).Replace("\"", "\\\"");
              return request.Respond()
                  .Header("Content-Disposition", $"attachment; filename=\"{downloadName}\"")
                  .Content(new ResourceContent(
                      new FileResource(new FileInfo(filename), null, null)));
          }

          #endregion

          #region GenHTTP stuff

          public IHandler Parent { get; }

          public ValueTask PrepareAsync()
          {
              return new ValueTask();
          }

          /// <summary>
          /// This handler exposes no discoverable content, so an empty sequence is returned,
          /// matching <see cref="GetContent"/>. The original implementation threw instead.
          /// </summary>
          public async IAsyncEnumerable<ContentElement> GetContentAsync(IRequest request)
          {
              // Present only so this is a well-formed async iterator; there is nothing to yield.
              await Task.CompletedTask;
              yield break;
          }

          public IEnumerable<ContentElement> GetContent(IRequest request)
          {
              return Enumerable.Empty<ContentElement>();
          }

          #endregion
      }
  }