python

Django 新增一条纪录背后发生的事 - 知乎

文章暂存

systemime
2021-03-25
11 min

摘要.

Django中如需要新建一条纪录,肯定会调用model.save()方法。那么django之后到底做了哪些工作呢? 官方文档是这样说的:

  1. 发送一个django.db.models.signals.pre_save 信号,以允许监听该信号的函数完成一些自定义的动作。
  2. 预处理数据。 如果需要,对对象的每个字段进行自动转换。大部分字段不需要预处理,字段的数据将保持原样。预处理只用于具有特殊行为的字段。例如,如果你的模型具有一个auto_now=TrueDateField,那么预处理阶段将修改对象中的数据以确保该日期字段包含当前的时间戳。
  3. 准备数据库数据。 要求每个字段提供的当前值是能够写入到对应数据库中的类型。大部分字段不需要数据准备。简单的数据类型,例如整数和字符串,是可以直接写入的Python 对象。但是,复杂的数据类型通常需要一些改动。例如,DateField 字段使用Pythondatetime 对象来保存数据。数据库保存的不是datetime 对象,所以该字段的值必须转换成 ISO 兼容的日期字符串才能插入到数据库中。
  4. 插入数据到数据库中。 将预处理过、准备好的数据组织成一个 SQL 语句插入到数据库中。
  5. 发送一个django.db.models.signals.post_save 信号,以允许监听听信号的函数完成一些自定义的动作。

看了说明,有个疑问还是不太清楚,save() 是如何判断对象的纪录是新增呢还是更新呢?还是看源码能够比较清晰的回答这个疑问。下面是django 2.0的源码,根据我自己的理解写上了注释。
就从save()开始

# django/db/models/base.py

def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
    """
    force_insert, force_update 强制 save() 执行INSERT或UPDATE
    using 指定保存的数据库,setting.py中配置的DATABASES的key,默认是'default'
    这个函数主要是检查和处理字段的一些额外情况
    """

     # 检查外键关系的正确性并清除缓存
    for field in self._meta.concrete_fields:
        if field.is_relation and field.is_cached(self):
            obj = getattr(self, field.name, None)
            if obj and obj.pk is None:
                if not field.remote_field.multiple:
                    field.remote_field.delete_cached_value(obj)
                raise ValueError(
                    "save() prohibited to prevent data loss due to "
                    "unsaved related object '%s'." % field.name
                )
    
    using = using or router.db_for_write(self.__class__, instance=self)

    if force_insert and (force_update or update_fields):
        raise ValueError("Cannot force both insert and updating in model saving.")
 
    deferred_fields = self.get_deferred_fields()

    # 处理有指定更新的字段
    if update_fields is not None:
        if len(update_fields) == 0:
            return

        update_fields = frozenset(update_fields) # fronzenset 创建的是一个不可变集合
        field_names = set()

        for field in self._meta.fields:
            if not field.primary_key:
                field_names.add(field.name)

                if field.name != field.attname:
                    field_names.add(field.attname)

        non_model_fields = update_fields.difference(field_names)

        if non_model_fields:
            raise ValueError("The following fields do not exist in this "
                             "model or are m2m fields: %s"
                             % ', '.join(non_model_fields))

    # 这里不大明白什么情况会触发
    elif not force_insert and deferred_fields and using == self._state.db:
        field_names = set()
        for field in self._meta.concrete_fields:
            if not field.primary_key and not hasattr(field, 'through'):
                field_names.add(field.attname)
        loaded_fields = field_names.difference(deferred_fields)
        if loaded_fields:
            update_fields = frozenset(loaded_fields)

    # 保存动作进入下一个函数
    self.save_base(using=using, force_insert=force_insert,
                   force_update=force_update, update_fields=update_fields)

进入save_base(),主要逻辑是在这里处理

