RestHandler.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. using System.Reflection;
  2. using GenHTTP.Api.Content;
  3. using GenHTTP.Api.Protocol;
  4. using GenHTTP.Modules.IO;
  5. using GenHTTP.Modules.IO.FileSystem;
  6. using GenHTTP.Modules.IO.Streaming;
  7. using InABox.Clients;
  8. using InABox.Core;
  9. using InABox.Database;
  10. using InABox.Server;
  11. using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  namespace InABox.API
  {
      /// <summary>
      /// GenHTTP handler for the InABox REST API.
      /// GET/HEAD serve update files, server info and a ping endpoint; POST handles authentication
      /// (<c>validate</c>, <c>check_2fa</c>) and generic entity database requests.
      /// </summary>
      public class RestHandler : IHandler
      {
          public RestHandler(IHandler parent)
          {
              Parent = parent;
          }

          /// <summary>
          /// Reads the serialization options for a request from its query string.
          /// </summary>
          /// <remarks>
          /// <c>serializationVersion</c> selects the binary serialization settings (default <c>V1_0</c>);
          /// <c>format</c> selects the request body format; <c>responseFormat</c> selects the response
          /// format (default JSON). Format values that do not parse as a <see cref="SerializationFormat"/>
          /// are ignored.
          /// </remarks>
          private RequestData GetRequestData(IRequest request)
          {
              BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
              if (request.Query.TryGetValue("serializationVersion", out var versionString))
              {
                  settings = BinarySerializationSettings.ConvertVersionString(versionString);
              }

              var data = new RequestData(settings);
              if (request.Query.TryGetValue("format", out var formatString)
                  && Enum.TryParse<SerializationFormat>(formatString, out var format))
              {
                  data.RequestFormat = format;
              }

              data.ResponseFormat = SerializationFormat.Json;
              if (request.Query.TryGetValue("responseFormat", out formatString)
                  && Enum.TryParse<SerializationFormat>(formatString, out format))
              {
                  data.ResponseFormat = format;
              }

              return data;
          }

          /// <summary>
          /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
          /// </summary>
          /// <param name="request">The incoming HTTP request.</param>
          /// <returns>
          /// The response. Unknown endpoints yield 404, unsupported methods yield 405, and any
          /// exception raised while handling yields 500 (after being logged).
          /// </returns>
          public ValueTask<IResponse?> HandleAsync(IRequest request)
          {
              try
              {
                  switch (request.Method.KnownMethod)
                  {
                      case RequestMethod.GET:
                      case RequestMethod.HEAD:
                          return new ValueTask<IResponse?>(HandleGet(request));
                      case RequestMethod.POST:
                          return HandlePost(request);
                      default:
                          return new ValueTask<IResponse?>(request.Respond()
                              .Status(ResponseStatus.MethodNotAllowed)
                              .Header("Allow", "GET, POST, HEAD")
                              .Build());
                  }
              }
              catch (Exception e)
              {
                  // Entity endpoints are dispatched through MethodInfo.Invoke, which wraps the real failure
                  // in a TargetInvocationException; log the underlying cause rather than the wrapper.
                  var cause = e is TargetInvocationException { InnerException: { } inner } ? inner : e;
                  Logger.Send(LogType.Error, "", CoreUtils.FormatException(cause));
                  return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
              }
          }

          /// <summary>
          /// Resolves a GET/HEAD endpoint. A leading "update" path segment, if present, is skipped
          /// before the endpoint name is examined.
          /// </summary>
          private IResponse HandleGet(IRequest request)
          {
              var current = request.Target.Current?.Value;
              if (string.Equals(current, "update"))
              {
                  request.Target.Advance();
                  current = request.Target.Current?.Value;
              }

              switch (current)
              {
                  case "operations" or "supported_operations":
                      Logger.Send(LogType.Error, "", "Supported operations is no longer supported");
                      return NotFound(request);
                  case "classes" or "supported_classes":
                      Logger.Send(LogType.Error, "", "Supported classes is no longer supported");
                      return NotFound(request);
                  case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                      return GetUpdateFile(request).Build();
                  case "info":
                      return GetServerInfo(request).Build();
                  case "ping":
                      return request.Respond().Status(ResponseStatus.OK).Build();
                  default:
                      Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                          $"GET/HEAD request to endpoint '{current}' is unresolved, because it does not exist");
                      return NotFound(request);
              }
          }

          /// <summary>
          /// Resolves a POST endpoint: authentication endpoints first, then everything else is treated
          /// as a database request. A request with no target segment yields 404.
          /// </summary>
          private ValueTask<IResponse?> HandlePost(IRequest request)
          {
              var target = request.Target.Current;
              if (target is null)
              {
                  return new ValueTask<IResponse?>(NotFound(request));
              }

              var data = GetRequestData(request);
              return target.Value switch
              {
                  "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
                  "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
                  _ => HandleDatabaseRequest(request, data),
              };
          }

          /// <summary>Builds an empty 404 response for <paramref name="request"/>.</summary>
          private static IResponse NotFound(IRequest request) =>
              request.Respond().Status(ResponseStatus.NotFound).Build();

          /// <summary>
          /// Returns the Splash Logo and Color Scheme for this Database
          /// </summary>
          /// <param name="request">The incoming request; its query string selects the response format.</param>
          /// <returns>A response builder carrying the serialized <see cref="InfoResponse"/>.</returns>
          private IResponseBuilder GetServerInfo(IRequest request)
          {
              var data = GetRequestData(request);
              InfoResponse response = RestService.Info(new InfoRequest());
              return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
          }

          #region Authentication

          /// <summary>Deserializes a <see cref="ValidateRequest"/> from the body and returns the serialized result.</summary>
          private IResponseBuilder Validate(IRequest request, RequestData data)
          {
              var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              var response = RestService.Validate(requestObj);
              return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
          }

          /// <summary>Deserializes a <see cref="Check2FARequest"/> from the body and returns the serialized result.</summary>
          private IResponseBuilder Check2FA(IRequest request, RequestData data)
          {
              var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              var response = RestService.Check2FA(requestObj);
              return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
          }

          #endregion

          #region Database

          /// <summary>Looks up a private static method of this class by name, failing loudly if it is missing.</summary>
          private static MethodInfo GetMethod(string name) =>
              typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
              ?? throw new Exception($"Invalid method '{name}'");

          /// <summary>
          /// Endpoint prefix → generic handler. An endpoint such as "SaveJob" is split into the prefix
          /// ("Save") and the entity name ("Job"), and the handler is closed over that entity type.
          /// </summary>
          private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
          {
              new("List", GetMethod(nameof(List))),
              new("Save", GetMethod(nameof(Save))),
              new("Delete", GetMethod(nameof(Delete))),
              new("MultiSave", GetMethod(nameof(MultiSave))),
              new("MultiDelete", GetMethod(nameof(MultiDelete)))
          };

          /// <summary>Per-request serialization options, parsed from the query string by <see cref="GetRequestData"/>.</summary>
          private class RequestData
          {
              public SerializationFormat RequestFormat { get; set; }

              public SerializationFormat ResponseFormat { get; set; }

              public BinarySerializationSettings BinarySerializationSettings { get; set; }

              public RequestData(BinarySerializationSettings binarySerializationSettings)
              {
                  BinarySerializationSettings = binarySerializationSettings;
              }
          }

          private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.List(requestObject);
          }

          private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.Save(requestObject);
          }

          private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.Delete(requestObject);
          }

          private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.MultiSave(requestObject);
          }

          private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
          {
              var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService<T>.MultiDelete(requestObject);
          }

          private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
          {
              var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
              return RestService.QueryMultiple(requestObject, false);
          }

          /// <summary>
          /// Reads a request object of type <typeparamref name="T"/> from the request body.
          /// Binary format is used only when requested AND <typeparamref name="T"/> implements
          /// <see cref="ISerializeBinary"/>; otherwise the body is read as text and deserialized.
          /// </summary>
          /// <exception cref="Exception">The body stream is null, or deserialization yielded null.</exception>
          private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
          {
              if (stream is null)
                  throw new Exception("Stream is null");

              if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
              {
                  return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
              }

              // Dispose the reader but leave the request stream open: the server owns it, not this method.
              using var reader = new StreamReader(stream, System.Text.Encoding.UTF8,
                  detectEncodingFromByteOrderMarks: true, bufferSize: -1, leaveOpen: true);
              return Serialization.Deserialize<T>(reader.ReadToEnd(), strict)
                  ?? throw new Exception("Deserialization failed");
          }

          /// <summary>
          /// Serializes <paramref name="result"/> into a response: binary (octet-stream) when requested and
          /// supported by the result, JSON otherwise.
          /// </summary>
          private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
          {
              if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
              {
                  var stream = new MemoryStream();
                  binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
                  // Rewind so the body is sent from the beginning however the server consumes the stream.
                  stream.Position = 0;
                  return request.Respond()
                      .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                      // The checksum must describe the bytes; the previous object-identity hash did not, so
                      // different bodies could report the same checksum.
                      .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>(ComputeChecksum(stream)));
              }

              var serialized = Serialization.Serialize(result);
              return request.Respond()
                  .Type(new FlexibleContentType(ContentType.ApplicationJson))
                  .Content(new ResourceContent(Resource.FromString(serialized).Build()));
          }

          /// <summary>
          /// 64-bit FNV-1a hash of the bytes held in <paramref name="stream"/> (independent of its position).
          /// Supplied to GenHTTP as the content checksum (NOTE(review): presumably used for cache validation/ETags).
          /// </summary>
          private static ulong ComputeChecksum(MemoryStream stream)
          {
              const ulong fnvOffsetBasis = 14695981039346656037UL;
              const ulong fnvPrime = 1099511628211UL;

              var hash = fnvOffsetBasis;
              foreach (var b in stream.GetBuffer().AsSpan(0, (int)stream.Length))
              {
                  hash = unchecked((hash ^ b) * fnvPrime);
              }
              return hash;
          }

          /// <summary>
          /// Handler for all database requests. "QueryMultiple…" endpoints go to <see cref="QueryMultiple"/>;
          /// otherwise the endpoint must be a <see cref="methodMap"/> prefix followed by the name of a
          /// persistent, remotable, non-secure entity type.
          /// </summary>
          /// <param name="request">The incoming request; its current target segment is the endpoint.</param>
          /// <param name="requestData">Serialization options for the request and response.</param>
          /// <returns>The serialized result, or 404 when the endpoint cannot be resolved.</returns>
          private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
          {
              var endpoint = request.Target.Current?.Value ?? "";
              if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
              {
                  var result = QueryMultiple(request, requestData);
                  return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
              }

              foreach (var (name, method) in methodMap)
              {
                  // Ordinal match is required: the entity name is sliced off by name.Length below, which is
                  // only correct when the prefix matched character-for-character.
                  if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                      continue;

                  var entityName = endpoint[name.Length..];
                  var entityType = GetEntity(entityName);
                  if (entityType is null)
                  {
                      Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                          $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                      return new ValueTask<IResponse?>(NotFound(request));
                  }

                  if (entityType.IsAssignableTo(typeof(ISecure)))
                  {
                      Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                      return new ValueTask<IResponse?>(NotFound(request));
                  }

                  var resolvedMethod = method.MakeGenericMethod(entityType);
                  var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
                  return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
              }

              Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
              return new ValueTask<IResponse?>(NotFound(request));
          }

          /// <summary>Lazily built map of entity type name → type for persistent, remotable entities.</summary>
          private Dictionary<string, Type>? _persistentRemotable;

          /// <summary>
          /// Resolves an entity type by its simple name among <see cref="Entity"/> subclasses implementing
          /// both <see cref="IRemotable"/> and <see cref="IPersistent"/>; returns null when not found.
          /// </summary>
          private Type? GetEntity(string entityName)
          {
              _persistentRemotable ??= CoreUtils.TypeList(
                  e => e.IsSubclassOf(typeof(Entity)) &&
                       e.GetInterfaces().Contains(typeof(IRemotable)) &&
                       e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
              return _persistentRemotable.GetValueOrDefault(entityName);
          }

          #endregion

          #region Installer

          /// <summary>
          /// Serves the update file selected by the current target segment ("version", "releasenotes"/
          /// "release_notes", "install"/"install_desktop") as an attachment, or 404 when the endpoint is
          /// unknown or the file does not exist.
          /// </summary>
          private IResponseBuilder GetUpdateFile(IRequest request)
          {
              string? filename = request.Target.Current?.Value switch
              {
                  "version" => UpdateData.GetUpdateVersionFile(),
                  "releasenotes" or "release_notes" => UpdateData.GetReleaseNotesFile(),
                  "install" or "install_desktop" => UpdateData.GetUpdateInstallerFile(),
                  _ => null
              };

              if (filename is null || !File.Exists(filename))
                  return request.Respond().Status(ResponseStatus.NotFound);

              // Quote the filename (RFC 6266 quoted-string) so names containing spaces, ';' or ','
              // do not corrupt the header; backslashes and quotes must be escaped inside the quotes.
              var downloadName = Path.GetFileName(filename)
                  .Replace("\\", "\\\\")
                  .Replace("\"", "\\\"");
              return request.Respond()
                  .Header("Content-Disposition", $"attachment; filename=\"{downloadName}\"")
                  .Content(new ResourceContent(
                      new FileResource(new FileInfo(filename), null, null)));
          }

          #endregion

          #region GenHTTP stuff

          public IHandler Parent { get; }

          public ValueTask PrepareAsync()
          {
              return new ValueTask();
          }

          public IEnumerable<ContentElement> GetContent(IRequest request)
          {
              return Enumerable.Empty<ContentElement>();
          }

          #endregion
      }
  }