DbFactory.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. using System.Reflection;
  2. using FluentResults;
  3. using InABox.Clients;
  4. using InABox.Configuration;
  5. using InABox.Core;
  6. using InABox.Scripting;
  7. namespace InABox.Database;
  /// <summary>
  /// Per-database metadata that DbFactory stores as a <see cref="GlobalSettings"/> record
  /// (section "DatabaseMetadata"). Currently it holds only the database identifier.
  /// </summary>
  public class DatabaseMetadata : BaseObject, IGlobalConfigurationSettings
  {
      /// <summary>
      /// Identifier of this database. Every newly constructed instance starts with a fresh
      /// <see cref="Guid"/>; a loaded instance carries the persisted value instead.
      /// </summary>
      public Guid DatabaseID
      {
          get;
          set;
      } = Guid.NewGuid();
  }
  /// <summary>
  /// Exception signalling that the database is read-only because of PRS license expiry.
  /// NOTE(review): this type is not thrown within DbFactory itself; presumably write paths
  /// raise it while DbFactory.IsReadOnly is set — confirm against callers.
  /// </summary>
  public class DbLockedException : Exception
  {
      private const string DefaultMessage = "Database is read-only due to PRS license expiry.";

      /// <summary>Creates the exception with the default read-only/license-expiry message.</summary>
      public DbLockedException() : base(DefaultMessage) { }

      /// <summary>Creates the exception with a caller-supplied message.</summary>
      /// <param name="message">Description of the blocked operation.</param>
      public DbLockedException(string message) : base(message) { }

      /// <summary>Creates the exception with a caller-supplied message and the underlying cause.</summary>
      /// <param name="message">Description of the blocked operation.</param>
      /// <param name="innerException">The exception that caused this one.</param>
      public DbLockedException(string message, Exception innerException) : base(message, innerException) { }
  }
  /// <summary>
  /// Static entry point of the server-side database layer. It owns the active
  /// <see cref="IProviderFactory"/>, the registered store types, the persisted
  /// <see cref="DatabaseMetadata"/>, the periodic license check and read-only state,
  /// and the compiled script cache.
  /// </summary>
  public static class DbFactory
  {
      /// <summary>
      /// Compiled scripts keyed by "{Section} {ScriptType}". Rebuilt by <see cref="LoadScripts"/>.
      /// </summary>
      public static Dictionary<string, ScriptDocument> LoadedScripts = new();

      private static DatabaseMetadata MetaData { get; set; } = new();

      /// <summary>
      /// Identifier of this database. Assigning a new value persists the metadata immediately.
      /// </summary>
      public static Guid ID
      {
          get => MetaData.DatabaseID;
          set
          {
              MetaData.DatabaseID = value;
              SaveMetadata();
          }
      }

      private static IProviderFactory? _providerFactory;

      /// <summary>
      /// The active provider factory.
      /// </summary>
      /// <exception cref="InvalidOperationException">Thrown when read before a factory has been assigned.</exception>
      public static IProviderFactory ProviderFactory
      {
          get => _providerFactory ?? throw new InvalidOperationException("Provider is not set");
          set => _providerFactory = value;
      }

      /// <summary>True once <see cref="ProviderFactory"/> has been assigned.</summary>
      public static bool IsProviderSet => _providerFactory is not null;

      public static string? ColorScheme { get; set; }

      public static byte[]? Logo { get; set; }

      // See notes in Request.DatabaseInfo class.
      // Once RPC transport is stable, these settings need to be removed.
      public static int RestPort { get; set; }
      public static int RPCPort { get; set; }

      /// <summary>
      /// Return every <see cref="IPersistent"/> entity in <see cref="CoreUtils.Entities"/>.
      /// </summary>
      public static IEnumerable<Type> Entities => CoreUtils.Entities.Where(x => x.HasInterface<IPersistent>());

      /// <summary>
      /// Registered store implementations. Assignment keeps only concrete, non-generic classes and
      /// throws if any of them does not implement <see cref="IStore"/>.
      /// </summary>
      public static Type[] Stores
      {
          get => stores;
          set => SetStoreTypes(value);
      }

      public static DateTime Expiry { get; set; }

      /// <summary>Creates a new provider from <see cref="ProviderFactory"/> using <paramref name="logger"/>.</summary>
      public static IProvider NewProvider(Logger logger) => ProviderFactory.NewProvider(logger);

      /// <summary>
      /// Starts the database layer. The steps are: register entity types, start the provider,
      /// load or create metadata, migrate, load custom properties, check the license (and start
      /// the periodic re-check), initialise stores and compile scripts.
      /// </summary>
      /// <param name="types">
      /// Optional explicit entity types. They are combined with the non-abstract types from the
      /// InABox.Core assembly. When null, <see cref="Entities"/> is used.
      /// </param>
      /// <exception cref="InvalidOperationException">Thrown when the database migration fails.</exception>
      public static void Start(Type[]? types = null)
      {
          CoreUtils.CheckLicensing();

          ProviderFactory.Types = types is not null
              ? types
                  .Concat(CoreUtils.IterateTypes(typeof(CoreUtils).Assembly).Where(x => !x.IsAbstract))
                  .Where(IsConcreteEntityType)
                  // The caller's list may already contain core types; register each type once.
                  .Distinct()
                  .ToArray()
              : Entities.Where(IsConcreteEntityType).ToArray();

          ProviderFactory.Start();
          CheckMetadata();

          if (!DataUpdater.MigrateDatabase())
              throw new InvalidOperationException("Database migration failed. Aborting startup");

          // Clients can't be used here (we're already inside the database layer),
          // so custom properties are read straight from a provider.
          var props = NewProvider(Logger.Main).Query<CustomProperty>().ToArray<CustomProperty>();
          DatabaseSchema.Load(props);

          AssertLicense();
          BeginLicenseCheckTimer();

          InitStores();
          LoadScripts();
      }

      /// <summary>True for non-generic classes deriving from <see cref="Entity"/>.</summary>
      private static bool IsConcreteEntityType(Type type) =>
          type.IsClass
          && !type.IsGenericType
          && type.IsSubclassOf(typeof(Entity));

      #region MetaData

      /// <summary>Writes <see cref="MetaData"/> as the "DatabaseMetadata" global settings record.</summary>
      private static void SaveMetadata()
      {
          var settings = new GlobalSettings
          {
              Section = nameof(DatabaseMetadata),
              Key = "",
              Contents = Serialization.Serialize(MetaData)
          };
          NewProvider(Logger.Main).Save(settings);
      }

      /// <summary>
      /// Loads <see cref="MetaData"/> from the stored global settings. If nothing usable is
      /// stored, it creates and saves a fresh instance.
      /// </summary>
      private static void CheckMetadata()
      {
          var stored = NewProvider(Logger.Main)
              .Query(new Filter<GlobalSettings>(x => x.Section).IsEqualTo(nameof(DatabaseMetadata)))
              .Rows.FirstOrDefault()?.ToObject<GlobalSettings>();

          var data = stored is not null
              ? Serialization.Deserialize<DatabaseMetadata>(stored.Contents)
              : null;

          if (data is null)
          {
              // No record stored (or it deserialized to null): start fresh and persist it.
              MetaData = new DatabaseMetadata();
              SaveMetadata();
          }
          else
          {
              MetaData = data;
          }
      }

      #endregion

      #region License

      private enum LicenseValidation
      {
          Valid,
          Missing,
          Expired,
          Corrupt,
          Tampered
      }

      /// <summary>
      /// Loads and validates the stored <see cref="License"/>.
      /// </summary>
      /// <param name="expiry">
      /// The license expiry when the license decrypts and passes the tamper checks;
      /// otherwise <see cref="DateTime.MinValue"/>.
      /// </param>
      private static LicenseValidation CheckLicenseValidity(out DateTime expiry)
      {
          var provider = NewProvider(Logger.New());
          expiry = DateTime.MinValue;

          var license = provider.Load<License>().FirstOrDefault();
          if (license is null)
              return LicenseValidation.Missing;

          if (!LicenseUtils.TryDecryptLicense(license.Data, out var licenseData, out _))
              return LicenseValidation.Corrupt;

          if (!LicenseUtils.ValidateMacAddresses(licenseData.Addresses))
              return LicenseValidation.Tampered;

          // Every UserTracking ID listed in the license must still exist in the database;
          // otherwise the license is treated as tampered.
          var existingTrackingIDs = provider.Query(
                  new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
                  Columns.None<UserTracking>().Add(x => x.ID),
                  log: false)
              .Rows
              .Select(r => r.Get<UserTracking, Guid>(c => c.ID))
              .ToHashSet();

          foreach (var item in licenseData.UserTrackingItems)
          {
              if (!existingTrackingIDs.Contains(item))
                  return LicenseValidation.Tampered;
          }

          expiry = licenseData.Expiry;
          return licenseData.Expiry < DateTime.Now
              ? LicenseValidation.Expired
              : LicenseValidation.Valid;
      }

      // Number of license checks since the last "3 days"/"week" warning; used to throttle them.
      private static int _expiredLicenseCounter = 0;

      // Must stay declared before LicenseTimer: static field initializers run in textual order.
      private static readonly TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

      // Written from the license timer's thread and read by callers of IsReadOnly.
      private static volatile bool _readOnly;

      /// <summary>True while the license is invalid and the database is locked for saving.</summary>
      public static bool IsReadOnly => _readOnly;

      private static readonly System.Timers.Timer LicenseTimer =
          new(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };

      private static void LogRenew(string message)
      {
          LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
      }

      /// <summary>Logs that the database is currently read-only.</summary>
      public static void LogReadOnly()
      {
          LogImportant($"Your database is in read-only mode; please renew your license to enable database updates.");
      }

      /// <summary>
      /// Logs an upcoming-expiry warning for a valid license. How often it logs depends on how close expiry is:
      /// <list type="bullet">
      /// <item>Expiring today, or in under a day: every check.</item>
      /// <item>Under three days: at most hourly.</item>
      /// <item>Under a week: at most every two hours.</item>
      /// </list>
      /// </summary>
      private static void LogLicenseExpiry(DateTime expiry)
      {
          if (expiry.Date == DateTime.Today)
          {
              // Format the DateTime itself: "HH:mm" is not a valid TimeSpan format and would throw.
              LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
              return;
          }

          var sinceLastWarning = _expiredLicenseCounter * LicenseCheckInterval;
          var diffInDays = (expiry - DateTime.Now).TotalDays;

          if (diffInDays < 1)
          {
              LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
          }
          else if (diffInDays < 3 && sinceLastWarning.TotalHours >= 1)
          {
              LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }
          else if (diffInDays < 7 && sinceLastWarning.TotalHours >= 2)
          {
              LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }

          ++_expiredLicenseCounter;
      }

      private static void BeginReadOnly()
      {
          if (!IsReadOnly)
          {
              LogImportant(
                  "Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
              _readOnly = true;
          }
      }

      private static void EndReadOnly()
      {
          if (IsReadOnly)
          {
              LogImportant("Valid license found; the database is no longer read-only.");
              _readOnly = false;
          }
      }

      private static void BeginLicenseCheckTimer()
      {
          // Detach first so that calling Start() more than once never registers the handler twice.
          LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
          LicenseTimer.Elapsed += LicenseTimer_Elapsed;
          LicenseTimer.Start();
      }

      private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
      {
          try
          {
              AssertLicense();
          }
          catch (Exception ex)
          {
              // System.Timers.Timer suppresses exceptions from Elapsed handlers; log instead of losing them.
              LogError($"License check failed: {ex.Message}");
          }
      }

      /// <summary>
      /// Re-validates the license. A valid license clears read-only mode and may log an
      /// upcoming-expiry warning. Any other outcome logs the reason and enters read-only mode.
      /// </summary>
      public static void AssertLicense()
      {
          var result = CheckLicenseValidity(out DateTime expiry);
          switch (result)
          {
              case LicenseValidation.Valid:
                  LogLicenseExpiry(expiry);
                  EndReadOnly();
                  break;
              case LicenseValidation.Missing:
                  LogImportant("Database is unlicensed!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Expired:
                  LogImportant("Database license has expired!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Corrupt:
                  LogImportant("Database license is corrupt - you will need to renew your license.");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Tampered:
                  LogImportant("Database license has been tampered with - you will need to renew your license.");
                  BeginReadOnly();
                  break;
          }
      }

      #endregion

      #region Logging

      private static void LogInfo(string message)
      {
          Logger.Send(LogType.Information, "", message);
      }

      private static void LogImportant(string message)
      {
          Logger.Send(LogType.Important, "", message);
      }

      private static void LogError(string message)
      {
          Logger.Send(LogType.Error, "", message);
      }

      #endregion

      /// <summary>
      /// Instantiates each registered store type once. Each instance gets a fresh provider and
      /// the main logger, and then <c>Init()</c> is called on it.
      /// </summary>
      public static void InitStores()
      {
          foreach (var storeType in stores)
          {
              var store = (IStore)Activator.CreateInstance(storeType)!;
              store.Provider = NewProvider(Logger.Main);
              store.Logger = Logger.Main;
              store.Init();
          }
      }

      /// <summary>
      /// Creates a store for entity <paramref name="type"/>. It uses the first registered subclass of
      /// <c>Store&lt;type&gt;</c> when one exists, otherwise the generic <c>Store&lt;type&gt;</c>.
      /// The store is bound to a fresh provider and the given user/session details.
      /// </summary>
      public static IStore FindStore(Type type, Guid userguid, string userid, Platform platform, string version, Logger logger)
      {
          var defType = typeof(Store<>).MakeGenericType(type);
          var storeType = Stores.FirstOrDefault(x => x.IsSubclassOf(defType)) ?? defType;

          var store = (IStore)Activator.CreateInstance(storeType)!;
          store.Provider = NewProvider(logger);
          store.UserGuid = userguid;
          store.UserID = userid;
          store.Platform = platform;
          store.Version = version;
          store.Logger = logger;
          return store;
      }

      /// <summary>Generic convenience overload of <see cref="FindStore(Type, Guid, string, Platform, string, Logger)"/>.</summary>
      public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, Platform platform, string version, Logger logger)
          where TEntity : Entity, new()
      {
          return (IStore<TEntity>)FindStore(typeof(TEntity), userguid, userid, platform, version, logger);
      }

      // Invoked via reflection from QueryMultiple with TEntity = IQueryDef.Type.
      private static CoreTable DoQueryMultipleQuery<TEntity>(
          IQueryDef query,
          Guid userguid, string userid, Platform platform, string version, Logger logger)
          where TEntity : Entity, new()
      {
          var store = FindStore<TEntity>(userguid, userid, platform, version, logger);
          return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
      }

      /// <summary>
      /// Runs every query in <paramref name="queries"/> in parallel on the thread pool and
      /// returns the results keyed like the input. It blocks until all queries complete.
      /// </summary>
      /// <exception cref="AggregateException">Thrown when any query fails.</exception>
      public static Dictionary<string, CoreTable> QueryMultiple(
          Dictionary<string, IQueryDef> queries,
          Guid userguid, string userid, Platform platform, string version, Logger logger)
      {
          var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

          // Each task returns its own table. The result dictionary is filled only after all
          // tasks finish, because Dictionary<,> is not safe for concurrent writes.
          var pending = queries
              .Select(query => (query.Key, Task: Task.Run(() =>
                  (queryMethod.MakeGenericMethod(query.Value.Type).Invoke(null, new object[]
                  {
                      query.Value,
                      userguid, userid, platform, version, logger
                  }) as CoreTable)!)))
              .ToArray();

          Task.WaitAll(pending.Select(p => p.Task).ToArray());

          var result = new Dictionary<string, CoreTable>(pending.Length);
          foreach (var (key, task) in pending)
              result[key] = task.Result;
          return result;
      }

      #region Supported Types

      private class ModuleConfiguration : Dictionary<string, bool>, ILocalConfigurationSettings
      {
      }

      private static Type[]? _dbtypes;

      /// <summary>Names of the enabled entity types, with '.' replaced by '_'.</summary>
      public static IEnumerable<string> SupportedTypes()
      {
          _dbtypes ??= LoadSupportedTypes();
          return _dbtypes.Select(x => x.EntityName().Replace(".", "_"));
      }

      /// <summary>
      /// Reads the per-database module configuration (located from the lower-cased provider URL)
      /// and returns the enabled entities. Entities missing from the configuration are enabled
      /// and the configuration is saved.
      /// </summary>
      private static Type[] LoadSupportedTypes()
      {
          var result = new List<Type>();
          // NOTE(review): lower-casing the path may matter on case-sensitive file systems; kept as-is.
          var path = ProviderFactory.URL.ToLower();
          var directory = Path.GetDirectoryName(path) ?? "";
          var fileName = Path.GetFileName(path);

          var config = new LocalConfiguration<ModuleConfiguration>(directory, fileName).Load();
          var changed = false;

          foreach (var type in Entities)
          {
              var key = type.EntityName();
              if (config.TryGetValue(key, out var enabled))
              {
                  if (enabled)
                      result.Add(type);
                  else
                      Logger.Send(LogType.Information, "", $"Entity [{key}] is disabled");
              }
              else
              {
                  config[key] = true;
                  result.Add(type);
                  changed = true;
              }
          }

          if (changed)
              new LocalConfiguration<ModuleConfiguration>(directory, fileName).Save(config);

          return result.ToArray();
      }

      /// <summary>True when <typeparamref name="T"/> is enabled in the module configuration.</summary>
      public static bool IsSupported<T>() where T : Entity
      {
          _dbtypes ??= LoadSupportedTypes();
          return _dbtypes.Contains(typeof(T));
      }

      #endregion

      #region Scripts

      /// <summary>
      /// Rebuilds <see cref="LoadedScripts"/> from all before/after query, save and delete scripts
      /// and all after-load scripts. Only scripts that compile are cached; compile failures are logged.
      /// </summary>
      public static void LoadScripts()
      {
          Logger.Send(LogType.Information, "", "Loading Script Cache...");
          LoadedScripts.Clear();

          var scripts = NewProvider(Logger.Main).Load(
              new Filter<Script>(x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
          );

          foreach (var script in scripts)
          {
              var key = $"{script.Section} {script.ScriptType}";
              var doc = new ScriptDocument(script.Code);
              if (doc.Compile())
              {
                  Logger.Send(LogType.Information, "", $"- {script.Section}.{script.ScriptType} Compiled Successfully");
                  LoadedScripts[key] = doc;
              }
              else
              {
                  Logger.Send(LogType.Error, "", $"- {script.Section}.{script.ScriptType} Compile Exception:\n{doc.Result}");
              }
          }

          Logger.Send(LogType.Information, "", "Loading Script Cache Complete");
      }

      #endregion

      #region Store Types

      private static Type[] stores = Array.Empty<Type>();

      /// <summary>
      /// Keeps only concrete, non-generic classes from <paramref name="types"/> and registers them as stores.
      /// </summary>
      /// <exception cref="ArgumentException">Thrown when a kept type does not implement <see cref="IStore"/>.</exception>
      private static void SetStoreTypes(Type[] types)
      {
          var candidates = types
              .Where(t => t.IsClass && !t.IsAbstract && !t.IsGenericType)
              .ToArray();

          var invalid = candidates.FirstOrDefault(t => !typeof(IStore).IsAssignableFrom(t));
          if (invalid is not null)
              throw new ArgumentException($"{invalid.Name} is not a valid store");

          stores = candidates;
      }

      #endregion
  }