def save_base(self, raw=False, force_insert=False,
              force_update=False, using=None, update_fields=None):
    """
    这个函数的主要功能有
    1 跳过代理
    2 发出信号
    3 开启事务执行
    """

    # 再次检查参数正确性
    using = using or router.db_for_write(self.__class__, instance=self)
    assert not (force_insert and (force_update or update_fields))
    assert update_fields is None or len(update_fields) > 0

    cls = origin = self.__class__
    # 如果是代理model,则跳过代理
    if cls._meta.proxy:
        cls = cls._meta.concrete_model
    meta = cls._meta

    # 步骤1 发出保存前信号
    if not meta.auto_created:
        pre_save.send(
            sender=origin, instance=self, raw=raw, using=using,
            update_fields=update_fields,
        )

    # 在事务下执行
    with transaction.atomic(using=using, savepoint=False):
        if not raw:
            self._save_parents(cls, using, update_fields) 
        updated = self._save_table(raw, cls, force_insert, force_update, using, update_fields) # 执行函数,进入下一步

    # 保存状态
    self._state.db = using
    self._state.adding = False

    # 步骤5 发出保存完成的信号
    if not meta.auto_created:
        post_save.send(
            sender=origin, instance=self, created=(not updated),
            update_fields=update_fields, raw=raw, using=using,
        )

进入_save_table()

def _save_table(self, raw=False, cls=None, force_insert=False,
                force_update=False, using=None, update_fields=None):
    """
    这个函数主要的功能就是分配UPDATE或INSERT操作
    """
    meta = cls._meta
    
    non_pks = [f for f in meta.local_concrete_fields if not f.primary_key]
    if update_fields:
        non_pks = [f for f in non_pks
                   if f.name in update_fields or f.attname in update_fields]

    # 获取pk的值,如果是更新则pk值就可以获取到
    pk_val = self._get_pk_val(meta)
    if pk_val is None:
        pk_val = meta.pk.get_pk_value_on_save(self)
        setattr(self, meta.pk.attname, pk_val)
    pk_set = pk_val is not None
    if not pk_set and (force_update or update_fields):
        raise ValueError("Cannot force an update in save() with no primary key.")
    updated = False

    # 如果pk值有了尝试使用UPDATE,所以save()处理是更新还是插入纪录就靠这个pk值来判断
    if pk_set and not force_insert:
        base_qs = cls._base_manager.using(using)
        values = [(f, None, (getattr(self, f.attname) if raw else f.pre_save(self, False)))
                  for f in non_pks]
        forced_update = update_fields or force_update

        # _do_update() 更新操作
        updated = self._do_update(base_qs, using, pk_val, values, update_fields,
                                  forced_update)

        if force_update and not updated:
            raise DatabaseError("Forced update did not affect any rows.")
        if update_fields and not updated:
            raise DatabaseError("Save with update_fields did not affect any rows.")

    # 执行插入
    if not updated:
        if meta.order_with_respect_to:
            field = meta.order_with_respect_to
            filter_args = field.get_filter_kwargs_for_object(self)
            order_value = cls._base_manager.using(using).filter(**filter_args).count()
            self._order = order_value
        fields = meta.local_concrete_fields
        if not pk_set:
            fields = [f for f in fields if f is not meta.auto_field]
        update_pk = meta.auto_field and not pk_set

        # _do_insert() 插入操作, 返回一个新建的pk值
        result = self._do_insert(cls._base_manager, using, fields, update_pk, raw)

        # 设置这个新新生成的pk值,用于判断
        if update_pk:
            setattr(self, meta.pk.attname, result)

    return updated

先看 INSERT 一条纪录,进入_do_insert()

def _do_insert(self, manager, using, fields, update_pk, raw):
    """
    调用model的管理器执行INSERT
    """
    return manager._insert([self], fields=fields, return_id=update_pk,
                       using=using, raw=raw)

Manager的工作比较少,进入 _insert()

# django/db/models/query.py

def _insert(self, objs, fields, return_id=False, raw=False, using=None):
    """
    管理器单独新建一个InserQuery实例,通过查询集去转化
    """
    self._for_write = True
    if using is None:
        using = self.db
    query = sql.InsertQuery(self.model)
    query.insert_values(fields, objs, raw=raw)
    return query.get_compiler(using=using).execute_sql(return_id)

结果实际的 sql 转化交给了 Query 去实现,下面看这个 InsertQuery 是如何转化的。
首先,获得编译器get_compiler()

