Browse Source

Added EntitySecurity attribute

Kenric Nugteren 2 weeks ago
parent
commit
48b19ccb4d

+ 53 - 7
InABox.Core/CoreUtils.cs

@@ -14,6 +14,7 @@ using System.Reflection;
 using System.Text;
 using System.Text.RegularExpressions;
 using System.Runtime.CompilerServices;
+using System.Threading.Tasks;
 
 namespace InABox.Core
 {
@@ -2731,19 +2732,19 @@ namespace InABox.Core
         }
 
         /// <summary>
-        /// Concatenate all <paramref name="arrays"/> together.
+        /// Concatenate all <paramref name="arrays"/> together into an array.
         /// </summary>
         /// <typeparam name="T"></typeparam>
         /// <param name="arrays"></param>
         /// <returns></returns>
-        public static T[] Concatenate<T>(params T[][] arrays)
+        public static T[] Concatenate<T>(params IList<T>[] arrays)
         {
-            var newArr = new T[arrays.Sum(x => x.Length)];
+            var newArr = new T[arrays.Sum(x => x.Count)];
             for(int i = 0, idx = 0; i < arrays.Length; ++i)
             {
                 var arr = arrays[i];
                 arr.CopyTo(newArr, idx);
-                idx += arr.Length;
+                idx += arr.Count;
             }
             return newArr;
         }
@@ -2916,6 +2917,43 @@ namespace InABox.Core
             list.Sort((a, b) => comparison(a).CompareTo(comparison(b)));
         }
 
+        public static Comparison<T> OrderBy<T, TProp>(Func<T, TProp> comparison)
+            where TProp : IComparable
+        {
+            return (a, b) => comparison(a).CompareTo(comparison(b));
+        }
+        public static Comparison<T> ThenBy<T>(this Comparison<T> comparison1, Comparison<T> comparison2)
+        {
+            return (x, y) =>
+            {
+                var result = comparison1(x, y);
+                if (result == 0)
+                {
+                    return comparison2(x, y);
+                }
+                else
+                {
+                    return result;
+                }
+            };
+        }
+        public static Comparison<T> ThenBy<T, TProp>(this Comparison<T> comparison1, Func<T, TProp> comparison2)
+            where TProp : IComparable
+        {
+            return (x, y) =>
+            {
+                var result = comparison1(x, y);
+                if (result == 0)
+                {
+                    return comparison2(x).CompareTo(comparison2(y));
+                }
+                else
+                {
+                    return result;
+                }
+            };
+        }
+
         /// <summary>
         /// Compare the elements in this list to the other list, returning <see langword="true"/> if they have all the same
         /// elements, regardless of order.
@@ -2961,8 +2999,16 @@ namespace InABox.Core
         }
 
         #endregion
-        
-        
-        
+
+        #region Task Utilities
+
+        public static void WaitAllNotNull(params Task?[] tasks)
+        {
+            var nonNull = tasks.NotNull().ToArray();
+            Task.WaitAll(nonNull);
+        }
+
+        #endregion
+
     }
 }

+ 18 - 0
InABox.Core/Security/EntitySecurityAttribute.cs

@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace InABox.Core
+{
+    /// <summary>
+    /// Allows for customising the security tokens for entity actions, overriding the <see cref="AutoSecurityDescriptor{TEntity, TAction}"/> for
+    /// a given <see cref="Entity"/>.
+    /// </summary>
+    [AttributeUsage(AttributeTargets.Class)]
+    public class EntitySecurityAttribute : Attribute
+    {
+        public Type? CanView { get; set; }
+        public Type? CanEdit { get; set; }
+        public Type? CanDelete { get; set; }
+    }
+}

+ 166 - 65
InABox.Core/Security/Security.cs

@@ -4,6 +4,7 @@ using System.Collections.Generic;
 using System.ComponentModel;
 using System.Linq;
 using System.Reflection;
+using System.Security;
 using System.Threading.Tasks;
 using InABox.Clients;
 
