Generic Repository: Fake IDbSet implementation update (Find Method & Identity key)

UPDATE (again): Just a quick one. See https://github.com/refactorthis/GraphDiff/blob/master/EFDetachedUpdate/DetachedUpdate/DbContextExtensions.cs, line 209, for a replacement GetKeyProperties method that also handles keys mapped by convention or with the fluent API, so you no longer need to annotate your model with KeyAttribute.

UPDATE: Thanks to Eli Weinstock-Herman for pointing out that Find should return null if no result is found (SingleOrDefault instead of Single). Cheers, Eli.

Hey guys,

I've been back in the coding seat lately, creating a new generic repository for a system that we are building, and I've made some improvements to the FakeDbSet that I posted about earlier.

I want to add some notes to the previous post that are long enough to warrant a new post. Firstly, IT IS MUCH EASIER if you do not use foreign key fields in your objects but instead use 'association' object references. This means you will not have to coordinate two different fields when setting up test data. EF keeps the foreign key and the reference in sync for you when connected to the database, but in memory you would have to do this yourself.
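A small illustration of why this matters, using hypothetical Customer and OrderWithFk classes (not from the original post). Outside of EF nothing performs relationship fix-up, so setting only the navigation property leaves the foreign key at its default value, and any test query written against the foreign key silently misses the entity.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entities for illustration only
public class Customer
{
    public int Id { get; set; }
    public string Name { get; set; }
}

public class OrderWithFk
{
    public int Id { get; set; }
    public int CustomerId { get; set; }     // foreign key field
    public Customer Customer { get; set; }  // association (navigation) property
}

public static class Program
{
    public static void Main()
    {
        var alice = new Customer { Id = 7, Name = "Alice" };

        // In memory nothing syncs the two fields: setting the reference leaves CustomerId at 0
        var order = new OrderWithFk { Id = 1, Customer = alice };
        var orders = new List<OrderWithFk> { order };

        // A query written against the foreign key misses the order...
        Console.WriteLine(orders.Count(o => o.CustomerId == 7));   // 0
        // ...while a query written against the association finds it
        Console.WriteLine(orders.Count(o => o.Customer.Id == 7));  // 1
    }
}
```

If your model only has the association, there is just one field to set in test data and both the fake and the real database see the same relationship.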

Secondly, the implementation of Find was quite hard, though I believe I have come up with an elegant generic solution. If you look at the IDbSet documentation on MSDN, you will see that Find() expects the key values to be passed in "the same order that they are defined in the model".

Using reflection, I can find the key properties once, then iterate through them and check that each value passed to Find equals the value of the corresponding key property, as shown below.

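        private List<PropertyInfo> _keyProperties;

        public virtual T Find(params object[] keyValues)
        {
            if (keyValues.Length != _keyProperties.Count)
                throw new ArgumentException("Incorrect number of keys passed to find method");

            IQueryable<T> keyQuery = this.AsQueryable<T>();
            for (int i = 0; i < keyValues.Length; i++)
            {
                // Copy the loop variable so each lambda captures its own index
                // (a for-loop variable is shared by every iteration)
                var x = i;
                // object.Equals is null-safe, unlike calling .Equals on a key value that may be null
                keyQuery = keyQuery
                   .Where(entity => object.Equals(_keyProperties[x].GetValue(entity, null), keyValues[x]));
            }

            // Like the real Find, return null rather than throwing when nothing matches
            return keyQuery.SingleOrDefault();
        }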

        private void GetKeyProperties()
        {
            _keyProperties = new List<PropertyInfo>();
            PropertyInfo[] properties = typeof(T).GetProperties();
            foreach (PropertyInfo property in properties)
            {
                foreach (Attribute attribute in property.GetCustomAttributes(true))
                {
                    if (attribute is KeyAttribute)
                    {
                        _keyProperties.Add(property);
                    }
                }
            }
        }
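One caveat on key order: the .NET documentation for Type.GetProperties states that it does not return properties in any particular order, so for a composite key the order GetKeyProperties produces is not guaranteed to match the model. With data annotations, EF expects composite keys to be ordered with [Column(Order = n)]. Below is a minimal sketch that sorts the [Key] properties by that order, assuming your composite keys are annotated this way; KeyHelper, GetOrderedKeyProperties and OrderLine are hypothetical names for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;
using System.Linq;
using System.Reflection;

// Hypothetical composite-key entity; properties are deliberately declared "out of order"
public class OrderLine
{
    [Key, Column(Order = 1)]
    public int LineNumber { get; set; }

    [Key, Column(Order = 0)]
    public int OrderId { get; set; }

    public string Product { get; set; }
}

public static class KeyHelper
{
    // Returns the [Key] properties of T sorted by [Column(Order = n)].
    // Keys without an explicit column order sort last; OrderBy is stable, so ties keep reflection order.
    public static List<PropertyInfo> GetOrderedKeyProperties<T>()
    {
        return typeof(T).GetProperties()
            .Where(p => p.IsDefined(typeof(KeyAttribute), true))
            .OrderBy(p =>
            {
                var column = p.GetCustomAttribute<ColumnAttribute>(true);
                return column == null || column.Order < 0 ? int.MaxValue : column.Order;
            })
            .ToList();
    }
}

public static class Program
{
    public static void Main()
    {
        var names = KeyHelper.GetOrderedKeyProperties<OrderLine>().Select(p => p.Name);
        Console.WriteLine(string.Join(", ", names)); // OrderId, LineNumber
    }
}
```

With the keys in this order, Find(orderId, lineNumber) lines up with the values the way callers of the real DbSet would pass them.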

Thirdly, I wanted the FakeDbSet to act like the database and use an identity column when an entity has a single int property marked with the [Key] attribute. Here are the changes:

private int _identity = 1;

private void GenerateId(T entity)
{
    // If non-composite integer key
    if (_keyProperties.Count == 1 && _keyProperties[0].PropertyType == typeof(Int32))
        _keyProperties[0].SetValue(entity, _identity++, null);
}

public T Add(T item)
{
    GenerateId(item);
    _data.Add(item);
    return item;
}

Of course, this is done in the Add method rather than on commit, as the database would do it. For my purposes this makes no difference. If, however, you want keys to be generated on commit, you need to keep an uncommitted list inside the FakeDbSet. When commit is called, iterate over that list, generating an ID for each element and moving it to the 'committed' list.
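A minimal sketch of that commit-time variant, kept independent of EF so it compiles on its own. DeferredIdentityStore, Widget and the keyName constructor parameter are hypothetical names for illustration; in a real FakeDbSet the queue would live inside the set and Commit would be triggered by whatever your fake context uses in place of SaveChanges.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical in-memory store that defers identity generation until Commit().
// Assumes a single int key property, identified by name.
public class DeferredIdentityStore<T> where T : class
{
    private readonly List<T> _uncommitted = new List<T>();
    private readonly HashSet<T> _committed = new HashSet<T>();
    private readonly PropertyInfo _key;
    private int _identity = 1;

    public DeferredIdentityStore(string keyName)
    {
        _key = typeof(T).GetProperty(keyName);
        if (_key == null || _key.PropertyType != typeof(int))
            throw new ArgumentException("Expected a single int key property", nameof(keyName));
    }

    // Add only queues the entity; its key keeps its default value until Commit
    public T Add(T item)
    {
        _uncommitted.Add(item);
        return item;
    }

    // Commit assigns identities in insertion order, then moves entities to the committed set
    public void Commit()
    {
        foreach (T item in _uncommitted)
        {
            _key.SetValue(item, _identity++, null);
            _committed.Add(item);
        }
        _uncommitted.Clear();
    }

    public IEnumerable<T> Committed => _committed;
}

// Hypothetical entity for the example
public class Widget
{
    public int Id { get; set; }
    public string Name { get; set; }
}

public static class Program
{
    public static void Main()
    {
        var store = new DeferredIdentityStore<Widget>("Id");
        var a = store.Add(new Widget { Name = "a" });
        var b = store.Add(new Widget { Name = "b" });
        Console.WriteLine(a.Id);                     // 0 (not committed yet)
        store.Commit();
        Console.WriteLine($"{a.Id}, {b.Id}");        // 1, 2
        Console.WriteLine(store.Committed.Count());  // 2
    }
}
```

The trade-off is that queries against the set only see committed entities, which is closer to the database but means tests must remember to commit before asserting.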

Here is the new FakeDbSet implementation:

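using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel.DataAnnotations;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class FakeDbSet<T> : IDbSet<T> where T : class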
    {
        private readonly HashSet<T> _data;
        private readonly IQueryable _query;
        private int _identity = 1;
        private List<PropertyInfo> _keyProperties;

        private void GetKeyProperties()
        {
            _keyProperties = new List<PropertyInfo>();
            PropertyInfo[] properties = typeof(T).GetProperties();
            foreach (PropertyInfo property in properties)
            {
                foreach (Attribute attribute in property.GetCustomAttributes(true))
                {
                    if (attribute is KeyAttribute)
                    {
                        _keyProperties.Add(property);
                    }
                }
            }
        }

        private void GenerateId(T entity)
        {
            // If non-composite integer key
            if (_keyProperties.Count == 1 && _keyProperties[0].PropertyType == typeof(Int32))
                _keyProperties[0].SetValue(entity, _identity++, null);
        }

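        public FakeDbSet(IEnumerable<T> startData = null)
        {
            GetKeyProperties();
            _data = (startData != null ? new HashSet<T>(startData) : new HashSet<T>());
            _query = _data.AsQueryable();

            // Continue the identity sequence after any seeded keys so Add() cannot reuse an existing ID
            if (_keyProperties.Count == 1 && _keyProperties[0].PropertyType == typeof(Int32) && _data.Count > 0)
                _identity = _data.Max(e => (int)_keyProperties[0].GetValue(e, null)) + 1;
        }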

        public virtual T Find(params object[] keyValues)
        {
            if (keyValues.Length != _keyProperties.Count)
                throw new ArgumentException("Incorrect number of keys passed to find method");

            IQueryable<T> keyQuery = this.AsQueryable<T>();
            for (int i = 0; i < keyValues.Length; i++)
            {
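                // Copy the loop variable so each lambda captures its own index
                var x = i;
                // object.Equals is null-safe, unlike calling .Equals on a key value that may be null
                keyQuery = keyQuery.Where(entity => object.Equals(_keyProperties[x].GetValue(entity, null), keyValues[x]));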
            }

            return keyQuery.SingleOrDefault();
        }

        public T Add(T item)
        {
            GenerateId(item);
            _data.Add(item);
            return item;
        }

        public T Remove(T item)
        {
            _data.Remove(item);
            return item;
        }

        public T Attach(T item)
        {
            _data.Add(item);
            return item;
        }

        public void Detach(T item)
        {
            _data.Remove(item);
        }

        Type IQueryable.ElementType
        {
            get { return _query.ElementType; }
        }

        Expression IQueryable.Expression
        {
            get { return _query.Expression; }
        }

        IQueryProvider IQueryable.Provider
        {
            get { return _query.Provider; }
        }

        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return _data.GetEnumerator();
        }

        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return _data.GetEnumerator();
        }

        public T Create()
        {
            return Activator.CreateInstance<T>();
        }

        public ObservableCollection<T> Local
        {
            get
            {
                return new ObservableCollection<T>(_data);
            }
        }

        public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, T
        {
            return Activator.CreateInstance<TDerivedEntity>();
        }
    }

Hope this code is useful to someone else :)
