// RestHandler.cs (19 KB)
  1. using System.Reflection;
  2. using GenHTTP.Api.Content;
  3. using GenHTTP.Api.Protocol;
  4. using GenHTTP.Modules.IO;
  5. using GenHTTP.Modules.IO.FileSystem;
  6. using GenHTTP.Modules.IO.Streaming;
  7. using InABox.Clients;
  8. using InABox.Core;
  9. using InABox.Database;
  10. using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  11. namespace InABox.API
  12. {
  13. public class RestHandler : IHandler
  14. {
  15. private readonly List<string> endpoints;
  16. private readonly List<string> operations;
  17. //private int? WebSocketPort;
  18. public RestHandler(IHandler parent) //, int? webSocketPort)
  19. {
  20. // WebSocketPort = webSocketPort;
  21. Parent = parent;
  22. endpoints = new();
  23. operations = new();
  24. var types = CoreUtils.TypeList(
  25. x => x.IsSubclassOf(typeof(Entity))
  26. && x.GetInterfaces().Contains(typeof(IRemotable))
  27. );
  28. var DBTypes = DbFactory.SupportedTypes();
  29. foreach (var t in types)
  30. if (DBTypes.Contains(t.EntityName().Replace(".", "_")))
  31. {
  32. operations.Add(t.EntityName().Replace(".", "_"));
  33. endpoints.Add(string.Format("List{0}", t.Name));
  34. endpoints.Add(string.Format("Load{0}", t.Name));
  35. endpoints.Add(string.Format("Save{0}", t.Name));
  36. endpoints.Add(string.Format("MultiSave{0}", t.Name));
  37. endpoints.Add(string.Format("Delete{0}", t.Name));
  38. endpoints.Add(string.Format("MultiDelete{0}", t.Name));
  39. }
  40. endpoints.Add("QueryMultiple");
  41. }
  42. private RequestData GetRequestData(IRequest request)
  43. {
  44. BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
  45. if (request.Query.TryGetValue("serializationVersion", out var versionString))
  46. {
  47. settings = BinarySerializationSettings.ConvertVersionString(versionString);
  48. }
  49. var data = new RequestData(settings);
  50. if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
  51. {
  52. data.RequestFormat = format;
  53. }
  54. data.ResponseFormat = SerializationFormat.Json;
  55. if (request.Query.TryGetValue("responseFormat", out formatString) && Enum.TryParse<SerializationFormat>(formatString, out format))
  56. {
  57. data.ResponseFormat = format;
  58. }
  59. return data;
  60. }
  61. /// <summary>
  62. /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
  63. /// </summary>
  64. /// <param name="request"></param>
  65. /// <returns></returns>
  66. public ValueTask<IResponse?> HandleAsync(IRequest request)
  67. {
  68. try
  69. {
  70. switch (request.Method.KnownMethod)
  71. {
  72. case RequestMethod.GET:
  73. case RequestMethod.HEAD:
  74. var current = request.Target.Current?.Value;
  75. if (String.Equals(current,"update"))
  76. {
  77. request.Target.Advance();
  78. current = request.Target.Current?.Value;
  79. }
  80. switch (current)
  81. {
  82. case "operations" or "supported_operations":
  83. return new ValueTask<IResponse?>(GetSupportedOperations(request).Build());
  84. case "classes" or "supported_classes":
  85. return new ValueTask<IResponse?>(GetSupportedClasses(request).Build());
  86. case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
  87. return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
  88. case "info":
  89. return new ValueTask<IResponse?>(GetServerInfo(request).Build());
  90. case "ping":
  91. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
  92. }
  93. Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
  94. string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
  95. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
  96. case RequestMethod.POST:
  97. var target = request.Target.Current;
  98. if (target is not null)
  99. {
  100. var data = GetRequestData(request);
  101. return target.Value switch
  102. {
  103. "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
  104. "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
  105. //"notify" or "push" => new ValueTask<IResponse?>(GetPush(request, data).Build()),
  106. _ => HandleDatabaseRequest(request, data),
  107. };
  108. }
  109. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
  110. }
  111. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
  112. }
  113. catch(Exception e)
  114. {
  115. Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
  116. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
  117. }
  118. }
  119. /// <summary>
  120. /// Returns a JSON list of operation names; used for checking support of operations client side
  121. /// </summary>
  122. /// <param name="request"></param>
  123. /// <returns></returns>
  124. private IResponseBuilder GetSupportedOperations(IRequest request)
  125. {
  126. var serialized = Core.Serialization.Serialize(endpoints, true) ?? "";
  127. return request.Respond()
  128. .Type(new FlexibleContentType(ContentType.ApplicationJson))
  129. .Content(new ResourceContent(Resource.FromString(serialized).Build()));
  130. }
  131. /// <summary>
  132. /// Returns a JSON list of class names; used for checking support of operations client side
  133. /// </summary>
  134. /// <param name="request"></param>
  135. /// <returns></returns>
  136. private IResponseBuilder GetSupportedClasses(IRequest request)
  137. {
  138. var serialized = Serialization.Serialize(operations) ?? "";
  139. return request.Respond()
  140. .Type(new FlexibleContentType(ContentType.ApplicationJson))
  141. .Content(new ResourceContent(Resource.FromString(serialized).Build()));
  142. }
  143. /// <summary>
  144. /// Returns the Splash Logo and Color Scheme for this Database
  145. /// </summary>
  146. /// <param name="request"></param>
  147. /// <returns></returns>
  148. private IResponseBuilder GetServerInfo(IRequest request)
  149. {
  150. var data = GetRequestData(request);
  151. InfoResponse response = RestService.Info(new InfoRequest());
  152. return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
  153. }
  154. /// <summary>
  155. /// Gets port for web socket
  156. /// </summary>
  157. /// <param name="request"></param>
  158. /// <returns></returns>
  159. // private IResponseBuilder GetPush(IRequest request, RequestData data)
  160. // {
  161. // var requestObj = Deserialize<PushRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  162. // if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
  163. // {
  164. // return request.Respond().Status(ResponseStatus.NotFound);
  165. // }
  166. // var response = new PushResponse
  167. // {
  168. // Status = StatusCode.OK,
  169. // SocketPort = WebSocketPort
  170. // };
  171. // return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
  172. // }
  173. #region Authentication
  174. private IResponseBuilder Validate(IRequest request, RequestData data)
  175. {
  176. var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  177. var response = RestService.Validate(requestObj);
  178. return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
  179. }
  180. private IResponseBuilder Check2FA(IRequest request, RequestData data)
  181. {
  182. var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  183. var response = RestService.Check2FA(requestObj);
  184. return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
  185. }
  186. #endregion
  187. #region Database
  188. private static MethodInfo GetMethod(string name) =>
  189. typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
  190. ?? throw new Exception($"Invalid method '{name}'");
  191. private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
  192. {
  193. new("List", GetMethod(nameof(List))),
  194. new("Save", GetMethod(nameof(Save))),
  195. new("Delete", GetMethod(nameof(Delete))),
  196. new("MultiSave", GetMethod(nameof(MultiSave))),
  197. new("MultiDelete", GetMethod(nameof(MultiDelete)))
  198. };
  199. private class RequestData
  200. {
  201. public SerializationFormat RequestFormat { get; set; }
  202. public SerializationFormat ResponseFormat { get; set; }
  203. public BinarySerializationSettings BinarySerializationSettings { get; set; }
  204. public RequestData(BinarySerializationSettings binarySerializationSettings)
  205. {
  206. BinarySerializationSettings = binarySerializationSettings;
  207. }
  208. }
  209. private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
  210. {
  211. var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  212. return RestService<T>.List(requestObject);
  213. }
  214. private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
  215. {
  216. var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  217. return RestService<T>.Save(requestObject);
  218. }
  219. private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
  220. {
  221. var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  222. return RestService<T>.Delete(requestObject);
  223. }
  224. private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
  225. {
  226. var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  227. return RestService<T>.MultiSave(requestObject);
  228. }
  229. private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
  230. {
  231. var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  232. return RestService<T>.MultiDelete(requestObject);
  233. }
  234. private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
  235. {
  236. var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
  237. return RestService.QueryMultiple(requestObject, false);
  238. }
  239. private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
  240. {
  241. if (stream is null)
  242. throw new Exception("Stream is null");
  243. if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
  244. {
  245. return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
  246. }
  247. else
  248. {
  249. var str = new StreamReader(stream).ReadToEnd();
  250. return Serialization.Deserialize<T>(str, strict)
  251. ?? throw new Exception("Deserialization failed");
  252. }
  253. }
  254. private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
  255. {
  256. if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
  257. {
  258. var stream = new MemoryStream();
  259. binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
  260. var response = request.Respond()
  261. .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
  262. .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
  263. return response;
  264. }
  265. else
  266. {
  267. var serialized = Serialization.Serialize(result);
  268. var response = request.Respond()
  269. .Type(new FlexibleContentType(ContentType.ApplicationJson))
  270. .Content(new ResourceContent(Resource.FromString(serialized).Build()));
  271. return response;
  272. }
  273. }
  274. /// <summary>
  275. /// Handler for all database requests
  276. /// </summary>
  277. /// <param name="request"></param>
  278. /// <returns></returns>
  279. private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
  280. {
  281. var endpoint = request.Target.Current?.Value ?? "";
  282. if (endpoint.StartsWith("QueryMultiple"))
  283. {
  284. var result = QueryMultiple(request, requestData);
  285. return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
  286. }
  287. foreach (var (name, method) in methodMap)
  288. if (endpoint.Length > name.Length && endpoint.StartsWith(name))
  289. {
  290. var entityName = endpoint[name.Length..];
  291. var entityType = GetEntity(entityName);
  292. if (entityType != null)
  293. {
  294. if (entityType.IsAssignableTo(typeof(ISecure)))
  295. {
  296. Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
  297. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
  298. }
  299. var resolvedMethod = method.MakeGenericMethod(entityType);
  300. var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
  301. return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
  302. }
  303. Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
  304. $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
  305. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
  306. }
  307. Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
  308. return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
  309. }
  310. private Dictionary<string, Type>? _persistentRemotable;
  311. private Type? GetEntity(string entityName)
  312. {
  313. _persistentRemotable ??= CoreUtils.TypeList(
  314. e => e.IsSubclassOf(typeof(Entity)) &&
  315. e.GetInterfaces().Contains(typeof(IRemotable)) &&
  316. e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
  317. return _persistentRemotable.GetValueOrDefault(entityName);
  318. }
  319. #endregion
  320. #region Installer
  321. private IResponseBuilder GetUpdateFile(IRequest request)
  322. {
  323. var endpoint = request.Target.Current;
  324. string? filename = null;
  325. switch (endpoint?.Value)
  326. {
  327. case "version":
  328. filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/version.txt");
  329. break;
  330. case "releasenotes" or "release_notes":
  331. filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/Release Notes.txt");
  332. break;
  333. case "install" or "install_desktop":
  334. filename = Path.Combine(CoreUtils.GetCommonAppData(), "update/PRSDesktopSetup.exe");
  335. break;
  336. }
  337. if (filename is null) return request.Respond().Status(ResponseStatus.NotFound);
  338. if (File.Exists(filename))
  339. {
  340. return request.Respond()
  341. .Header("Content-Disposition", $"attachment; filename={Path.GetFileName(filename)}")
  342. .Content(new ResourceContent(
  343. new FileResource(new FileInfo(filename), null, null)));
  344. }
  345. return request.Respond().Status(ResponseStatus.NotFound);
  346. }
  347. #endregion
  348. #region GenHTTP stuff
  349. public IHandler Parent { get; }
  350. public ValueTask PrepareAsync()
  351. {
  352. return new ValueTask();
  353. }
  354. public IEnumerable<ContentElement> GetContent(IRequest request)
  355. {
  356. return Enumerable.Empty<ContentElement>();
  357. }
  358. #endregion
  359. }
  360. }