Browse Source

- Added custom functions for dealing with 'decimal' values in SQLite aggregates and formulae.
- Added a way to customise type schema in DBFactory, for testing purposes.

Kenric Nugteren 2 weeks ago
parent
commit
13117f52f0

+ 9 - 1
InABox.Core/Aggregate.cs

@@ -140,6 +140,8 @@ namespace InABox.Core
     {
         Type TAggregate { get; }
 
+        Type TResult { get; }
+
         IComplexFormulaNode GetExpression();
 
         AggregateCalculation GetCalculation();
@@ -161,6 +163,8 @@ namespace InABox.Core
 
         Type IComplexFormulaAggregateNode.TAggregate => typeof(TAggregate);
 
+        Type IComplexFormulaAggregateNode.TResult => typeof(TResult);
+
         public ComplexFormulaAggregateNode(IComplexFormulaNode<TAggregate, TResult> expression, AggregateCalculation calculation, Filter<TAggregate>? filter, Dictionary<Expression<Func<TAggregate, object?>>, Expression<Func<TType, object?>>> links)
         {
             Expression = expression;
@@ -260,6 +264,8 @@ namespace InABox.Core
 
     public interface IComplexFormulaFormulaNode : IComplexFormulaNode
     {
+        Type TResult { get; }
+
         IEnumerable<IComplexFormulaNode> GetOperands();
 
         FormulaOperator GetOperator();
@@ -270,6 +276,8 @@ namespace InABox.Core
 
         public FormulaOperator Operator { get; set; }
 
+        Type IComplexFormulaFormulaNode.TResult => typeof(TResult);
+
         public ComplexFormulaFormulaNode(IComplexFormulaNode<TType, TResult>[] operands, FormulaOperator op)
         {
             Operands = operands;
@@ -430,7 +438,7 @@ namespace InABox.Core
 
     public interface IComplexFormulaGenerator<TType, TResult>
     {
-        IComplexFormulaNode<TType, TResult> Property(Expression<Func<TType, TResult>> epxression);
+        IComplexFormulaNode<TType, TResult> Property(Expression<Func<TType, TResult>> expression);
 
         IComplexFormulaNode<TType, TResult> Formula(FormulaOperator op, params IComplexFormulaNode<TType, TResult>[] operands);
 

+ 4 - 2
InABox.Core/Query/Column.cs

@@ -22,6 +22,8 @@ namespace InABox.Core
         /// Every column needs a name to distinguish it from other columns, for example in query result tables, or in SQL queries.
         /// </remarks>
         string Name { get; }
+
+        Type Type { get; }
     }
 
     public interface IBaseColumns
@@ -38,6 +40,8 @@ namespace InABox.Core
     {
         public string Name { get; }
 
+        public Type Type => typeof(TResult);
+
         public IComplexFormulaNode<T, TResult> Formula { get; }
 
         public ComplexColumn(string name, IComplexFormulaNode<T, TResult> formula)
@@ -57,8 +61,6 @@ namespace InABox.Core
     {
         string Property { get; }
 
-        Type Type { get; }
-
         string IBaseColumn.Name => Property;
     }
 

+ 16 - 6
InABox.Database/DbFactory.cs

@@ -66,16 +66,26 @@ public static class DbFactory
 
     public static IProvider NewProvider(Logger logger) => ProviderFactory.NewProvider(logger);
 
-    public static void Start()
+    public static void Start(Type[]? types = null)
     {
         CoreUtils.CheckLicensing();
+
+        if(types is not null)
+        {
+            ProviderFactory.Types = types.Concat(CoreUtils.IterateTypes(typeof(CoreUtils).Assembly).Where(x => !x.IsAbstract))
+                .Where(x => x.IsClass && !x.IsGenericType && x.IsSubclassOf(typeof(Entity)))
+                .ToArray();
+        }
+        else
+        {
+            ProviderFactory.Types = Entities.Where(x =>
+                x.IsClass
+                && !x.IsGenericType
+                && x.IsSubclassOf(typeof(Entity))
+            ).ToArray();
+        }
         
         // Start the provider
-        ProviderFactory.Types = Entities.Where(x =>
-            x.IsClass
-            && !x.IsGenericType
-            && x.IsSubclassOf(typeof(Entity))
-        ).ToArray();
 
         ProviderFactory.Start();
 

+ 204 - 25
inabox.database.sqlite/SQLiteProvider.cs

@@ -1,5 +1,6 @@
 using System.Collections;
 using System.Data;
+using System.Data.Common;
 using System.Data.SQLite;
 using System.Diagnostics.CodeAnalysis;
 using System.Linq.Expressions;
@@ -41,12 +42,163 @@ internal abstract class SQLiteAccessor : IDisposable
 
         Connection = new SQLiteConnection(conn);
         Connection.BusyTimeout = Convert.ToInt32(TimeSpan.FromMinutes(2).TotalMilliseconds);
+
         Connection.Open();
+
         Connection.SetLimitOption(SQLiteLimitOpsEnum.SQLITE_LIMIT_VARIABLE_NUMBER, 10000);
         ++nConnections;
     }
 }
 
+#region Custom Decimal Functions
+
+[SQLiteFunction(Name = "DECIMAL_SUM", Arguments = 1, FuncType = FunctionType.Aggregate)]
+public class SQLiteDecimalSum : SQLiteFunction
+{
+    public override void Step(object[] args, int stepNumber, ref object contextData)
+    {
+        if (args.Length < 1 || args[0] == DBNull.Value)
+            return;
+        decimal d = Convert.ToDecimal(args[0]);
+        if (contextData != null) d += (decimal)contextData;
+        contextData = d;
+    }
+
+    public override object Final(object contextData)
+    {
+        return contextData;
+    }
+}
+
+[SQLiteFunction(Name = "DECIMAL_ADD", Arguments = -1, FuncType = FunctionType.Scalar)]
+public class SQLiteDecimalAdd : SQLiteFunction
+{
+    public override object? Invoke(object[] args)
+    {
+        var result = 0.0M;
+        for(int i = 0; i < args.Length; ++i)
+        {
+            var arg = args[i];
+            if(arg == DBNull.Value)
+            {
+                return null;
+            }
+            else
+            {
+                result += Convert.ToDecimal(arg);
+            }
+        }
+        return result;
+    }
+}
+[SQLiteFunction(Name = "DECIMAL_SUB", Arguments = -1, FuncType = FunctionType.Scalar)]
+public class SQLiteDecimalSub : SQLiteFunction
+{
+    public override object? Invoke(object[] args)
+    {
+        if(args.Length == 0)
+        {
+            return 0.0M;
+        }
+        else if(args.Length == 1)
+        {
+            if (args[0] == DBNull.Value)
+            {
+                return null;
+            }
+            else
+            {
+                return -Convert.ToDecimal(args[0]);
+            }
+        }
+        else
+        {
+            if (args[0] == DBNull.Value)
+            {
+                return null;
+            }
+            var result = Convert.ToDecimal(args[0]);
+            foreach(var arg in args.Skip(1))
+            {
+                if(arg == DBNull.Value)
+                {
+                    return null;
+                }
+                result -= Convert.ToDecimal(arg);
+            }
+            return result;
+        }
+    }
+}
+[SQLiteFunction(Name = "DECIMAL_MUL", Arguments = -1, FuncType = FunctionType.Scalar)]
+public class SQLiteDecimalMult : SQLiteFunction
+{
+    public override object? Invoke(object[] args)
+    {
+        var result = 1.0M;
+        foreach(var arg in args)
+        {
+            if(arg == DBNull.Value)
+            {
+                return null;
+            }
+            result *= Convert.ToDecimal(arg);
+        }
+        return result;
+    }
+}
+[SQLiteFunction(Name = "DECIMAL_DIV", Arguments = -1, FuncType = FunctionType.Scalar)]
+public class SQLiteDecimalDiv : SQLiteFunction
+{
+    public override object? Invoke(object[] args)
+    {
+        if(args.Length == 0)
+        {
+            return 1.0M;
+        }
+        else if(args.Length == 1)
+        {
+            if (args[0] == DBNull.Value)
+            {
+                return null;
+            }
+            else
+            {
+                var denom = Convert.ToDecimal(args[0]);
+                if(denom == 0M)
+                {
+                    return new Exception("Attempt to divide by zero.");
+                }
+                return 1.0M / denom;
+            }
+        }
+        else
+        {
+            if (args[0] == DBNull.Value)
+            {
+                return null;
+            }
+            var result = Convert.ToDecimal(args[0]);
+            foreach(var arg in args.Skip(1))
+            {
+                if(arg == DBNull.Value)
+                {
+                    return null;
+                }
+                var denom = Convert.ToDecimal(arg);
+                if(denom == 0M)
+                {
+                    return new Exception("Attempt to divide by zero.");
+                }
+                result /= denom;
+            }
+            return result;
+        }
+    }
+}
+
+#endregion
+
 internal class SQLiteReadAccessor : SQLiteAccessor
 {
     public SQLiteReadAccessor(string url)
@@ -478,7 +630,7 @@ public class SQLiteProviderFactory : IProviderFactory
         if (type == typeof(byte[]))
             return "BLOB";
 
-        if (type.IsFloatingPoint())
+        if (type.IsFloatingPoint() || type == typeof(decimal))
             return "NUM";
 
         if (type.GetInterfaces().Contains(typeof(IPackable)))
@@ -1720,11 +1872,11 @@ public class SQLiteProvider : IProvider
         Dictionary<string, string> fieldmap, List<string> columns, bool useparams)
         => GetSortClauseNonGeneric(typeof(T), command, sort, prefix, tables, fieldmap, columns, useparams);
 
-    private static string GetCalculation(AggregateCalculation calculation, string columnname)
+    private static string GetCalculation(AggregateCalculation calculation, string columnname, Type TResult)
     {
         return calculation switch
         {
-            AggregateCalculation.Sum => "SUM",
+            AggregateCalculation.Sum => TResult == typeof(decimal) ? "DECIMAL_SUM" : "SUM",
             AggregateCalculation.Count => "COUNT",
             AggregateCalculation.Maximum => "MAX",
             AggregateCalculation.Minimum => "MIN",
@@ -1866,7 +2018,7 @@ public class SQLiteProvider : IProvider
 
                 var aggregates = new Dictionary<string, string>
                 {
-                    { aggCol, GetCalculation(agg.GetCalculation(), aggCol) }
+                    { aggCol, GetCalculation(agg.GetCalculation(), aggCol, agg.TResult) }
                 };
 
                 var subquery = string.Format("({0})",
@@ -1947,48 +2099,75 @@ public class SQLiteProvider : IProvider
                 switch (op)
                 {
                     case FormulaOperator.Add:
-                        if(operands.Count == 0)
+                        if(formula.TResult == typeof(decimal))
                         {
-                            return "0.00";
+                            return $"DECIMAL_ADD({string.Join(',', operands)})";
                         }
                         else
                         {
-                            return $"({string.Join('+', operands)})";
+                            if(operands.Count == 0)
+                            {
+                                return "0.00";
+                            }
+                            {
+                                return $"({string.Join('+', operands)})";
+                            }
                         }
                     case FormulaOperator.Subtract:
-                        if(operands.Count == 0)
-                        {
-                            return "0.00";
-                        }
-                        else if(operands.Count == 1)
+                        if (formula.TResult == typeof(decimal))
                         {
-                            return $"(-{operands[0]})";
+                            return $"DECIMAL_SUB({string.Join(',', operands)})";
                         }
                         else
                         {
-                            return $"({string.Join('-', operands)})";
+                            if (operands.Count == 0)
+                            {
+                                return "0.00";
+                            }
+                            else if (operands.Count == 1)
+                            {
+                                return $"(-{operands[0]})";
+                            }
+                            else
+                            {
+                                return $"({string.Join('-', operands)})";
+                            }
                         }
                     case FormulaOperator.Multiply:
-                        if(operands.Count == 0)
+                        if (formula.TResult == typeof(decimal))
                         {
-                            return "1.00";
+                            return $"DECIMAL_MUL({string.Join(',', operands)})";
                         }
                         else
                         {
-                            return $"({string.Join('*', operands)})";
+                            if (operands.Count == 0)
+                            {
+                                return "1.00";
+                            }
+                            else
+                            {
+                                return $"({string.Join('*', operands)})";
+                            }
                         }
                     case FormulaOperator.Divide:
-                        if(operands.Count == 0)
-                        {
-                            return "1.00";
-                        }
-                        else if(operands.Count == 1)
+                        if (formula.TResult == typeof(decimal))
                         {
-                            return $"(1.00 / {operands[0]})";
+                            return $"DECIMAL_DIV({string.Join(',', operands)})";
                         }
                         else
                         {
-                            return $"({string.Join('/', operands)})";
+                            if (operands.Count == 0)
+                            {
+                                return "1.00";
+                            }
+                            else if (operands.Count == 1)
+                            {
+                                return $"(1.00 / {operands[0]})";
+                            }
+                            else
+                            {
+                                return $"({string.Join('/', operands)})";
+                            }
                         }
                     case FormulaOperator.Maximum:
                         return $"MAX({string.Join(',', operands)})";
@@ -2056,7 +2235,7 @@ public class SQLiteProvider : IProvider
 
                             if (!internalaggregate)
                             {
-                                var scols = new Dictionary<string, string> { { agg.Aggregate, GetCalculation(agg.Calculation, baseCol.Name) } };
+                                var scols = new Dictionary<string, string> { { agg.Aggregate, GetCalculation(agg.Calculation, baseCol.Name, baseCol.Type) } };
 
                                 var linkedtype = agg.Source;
                                 /*var siblings = columns.Where(x => !x.Equals(baseCol.Name) && x.Split('.').First().Equals(bits.First()))