DbFactory.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. using System.Composition;
  2. using System.Diagnostics.CodeAnalysis;
  3. using System.Globalization;
  4. using System.Reflection;
  5. using InABox.Clients;
  6. using InABox.Configuration;
  7. using InABox.Core;
  8. using InABox.Scripting;
  9. using Microsoft.CodeAnalysis.CSharp;
  namespace InABox.Database
  {
      /// <summary>
      /// Static entry point of the database layer: holds the active <see cref="IProvider"/>, creates or
      /// upgrades the schema in <see cref="Start"/>, runs the periodic license check (and read-only mode),
      /// creates entity stores, runs batched queries and caches the compiled server-side scripts.
      /// </summary>
      public static class DbFactory
      {
          /// <summary>
          /// Scripts that compiled successfully, keyed by "{Section} {ScriptType}". Rebuilt by <see cref="LoadScripts"/>.
          /// </summary>
          public static Dictionary<string, ScriptDocument> LoadedScripts = new();

          private static IProvider? _provider;

          /// <summary>The database provider. Reading it before a provider has been assigned throws.</summary>
          /// <exception cref="InvalidOperationException">No provider has been assigned yet.</exception>
          public static IProvider Provider
          {
              get => _provider ?? throw new InvalidOperationException("Provider is not set");
              set => _provider = value;
          }

          /// <summary>True once <see cref="Provider"/> has been assigned.</summary>
          public static bool IsProviderSet => _provider is not null;

          /// <summary>Colour scheme setting; not read within this class.</summary>
          public static string? ColorScheme { get; set; }

          /// <summary>Logo data; not read within this class.</summary>
          public static byte[]? Logo { get; set; }

          /// <summary>
          /// Every type in <c>CoreUtils.Entities</c> that implements <see cref="IPersistent"/>.
          /// Re-evaluated (via reflection) on each enumeration.
          /// </summary>
          public static IEnumerable<Type> Entities =>
              CoreUtils.Entities.Where(x => x.GetInterfaces().Contains(typeof(IPersistent)));

          /// <summary>
          /// Custom store types. The setter keeps only non-abstract, non-generic classes and throws if any
          /// of those does not implement <see cref="IStore"/>.
          /// </summary>
          public static Type[] Stores
          {
              get => stores;
              set => SetStoreTypes(value);
          }

          /// <summary>NOTE(review): not read within this class — confirm it is still used elsewhere.</summary>
          public static DateTime Expiry { get; set; }

          /// <summary>
          /// Brings the database online. In order, it:
          /// <list type="number">
          /// <item>checks licensing;</item>
          /// <item>creates or upgrades the schema when it differs from the object model;</item>
          /// <item>starts the provider and runs the data migration;</item>
          /// <item>loads custom properties;</item>
          /// <item>checks the license and starts the periodic license timer;</item>
          /// <item>initialises the stores and compiles the scripts.</item>
          /// </list>
          /// </summary>
          /// <exception cref="InvalidOperationException">
          /// Schema creation/upgrade failed (the original error is kept as <see cref="Exception.InnerException"/>),
          /// or the data migration reported failure.
          /// </exception>
          public static void Start()
          {
              CoreUtils.CheckLicensing();

              switch (ValidateSchema())
              {
                  case SchemaStatus.New:
                      try
                      {
                          Provider.CreateSchema(ConsolidatedObjectModel());
                          SaveSchema();
                      }
                      catch (Exception err)
                      {
                          // Attach the original exception so its type and stack trace are not lost.
                          throw new InvalidOperationException($"Unable to Create Schema\n\n{err.Message}", err);
                      }
                      break;

                  case SchemaStatus.Changed:
                      try
                      {
                          Provider.UpgradeSchema(ConsolidatedObjectModel());
                          SaveSchema();
                      }
                      catch (Exception err)
                      {
                          throw new InvalidOperationException($"Unable to Update Schema\n\n{err.Message}", err);
                      }
                      break;
              }

              // Start the provider. Detach the log handler first so a repeated Start() does not log every message twice.
              Provider.Types = ConsolidatedObjectModel();
              Provider.OnLog -= LogMessage;
              Provider.OnLog += LogMessage;
              Provider.Start();

              if (!DataUpdater.MigrateDatabase())
                  throw new InvalidOperationException("Database migration failed. Aborting startup");

              // Load the custom properties. Clients can't be used here (we are already inside the
              // database layer), so they are read straight from the provider.
              var props = Provider.Query<CustomProperty>().Rows.Select(row => row.ToObject<CustomProperty>()).ToArray();
              DatabaseSchema.Load(props);

              AssertLicense();
              BeginLicenseCheckTimer();
              InitStores();
              LoadScripts();
          }

          #region License

          /// <summary>Result of <see cref="CheckLicenseValidity"/>.</summary>
          private enum LicenseValidation
          {
              Valid,
              Missing,
              Expired,
              Corrupt,
              Tampered
          }

          /// <summary>
          /// Loads the first stored <see cref="License"/> and classifies it. Checks run in this order:
          /// <list type="number">
          /// <item>no record → <see cref="LicenseValidation.Missing"/>;</item>
          /// <item>decryption fails → <see cref="LicenseValidation.Corrupt"/>;</item>
          /// <item>expiry in the past → <see cref="LicenseValidation.Expired"/>;</item>
          /// <item>any tracked <see cref="UserTracking"/> ID no longer in the database → <see cref="LicenseValidation.Tampered"/>.</item>
          /// </list>
          /// </summary>
          /// <param name="license">The stored license record, or null when none exists.</param>
          /// <param name="licenseData">The decrypted license payload; null when no license exists.</param>
          private static LicenseValidation CheckLicenseValidity(out License? license, out LicenseData? licenseData)
          {
              license = Provider.Load<License>().FirstOrDefault();
              if (license is null)
              {
                  licenseData = null;
                  return LicenseValidation.Missing;
              }

              if (!LicenseUtils.TryDecryptLicense(license.Data, out licenseData, out _))
                  return LicenseValidation.Corrupt;

              if (licenseData.Expiry < DateTime.Now)
                  return LicenseValidation.Expired;

              // Materialise the IDs that still exist once; the previous deferred query was
              // re-enumerated for every tracked ID.
              var existingIds = Provider.Query(
                      new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
                      new Columns<UserTracking>(x => x.ID),
                      log: false)
                  .Rows
                  .Select(row => row.Get<UserTracking, Guid>(x => x.ID))
                  .ToHashSet();

              return licenseData.UserTrackingItems.All(id => existingIds.Contains(id))
                  ? LicenseValidation.Valid
                  : LicenseValidation.Tampered;
          }

          // Number of license checks since the last "expiring soon" reminder was logged.
          private static int _expiredLicenseCounter;

          // Must be declared before LicenseTimer, whose initializer reads it.
          private static readonly TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

          private static bool _readOnly;

          /// <summary>True while the database is in license-enforced read-only mode.</summary>
          public static bool IsReadOnly => _readOnly;

          private static readonly System.Timers.Timer LicenseTimer = new(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };

          // How many UserTracking IDs are sampled into a license when it is refreshed.
          private const int LicenseTrackingSampleSize = 10;

          /// <summary>Logs <paramref name="message"/> followed by the standard renewal instructions.</summary>
          private static void LogRenew(string message)
          {
              LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
          }

          /// <summary>
          /// Logs an "expiring soon" reminder, with a frequency that depends on how close the expiry is:
          /// <list type="bullet">
          /// <item>expiring today, or less than a day away: every check;</item>
          /// <item>less than three days away: at most hourly;</item>
          /// <item>less than a week away: at most every two hours.</item>
          /// </list>
          /// </summary>
          private static void LogLicenseExpiry(DateTime expiry)
          {
              if (expiry.Date == DateTime.Today)
              {
                  // Format the DateTime itself: a TimeSpan (expiry.TimeOfDay) rejects "HH" and unescaped ':'
                  // in custom format strings and would throw FormatException.
                  LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
                  return;
              }

              var daysRemaining = (expiry - DateTime.Now).TotalDays;
              var sinceLastReminder = _expiredLicenseCounter * LicenseCheckInterval;

              // "hh:mm tt" renders "03:45 PM"; the previous "hh:mm:tt" rendered "03:45:PM".
              if (daysRemaining < 1)
              {
                  LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              }
              else if (daysRemaining < 3 && sinceLastReminder.TotalHours >= 1)
              {
                  LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
                  _expiredLicenseCounter = 0;
              }
              else if (daysRemaining < 7 && sinceLastReminder.TotalHours >= 2)
              {
                  LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
                  _expiredLicenseCounter = 0;
              }

              ++_expiredLicenseCounter;
          }

          /// <summary>Logs an error stating the database is read-only due to an invalid license.</summary>
          public static void LogReadOnly()
          {
              LogError("Database is read-only because your license is invalid!");
          }

          private static void BeginReadOnly()
          {
              LogImportant("Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
              _readOnly = true;
          }

          private static void EndReadOnly()
          {
              LogImportant("Valid license found; the database is no longer read-only.");
              _readOnly = false;
          }

          /// <summary>Starts re-running <see cref="AssertLicense"/> every <see cref="LicenseCheckInterval"/>.</summary>
          private static void BeginLicenseCheckTimer()
          {
              // Detach first so calling Start() again does not run the check twice per tick.
              LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
              LicenseTimer.Elapsed += LicenseTimer_Elapsed;
              LicenseTimer.Start();
          }

          private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
          {
              try
              {
                  AssertLicense();
              }
              catch (Exception err)
              {
                  // System.Timers.Timer suppresses exceptions thrown by Elapsed handlers, so log them here
                  // or a failing license check would go unnoticed.
                  LogError($"License check failed: {err}");
              }
          }

          /// <summary>
          /// Re-samples up to <see cref="LicenseTrackingSampleSize"/> <see cref="UserTracking"/> IDs created
          /// since the last renewal into <paramref name="licenseData"/>, then re-encrypts and saves
          /// <paramref name="license"/>. The sample is drawn with replacement, so IDs may repeat.
          /// </summary>
          private static void UpdateValidLicense(License license, LicenseData licenseData)
          {
              var rows = Provider.Query(
                      new Filter<UserTracking>(x => x.Created).IsGreaterThanOrEqualTo(licenseData.LastRenewal),
                      new Columns<UserTracking>(x => x.ID),
                      log: false)
                  .Rows;

              var sample = new List<Guid>();
              if (rows.Count > 0)
              {
                  for (var i = 0; i < LicenseTrackingSampleSize; i++)
                      sample.Add(rows[Random.Shared.Next(0, rows.Count)].Get<UserTracking, Guid>(x => x.ID));
              }

              licenseData.UserTrackingItems = sample.ToArray();
              if (LicenseUtils.TryEncryptLicense(licenseData, out var newData, out var error))
              {
                  license.Data = newData;
                  Provider.Save(license);
              }
              else
              {
                  LogError($"Error updating license: {error}");
              }
          }

          /// <summary>
          /// Checks the stored license.
          /// <list type="bullet">
          /// <item>In read-only mode, a valid license ends read-only mode; nothing else happens.</item>
          /// <item>Otherwise (temporary behaviour, see TODO) an invalid license is replaced by a freshly
          /// generated one, and a valid license is left untouched.</item>
          /// </list>
          /// </summary>
          private static void AssertLicense()
          {
              var result = CheckLicenseValidity(out var license, out _);

              if (IsReadOnly)
              {
                  if (result == LicenseValidation.Valid)
                      EndReadOnly();
                  return;
              }

              // TODO: Switch to real system — i.e. call EnforceLicense(result, license, licenseData) instead of the code below.
              if (result == LicenseValidation.Valid)
                  return;

              var newLicense = LicenseUtils.GenerateNewLicense();
              if (LicenseUtils.TryEncryptLicense(newLicense, out var newData, out var error))
              {
                  license ??= new License();
                  license.Data = newData;
                  Provider.Save(license);
              }
              else
              {
                  LogError($"Error updating license: {error}");
              }
          }

          /// <summary>
          /// The intended license enforcement.
          /// <list type="bullet">
          /// <item>For a valid license: logs expiry reminders and refreshes the tracking sample.</item>
          /// <item>Otherwise: logs the reason and switches the database to read-only.</item>
          /// </list>
          /// </summary>
          /// <remarks>
          /// NOTE(review): not called yet — this was unreachable code at the end of <see cref="AssertLicense"/>,
          /// which still uses the temporary "regenerate the license" behaviour (see the TODO there).
          /// </remarks>
          private static void EnforceLicense(LicenseValidation result, License? license, LicenseData? licenseData)
          {
              switch (result)
              {
                  case LicenseValidation.Valid:
                      LogLicenseExpiry(licenseData!.Expiry);
                      UpdateValidLicense(license!, licenseData);
                      break;
                  case LicenseValidation.Missing:
                      LogImportant("Database is unlicensed!");
                      BeginReadOnly();
                      break;
                  case LicenseValidation.Expired:
                      LogImportant("Database license has expired!");
                      BeginReadOnly();
                      break;
                  case LicenseValidation.Corrupt:
                      LogImportant("Database license is corrupt - you will need to renew your license.");
                      BeginReadOnly();
                      break;
                  case LicenseValidation.Tampered:
                      LogImportant("Database license has been tampered with - you will need to renew your license.");
                      BeginReadOnly();
                      break;
              }
          }

          #endregion

          #region Logging

          // Handler for Provider.OnLog; forwards provider messages to the Logger.
          private static void LogMessage(LogType type, string message)
          {
              Logger.Send(type, "", message);
          }

          private static void LogInfo(string message)
          {
              Logger.Send(LogType.Information, "", message);
          }

          private static void LogImportant(string message)
          {
              Logger.Send(LogType.Important, "", message);
          }

          private static void LogError(string message)
          {
              Logger.Send(LogType.Error, "", message);
          }

          #endregion

          /// <summary>Instantiates each registered store type, assigns <see cref="Provider"/> and calls <c>Init()</c> on it.</summary>
          public static void InitStores()
          {
              foreach (var storeType in stores)
              {
                  // SetStoreTypes only accepts types implementing IStore, so this cast cannot fail.
                  var store = (IStore)Activator.CreateInstance(storeType)!;
                  store.Provider = Provider;
                  store.Init();
              }
          }

          /// <summary>
          /// Creates a store for <typeparamref name="TEntity"/>. It uses the first registered store type that
          /// subclasses <c>Store&lt;TEntity&gt;</c>, or <c>Store&lt;TEntity&gt;</c> itself when there is none.
          /// The store is configured with the provider and the calling user's identity.
          /// </summary>
          public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, Platform platform, string version)
              where TEntity : Entity, new()
          {
              var defType = typeof(Store<>).MakeGenericType(typeof(TEntity));
              var subType = Stores.FirstOrDefault(storeType => storeType.IsSubclassOf(defType));

              var store = (Store<TEntity>)Activator.CreateInstance(subType ?? defType)!;
              store.Provider = Provider;
              store.UserGuid = userguid;
              store.UserID = userid;
              store.Platform = platform;
              store.Version = version;
              return store;
          }

          /// <summary>
          /// Runs a single query of a <see cref="QueryMultiple"/> batch. It is invoked through reflection so the
          /// entity type can be supplied at runtime.
          /// </summary>
          private static CoreTable DoQueryMultipleQuery<TEntity>(
              IQueryDef query,
              Guid userguid, string userid, Platform platform, string version)
              where TEntity : Entity, new()
          {
              var store = FindStore<TEntity>(userguid, userid, platform, version);
              return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
          }

          /// <summary>
          /// Runs every query in <paramref name="queries"/> in parallel and returns the results under the same keys.
          /// Failures surface as the <see cref="AggregateException"/> thrown by <see cref="Task.WaitAll(Task[])"/>.
          /// </summary>
          public static Dictionary<string, CoreTable> QueryMultiple(
              Dictionary<string, IQueryDef> queries,
              Guid userguid, string userid, Platform platform, string version)
          {
              var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

              // Each task returns its own (key, table) pair, and the dictionary is only filled once all tasks
              // have finished: Dictionary<,> is not safe for concurrent writes.
              var tasks = queries.Select(item => Task.Run(() =>
              {
                  // DoQueryMultipleQuery is static, so no target instance is needed.
                  var table = (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(null, new object[]
                  {
                      item.Value,
                      userguid, userid, platform, version
                  }) as CoreTable)!;
                  return (Key: item.Key, Table: table);
              })).ToArray();

              Task.WaitAll(tasks);

              var result = new Dictionary<string, CoreTable>(tasks.Length);
              foreach (var task in tasks)
                  result[task.Result.Key] = task.Result.Table;
              return result;
          }

          #region Supported Types

          /// <summary>Per-database map of entity name → enabled flag, stored as local configuration.</summary>
          private class ModuleConfiguration : Dictionary<string, bool>, ILocalConfigurationSettings
          {
          }

          // Lazily loaded list of enabled entity types; see LoadSupportedTypes.
          private static Type[]? _dbtypes;

          /// <summary>Names of the enabled entity types, with '.' replaced by '_'.</summary>
          public static IEnumerable<string> SupportedTypes()
          {
              _dbtypes ??= LoadSupportedTypes();
              return _dbtypes.Select(x => x.EntityName().Replace(".", "_"));
          }

          /// <summary>
          /// Reads the module configuration stored alongside <c>Provider.URL</c> (lower-cased) and returns the
          /// enabled entity types. Entities not yet in the configuration are enabled, and the configuration is
          /// saved back if any were added.
          /// </summary>
          private static Type[] LoadSupportedTypes()
          {
              var path = Provider.URL.ToLower();
              LocalConfiguration<ModuleConfiguration> Configuration() =>
                  new(Path.GetDirectoryName(path) ?? "", Path.GetFileName(path));

              var config = Configuration().Load();
              var result = new List<Type>();
              var changed = false;

              foreach (var type in Entities)
              {
                  var key = type.EntityName();
                  if (config.TryGetValue(key, out var enabled))
                  {
                      if (enabled)
                          result.Add(type);
                      else
                          LogInfo($"Entity [{key}] is disabled");
                  }
                  else
                  {
                      // Entity not seen before: enable it and persist the updated configuration below.
                      config[key] = true;
                      result.Add(type);
                      changed = true;
                  }
              }

              if (changed)
                  Configuration().Save(config);

              return result.ToArray();
          }

          /// <summary>True if <typeparamref name="T"/> is one of the enabled entity types.</summary>
          public static bool IsSupported<T>() where T : Entity
          {
              _dbtypes ??= LoadSupportedTypes();
              return _dbtypes.Contains(typeof(T));
          }

          #endregion

          #region Private Methods

          /// <summary>
          /// Rebuilds <see cref="LoadedScripts"/> from all query/save/delete/load hook scripts in the database.
          /// Scripts that fail to compile are logged and skipped.
          /// </summary>
          public static void LoadScripts()
          {
              LogInfo("Loading Script Cache...");
              LoadedScripts.Clear();

              var scripts = Provider.Load(
                  new Filter<Script>
                          (x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
              );

              foreach (var script in scripts)
              {
                  var key = $"{script.Section} {script.ScriptType}";
                  var doc = new ScriptDocument(script.Code);
                  if (doc.Compile())
                  {
                      LogInfo($"- {script.Section}.{script.ScriptType} Compiled Successfully");
                      LoadedScripts[key] = doc;
                  }
                  else
                  {
                      LogError($"- {script.Section}.{script.ScriptType} Compile Exception:\n{doc.Result}");
                  }
              }

              LogInfo("Loading Script Cache Complete");
          }

          private static Type[] stores = Array.Empty<Type>();

          /// <summary>
          /// Keeps only non-abstract, non-generic classes from <paramref name="types"/> and stores them as
          /// <see cref="Stores"/>.
          /// </summary>
          /// <exception cref="ArgumentException">A remaining type does not implement <see cref="IStore"/>.</exception>
          private static void SetStoreTypes(Type[] types)
          {
              var candidates = types
                  .Where(type => type.IsClass && !type.IsAbstract && !type.IsGenericType)
                  .ToArray();

              var invalid = candidates.FirstOrDefault(type => !type.GetInterfaces().Contains(typeof(IStore)));
              if (invalid is not null)
                  throw new ArgumentException(string.Format("{0} is not a valid store", invalid.Name));

              stores = candidates;
          }

          /// <summary>
          /// The persisted object model: every type in <see cref="Entities"/> that is a non-generic class
          /// deriving from <see cref="Entity"/>. A new array is returned on each call.
          /// </summary>
          private static Type[] ConsolidatedObjectModel() =>
              Entities
                  .Where(type => type.IsClass && !type.IsGenericType && type.IsSubclassOf(typeof(Entity)))
                  .ToArray();

          private enum SchemaStatus
          {
              New,
              Changed,
              Validated
          }

          /// <summary>
          /// Flattens the object model into "TypeName.Property" → property type. The result is compared
          /// (as serialized JSON) with the schema stored by the provider.
          /// </summary>
          private static Dictionary<string, Type> GetSchema()
          {
              var model = new Dictionary<string, Type>();
              foreach (var type in ConsolidatedObjectModel())
              {
                  Dictionary<string, Type> properties = CoreUtils.PropertyList(type, x => true, true);
                  foreach (var (name, propertyType) in properties)
                      model[type.Name + "." + name] = propertyType;
              }
              return model;
          }

          /// <summary>
          /// Compares the provider's stored schema with the current object model.
          /// <list type="bullet">
          /// <item><see cref="SchemaStatus.New"/> when the provider has no stored schema;</item>
          /// <item><see cref="SchemaStatus.Validated"/> when the serialized forms are equal;</item>
          /// <item><see cref="SchemaStatus.Changed"/> otherwise.</item>
          /// </list>
          /// </summary>
          private static SchemaStatus ValidateSchema()
          {
              var dbSchema = Provider.GetSchema();
              if (!dbSchema.Any())
                  return SchemaStatus.New;

              var modelJson = Serialization.Serialize(GetSchema());
              var dbJson = Serialization.Serialize(dbSchema);
              return modelJson.Equals(dbJson) ? SchemaStatus.Validated : SchemaStatus.Changed;
          }

          private static void SaveSchema()
          {
              Provider.SaveSchema(GetSchema());
          }

          #endregion
      }
  }