RestListener.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. using System.Collections.Generic;
  2. using System.Diagnostics.CodeAnalysis;
  3. using System.Net;
  4. using System.Reflection;
  5. using System.Security.Cryptography.X509Certificates;
  6. using GenHTTP.Api.Content;
  7. using GenHTTP.Api.Infrastructure;
  8. using GenHTTP.Api.Protocol;
  9. using GenHTTP.Engine;
  10. using GenHTTP.Modules.IO;
  11. using GenHTTP.Modules.IO.FileSystem;
  12. using GenHTTP.Modules.IO.Streaming;
  13. using GenHTTP.Modules.Practices;
  14. using InABox.Clients;
  15. using InABox.Core;
  16. using InABox.Database;
  17. using InABox.Server.WebSocket;
  18. using InABox.WebSocket.Shared;
  19. using NPOI.SS.Formula.Functions;
  20. using NPOI.XSSF.Streaming.Values;
  21. using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  22. using StreamContent = GenHTTP.Modules.IO.Streaming.StreamContent;
  23. namespace InABox.API
  24. {
  /// <summary>
  /// GenHTTP handler implementing the REST API. GET/HEAD requests serve server and update information; POST requests
  /// serve authentication, notification and entity database operations. The first path segment selects the endpoint.
  /// </summary>
  public class RestHandler : IHandler
  {
      /// <summary>Names returned by the "operations" / "supported_operations" endpoint (e.g. "ListFoo", "QueryMultiple").</summary>
      private readonly List<string> endpoints;

      /// <summary>Entity names with '.' replaced by '_', returned by the "classes" / "supported_classes" endpoint.</summary>
      private readonly List<string> operations;

      /// <summary>Port reported to clients by the "notify" endpoint; may be null.</summary>
      private readonly int? WebSocketPort;

      /// <summary>
      /// Creates the handler and pre-computes the advertised endpoint and class lists. The lists are built from every
      /// remotable Entity subclass whose name DbFactory.SupportedTypes() reports as supported.
      /// </summary>
      /// <param name="parent">Parent handler in the GenHTTP handler chain.</param>
      /// <param name="webSocketPort">Port returned by the "notify" endpoint; may be null.</param>
      public RestHandler(IHandler parent, int? webSocketPort)
      {
          WebSocketPort = webSocketPort;
          Parent = parent;
          endpoints = new();
          operations = new();

          var types = CoreUtils.TypeList(
              x => x.IsSubclassOf(typeof(Entity))
                   && x.GetInterfaces().Contains(typeof(IRemotable))
          );
          var dbTypes = DbFactory.SupportedTypes();
          foreach (var t in types)
          {
              var entityName = t.EntityName().Replace(".", "_");
              if (!dbTypes.Contains(entityName))
                  continue;

              operations.Add(entityName);
              endpoints.Add($"List{t.Name}");
              // NOTE(review): "Load" endpoints are advertised, but methodMap has no "Load" entry, so
              // HandleDatabaseRequest answers them with 404. Confirm whether this is intended.
              endpoints.Add($"Load{t.Name}");
              endpoints.Add($"Save{t.Name}");
              endpoints.Add($"MultiSave{t.Name}");
              endpoints.Add($"Delete{t.Name}");
              endpoints.Add($"MultiDelete{t.Name}");
          }
          endpoints.Add("QueryMultiple");
      }

      /// <summary>
      /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
      /// </summary>
      /// <param name="request">The incoming request; its first path segment selects the endpoint.</param>
      /// <returns>
      /// The endpoint's response. Unknown endpoints get 404 and methods other than GET/HEAD/POST get 405.
      /// If handling throws, the exception is logged and 500 is returned.
      /// </returns>
      public ValueTask<IResponse?> HandleAsync(IRequest request)
      {
          try
          {
              switch (request.Method.KnownMethod)
              {
                  case RequestMethod.GET:
                  case RequestMethod.HEAD:
                      var current = request.Target.Current?.Value;
                      // A leading "update" segment is skipped, so e.g. "/update/version" is handled like "/version".
                      if (String.Equals(current, "update"))
                      {
                          request.Target.Advance();
                          current = request.Target.Current?.Value;
                      }
                      switch (current)
                      {
                          case "operations" or "supported_operations":
                              return new ValueTask<IResponse?>(GetSupportedOperations(request).Build());
                          case "classes" or "supported_classes":
                              return new ValueTask<IResponse?>(GetSupportedClasses(request).Build());
                          case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                              return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
                          case "info":
                              return new ValueTask<IResponse?>(GetServerInfo(request).Build());
                      }
                      Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                          string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
                      return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());

                  case RequestMethod.POST:
                      var target = request.Target.Current;
                      if (target is not null)
                      {
                          return target.Value switch
                          {
                              "validate" => new ValueTask<IResponse?>(Validate(request).Build()),
                              "check_2fa" => new ValueTask<IResponse?>(Check2FA(request).Build()),
                              "notify" => new ValueTask<IResponse?>(GetNotify(request).Build()),
                              _ => HandleDatabaseRequest(request),
                          };
                      }
                      return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
              }
              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
          }
          catch (Exception e)
          {
              Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
          }
      }

      /// <summary>
      /// Builds a JSON response whose body is <paramref name="serialized"/>. A null serialization result is sent as an
      /// empty body, so every endpoint treats a failed serialization the same way.
      /// </summary>
      /// <param name="request">The request being answered.</param>
      /// <param name="serialized">The JSON text to send; may be null.</param>
      /// <returns>A response builder with the JSON content type and body set.</returns>
      private static IResponseBuilder JsonResponse(IRequest request, string? serialized)
      {
          return request.Respond()
              .Type(FlexibleContentType.Get(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized ?? "").Build()));
      }

      /// <summary>
      /// Returns a JSON list of operation names; used for checking support of operations client side.
      /// </summary>
      /// <param name="request">The request being answered.</param>
      /// <returns>A JSON response containing the advertised endpoint names.</returns>
      private IResponseBuilder GetSupportedOperations(IRequest request)
      {
          return JsonResponse(request, Core.Serialization.Serialize(endpoints, true));
      }

      /// <summary>
      /// Returns a JSON list of class names; used for checking support of operations client side.
      /// </summary>
      /// <param name="request">The request being answered.</param>
      /// <returns>A JSON response containing the supported entity names ('.' replaced by '_').</returns>
      private IResponseBuilder GetSupportedClasses(IRequest request)
      {
          return JsonResponse(request, Serialization.Serialize(operations));
      }

      /// <summary>
      /// Returns the Splash Logo and Color Scheme for this Database.
      /// </summary>
      /// <param name="request">The request being answered.</param>
      /// <returns>A JSON response containing the result of RestService.Info.</returns>
      private IResponseBuilder GetServerInfo(IRequest request)
      {
          InfoResponse response = RestService.Info(new InfoRequest());
          return JsonResponse(request, Core.Serialization.Serialize(response, true));
      }

      /// <summary>
      /// Returns the web socket port to a client whose session exists in CredentialsCache.
      /// </summary>
      /// <param name="request">A POST request whose body is a serialized NotifyRequest.</param>
      /// <returns>A JSON NotifyResponse, or 404 if the session is unknown.</returns>
      private IResponseBuilder GetNotify(IRequest request)
      {
          var requestObj = Deserialize<NotifyRequest>(request.Content, true);
          if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
          {
              return request.Respond().Status(ResponseStatus.NotFound);
          }
          var response = new NotifyResponse
          {
              Status = StatusCode.OK,
              SocketPort = WebSocketPort
          };
          return JsonResponse(request, Serialization.Serialize(response));
      }

      #region Authentication

      /// <summary>Deserializes a ValidateRequest from the body and returns RestService.Validate's result as JSON.</summary>
      private IResponseBuilder Validate(IRequest request)
      {
          var requestObj = Deserialize<ValidateRequest>(request.Content, true);
          var response = RestService.Validate(requestObj);
          return JsonResponse(request, Serialization.Serialize(response));
      }

      /// <summary>Deserializes a Check2FARequest from the body and returns RestService.Check2FA's result as JSON.</summary>
      private IResponseBuilder Check2FA(IRequest request)
      {
          var requestObj = Deserialize<Check2FARequest>(request.Content, true);
          var response = RestService.Check2FA(requestObj);
          return JsonResponse(request, Serialization.Serialize(response));
      }

      #endregion

      #region Database

      /// <summary>Looks up one of this class's private static generic handlers by name; throws if it is missing.</summary>
      private static MethodInfo GetMethod(string name) =>
          typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
          ?? throw new Exception($"Invalid method '{name}'");

      /// <summary>
      /// Endpoint prefixes mapped to the generic handler that serves them. An endpoint "&lt;prefix&gt;&lt;EntityName&gt;"
      /// is dispatched to the handler closed over that entity type.
      /// </summary>
      private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
      {
          new("List", GetMethod(nameof(List))),
          new("Save", GetMethod(nameof(Save))),
          new("Delete", GetMethod(nameof(Delete))),
          new("MultiSave", GetMethod(nameof(MultiSave))),
          new("MultiDelete", GetMethod(nameof(MultiDelete)))
      };

      private static QueryResponse<T> List<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<QueryRequest<T>>(request.Content, true);
          return RestService<T>.List(requestObject);
      }

      private static SaveResponse<T> Save<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<SaveRequest<T>>(request.Content, true);
          return RestService<T>.Save(requestObject);
      }

      private static DeleteResponse<T> Delete<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<DeleteRequest<T>>(request.Content, true);
          return RestService<T>.Delete(requestObject);
      }

      private static MultiSaveResponse<T> MultiSave<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, true);
          return RestService<T>.MultiSave(requestObject);
      }

      private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, true);
          return RestService<T>.MultiDelete(requestObject);
      }

      /// <summary>
      /// Reads the whole stream as text and deserializes it into <typeparamref name="T"/>.
      /// </summary>
      /// <param name="stream">The request body stream; must not be null.</param>
      /// <param name="strict">Passed through to Serialization.Deserialize.</param>
      /// <returns>The deserialized object.</returns>
      /// <exception cref="Exception">If the stream is null or deserialization yields null.</exception>
      private static T Deserialize<T>(Stream? stream, bool strict = false)
      {
          if (stream is null)
              throw new Exception("Stream is null");
          // Dispose the reader, but leave the stream open: every caller here passes request.Content,
          // which this class does not own.
          using var reader = new StreamReader(stream, leaveOpen: true);
          var str = reader.ReadToEnd();
          return Serialization.Deserialize<T>(str, strict)
              ?? throw new Exception("Deserialization failed");
      }

      /// <summary>
      /// Handler for all database requests: "QueryMultiple", or "&lt;List|Save|Delete|MultiSave|MultiDelete&gt;&lt;EntityName&gt;".
      /// </summary>
      /// <param name="request">A POST request whose current path segment names the endpoint.</param>
      /// <returns>The JSON-serialized service response, or 404 if the endpoint or entity cannot be resolved.</returns>
      private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request)
      {
          // HandleAsync only dispatches here after checking that Target.Current is not null.
          var endpoint = request.Target.Current!.Value;

          // Ordinal comparison: endpoint names are identifiers, not culture-sensitive text.
          if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
          {
              var queryRequest = Deserialize<MultiQueryRequest>(request.Content, true);
              var queryResult = RestService.QueryMultiple(queryRequest, false);
              return new ValueTask<IResponse?>(JsonResponse(request, Serialization.Serialize(queryResult)).Build());
          }

          foreach (var (name, method) in methodMap)
          {
              // The endpoint must be the method prefix followed by a non-empty entity name, e.g. "SaveFoo".
              if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                  continue;

              var entityName = endpoint[name.Length..];
              var entityType = GetEntity(entityName);
              if (entityType is null)
              {
                  Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                      $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                  return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
              }

              // Entities implementing ISecure are refused and answered as if the endpoint did not exist.
              if (entityType.IsAssignableTo(typeof(ISecure)))
              {
                  Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                  return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
              }

              // Close the generic handler over the resolved entity type and invoke it.
              var resolvedMethod = method.MakeGenericMethod(entityType);
              var result = resolvedMethod.Invoke(null, new object[] { request }) as Response;
              return new ValueTask<IResponse?>(JsonResponse(request, Serialization.Serialize(result)).Build());
          }

          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
          return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
      }

      /// <summary>
      /// Lazily built map from short type name to remotable, persistent Entity subclass. Concurrent first calls may
      /// each build it; the last assignment wins.
      /// </summary>
      private Dictionary<string, Type>? _persistentRemotable;

      /// <summary>
      /// Resolves an entity type by its short (non-namespaced) name, or returns null if there is none.
      /// NOTE(review): ToDictionary throws if two such types share a short name. Confirm that this cannot happen.
      /// </summary>
      private Type? GetEntity(string entityName)
      {
          _persistentRemotable ??= CoreUtils.TypeList(
              e => e.IsSubclassOf(typeof(Entity)) &&
                   e.GetInterfaces().Contains(typeof(IRemotable)) &&
                   e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
          return _persistentRemotable.GetValueOrDefault(entityName);
      }

      #endregion

      #region Installer

      /// <summary>
      /// Serves a file from the "update" folder under CoreUtils.GetCommonAppData() as a download attachment.
      /// </summary>
      /// <param name="request">A request whose current segment is "version", "releasenotes"/"release_notes" or "install"/"install_desktop".</param>
      /// <returns>The file as an attachment, or 404 if the segment is unknown or the file does not exist.</returns>
      private IResponseBuilder GetUpdateFile(IRequest request)
      {
          string? relativePath = request.Target.Current?.Value switch
          {
              "version" => "update/version.txt",
              "releasenotes" or "release_notes" => "update/Release Notes.txt",
              "install" or "install_desktop" => "update/PRSDesktopSetup.exe",
              _ => null
          };
          if (relativePath is null)
              return request.Respond().Status(ResponseStatus.NotFound);

          var filename = Path.Combine(CoreUtils.GetCommonAppData(), relativePath);
          if (!File.Exists(filename))
              return request.Respond().Status(ResponseStatus.NotFound);

          // Quote the filename (RFC 6266): "Release Notes.txt" contains a space, which an unquoted token cannot carry.
          return request.Respond()
              .Header("Content-Disposition", $"attachment; filename=\"{Path.GetFileName(filename)}\"")
              .Content(new ResourceContent(
                  new FileResource(new FileInfo(filename), null, null)));
      }

      #endregion

      #region GenHTTP stuff

      /// <summary>The parent handler supplied to the constructor.</summary>
      public IHandler Parent { get; }

      /// <summary>No preparation is needed; completes immediately.</summary>
      public ValueTask PrepareAsync()
      {
          return new ValueTask();
      }

      /// <summary>This handler exposes no browsable content.</summary>
      public IEnumerable<ContentElement> GetContent(IRequest request)
      {
          return Enumerable.Empty<ContentElement>();
      }

      #endregion
  }
  /// <summary>
  /// GenHTTP builder that produces a RestHandler wrapped in any concerns registered through <see cref="Add"/>.
  /// </summary>
  public class RestHandlerBuilder : IHandlerBuilder<RestHandlerBuilder>
  {
      private readonly List<IConcernBuilder> _concerns = new();
      private readonly int? _webSocketPort;

      /// <summary>Creates a builder whose handlers report <paramref name="webSocketPort"/> to clients (may be null).</summary>
      public RestHandlerBuilder(int? webSocketPort) => _webSocketPort = webSocketPort;

      /// <summary>Registers a concern to wrap around the handler; returns this builder for chaining.</summary>
      public RestHandlerBuilder Add(IConcernBuilder concern)
      {
          _concerns.Add(concern);
          return this;
      }

      /// <summary>Builds the RestHandler under <paramref name="parent"/>, chained through the registered concerns.</summary>
      public IHandler Build(IHandler parent) =>
          Concerns.Chain(parent, _concerns, p => new RestHandler(p, _webSocketPort));
  }
  /// <summary>
  /// Notifier that delivers notifications through a WebSocketServer listening on a given port.
  /// </summary>
  class RestNotifier : Notifier
  {
      private readonly WebSocketServer _server;

      /// <summary>Port reported by the underlying web socket server.</summary>
      public int Port => _server.Port;

      /// <summary>
      /// Creates the web socket server on <paramref name="port"/> and forwards its Poll events to Notifier.Poll.
      /// </summary>
      public RestNotifier(int port)
      {
          _server = new WebSocketServer(port);
          _server.Poll += HandleServerPoll;
      }

      /// <summary>Forwards a poll from a web socket session to the base notifier, keyed by session id.</summary>
      private void HandleServerPoll(NotifyState.Session session) => Poll(session.SessionID);

      /// <summary>Starts the web socket server.</summary>
      public void Start() => _server.Start();

      /// <summary>Stops the web socket server.</summary>
      public void Stop() => _server.Stop();

      protected override void NotifyAll<TNotification>(TNotification notification) =>
          _server.Push(notification);

      protected override void NotifySession(Guid session, Type TNotification, object? notification) =>
          _server.Push(session, TNotification, notification);

      protected override void NotifySession<TNotification>(Guid session, TNotification notification) =>
          _server.Push(session, notification);

      /// <summary>Sessions of the given user, as recorded in CredentialsCache.</summary>
      protected override IEnumerable<Guid> GetUserSessions(Guid userID) =>
          CredentialsCache.GetUserSessions(userID);

      /// <summary>Sessions for the given platform, as reported by the web socket server.</summary>
      protected override IEnumerable<Guid> GetSessions(Platform platform) =>
          _server.GetSessions(platform);
  }
  /// <summary>
  /// Static facade over the REST server host and its optional web socket notifier. InitPort and InitCertificate
  /// only bind once Init has created the host; they do nothing before that.
  /// </summary>
  public static class RestListener
  {
      private static IServerHost? _host;
      private static X509Certificate2? _certificate;
      private static RestNotifier? _notifier;

      /// <summary>The certificate last passed to InitCertificate, or null if none was set (or after Clear).</summary>
      public static X509Certificate2? Certificate => _certificate;

      /// <summary>Starts the host and then the notifier, skipping whichever does not exist.</summary>
      public static void Start()
      {
          _host?.Start();
          _notifier?.Start();
      }

      /// <summary>Stops the host and then the notifier, skipping whichever does not exist.</summary>
      public static void Stop()
      {
          _host?.Stop();
          _notifier?.Stop();
      }

      /// <summary>
      /// Remembers <paramref name="certificate"/>. If the host exists, also binds it on all interfaces at
      /// <paramref name="port"/> using that certificate.
      /// </summary>
      public static void InitCertificate(ushort port, X509Certificate2 certificate)
      {
          _certificate = certificate;
          _host?.Bind(IPAddress.Any, port, certificate);
      }

      /// <summary>Loads a certificate from <paramref name="certificateFile"/> and delegates to the X509Certificate2 overload.</summary>
      public static void InitCertificate(ushort port, string certificateFile) =>
          InitCertificate(port, new X509Certificate2(certificateFile));

      /// <summary>Binds the host, if it exists, on all interfaces at <paramref name="port"/> without a certificate.</summary>
      public static void InitPort(ushort port) => _host?.Bind(IPAddress.Any, port);

      /// <summary>
      /// Clears certificate and host information, and stops the listener.
      /// </summary>
      public static void Clear()
      {
          _host?.Stop();
          _host = null;
          _notifier?.Stop();
          _notifier = null;
          _certificate = null;
      }

      /// <summary>
      /// Creates a new host serving a RestHandler. When <paramref name="webSocketPort"/> is non-zero, it also creates
      /// a notifier on that port and installs it as Notify.Notifier. The handler is given the notifier's port, or null
      /// when there is no notifier.
      /// NOTE(review): a host or notifier from an earlier Init is neither stopped nor replaced (when the port is 0);
      /// callers presumably call Clear first. Confirm this.
      /// </summary>
      public static void Init(int webSocketPort)
      {
          if (webSocketPort != 0)
          {
              _notifier = new RestNotifier(webSocketPort);
              Notify.Notifier = _notifier;
          }

          _host = Host.Create();
          _host.Handler(new RestHandlerBuilder(_notifier?.Port))
              .Defaults();
      }
  }
  440. }