DbFactory.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. using System.Reflection;
  2. using InABox.Clients;
  3. using InABox.Configuration;
  4. using InABox.Core;
  5. using InABox.Scripting;
  6. namespace InABox.Database;
  /// <summary>
  /// Global configuration record describing this database instance. DbFactory stores it as a
  /// GlobalSettings row whose section is <c>nameof(DatabaseMetadata)</c>.
  /// </summary>
  public class DatabaseMetadata : BaseObject, IGlobalConfigurationSettings
  {
      /// <summary>
      /// Identity of the database. A freshly constructed record receives a new random
      /// <see cref="Guid"/>; a persisted record overwrites it when deserialized.
      /// </summary>
      public Guid DatabaseID { get; set; } = Guid.NewGuid();
  }
  /// <summary>
  /// Static entry point of the database layer. It owns the active <see cref="IProvider"/>,
  /// creates or upgrades the schema in <see cref="Start"/>, tracks database metadata and
  /// license state, caches compiled scripts, and hands out <see cref="IStore"/> instances.
  /// </summary>
  public static class DbFactory
  {
      /// <summary>Compiled scripts keyed by "{Section} {ScriptType}"; rebuilt by <see cref="LoadScripts"/>.</summary>
      public static Dictionary<string, ScriptDocument> LoadedScripts = new();

      // Assigned by CheckMetadata() during Start(); reading ID before Start() dereferences null.
      private static DatabaseMetadata MetaData { get; set; }

      /// <summary>Identity of this database; setting it persists the metadata immediately.</summary>
      public static Guid ID
      {
          get => MetaData.DatabaseID;
          set
          {
              MetaData.DatabaseID = value;
              SaveMetadata();
          }
      }

      private static IProvider? _provider;

      /// <summary>The active database provider.</summary>
      /// <exception cref="InvalidOperationException">No provider has been assigned yet.</exception>
      public static IProvider Provider
      {
          get => _provider ?? throw new InvalidOperationException("Provider is not set");
          set => _provider = value;
      }

      public static bool IsProviderSet => _provider is not null;

      public static string? ColorScheme { get; set; }
      public static byte[]? Logo { get; set; }

      // See notes in Request.DatabaseInfo class
      // Once RPC transport is stable, these settings need
      // to be removed
      public static int RestPort { get; set; }
      public static int RPCPort { get; set; }

      /// <summary>All known entity types that implement <see cref="IPersistent"/>.</summary>
      public static IEnumerable<Type> Entities
      {
          get { return CoreUtils.Entities.Where(x => x.GetInterfaces().Contains(typeof(IPersistent))); }
      }

      /// <summary>Concrete store implementations; validated by <see cref="SetStoreTypes"/> on assignment.</summary>
      public static Type[] Stores
      {
          get => stores;
          set => SetStoreTypes(value);
      }

      public static DateTime Expiry { get; set; }

      /// <summary>
      /// Brings the database online. In order, it:
      /// <list type="number">
      /// <item>checks licensing;</item>
      /// <item>creates or upgrades the schema when it differs from the object model;</item>
      /// <item>starts the provider and loads metadata;</item>
      /// <item>runs migrations;</item>
      /// <item>loads custom properties and checks the license;</item>
      /// <item>initialises stores and compiles scripts.</item>
      /// </list>
      /// </summary>
      /// <exception cref="Exception">Schema creation/upgrade or database migration failed.</exception>
      public static void Start()
      {
          CoreUtils.CheckLicensing();

          switch (ValidateSchema())
          {
              case SchemaStatus.New:
                  try
                  {
                      Provider.CreateSchema(ConsolidatedObjectModel());
                      SaveSchema();
                  }
                  catch (Exception err)
                  {
                      // Preserve the original failure as InnerException so its stack trace survives.
                      throw new Exception($"Unable to Create Schema\n\n{err.Message}", err);
                  }
                  break;
              case SchemaStatus.Changed:
                  try
                  {
                      Provider.UpgradeSchema(ConsolidatedObjectModel());
                      SaveSchema();
                  }
                  catch (Exception err)
                  {
                      throw new Exception($"Unable to Update Schema\n\n{err.Message}", err);
                  }
                  break;
          }

          // Start the provider
          Provider.Types = ConsolidatedObjectModel();
          Provider.OnLog += LogMessage;
          Provider.Start();

          CheckMetadata();

          if (!DataUpdater.MigrateDatabase())
          {
              throw new Exception("Database migration failed. Aborting startup");
          }

          // Clients can't be used here (we are already inside the database layer),
          // so custom properties are read straight from the provider.
          var props = Provider.Query<CustomProperty>().Rows.Select(x => x.ToObject<CustomProperty>()).ToArray();
          DatabaseSchema.Load(props);

          AssertLicense();
          BeginLicenseCheckTimer();

          InitStores();
          LoadScripts();
      }

      #region MetaData

      /// <summary>Serializes <see cref="MetaData"/> into the GlobalSettings row for section DatabaseMetadata.</summary>
      private static void SaveMetadata()
      {
          var settings = new GlobalSettings
          {
              Section = nameof(DatabaseMetadata),
              Key = "",
              Contents = Serialization.Serialize(MetaData)
          };
          Provider.Save(settings);
      }

      /// <summary>Loads the persisted metadata, creating and saving a new record if none exists.</summary>
      private static void CheckMetadata()
      {
          var result = Provider.Query(new Filter<GlobalSettings>(x => x.Section).IsEqualTo(nameof(DatabaseMetadata)))
              .Rows.FirstOrDefault()?.ToObject<GlobalSettings>();
          var data = result is not null ? Serialization.Deserialize<DatabaseMetadata>(result.Contents) : null;
          if (data is null)
          {
              MetaData = new DatabaseMetadata();
              SaveMetadata();
          }
          else
          {
              MetaData = data;
          }
      }

      #endregion

      #region License

      private enum LicenseValidation
      {
          Valid,
          Missing,
          Expired,
          Corrupt,
          Tampered
      }

      /// <summary>
      /// Loads the first stored license and classifies it.
      /// <list type="bullet">
      /// <item>Missing: no license row exists.</item>
      /// <item>Corrupt: the data cannot be decrypted.</item>
      /// <item>Expired: the expiry date is in the past.</item>
      /// <item>Tampered: a referenced UserTracking row no longer exists.</item>
      /// <item>Valid: none of the above.</item>
      /// </list>
      /// </summary>
      private static LicenseValidation CheckLicenseValidity(out License? license, out LicenseData? licenseData)
      {
          license = Provider.Load<License>().FirstOrDefault();
          if (license is null)
          {
              licenseData = null;
              return LicenseValidation.Missing;
          }

          if (!LicenseUtils.TryDecryptLicense(license.Data, out licenseData, out _))
              return LicenseValidation.Corrupt;

          if (licenseData.Expiry < DateTime.Now)
              return LicenseValidation.Expired;

          // Materialise once into a set; the previous deferred projection was re-enumerated
          // for every tracked item checked below.
          var existingItems = Provider.Query(
                  new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
                  new Columns<UserTracking>(x => x.ID), log: false)
              .Rows.Select(row => row.Get<UserTracking, Guid>(x => x.ID))
              .ToHashSet();

          foreach (var item in licenseData.UserTrackingItems)
          {
              if (!existingItems.Contains(item))
                  return LicenseValidation.Tampered;
          }

          return LicenseValidation.Valid;
      }

      private static int _expiredLicenseCounter = 0;

      private static readonly TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

      private static bool _readOnly;
      public static bool IsReadOnly { get => _readOnly; }

      // Must stay declared after LicenseCheckInterval: static field initializers run in textual order.
      private static readonly System.Timers.Timer LicenseTimer = new System.Timers.Timer(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };

      private static void LogRenew(string message)
      {
          LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
      }

      /// <summary>
      /// Logs an expiry warning. Expiry today or within a day logs on every call. Within three
      /// days it logs at most hourly; within a week at most every two hours. The cadence is
      /// measured in license-check intervals counted by _expiredLicenseCounter.
      /// </summary>
      private static void LogLicenseExpiry(DateTime expiry)
      {
          if (expiry.Date == DateTime.Today)
          {
              // Format the DateTime itself. TimeSpan (TimeOfDay) has no "HH" specifier and rejects
              // an unescaped ':', so the former "{expiry.TimeOfDay:HH:mm}" threw FormatException.
              LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
              return;
          }

          var diffInDays = (expiry - DateTime.Now).TotalDays;
          var hoursSinceLastWarning = (_expiredLicenseCounter * LicenseCheckInterval).TotalHours;
          if (diffInDays < 1)
          {
              LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
          }
          else if (diffInDays < 3 && hoursSinceLastWarning >= 1)
          {
              LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }
          else if (diffInDays < 7 && hoursSinceLastWarning >= 2)
          {
              LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }
          ++_expiredLicenseCounter;
      }

      public static void LogReadOnly()
      {
          LogError("Database is read-only because your license is invalid!");
      }

      private static void BeginReadOnly()
      {
          LogImportant("Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
          _readOnly = true;
      }

      private static void EndReadOnly()
      {
          LogImportant("Valid license found; the database is no longer read-only.");
          _readOnly = false;
      }

      /// <summary>Starts the periodic license check.</summary>
      private static void BeginLicenseCheckTimer()
      {
          // Detach first so a repeated Start() never wires the handler twice.
          LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
          LicenseTimer.Elapsed += LicenseTimer_Elapsed;
          LicenseTimer.Start();
      }

      private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
      {
          // System.Timers.Timer silently swallows exceptions thrown by Elapsed handlers,
          // so log failures explicitly instead of losing them.
          try
          {
              AssertLicense();
          }
          catch (Exception ex)
          {
              LogError($"License check failed: {ex.Message}");
          }
      }

      /// <summary>
      /// Re-samples up to 10 UserTracking IDs (with replacement) created since the last renewal.
      /// It stores them in the license and re-saves the encrypted license.
      /// </summary>
      private static void UpdateValidLicense(License license, LicenseData licenseData)
      {
          var ids = Provider.Query(
              new Filter<UserTracking>(x => x.Created).IsGreaterThanOrEqualTo(licenseData.LastRenewal),
              new Columns<UserTracking>(x => x.ID), log: false);

          var newIDList = new List<Guid>();
          if (ids.Rows.Count > 0)
          {
              for (int i = 0; i < 10; i++)
              {
                  // Random.Shared is thread-safe; this can run on the timer thread.
                  newIDList.Add(ids.Rows[Random.Shared.Next(0, ids.Rows.Count)].Get<UserTracking, Guid>(x => x.ID));
              }
          }
          licenseData.UserTrackingItems = newIDList.ToArray();

          if (LicenseUtils.TryEncryptLicense(licenseData, out var newData, out var error))
          {
              license.Data = newData;
              Provider.Save(license);
          }
          else
          {
              LogError($"Error updating license: {error}");
          }
      }

      /// <summary>
      /// Checks the stored license. If the database is currently read-only, a valid license ends
      /// read-only mode. Otherwise an invalid or missing license is replaced with a newly
      /// generated one (temporary behaviour; see the TODO below).
      /// </summary>
      private static void AssertLicense()
      {
          var result = CheckLicenseValidity(out var license, out var licenseData);
          if (IsReadOnly)
          {
              if (result == LicenseValidation.Valid)
              {
                  EndReadOnly();
              }
              return;
          }

          // TODO: Switch to real system. Until then an invalid license is silently regenerated
          // rather than enforced; ApplyLicenseEnforcement holds the intended behaviour.
          if (result == LicenseValidation.Valid)
              return;

          var newLicense = LicenseUtils.GenerateNewLicense();
          if (LicenseUtils.TryEncryptLicense(newLicense, out var newData, out var error))
          {
              license ??= new License();
              license.Data = newData;
              Provider.Save(license);
          }
          else
          {
              Logger.Send(LogType.Error, "", $"Error updating license: {error}");
          }
      }

      /// <summary>
      /// Intended enforcement once the TODO in <see cref="AssertLicense"/> is resolved; it is not
      /// called yet. In the former code it sat as unreachable statements after an unconditional
      /// return.
      /// </summary>
      private static void ApplyLicenseEnforcement(LicenseValidation result, License? license, LicenseData? licenseData)
      {
          switch (result)
          {
              case LicenseValidation.Valid:
                  // CheckLicenseValidity only returns Valid after loading and decrypting both.
                  LogLicenseExpiry(licenseData!.Expiry);
                  UpdateValidLicense(license!, licenseData);
                  break;
              case LicenseValidation.Missing:
                  LogImportant("Database is unlicensed!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Expired:
                  LogImportant("Database license has expired!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Corrupt:
                  LogImportant("Database license is corrupt - you will need to renew your license.");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Tampered:
                  LogImportant("Database license has been tampered with - you will need to renew your license.");
                  BeginReadOnly();
                  break;
          }
      }

      #endregion

      #region Logging

      private static void LogMessage(LogType type, string message)
      {
          Logger.Send(type, "", message);
      }

      private static void LogInfo(string message)
      {
          Logger.Send(LogType.Information, "", message);
      }

      private static void LogImportant(string message)
      {
          Logger.Send(LogType.Important, "", message);
      }

      private static void LogError(string message)
      {
          Logger.Send(LogType.Error, "", message);
      }

      #endregion

      /// <summary>Instantiates every registered store type once and calls its Init().</summary>
      public static void InitStores()
      {
          foreach (var storetype in stores)
          {
              var store = (Activator.CreateInstance(storetype) as IStore)!;
              store.Provider = Provider;
              store.Init();
          }
      }

      /// <summary>
      /// Creates a store for <paramref name="type"/> bound to the given user context. It prefers
      /// a registered subclass of <c>Store&lt;type&gt;</c> and otherwise uses the generic store.
      /// </summary>
      public static IStore FindStore(Type type, Guid userguid, string userid, Platform platform, string version)
      {
          var defType = typeof(Store<>).MakeGenericType(type);
          Type? subType = Stores.FirstOrDefault(myType => myType.IsSubclassOf(defType));
          var store = (Activator.CreateInstance(subType ?? defType) as IStore)!;
          store.Provider = Provider;
          store.UserGuid = userguid;
          store.UserID = userid;
          store.Platform = platform;
          store.Version = version;
          return store;
      }

      /// <summary>Typed convenience wrapper around the non-generic <c>FindStore</c>.</summary>
      public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, Platform platform, string version)
          where TEntity : Entity, new()
      {
          return (FindStore(typeof(TEntity), userguid, userid, platform, version) as IStore<TEntity>)!;
      }

      // Invoked via reflection from QueryMultiple with TEntity = IQueryDef.Type.
      private static CoreTable DoQueryMultipleQuery<TEntity>(
          IQueryDef query,
          Guid userguid, string userid, Platform platform, string version)
          where TEntity : Entity, new()
      {
          var store = FindStore<TEntity>(userguid, userid, platform, version);
          return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
      }

      /// <summary>
      /// Runs each query in parallel through its entity's store. It returns the resulting tables
      /// keyed like <paramref name="queries"/>. Failures surface as an AggregateException from
      /// Task.WaitAll.
      /// </summary>
      public static Dictionary<string, CoreTable> QueryMultiple(
          Dictionary<string, IQueryDef> queries,
          Guid userguid, string userid, Platform platform, string version)
      {
          var result = new Dictionary<string, CoreTable>();
          // Dictionary<,> is not safe for concurrent writers; every insert from the worker tasks
          // is serialised through this lock.
          var resultLock = new object();
          var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;
          var tasks = new List<Task>();
          foreach (var item in queries)
          {
              tasks.Add(Task.Run(() =>
              {
                  // DoQueryMultipleQuery is static, so the reflection target is null.
                  var table = (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(null, new object[]
                  {
                      item.Value,
                      userguid, userid, platform, version
                  }) as CoreTable)!;
                  lock (resultLock)
                  {
                      result[item.Key] = table;
                  }
              }));
          }
          Task.WaitAll(tasks.ToArray());
          return result;
      }

      #region Supported Types

      private class ModuleConfiguration : Dictionary<string, bool>, ILocalConfigurationSettings
      {
      }

      private static Type[]? _dbtypes;

      /// <summary>Names of enabled entity types, with '.' replaced by '_'.</summary>
      public static IEnumerable<string> SupportedTypes()
      {
          _dbtypes ??= LoadSupportedTypes();
          return _dbtypes.Select(x => x.EntityName().Replace(".", "_"));
      }

      /// <summary>
      /// Reads the per-database module configuration (named after the provider URL). It returns
      /// the enabled entity types. Entities missing from the configuration are enabled and the
      /// configuration is saved back.
      /// </summary>
      private static Type[] LoadSupportedTypes()
      {
          var result = new List<Type>();
          var path = Provider.URL.ToLower();
          var directory = Path.GetDirectoryName(path) ?? "";
          var fileName = Path.GetFileName(path);
          var config = new LocalConfiguration<ModuleConfiguration>(directory, fileName).Load();
          var changed = false;
          foreach (var type in Entities)
          {
              var key = type.EntityName();
              if (config.TryGetValue(key, out var enabled))
              {
                  if (enabled)
                      result.Add(type);
                  else
                      Logger.Send(LogType.Information, "", $"Entity [{key}] is disabled");
              }
              else
              {
                  config[key] = true;
                  result.Add(type);
                  changed = true;
              }
          }

          if (changed)
              new LocalConfiguration<ModuleConfiguration>(directory, fileName).Save(config);

          return result.ToArray();
      }

      public static bool IsSupported<T>() where T : Entity
      {
          _dbtypes ??= LoadSupportedTypes();
          return _dbtypes.Contains(typeof(T));
      }

      #endregion

      #region Private Methods

      /// <summary>
      /// Rebuilds <see cref="LoadedScripts"/> by compiling every query/save/delete/load hook
      /// script. Scripts that fail to compile are logged and skipped.
      /// </summary>
      public static void LoadScripts()
      {
          Logger.Send(LogType.Information, "", "Loading Script Cache...");
          LoadedScripts.Clear();
          var scripts = Provider.Load(
              new Filter<Script>
                      (x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
          );
          foreach (var script in scripts)
          {
              var key = $"{script.Section} {script.ScriptType}";
              var doc = new ScriptDocument(script.Code);
              if (doc.Compile())
              {
                  Logger.Send(LogType.Information, "",
                      $"- {script.Section}.{script.ScriptType} Compiled Successfully");
                  LoadedScripts[key] = doc;
              }
              else
              {
                  Logger.Send(LogType.Error, "",
                      $"- {script.Section}.{script.ScriptType} Compile Exception:\n{doc.Result}");
              }
          }
          Logger.Send(LogType.Information, "", "Loading Script Cache Complete");
      }

      private static Type[] stores = { };

      /// <summary>
      /// Keeps only concrete, non-generic classes from <paramref name="types"/>. Each one must
      /// implement <see cref="IStore"/>.
      /// </summary>
      /// <exception cref="ArgumentException">A remaining type does not implement IStore.</exception>
      private static void SetStoreTypes(Type[] types)
      {
          types = types.Where(
              myType => myType.IsClass
                        && !myType.IsAbstract
                        && !myType.IsGenericType).ToArray();
          foreach (var type in types)
          {
              if (!type.GetInterfaces().Contains(typeof(IStore)))
                  throw new ArgumentException($"{type.Name} is not a valid store", nameof(types));
          }
          stores = types;
      }

      /// <summary>The end-user object model: non-generic Entity subclasses among <see cref="Entities"/>.</summary>
      private static Type[] ConsolidatedObjectModel()
      {
          return Entities.Where(x =>
              x.IsClass
              && !x.IsGenericType
              && x.IsSubclassOf(typeof(Entity))
          ).ToArray();
      }

      private enum SchemaStatus
      {
          New,
          Changed,
          Validated
      }

      /// <summary>Flattened schema: "TypeName.Property" mapped to the property type, for every model type.</summary>
      private static Dictionary<string, Type> GetSchema()
      {
          var model = new Dictionary<string, Type>();
          foreach (var type in ConsolidatedObjectModel())
          {
              Dictionary<string, Type> thismodel = CoreUtils.PropertyList(type, x => true, true);
              foreach (var pair in thismodel)
                  model[type.Name + "." + pair.Key] = pair.Value;
          }
          return model;
      }

      /// <summary>
      /// Compares the stored schema with the current object model. Both sides are compared by
      /// their serialized form.
      /// </summary>
      private static SchemaStatus ValidateSchema()
      {
          var db_schema = Provider.GetSchema();
          if (!db_schema.Any())
              return SchemaStatus.New;

          var mdl_json = Serialization.Serialize(GetSchema());
          var db_json = Serialization.Serialize(db_schema);
          return mdl_json.Equals(db_json) ? SchemaStatus.Validated : SchemaStatus.Changed;
      }

      private static void SaveSchema()
      {
          Provider.SaveSchema(GetSchema());
      }

      #endregion
  }