The other day Aaron was having trouble with a fairly simple piece of code and asked for my help. He had a Celery task, running every 15 minutes, that refreshed access tokens from a third-party service. Something like this:

from celery import task

@task
def refresh_tokens():
    # Every 15 minutes: fetch a fresh token for each object and persist it.
    for obj in MyModel.objects.all():
        new_token = obj.get_new_token()
        obj.token = new_token
        obj.save()

Now the funny thing is that this task worked perfectly at 11:15, 11:30 and 11:45, but it didn't work at 12:00. Yes, it was a little picky about the time.

We knew it wasn't working because another part of the code was making API requests to that service and getting "token expired" errors.

So we added some code to email us the value of the obj.token field before and after we fetched the new token, calling obj.refresh_from_db() after the save to make sure the new value was actually persisted.
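Roughly, the instrumented task looked like this (send_debug_email is just a stand-in for however you report the two values, not a helper we actually had):

@task
def refresh_tokens():
    for obj in MyModel.objects.all():
        before = obj.token
        obj.token = obj.get_new_token()
        obj.save()
        obj.refresh_from_db()  # re-read the row to verify the write stuck
        send_debug_email(before=before, after=obj.token)  # hypothetical reporting helper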

And again: at 13:15, 13:30 and 13:45 the email showed the right values, but at 14:00 the before and after values were the same.

In other words, the call obj.save() was not working.

At that point I was very puzzled, and also very interested in solving this mystery. So I went home and kept thinking about it. It was on my walk from the train station to my home that I finally had an idea of what could be happening.

I checked our Celery schedule for periodic tasks and confirmed my suspicion: we had another task that ran every hour and performed writes to the same model! The main difference was that this older task ran a query that returned a lot more objects than our newer task did.

Essentially we had 2 tasks that looked like this:


@task
def refresh_tokens_for_service_1():
    # The older, hourly task: its query returns a lot of objects.
    for obj in MyModel.objects.filter(service_1_token__isnull=False):
        new_token = obj.get_new_token_for_service_1()
        obj.service_1_token = new_token
        obj.save()

@task
def refresh_tokens_for_service_2():
    # The newer task: far fewer objects, so it finishes first.
    for obj in MyModel.objects.filter(service_2_token__isnull=False):
        new_token = obj.get_new_token_for_service_2()
        obj.service_2_token = new_token
        obj.save()

The way the ORM works is that when you run a query it fetches the data from the database and creates the model objects from that data. That means each of these two tasks does a single SELECT, and those queries run at about the same time (on different Celery workers). Once the data is in a worker's memory it never gets re-read. And because Model.save() writes every field by default, not just the ones you changed, each save also writes back whatever stale values the worker read at the start. The second task finishes faster since there are fewer objects with service_2 enabled, so as the first task keeps executing it overwrites the fresh service_2 tokens with the stale values from its initial query.
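Here is the lost update condensed into a single interleaving; the two variables stand for the in-memory copies held by the two workers, and pk=1 is just for illustration:

obj_1 = MyModel.objects.get(pk=1)  # task 1 reads; service_2_token is "OLD"
obj_2 = MyModel.objects.get(pk=1)  # task 2 reads the same row

obj_2.service_2_token = "NEW"
obj_2.save()  # "NEW" is now in the database

obj_1.service_1_token = "FRESH"
obj_1.save()  # writes ALL fields, including the stale
              # service_2_token = "OLD", clobbering "NEW"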

An easy "fix" is to schedule the two tasks at different times, but that "solution" is guaranteed to bite us later.

A slightly better approach is to split each task into N + 1 tasks: a parent task that just runs the initial query and dispatches a separate child task for each object in the loop. Each child task then re-reads its object from the DB, computes the new token and writes it back. This doesn't really solve the problem, but it greatly reduces the probability of it happening because each object's data now stays in memory for far less time. Note that splitting the work into smaller tasks (i.e. not having fat tasks) is good practice anyway because it helps when restarting the worker process: with a big fat task the worker has to wait for the whole task to finish before it can restart.
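A minimal sketch of that split for the first task (refresh_token_for_service_1 is a new child task, and .delay() is the standard way to queue a Celery task):

@task
def refresh_tokens_for_service_1():
    # Parent: run the query and fan out one child task per object.
    pks = MyModel.objects.filter(service_1_token__isnull=False).values_list("pk", flat=True)
    for pk in pks:
        refresh_token_for_service_1.delay(pk)

@task
def refresh_token_for_service_1(pk):
    # Child: re-read the object right before writing it.
    obj = MyModel.objects.get(pk=pk)
    obj.service_1_token = obj.get_new_token_for_service_1()
    obj.save()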

Another way to solve the problem is to specify which fields to write when saving the object, using the update_fields argument to the save method. The problem with this approach is that it requires auditing every task that writes to this model in order to put the right fields in update_fields. And worse, anyone touching the code in the future needs to know about it too.
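With update_fields the generated UPDATE only touches the named columns, so a save in one task can no longer write back a stale copy of the other task's field:

obj.service_1_token = new_token
obj.save(update_fields=["service_1_token"])  # UPDATE ... SET service_1_token = ... only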

Finally, the right approach is to synchronize the two tasks so that each one waits until the other is done with an object of that model before touching it. Of course, I'm talking about using select_for_update.
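A sketch of what that looks like for the first task; select_for_update() has to run inside a transaction, and the second task needs the same treatment for the lock to protect anything:

from django.db import transaction

@task
def refresh_tokens_for_service_1():
    pks = MyModel.objects.filter(service_1_token__isnull=False).values_list("pk", flat=True)
    for pk in pks:
        with transaction.atomic():
            # Locks the row until the transaction commits; a concurrent
            # select_for_update() on the same row blocks until then, and
            # the re-read here guarantees we start from fresh values.
            obj = MyModel.objects.select_for_update().get(pk=pk)
            obj.service_1_token = obj.get_new_token_for_service_1()
            obj.save()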