@@ -11,10 +12,10 @@ namespace InABox.Core
 {
     public static class Security
     {
-        private static ConcurrentBag<ISecurityDescriptor>? _descriptors;
+        private static ISecurityDescriptor[]? _descriptors;
         private static GlobalSecurityToken[]? _globaltokens;
-        private static SecurityToken[]? _grouptokens;
-        private static UserSecurityToken[]? _usertokens;
+        private static Dictionary<Guid, SecurityToken[]> _grouptokens = new Dictionary<Guid, SecurityToken[]>();
+        private static Dictionary<Guid, UserSecurityToken[]> _usertokens = new Dictionary<Guid, UserSecurityToken[]>();
 
         public static IEnumerable<ISecurityDescriptor> Descriptors
         {
@@ -22,125 +23,172 @@ namespace InABox.Core
             {
                 if (_descriptors == null)
                 {
-                    _descriptors = new ConcurrentBag<ISecurityDescriptor>();
+                    ISecurityDescriptor[] GetTokens(params Task<ISecurityDescriptor[]>[] tasks)
+                    {
+                        Task.WaitAll(tasks);
+                        return CoreUtils.Concatenate(tasks.ToArray(x => x.Result));
+                    }
 
                     var custom = Task.Run(() =>
                     {
-                        var tokens = CoreUtils.Entities.Where(
-                            x => !x.IsGenericType && x.HasInterface<ISecurityDescriptor>());
-                        foreach (var _class in tokens)
-                        {
-                            var token = (Activator.CreateInstance(_class) as ISecurityDescriptor)!;
-                            _descriptors.Add(token);
-                        }
+                        return CoreUtils.Entities.Where(x => !x.IsGenericType && x.HasInterface<ISecurityDescriptor>())
+                            .Select(x => Activator.CreateInstance(x) as ISecurityDescriptor)
+                            .NotNull()
+                            .ToArray();
                     });
 
+                    bool Overridden(Type @class, Func<EntitySecurityAttribute, Type?> getToken)
+                    {
+                        return @class.GetCustomAttribute<EntitySecurityAttribute>() is EntitySecurityAttribute attr && getToken(attr) != null;
+                    }
+
                     var auto = Task.Run(() =>
                     {
-                        var tokens = CoreUtils.Entities.Where( x => !x.IsGenericType && x.IsSubclassOf(typeof(Entity))).ToArray();
+                        var entities = CoreUtils.Entities.Where( x => !x.IsGenericType && x.IsSubclassOf(typeof(Entity))).ToArray();
                         var view = Task.Run(() =>
                         {
-                            foreach (var _class in tokens)
-                                CheckAutoToken(_class, typeof(CanView<>));
+                            return entities
+                                .Where(x => !Overridden(x, x => x.CanView))
+                                .Select(x => GetAutoToken(x, typeof(CanView<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var edit = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetCustomAttribute<AutoEntity>() == null))
-                                CheckAutoToken(_class, typeof(CanEdit<>));
+                            return entities
+                                .Where(x => !Overridden(x, x => x.CanEdit))
+                                .Select(x => GetAutoToken(x, typeof(CanEdit<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var delete = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetCustomAttribute<AutoEntity>() == null))
-                                CheckAutoToken(_class, typeof(CanDelete<>));
+                            return entities
+                                .Where(x => !Overridden(x, x => x.CanDelete))
+                                .Select(x => GetAutoToken(x, typeof(CanDelete<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var issues = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetInterfaces().Contains(typeof(IIssues))))
-                                CheckAutoToken(_class, typeof(CanManageIssues<>));
+                            return entities.Where(x => x.GetInterfaces().Contains(typeof(IIssues)))
+                                .Select(x => GetAutoToken(x, typeof(CanManageIssues<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var exports = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetInterfaces().Contains(typeof(IExportable))))
-                                CheckAutoToken(_class, typeof(CanExport<>));
+                            return entities.Where(x => x.GetInterfaces().Contains(typeof(IExportable)))
+                                .Select(x => GetAutoToken(x, typeof(CanExport<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var imports = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetInterfaces().Contains(typeof(IImportable))))
-                                CheckAutoToken(_class, typeof(CanImport<>));
+                            return entities.Where(x => x.GetInterfaces().Contains(typeof(IImportable)))
+                                .Select(x => GetAutoToken(x, typeof(CanImport<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var merges = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetInterfaces().Contains(typeof(IMergeable))))
-                                CheckAutoToken(_class, typeof(CanMerge<>));
+                            return entities.Where(x => x.GetInterfaces().Contains(typeof(IMergeable)))
+                                .Select(x => GetAutoToken(x, typeof(CanMerge<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var posts = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetInterfaces().Contains(typeof(IPostable))))
-                                CheckAutoToken(_class, typeof(CanPost<>));
+                            return entities.Where(x => x.GetInterfaces().Contains(typeof(IPostable)))
+                                .Select(x => GetAutoToken(x, typeof(CanPost<>)))
+                                .NotNull()
+                                .ToArray();
                         });
                         var configPosts = Task.Run(() =>
                         {
-                            foreach (var _class in tokens.Where(x => x.GetInterfaces().Contains(typeof(IPostable))))
-                                CheckAutoToken(_class, typeof(CanConfigurePost<>));
+                            return entities.Where(x => x.GetInterfaces().Contains(typeof(IPostable)))
+                                .Select(x => GetAutoToken(x, typeof(CanConfigurePost<>)))
+                                .NotNull()
+                                .ToArray();
                         });
-                        Task.WaitAll(view, edit, delete, issues, exports, merges, posts, configPosts);
+                        return GetTokens(view, edit, delete, issues, exports, merges, posts, configPosts);
                     });
-                    Task.WaitAll(custom, auto);
+                    _descriptors = GetTokens(custom, auto);
+                    Array.Sort(_descriptors, CoreUtils.OrderBy((ISecurityDescriptor x) => x.Type).ThenBy(x => x.Code));
                 }
 
-                return _descriptors.OrderBy(x => x.Type).ThenBy(x => x.Code);
+                return _descriptors;
             }
         }
 
         public static void Reset()
         {
             _globaltokens = null;
-            _grouptokens = null;
-            _usertokens = null;
+            _grouptokens.Clear();
+            _usertokens.Clear();
             _descriptors = null;
         }
 
-        public static void CheckTokens()
+        public static void CheckTokens(Guid userId, Guid securityID)
         {
-            _usertokens ??= Client.Query(
-                Filter<UserSecurityToken>.Where(x => x.User.ID).IsEqualTo(ClientFactory.UserGuid),
-                Columns.None<UserSecurityToken>().Add(x => x.Descriptor).Add(x => x.Enabled)
-            ).ToArray<UserSecurityToken>();
-            _grouptokens ??= Client.Query(
-                Filter<SecurityToken>.Where(x => x.Group.ID).IsEqualTo(ClientFactory.UserSecurityID),
-                Columns.None<SecurityToken>().Add(x => x.Descriptor).Add(x => x.Enabled)
-            ).ToArray<SecurityToken>();
-            _globaltokens ??= Client.Query(
-                null,
-                Columns.None<GlobalSecurityToken>().Add(x => x.Descriptor).Add(x => x.Enabled))
-            .ToArray<GlobalSecurityToken>();
+            var userTask = !_usertokens.ContainsKey(userId)
+                ? Client.QueryAsync(
+                    Filter<UserSecurityToken>.Where(x => x.User.ID).IsEqualTo(ClientFactory.UserGuid),
+                    Columns.None<UserSecurityToken>().Add(x => x.Descriptor).Add(x => x.Enabled))
+                : null;
+            var groupTask = !_grouptokens.ContainsKey(securityID)
+                ?  Client.QueryAsync(
+                    Filter<SecurityToken>.Where(x => x.Group.ID).IsEqualTo(ClientFactory.UserSecurityID),
+                    Columns.None<SecurityToken>().Add(x => x.Descriptor).Add(x => x.Enabled))
+                : null;
+            var globalTask = _globaltokens is null
+                ? Client.QueryAsync(
+                    null,
+                    Columns.None<GlobalSecurityToken>().Add(x => x.Descriptor).Add(x => x.Enabled))
+                : null;
+            if (userTask is null && groupTask is null && globalTask is null) return;
+
+            CoreUtils.WaitAllNotNull(userTask, groupTask, globalTask);
+            if(userTask != null)
+            {
+                _usertokens.Add(userId, userTask.Result.ToArray<UserSecurityToken>());
+            }
+            if(groupTask != null)
+            {
+                _grouptokens.Add(securityID, groupTask.Result.ToArray<SecurityToken>());
+            }
+            if(globalTask != null)
+            {
+                _globaltokens = globalTask.Result.ToArray<GlobalSecurityToken>();
+            }
         }
         
-        private static void CheckAutoToken(Type _class, Type type)
+        private static ISecurityDescriptor? GetAutoToken(Type _class, Type type)
         {
             var basetype = typeof(AutoSecurityDescriptor<,>);
             var actiontype = type.MakeGenericType(_class);
             var descriptortype = basetype.MakeGenericType(_class, actiontype);
             var descriptor = (Activator.CreateInstance(descriptortype) as ISecurityDescriptor)!;
-            if (!_descriptors.Any(x => string.Equals(x.Code, descriptor.Code)))
-                _descriptors.Add(descriptor);
+            return descriptor;
+            // if (!_descriptors.Any(x => string.Equals(x.Code, descriptor.Code)))
+            //     _descriptors.Add(descriptor);
         }
 
-        private static bool IsAllowedInternal(ISecurityDescriptor descriptor, Guid userGuid)
+        private static bool IsAllowedInternal(ISecurityDescriptor descriptor, Guid userGuid, Guid securityId)
         {
             // If you're not logged in, you can't do jack!
             if (userGuid == Guid.Empty)
                 return false;
 
-            CheckTokens();
+            CheckTokens(userGuid, securityId);
             
             // First Check for a matching User Token (override)
-            var usertoken = _usertokens.FirstOrDefault(x => x.Descriptor.Equals(descriptor.Code));
+            var usertoken = _usertokens[userGuid].FirstOrDefault(x => x.Descriptor.Equals(descriptor.Code));
             if (usertoken != null)
                 return usertoken.Enabled;
 
             // If not found, fall back to the Group Token
-            var grouptoken = _grouptokens.FirstOrDefault(x => x.Descriptor.Equals(descriptor.Code));
+            var grouptoken = _grouptokens[securityId].FirstOrDefault(x => x.Descriptor.Equals(descriptor.Code));
             if (grouptoken != null)
                 return grouptoken.Enabled;
 
@@ -158,7 +206,7 @@ namespace InABox.Core
             var descriptor = (Activator.CreateInstance(T) as ISecurityDescriptor)!;
             try
             {
-                if(IsAllowedInternal(descriptor, userGuid))
+                if(IsAllowedInternal(descriptor, userGuid, securityId))
                 {
                     if(descriptor is IDependentSecurityDescriptor dependent)
                     {
@@ -190,38 +238,75 @@ namespace InABox.Core
         public static bool IsAllowed(Type T) 
             => IsAllowed(T, ClientFactory.UserGuid, ClientFactory.UserSecurityID);
 
+        private static Type CreateAutoDescriptor(Type TAction, Type TEntity)
+        {
+            return typeof(AutoSecurityDescriptor<,>).MakeGenericType(TEntity, TAction.MakeGenericType(TEntity));
+        }
+
+        #region CanView
+
+        private static Type CanViewSecurityDescriptor(Type T)
+        {
+            var security = T.GetCustomAttribute<EntitySecurityAttribute>();
+            return security?.CanView ?? CreateAutoDescriptor(typeof(CanView<>), T);
+        }
+        private static Type CanViewSecurityDescriptor<T>()
+            where T : Entity, new()
+        {
+            var security = typeof(T).GetCustomAttribute<EntitySecurityAttribute>();
+            return security?.CanView ?? typeof(AutoSecurityDescriptor<T, CanView<T>>);
+        }
+
         public static bool CanView<TEntity>(Guid userGuid, Guid securityId) where TEntity : Entity, new()
         {
-            return IsAllowed<AutoSecurityDescriptor<TEntity, CanView<TEntity>>>(userGuid, securityId);
+            return IsAllowed(CanViewSecurityDescriptor<TEntity>(), userGuid, securityId);
         }
 
         public static bool CanView(Type TEntity)
         {
-            return IsAllowed(typeof(AutoSecurityDescriptor<,>).MakeGenericType(TEntity, typeof(CanView<>).MakeGenericType(TEntity)));
+            return IsAllowed(CanViewSecurityDescriptor(TEntity));
         }
         public static bool CanView<TEntity>() where TEntity : Entity, new()
         {
-            return IsAllowed<AutoSecurityDescriptor<TEntity, CanView<TEntity>>>();
+            return IsAllowed(CanViewSecurityDescriptor<TEntity>());
+        }
+
+        #endregion
+
+        #region CanEdit
+
+        private static Type CanEditSecurityDescriptor(Type T)
+        {
+            var security = T.GetCustomAttribute<EntitySecurityAttribute>();
+            return security?.CanEdit ?? CreateAutoDescriptor(typeof(CanEdit<>), T);
+        }
+        private static Type CanEditSecurityDescriptor<T>()
+            where T : Entity, new()
+        {
+            var security = typeof(T).GetCustomAttribute<EntitySecurityAttribute>();
+            return security?.CanEdit ?? typeof(AutoSecurityDescriptor<T, CanEdit<T>>);
         }
 
         public static bool CanEdit(Type TEntity, Guid userGuid, Guid securityId)
         {
-            return IsAllowed(typeof(AutoSecurityDescriptor<,>).MakeGenericType(TEntity, typeof(CanEdit<>).MakeGenericType(TEntity)), userGuid, securityId);
+            return IsAllowed(CanEditSecurityDescriptor(TEntity), userGuid, securityId);
         }
         public static bool CanEdit<TEntity>(Guid userGuid, Guid securityId) where TEntity : Entity, new()
         {
-            return IsAllowed<AutoSecurityDescriptor<TEntity, CanEdit<TEntity>>>(userGuid, securityId);
+            return IsAllowed(CanEditSecurityDescriptor<TEntity>(), userGuid, securityId);
         }
 
         public static bool CanEdit(Type TEntity)
         {
-            return IsAllowed(typeof(AutoSecurityDescriptor<,>).MakeGenericType(TEntity, typeof(CanEdit<>).MakeGenericType(TEntity)));
+            return IsAllowed(CanEditSecurityDescriptor(TEntity));
         }
         public static bool CanEdit<TEntity>() where TEntity : Entity, new()
         {
-            return IsAllowed<AutoSecurityDescriptor<TEntity, CanEdit<TEntity>>>();
+            return IsAllowed(CanEditSecurityDescriptor<TEntity>());
         }
 
+        #endregion
+
         public static bool CanImport<TEntity>() where TEntity : Entity, new()
         {
             return IsAllowed<AutoSecurityDescriptor<TEntity, CanImport<TEntity>>>();
@@ -247,16 +332,32 @@ namespace InABox.Core
             return IsAllowed<AutoSecurityDescriptor<TEntity, CanConfigurePost<TEntity>>>();
         }
 
+        #region CanDelete
+
+        private static Type CanDeleteSecurityDescriptor(Type T)
+        {
+            var security = T.GetCustomAttribute<EntitySecurityAttribute>();
+            return security?.CanDelete ?? CreateAutoDescriptor(typeof(CanDelete<>), T);
+        }
+        private static Type CanDeleteSecurityDescriptor<T>()
+            where T : Entity, new()
+        {
+            var security = typeof(T).GetCustomAttribute<EntitySecurityAttribute>();
+            return security?.CanDelete ?? typeof(AutoSecurityDescriptor<T, CanDelete<T>>);
+        }
+
         public static bool CanDelete<TEntity>() where TEntity : Entity, new()
         {
-            return IsAllowed<AutoSecurityDescriptor<TEntity, CanDelete<TEntity>>>();
+            return IsAllowed(CanDeleteSecurityDescriptor<TEntity>());
         }
         
         public static bool CanDelete(Type TEntity)
         {
-            return IsAllowed(typeof(AutoSecurityDescriptor<,>).MakeGenericType(TEntity, typeof(CanDelete<>).MakeGenericType(TEntity)));
+            return IsAllowed(CanDeleteSecurityDescriptor(TEntity));
         }
 
+        #endregion
+
         public static bool CanManageIssues(Type TEntity)
         {
             return IsAllowed(typeof(AutoSecurityDescriptor<,>).MakeGenericType(TEntity, typeof(CanManageIssues<>).MakeGenericType(TEntity)));