RestListener.cs 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. using System.Collections.Generic;
  2. using System.Diagnostics.CodeAnalysis;
  3. using System.Net;
  4. using System.Reflection;
  5. using System.Security.Cryptography.X509Certificates;
  6. using GenHTTP.Api.Content;
  7. using GenHTTP.Api.Infrastructure;
  8. using GenHTTP.Api.Protocol;
  9. using GenHTTP.Engine;
  10. using GenHTTP.Modules.IO;
  11. using GenHTTP.Modules.IO.FileSystem;
  12. using GenHTTP.Modules.IO.Streaming;
  13. using GenHTTP.Modules.Practices;
  14. using InABox.Clients;
  15. using InABox.Core;
  16. using InABox.Database;
  17. using InABox.Remote.Shared;
  18. using InABox.Server.WebSocket;
  19. using InABox.WebSocket.Shared;
  20. using NPOI.SS.Formula.Functions;
  21. using NPOI.XSSF.Streaming.Values;
  22. using Twilio.Rest.Taskrouter.V1.Workspace.TaskQueue;
  23. using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  24. using StreamContent = GenHTTP.Modules.IO.Streaming.StreamContent;
  25. namespace InABox.API
  26. {
  /// <summary>
  /// GenHTTP handler for the REST API. GET/HEAD requests serve discovery endpoints, server info and
  /// update/installer files; POST requests serve authentication, web-socket notify discovery and the
  /// per-entity database operations, which are dispatched by reflection over <see cref="methodMap"/>.
  /// </summary>
  public class RestHandler : IHandler
  {
      /// <summary>Endpoint names returned by the "operations" / "supported_operations" endpoint.</summary>
      private readonly List<string> endpoints;

      /// <summary>Entity names ('.' replaced by '_') returned by the "classes" / "supported_classes" endpoint.</summary>
      private readonly List<string> operations;

      /// <summary>Web-socket port returned to clients by the "notify" endpoint; may be null.</summary>
      private readonly int? WebSocketPort;

      /// <summary>
      /// Creates the handler and precomputes the advertised endpoint and entity-class lists.
      /// </summary>
      /// <param name="parent">The parent handler in the GenHTTP chain.</param>
      /// <param name="webSocketPort">Port reported by the "notify" endpoint; may be null.</param>
      public RestHandler(IHandler parent, int? webSocketPort)
      {
          WebSocketPort = webSocketPort;
          Parent = parent;
          endpoints = new();
          operations = new();

          var types = CoreUtils.TypeList(
              x => x.IsSubclassOf(typeof(Entity))
                   && x.GetInterfaces().Contains(typeof(IRemotable))
          );
          var DBTypes = DbFactory.SupportedTypes();
          foreach (var t in types)
          {
              var entityName = t.EntityName().Replace(".", "_");
              // Only entities the database layer reports as supported are advertised.
              if (!DBTypes.Contains(entityName))
                  continue;

              operations.Add(entityName);
              // NOTE(review): "Load<Entity>" is advertised, but methodMap has no "Load" entry, so a POST to it
              // is answered with 404 by HandleDatabaseRequest. Kept unchanged because clients may probe for it;
              // confirm whether it should be implemented or dropped.
              foreach (var prefix in new[] { "List", "Load", "Save", "MultiSave", "Delete", "MultiDelete" })
                  endpoints.Add($"{prefix}{t.Name}");
          }
          endpoints.Add("QueryMultiple");
      }

      /// <summary>
      /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
      /// </summary>
      /// <remarks>
      /// GET/HEAD: discovery, server info and update files. A leading "update" path segment is skipped, so
      /// e.g. "update/version" resolves like "version". POST: "validate", "check_2fa", "notify", otherwise a
      /// database endpoint. Any other method gets 405. Exceptions are logged and turned into a 500 response.
      /// </remarks>
      /// <param name="request">The incoming request.</param>
      /// <returns>The response to send.</returns>
      public ValueTask<IResponse?> HandleAsync(IRequest request)
      {
          try
          {
              switch (request.Method.KnownMethod)
              {
                  case RequestMethod.GET:
                  case RequestMethod.HEAD:
                      var current = request.Target.Current?.Value;
                      if (String.Equals(current, "update"))
                      {
                          request.Target.Advance();
                          current = request.Target.Current?.Value;
                      }
                      switch (current)
                      {
                          case "operations" or "supported_operations":
                              return Completed(GetSupportedOperations(request));
                          case "classes" or "supported_classes":
                              return Completed(GetSupportedClasses(request));
                          case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                              return Completed(GetUpdateFile(request));
                          case "info":
                              return Completed(GetServerInfo(request));
                      }
                      Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                          string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
                      return Completed(request.Respond().Status(ResponseStatus.NotFound));

                  case RequestMethod.POST:
                      var target = request.Target.Current;
                      if (target is null)
                          return Completed(request.Respond().Status(ResponseStatus.NotFound));
                      return target.Value switch
                      {
                          "validate" => Completed(Validate(request)),
                          "check_2fa" => Completed(Check2FA(request)),
                          "notify" => Completed(GetNotify(request)),
                          _ => HandleDatabaseRequest(request, target.Value),
                      };
              }
              return Completed(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD"));
          }
          catch (Exception e)
          {
              Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
          }
      }

      /// <summary>Builds <paramref name="builder"/> and wraps the result in a completed <see cref="ValueTask{TResult}"/>.</summary>
      private static ValueTask<IResponse?> Completed(IResponseBuilder builder) =>
          new ValueTask<IResponse?>(builder.Build());

      /// <summary>Builds a response whose body is <paramref name="serialized"/> with an application/json content type.</summary>
      private static IResponseBuilder JsonResponse(IRequest request, string serialized) =>
          request.Respond()
              .Type(FlexibleContentType.Get(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized).Build()));

      /// <summary>
      /// Returns a JSON list of operation names; used for checking support of operations client side
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>A JSON response listing the advertised endpoints.</returns>
      private IResponseBuilder GetSupportedOperations(IRequest request)
      {
          return JsonResponse(request, Core.Serialization.Serialize(endpoints, true) ?? "");
      }

      /// <summary>
      /// Returns a JSON list of class names; used for checking support of operations client side
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>A JSON response listing the supported entity names.</returns>
      private IResponseBuilder GetSupportedClasses(IRequest request)
      {
          return JsonResponse(request, Serialization.Serialize(operations) ?? "");
      }

      /// <summary>
      /// Returns the Splash Logo and Color Scheme for this Database
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>A JSON response containing the result of <c>RestService.Info</c>.</returns>
      private IResponseBuilder GetServerInfo(IRequest request)
      {
          InfoResponse response = RestService.Info(new InfoRequest());
          return JsonResponse(request, Core.Serialization.Serialize(response, true) ?? "");
      }

      /// <summary>
      /// Returns the web-socket port to a client whose session is known to <c>CredentialsCache</c>;
      /// unknown sessions get 404.
      /// </summary>
      /// <param name="request">A POST whose body is a JSON <c>NotifyRequest</c>.</param>
      /// <returns>A JSON <c>NotifyResponse</c>, or a 404 response.</returns>
      private IResponseBuilder GetNotify(IRequest request)
      {
          var requestObj = Deserialize<NotifyRequest>(request.Content, true);
          if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
          {
              return request.Respond().Status(ResponseStatus.NotFound);
          }
          var response = new NotifyResponse
          {
              Status = StatusCode.OK,
              SocketPort = WebSocketPort
          };
          return JsonResponse(request, Serialization.Serialize(response));
      }

      #region Authentication

      /// <summary>Deserializes a <c>ValidateRequest</c>, passes it to <c>RestService.Validate</c> and returns the result as JSON.</summary>
      private IResponseBuilder Validate(IRequest request)
      {
          var requestObj = Deserialize<ValidateRequest>(request.Content, true);
          var response = RestService.Validate(requestObj);
          return JsonResponse(request, Serialization.Serialize(response));
      }

      /// <summary>Deserializes a <c>Check2FARequest</c>, passes it to <c>RestService.Check2FA</c> and returns the result as JSON.</summary>
      private IResponseBuilder Check2FA(IRequest request)
      {
          var requestObj = Deserialize<Check2FARequest>(request.Content, true);
          var response = RestService.Check2FA(requestObj);
          return JsonResponse(request, Serialization.Serialize(response));
      }

      #endregion

      #region Database

      /// <summary>Looks up a non-public static method of this class by name; throws if it does not exist.</summary>
      private static MethodInfo GetMethod(string name) =>
          typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
          ?? throw new Exception($"Invalid method '{name}'");

      /// <summary>
      /// Endpoint prefix → open generic handler. An endpoint "&lt;Prefix&gt;&lt;EntityName&gt;" is dispatched to the
      /// matching method closed over the entity type. None of the prefixes is a prefix of another entry that
      /// appears earlier ("MultiSave" does not start with "Save"), so first-match order is unambiguous.
      /// </summary>
      private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
      {
          new("List", GetMethod(nameof(List))),
          new("Save", GetMethod(nameof(Save))),
          new("Delete", GetMethod(nameof(Delete))),
          new("MultiSave", GetMethod(nameof(MultiSave))),
          new("MultiDelete", GetMethod(nameof(MultiDelete)))
      };

      private static QueryResponse<T> List<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<QueryRequest<T>>(request.Content, true);
          return RestService<T>.List(requestObject);
      }

      private static SaveResponse<T> Save<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<SaveRequest<T>>(request.Content, true);
          return RestService<T>.Save(requestObject);
      }

      private static DeleteResponse<T> Delete<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<DeleteRequest<T>>(request.Content, true);
          return RestService<T>.Delete(requestObject);
      }

      private static MultiSaveResponse<T> MultiSave<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, true);
          return RestService<T>.MultiSave(requestObject);
      }

      private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, true);
          return RestService<T>.MultiDelete(requestObject);
      }

      /// <summary>
      /// Reads <paramref name="stream"/> to the end as text and deserializes it into <typeparamref name="T"/>.
      /// </summary>
      /// <param name="stream">The request body; must not be null.</param>
      /// <param name="strict">Forwarded to <c>Serialization.Deserialize</c>.</param>
      /// <exception cref="Exception">The stream is null or deserialization produced null.</exception>
      private static T Deserialize<T>(Stream? stream, bool strict = false)
      {
          if (stream is null)
              throw new Exception("Stream is null");
          // The request body belongs to the request, so leave it open and dispose only the reader.
          using var reader = new StreamReader(stream, leaveOpen: true);
          var str = reader.ReadToEnd();
          return Serialization.Deserialize<T>(str, strict)
              ?? throw new Exception("Deserialization failed");
      }

      /// <summary>
      /// Serializes a database <paramref name="result"/> as binary (when requested and supported by the
      /// result type) or as JSON otherwise.
      /// </summary>
      private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, Response? result)
      {
          if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
          {
              var stream = new MemoryStream();
              binary.SerializeBinary(new BinaryWriter(stream));
              // Rewind defensively so the body is read from the first byte.
              stream.Position = 0;
              // NOTE(review): this "checksum" is the MemoryStream's identity hash, not a hash of its bytes, so it
              // differs on every response even for identical payloads. Confirm that is intended.
              return request.Respond()
                  .Type(FlexibleContentType.Get(ContentType.ApplicationOctetStream))
                  .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
          }
          return JsonResponse(request, Serialization.Serialize(result));
      }

      /// <summary>
      /// Handler for all database requests
      /// </summary>
      /// <remarks>
      /// The optional query parameter "responseFormat" selects the <see cref="SerializationFormat"/> of the
      /// response (JSON by default). "QueryMultiple…" is handled directly; any other endpoint is matched against
      /// <see cref="methodMap"/> prefixes and the remainder is resolved as an entity type. Unknown endpoints,
      /// unknown entities and <c>ISecure</c> entities all get 404.
      /// </remarks>
      /// <param name="request">The incoming POST request.</param>
      /// <param name="endpoint">The current (already null-checked) path segment of the request.</param>
      /// <returns>The response to send.</returns>
      private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, string endpoint)
      {
          var responseFormat = SerializationFormat.Json;
          if (request.Query.TryGetValue("responseFormat", out var formatString)
              && Enum.TryParse<SerializationFormat>(formatString, out var format))
          {
              responseFormat = format;
          }

          if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
          {
              var requestObject = Deserialize<MultiQueryRequest>(request.Content, true);
              var queryResult = RestService.QueryMultiple(requestObject, false);
              return Completed(SerializeResponse(request, responseFormat, queryResult));
          }

          foreach (var (name, method) in methodMap)
          {
              if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                  continue;

              var entityName = endpoint[name.Length..];
              var entityType = GetEntity(entityName);
              if (entityType is null)
              {
                  Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                      $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                  return Completed(request.Respond().Status(ResponseStatus.NotFound));
              }
              if (entityType.IsAssignableTo(typeof(ISecure)))
              {
                  Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                  return Completed(request.Respond().Status(ResponseStatus.NotFound));
              }

              // Exceptions thrown by the invoked handler arrive wrapped in TargetInvocationException
              // and are logged by the catch in HandleAsync.
              var resolvedMethod = method.MakeGenericMethod(entityType);
              var result = resolvedMethod.Invoke(null, new object[] { request }) as Response;
              return Completed(SerializeResponse(request, responseFormat, result));
          }

          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
          return Completed(request.Respond().Status(ResponseStatus.NotFound));
      }

      /// <summary>Lazily built map from short type name to Entity types implementing IRemotable and IPersistent.</summary>
      private Dictionary<string, Type>? _persistentRemotable;

      /// <summary>
      /// Resolves <paramref name="entityName"/> (a short type name) to a remotable, persistent entity type,
      /// or null if none matches.
      /// </summary>
      /// <remarks>NOTE(review): ToDictionary throws if two such types share a short Name; confirm names are unique.</remarks>
      private Type? GetEntity(string entityName)
      {
          _persistentRemotable ??= CoreUtils.TypeList(
              e => e.IsSubclassOf(typeof(Entity)) &&
                   e.GetInterfaces().Contains(typeof(IRemotable)) &&
                   e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
          return _persistentRemotable.GetValueOrDefault(entityName);
      }

      #endregion

      #region Installer

      /// <summary>
      /// Serves a file from the "update" folder under <c>CoreUtils.GetCommonAppData()</c> as an attachment:
      /// "version" → version.txt, "releasenotes"/"release_notes" → Release Notes.txt,
      /// "install"/"install_desktop" → PRSDesktopSetup.exe. Unknown names or missing files get 404.
      /// </summary>
      private IResponseBuilder GetUpdateFile(IRequest request)
      {
          var relativePath = request.Target.Current?.Value switch
          {
              "version" => "update/version.txt",
              "releasenotes" or "release_notes" => "update/Release Notes.txt",
              "install" or "install_desktop" => "update/PRSDesktopSetup.exe",
              _ => null
          };
          if (relativePath is null)
              return request.Respond().Status(ResponseStatus.NotFound);

          var filename = Path.Combine(CoreUtils.GetCommonAppData(), relativePath);
          if (!File.Exists(filename))
              return request.Respond().Status(ResponseStatus.NotFound);

          // Quote the file name: "Release Notes.txt" contains a space, which is not allowed in an
          // unquoted (token) filename parameter of Content-Disposition.
          return request.Respond()
              .Header("Content-Disposition", $"attachment; filename=\"{Path.GetFileName(filename)}\"")
              .Content(new ResourceContent(
                  new FileResource(new FileInfo(filename), null, null)));
      }

      #endregion

      #region GenHTTP stuff

      /// <inheritdoc/>
      public IHandler Parent { get; }

      /// <summary>No preparation is needed; completes immediately.</summary>
      public ValueTask PrepareAsync()
      {
          return new ValueTask();
      }

      /// <summary>This handler exposes no browsable content.</summary>
      public IEnumerable<ContentElement> GetContent(IRequest request)
      {
          return Enumerable.Empty<ContentElement>();
      }

      #endregion
  }
  /// <summary>
  /// Builder that produces a <see cref="RestHandler"/> wrapped in whatever concerns were registered via <see cref="Add"/>.
  /// </summary>
  public class RestHandlerBuilder : IHandlerBuilder<RestHandlerBuilder>
  {
      private readonly List<IConcernBuilder> _registeredConcerns = new();
      private readonly int? _socketPort;

      /// <param name="webSocketPort">Port handed to each built <see cref="RestHandler"/>; may be null.</param>
      public RestHandlerBuilder(int? webSocketPort) => _socketPort = webSocketPort;

      /// <summary>Registers a concern to wrap around the handler; returns this builder for chaining.</summary>
      public RestHandlerBuilder Add(IConcernBuilder concern)
      {
          _registeredConcerns.Add(concern);
          return this;
      }

      /// <summary>Creates the handler chain: the registered concerns around a new <see cref="RestHandler"/>.</summary>
      public IHandler Build(IHandler parent) =>
          Concerns.Chain(parent, _registeredConcerns, innerParent => new RestHandler(innerParent, _socketPort));
  }
  /// <summary>
  /// <see cref="Notifier"/> that pushes notifications through a <see cref="WebSocketServer"/> and relays the
  /// server's poll events to <c>Poll</c>.
  /// </summary>
  class RestNotifier : Notifier
  {
      private readonly WebSocketServer _server;

      /// <summary>Port exposed by the underlying web-socket server.</summary>
      public int Port => _server.Port;

      /// <param name="port">Port the web-socket server is created on.</param>
      public RestNotifier(int port)
      {
          _server = new WebSocketServer(port);
          _server.Poll += OnServerPoll;
      }

      // Relays a poll from a connected session to Poll, keyed by that session's ID.
      private void OnServerPoll(NotifyState.Session session) => Poll(session.SessionID);

      /// <summary>Starts the web-socket server.</summary>
      public void Start() => _server.Start();

      /// <summary>Stops the web-socket server.</summary>
      public void Stop() => _server.Stop();

      protected override void NotifyAll<TNotification>(TNotification notification) =>
          _server.Push(notification);

      protected override void NotifySession(Guid session, Type TNotification, object? notification) =>
          _server.Push(session, TNotification, notification);

      protected override void NotifySession<TNotification>(Guid session, TNotification notification) =>
          _server.Push(session, notification);

      // User → session lookup comes from the credentials cache, not from the socket server.
      protected override IEnumerable<Guid> GetUserSessions(Guid userID) =>
          CredentialsCache.GetUserSessions(userID);

      protected override IEnumerable<Guid> GetSessions(Platform platform) =>
          _server.GetSessions(platform);
  }
  /// <summary>
  /// Static owner of the REST server host, its optional certificate and the optional web-socket notifier.
  /// </summary>
  public static class RestListener
  {
      private static IServerHost? _serverHost;
      private static X509Certificate2? _tlsCertificate;
      private static RestNotifier? _socketNotifier;

      /// <summary>The certificate most recently passed to <c>InitCertificate</c>, or null after <see cref="Clear"/>.</summary>
      public static X509Certificate2? Certificate => _tlsCertificate;

      /// <summary>Starts the host, then the notifier; each is skipped if it has not been created.</summary>
      public static void Start()
      {
          _serverHost?.Start();
          _socketNotifier?.Start();
      }

      /// <summary>Stops the host, then the notifier; each is skipped if it has not been created.</summary>
      public static void Stop()
      {
          _serverHost?.Stop();
          _socketNotifier?.Stop();
      }

      /// <summary>
      /// Records <paramref name="certificate"/> and binds the host to all interfaces on <paramref name="port"/>
      /// with it. The certificate is recorded even if no host exists yet, in which case nothing is bound.
      /// </summary>
      public static void InitCertificate(ushort port, X509Certificate2 certificate)
      {
          _tlsCertificate = certificate;
          _serverHost?.Bind(IPAddress.Any, port, certificate);
      }

      /// <summary>Loads a certificate from <paramref name="certificateFile"/> and delegates to the X509Certificate2 overload.</summary>
      public static void InitCertificate(ushort port, string certificateFile) =>
          InitCertificate(port, new X509Certificate2(certificateFile));

      /// <summary>Binds the host to all interfaces on <paramref name="port"/> without a certificate.</summary>
      public static void InitPort(ushort port) => _serverHost?.Bind(IPAddress.Any, port);

      /// <summary>
      /// Clears certificate and host information, and stops the listener.
      /// </summary>
      public static void Clear()
      {
          _serverHost?.Stop();
          _serverHost = null;
          _socketNotifier?.Stop();
          _socketNotifier = null;
          _tlsCertificate = null;
      }

      /// <summary>
      /// Creates the host with a <see cref="RestHandler"/> (built by <see cref="RestHandlerBuilder"/>) and GenHTTP
      /// defaults. A non-zero <paramref name="webSocketPort"/> also creates a <see cref="RestNotifier"/> on that port
      /// and installs it as <c>Notify.Notifier</c>; the notifier's port is what the handler reports to clients.
      /// </summary>
      /// <remarks>Does not start anything. A zero port leaves any previously created notifier in place.</remarks>
      public static void Init(int webSocketPort)
      {
          if (webSocketPort != 0)
          {
              _socketNotifier = new RestNotifier(webSocketPort);
              Notify.Notifier = _socketNotifier;
          }

          _serverHost = Host.Create();
          _serverHost.Handler(new RestHandlerBuilder(_socketNotifier?.Port))
              .Defaults();
      }
  }
  459. }