// HTTPDatabaseProxyEngine.cs
  using System;
  using System.Collections.Generic;
  using System.IO;
  using System.Linq;
  using System.Reflection;
  using System.Text;
  using System.Threading.Tasks;
  using GenHTTP.Api.Content;
  using GenHTTP.Api.Protocol;
  using GenHTTP.Modules.IO;
  using GenHTTP.Modules.IO.Streaming;
  using GenHTTP.Modules.IO.Strings;
  using InABox.Clients;
  using InABox.Core;
  using PRSServices;
  using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  namespace PRSServer;

  /// <summary>
  /// Editable settings for the HTTP database proxy server.
  /// </summary>
  public class HTTPDatabaseProxyProperties : DatabaseProxyProperties
  {
      /// <summary>
      /// Port number the proxy listener is started on.
      /// </summary>
      [IntegerEditor, EditorSequence(2)]
      public int ListenPort { get; set; }

      /// <summary>
      /// Path to a .pfx certificate used for HTTPS. When left blank the engine falls back
      /// to <c>CertificateEngine.CertificateFile</c>.
      /// </summary>
      [EditorSequence(3), FileNameEditor("Certificate Files (*.pfx)|*.pfx")]
      public string CertificateFile { get; set; }

      /// <summary>
      /// Identifies this configuration as belonging to a proxy server.
      /// </summary>
      public override ServerType Type()
      {
          return ServerType.Proxy;
      }
  }
  /// <summary>
  /// GenHTTP handler for the HTTP database proxy.
  /// GET/HEAD serve informational and installer endpoints ("info", "ping", "version",
  /// "releasenotes"/"release_notes", "install"/"install_desktop", optionally preceded by an
  /// "update" path segment). POST serves authentication ("validate", "check_2fa") and
  /// database endpoints ("QueryMultiple" and List/Save/Delete/MultiSave/MultiDelete followed
  /// by an entity type name).
  /// </summary>
  internal class HTTPDatabaseProxyHandler : Handler<HTTPDatabaseProxyProperties>
  {
      // NOTE(review): both lists are populated by the constructor but never read in this class.
      private readonly List<string> endpoints;
      private readonly List<string> operations;

      /// <summary>
      /// Creates the handler and records endpoint/operation names for every remotable entity
      /// type whose (underscore-separated) entity name is reported by DbFactory.SupportedTypes().
      /// </summary>
      /// <param name="parent">The enclosing GenHTTP handler.</param>
      public HTTPDatabaseProxyHandler(IHandler parent)
      {
          Parent = parent;
          endpoints = new();
          operations = new();

          var types = CoreUtils.TypeList(
              x => x.IsSubclassOf(typeof(Entity))
                   && x.GetInterfaces().Contains(typeof(IRemotable))
          );

          var DBTypes = DbFactory.SupportedTypes();

          foreach (var t in types)
          {
              if (DBTypes.Contains(t.EntityName().Replace(".", "_")))
              {
                  operations.Add(t.EntityName().Replace(".", "_"));
                  endpoints.Add(string.Format("List{0}", t.Name));
                  endpoints.Add(string.Format("Load{0}", t.Name));
                  endpoints.Add(string.Format("Save{0}", t.Name));
                  endpoints.Add(string.Format("MultiSave{0}", t.Name));
                  endpoints.Add(string.Format("Delete{0}", t.Name));
                  endpoints.Add(string.Format("MultiDelete{0}", t.Name));
              }
          }

          endpoints.Add("QueryMultiple");
      }

      /// <summary>
      /// Reads the serialization options from the query string.
      /// "serializationVersion" selects the binary settings (default V1_0). "format" sets the
      /// request format, which is left at its default if the parameter is absent or unparsable.
      /// "responseFormat" sets the response format, defaulting to Json.
      /// </summary>
      private RequestData GetRequestData(IRequest request)
      {
          BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
          if (request.Query.TryGetValue("serializationVersion", out var versionString))
          {
              settings = BinarySerializationSettings.ConvertVersionString(versionString);
          }

          var data = new RequestData(settings);

          if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
          {
              data.RequestFormat = format;
          }

          data.ResponseFormat = SerializationFormat.Json;
          if (request.Query.TryGetValue("responseFormat", out formatString) && Enum.TryParse<SerializationFormat>(formatString, out format))
          {
              data.ResponseFormat = format;
          }

          return data;
      }

      /// <summary>
      /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
      /// Any exception is logged and turned into a 500 response.
      /// </summary>
      /// <param name="request">The incoming GenHTTP request.</param>
      /// <returns>The response; 404 for unknown endpoints, 405 for unsupported methods.</returns>
      public ValueTask<IResponse?> HandleAsync(IRequest request)
      {
          try
          {
              switch (request.Method.KnownMethod)
              {
                  case RequestMethod.GET:
                  case RequestMethod.HEAD:
                      var current = request.Target.Current?.Value;
                      // An optional leading "update" segment is skipped.
                      if (string.Equals(current, "update", StringComparison.Ordinal))
                      {
                          request.Target.Advance();
                          current = request.Target.Current?.Value;
                      }

                      switch (current)
                      {
                          case "operations" or "supported_operations":
                              Logger.Send(LogType.Error, "", "Supported operations is no longer supported");
                              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                          case "classes" or "supported_classes":
                              Logger.Send(LogType.Error, "", "Supported classes is no longer supported");
                              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                          case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                              return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
                          case "info":
                              return new ValueTask<IResponse?>(GetServerInfo(request).Build());
                          case "ping":
                              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
                      }

                      Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                          string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
                      return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());

                  case RequestMethod.POST:
                      var target = request.Target.Current;
                      if (target is not null)
                      {
                          var data = GetRequestData(request);
                          return target.Value switch
                          {
                              "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
                              "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
                              _ => HandleDatabaseRequest(request, data),
                          };
                      }
                      return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
              }

              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
          }
          catch (Exception e)
          {
              Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
          }
      }

      /// <summary>
      /// Returns the result of <c>Client.Info()</c> wrapped in an <c>InfoResponse</c>,
      /// serialized in the requested response format.
      /// </summary>
      private IResponseBuilder GetServerInfo(IRequest request)
      {
          var data = GetRequestData(request);
          var response = new InfoResponse(Client.Info());
          response.Status = StatusCode.OK;
          return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
      }

      #region Authentication

      /// <summary>Deserializes a <c>ValidateRequest</c> body and forwards it to <c>RestService.Validate</c>.</summary>
      private IResponseBuilder Validate(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          var response = RestService.Validate(requestObj);
          return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
      }

      /// <summary>Deserializes a <c>Check2FARequest</c> body and forwards it to <c>RestService.Check2FA</c>.</summary>
      private IResponseBuilder Check2FA(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          var response = RestService.Check2FA(requestObj);
          return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
      }

      #endregion

      #region Database

      /// <summary>
      /// Looks up one of this class's private static generic handlers by name.
      /// It must reflect over this class: the handlers take this class's nested
      /// <see cref="RequestData"/> type.
      /// </summary>
      private static MethodInfo GetMethod(string name) =>
          typeof(HTTPDatabaseProxyHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
          ?? throw new InvalidOperationException($"Invalid method '{name}'");

      // Endpoint prefix -> open generic handler, closed over the entity type at request time.
      private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
      {
          new("List", GetMethod(nameof(List))),
          new("Save", GetMethod(nameof(Save))),
          new("Delete", GetMethod(nameof(Delete))),
          new("MultiSave", GetMethod(nameof(MultiSave))),
          new("MultiDelete", GetMethod(nameof(MultiDelete)))
      };

      /// <summary>Per-request serialization options parsed from the query string.</summary>
      private class RequestData
      {
          public SerializationFormat RequestFormat { get; set; }

          public SerializationFormat ResponseFormat { get; set; }

          public BinarySerializationSettings BinarySerializationSettings { get; set; }

          public RequestData(BinarySerializationSettings binarySerializationSettings)
          {
              BinarySerializationSettings = binarySerializationSettings;
          }
      }

      private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.List(requestObject);
      }

      private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.Save(requestObject);
      }

      private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.Delete(requestObject);
      }

      private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.MultiSave(requestObject);
      }

      private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.MultiDelete(requestObject);
      }

      private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
      {
          var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService.QueryMultiple(requestObject, false);
      }

      /// <summary>
      /// Reads a request body as <typeparamref name="T"/>. It uses the binary reader when the
      /// request format is Binary and <typeparamref name="T"/> implements ISerializeBinary.
      /// Otherwise it reads the body as text and passes it to <c>Serialization.Deserialize</c>.
      /// </summary>
      /// <exception cref="InvalidOperationException">The body is missing or deserializes to null.</exception>
      private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
      {
          if (stream is null)
              throw new InvalidOperationException("Stream is null");

          if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
              return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);

          // Same decoding defaults as new StreamReader(stream), but the reader is released
          // without closing the request stream, which this method does not own.
          using var reader = new StreamReader(stream, Encoding.UTF8, true, 1024, leaveOpen: true);
          var str = reader.ReadToEnd();
          return Serialization.Deserialize<T>(str, strict)
              ?? throw new InvalidOperationException("Deserialization failed");
      }

      /// <summary>
      /// Writes <paramref name="result"/> to the response. It uses the binary (octet-stream)
      /// encoding when requested and supported by the result. Otherwise it writes JSON.
      /// </summary>
      private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
      {
          if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
          {
              var stream = new MemoryStream();
              binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
              // Rewind explicitly rather than relying on the content writer to seek back.
              stream.Position = 0;
              // NOTE(review): MemoryStream does not override GetHashCode, so this checksum is
              // per-instance, not derived from the content.
              return request.Respond()
                  .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                  .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
          }

          var serialized = Serialization.Serialize(result);
          return request.Respond()
              .Type(new FlexibleContentType(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized).Build()));
      }

      /// <summary>
      /// Handler for all database requests. The endpoint is either "QueryMultiple…" or one of
      /// the <see cref="methodMap"/> prefixes followed by an entity type name. Entities
      /// implementing ISecure are refused with 404.
      /// </summary>
      /// <param name="request">The incoming request, positioned at the endpoint segment.</param>
      /// <param name="requestData">Serialization options for this request.</param>
      /// <returns>The serialized service response, or 404 if the endpoint cannot be resolved.</returns>
      private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
      {
          var endpoint = request.Target.Current?.Value ?? "";

          // Endpoint names are protocol identifiers, so match them ordinally (not culture-sensitively).
          if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
          {
              var result = QueryMultiple(request, requestData);
              return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
          }

          foreach (var (name, method) in methodMap)
          {
              if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                  continue;

              var entityName = endpoint[name.Length..];
              var entityType = GetEntity(entityName);
              if (entityType != null)
              {
                  if (entityType.IsAssignableTo(typeof(ISecure)))
                  {
                      Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                      return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
                  }

                  var resolvedMethod = method.MakeGenericMethod(entityType);
                  var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
                  return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
              }

              Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                  $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
              return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
          }

          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
          return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
      }

      // Lazily built map of simple type name -> Entity subclass implementing both IRemotable and IPersistent.
      // NOTE(review): ToDictionary throws if two such types share a simple Name — confirm names are unique.
      private Dictionary<string, Type>? _persistentRemotable;

      /// <summary>Resolves an entity type by its simple name, or returns null if unknown.</summary>
      private Type? GetEntity(string entityName)
      {
          _persistentRemotable ??= CoreUtils.TypeList(
              e => e.IsSubclassOf(typeof(Entity)) &&
                   e.GetInterfaces().Contains(typeof(IRemotable)) &&
                   e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
          return _persistentRemotable.GetValueOrDefault(entityName);
      }

      #endregion

      #region Installer

      /// <summary>
      /// Serves the version string, release notes (plain text), or the desktop installer
      /// (as a PRSDesktopSetup.exe attachment, or empty if <c>Client.Installer()</c> returns null).
      /// </summary>
      private IResponseBuilder GetUpdateFile(IRequest request)
      {
          var endpoint = request.Target.Current;
          switch (endpoint?.Value)
          {
              case "version":
                  return request.Respond()
                      .Type(new FlexibleContentType(ContentType.TextPlain))
                      .Content(new ResourceContent(Resource.FromString(Client.Version()).Build()));
              case "releasenotes" or "release_notes":
                  return request.Respond()
                      .Type(new FlexibleContentType(ContentType.TextPlain))
                      .Content(new ResourceContent(Resource.FromString(Client.ReleaseNotes()).Build()));
              case "install" or "install_desktop":
                  return request.Respond()
                      .Header("Content-Disposition", $"attachment; filename=PRSDesktopSetup.exe")
                      .Content(new ResourceContent(new ByteArrayResource(Client.Installer() ?? Array.Empty<byte>(), "PRSDesktopSetup.exe", new FlexibleContentType(ContentType.ApplicationOctetStream), null)));
          }

          return request.Respond().Status(ResponseStatus.NotFound);
      }

      #endregion

      #region GenHTTP stuff

      public IHandler Parent { get; }

      public ValueTask PrepareAsync()
      {
          return new ValueTask();
      }

      public IEnumerable<ContentElement> GetContent(IRequest request)
      {
          return Enumerable.Empty<ContentElement>();
      }

      #endregion
  }
  /// <summary>
  /// Proxy engine that serves <see cref="HTTPDatabaseProxyHandler"/> over HTTPS on the port
  /// configured in <see cref="HTTPDatabaseProxyProperties.ListenPort"/>.
  /// </summary>
  internal class HTTPDatabaseProxyEngine : DatabaseProxyEngine<HTTPDatabaseProxyProperties>
  {
      private Listener<HTTPDatabaseProxyHandler, HTTPDatabaseProxyProperties>? Listener;

      /// <summary>
      /// Creates, configures and starts the HTTPS listener. Failures (including an
      /// out-of-range port) are logged rather than thrown.
      /// </summary>
      protected override void RunProxy()
      {
          Logger.Send(LogType.Information, "", "Registering Classes");
          Logger.Send(LogType.Information, "", "Starting Listener on port " + Properties.ListenPort);

          // The port is narrowed to ushort below; an unchecked cast would silently wrap
          // negative or >65535 values onto an unintended port.
          if (Properties.ListenPort is < 1 or > ushort.MaxValue)
          {
              Logger.Send(LogType.Error, ClientFactory.UserID,
                  $"Invalid listen port {Properties.ListenPort}; must be between 1 and {ushort.MaxValue}");
              return;
          }

          try
          {
              Listener = new Listener<HTTPDatabaseProxyHandler, HTTPDatabaseProxyProperties>(Properties);
              Listener.InitHTTPS((ushort)Properties.ListenPort, CertificateFileName());
              Listener.Start();
          }
          catch (Exception eListen)
          {
              // Log the full exception (type, message, stack) rather than just the message.
              Logger.Send(LogType.Error, ClientFactory.UserID, CoreUtils.FormatException(eListen));
          }
      }

      /// <summary>
      /// The configured certificate file, or <c>CertificateEngine.CertificateFile</c> when
      /// none is configured.
      /// </summary>
      private string CertificateFileName() =>
          !string.IsNullOrWhiteSpace(Properties.CertificateFile)
              ? Properties.CertificateFile
              : CertificateEngine.CertificateFile;

      /// <summary>Stops the listener if one was created.</summary>
      public override void Stop()
      {
          Logger.Send(LogType.Information, "", "Stopping");
          Listener?.Stop();
      }
  }