class InsertQuery(Query):
    compiler = 'SQLInsertCompiler'

    def get_compiler(self, using=None, connection=None):
        """
        这个函数的主要功能就是找出对应数据库(比如mysql)中SQLInsertCompiler的实现
        """
        if using is None and connection is None:
            raise ValueError("Need either using or connection")
        if using:
            connection = connections[using]
        return connection.ops.compiler(self.compiler)(self, connection, using)

第二,找到了对应的编译器那就执行 SQL,继续看_insert()中最后的execute_sql()

# django/db/models/sql/compiler.py 
class SQLInsertCompiler(SQLCompiler):

    def execute_sql(self, return_id=False):
    """
    这个函数主要就是执行sql,返回最后生成的pk值
    """
        assert not (
            return_id and len(self.query.objs) != 1 and
            not self.connection.features.can_return_ids_from_bulk_insert
        )
        self.return_id = return_id

        # 字段是怎么变成sql语句的,最后就是在这个as_sql()中
        with self.connection.cursor() as cursor:
            for sql, params in self.as_sql():
                # 步骤4 执行SQL
                cursor.execute(sql, params)

            if not (return_id and cursor):
                return
            if self.connection.features.can_return_ids_from_bulk_insert and len(self.query.objs) > 1:
                return self.connection.ops.fetch_returned_insert_ids(cursor)
            if self.connection.features.can_return_id_from_insert:
                assert len(self.query.objs) == 1
                return self.connection.ops.fetch_returned_insert_id(cursor)
            return self.connection.ops.last_insert_id(
                cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column
            )

终于到了最神秘也是最关键的地方,ORM 中对象转化成 sql 的部分as_sql()

def as_sql(self):
    """
    这个函数主要就是执行步骤3准备数据库数据,生成SQL和值
    """
    qn = self.connection.ops.quote_name 
    # 根据不同的数据库做不同的处理比如mysql:
    # >>> name = "green"
    # >>> qn(name)
    #>>> "`green`"

    opts = self.query.get_meta()
    result = ['INSERT INTO %s' % qn(opts.db_table)]

    has_fields = bool(self.query.fields)
    fields = self.query.fields if has_fields else [opts.pk]
    # 加入字段名称
    result.append('(%s)' % ', '.join(qn(f.column) for f in fields))

    if has_fields:
        # self.pre_save_val(field, obj) 实际调用的是field.pre_save(obj, add=True)
        # 将转化交给field自己去处理
        # prepare_value() 对需要特殊转化的字段进行处理如DatetimeField
        value_rows = [
            [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
            for obj in self.query.objs
        ]
    else:
        # An empty object.
        value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
        fields = [None]

    # 数据库能否支持批量插入
    can_bulk = (not self.return_id and self.connection.features.has_bulk_insert)

    # 生成字段和值对应的sql
    placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

    # 下面就是根据不同的数据库能接受不同形式的SQL的处理,最终返回拼装好的SQL和值  
    if self.return_id and self.connection.features.can_return_id_from_insert:
        if self.connection.features.can_return_ids_from_bulk_insert:
            result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
            params = param_rows
        else:
            result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
            params = [param_rows[0]]
        col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
        r_fmt, r_params = self.connection.ops.return_insert_id()
        if r_fmt:
            result.append(r_fmt % col)
            params += [r_params]
        return [(" ".join(result), tuple(chain.from_iterable(params)))]

    if can_bulk:
        result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
        return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
    else:
        return [
            (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
            for p, vals in zip(placeholder_rows, param_rows)
        ]

最终生成了 SQL 执行完成,整个流程完成。

看到这里是不是觉得 django 调用链很长,我第一次看也没坚持看完。下面来整理这个流程以便于理解,如果没看完源码也可以看下面的整理。

  1. Model
    model 主要是处理外键关系 ,发送信号,判断是更新还是插入操作,然后交给 manager
  2. Manager
    manager 找到处理的 Query
  3. Query
    query 根据不同数据库的拼接 sql , 最后执行 https://zhuanlan.zhihu.com/p/33764375 https://zhuanlan.zhihu.com/p/33764375
上次编辑于: 5/20/2021, 7:26:49 AM