手写 SQLite 07:实现 SELECT count(*) FROM table

这是"从零手写 SQLite 读取器"系列的第 07 篇。前六篇中,我们依次实现了:文件头解析(第 01 篇)、页读取器(第 02 篇)、B-Tree 遍历(第 03 篇)、varint 与 Record 解码(第 04 篇)、sqlite_schema 读取与 .tables 命令(第 05 篇)、SELECT * FROM table 带列名输出(第 06 篇)。本篇在此基础上增加 SELECT count(*) FROM <table> 支持——只输出一个整数,代表表中的行数。

一、count(*) 为何特殊

SELECT count(*)SELECT * 的本质差别在于:前者不需要解码任何行数据

在我们的实现中,B-Tree 遍历的核心是 collect_leaf_cells(),它返回所有叶节点 Cell 的地址列表(Vec<CellAddr>)。每个 CellAddr 对应表中的一行。因此:

  • SELECT *:遍历所有 Cell 地址,对每个地址读取页数据、解码 varint、解析 Record、格式化每个字段值
  • SELECT count(*):遍历所有 Cell 地址,直接取 cells.len(),结束

不需要读取任何 Cell 的实际内容,不需要解析 varint,不需要解码 Record。唯一的开销是遍历 B-Tree 的内部节点以收集叶节点指针——这在任何情况下都无法省略,因为我们需要知道有多少叶节点 Cell。

对于多页 B-Tree(如 big.db 的 500 行数据),count(*) 只需访问内部节点页和叶节点的 Cell 指针区(每个 Cell 指针仅 2 字节),而 SELECT * 还要额外读取每个 Cell 的完整 payload。

二、实现 count_all()

src/sql.rs 中添加 count_all() 函数。它复用已有的 collect_leaf_cells()read_schema(),只需在拿到 Cell 地址列表后返回其长度:

// src/sql.rs  — 在 select_all() 之后追加

/// 执行 SELECT count(*) FROM <table_name>
/// 返回表中的行数,不解码任何行数据。
pub fn count_all(pager: &mut Pager, table_name: &str) -> Result<usize, String> {
    // 1. 从 sqlite_schema 找到表的根页号
    let entries = read_schema(pager);
    let entry = entries.iter()
        .find(|e| e.object_type == "table" && e.name == table_name)
        .ok_or_else(|| format!("table '{}' not found", table_name))?;

    let root_page = entry.root_page;

    // 2. 遍历 B-Tree,收集所有叶节点 Cell 地址
    let cells = collect_leaf_cells(pager, root_page);

    // 3. Cell 数量即行数,直接返回,无需解码 payload
    Ok(cells.len())
}

完整的 sql.rs 顶部 use 声明保持不变,count_allselect_all 共享同一组依赖:

// src/sql.rs 顶部(已有,无需修改)

use crate::btree::collect_leaf_cells;
use crate::pager::Pager;
use crate::record::{parse_leaf_cell, Value};
use crate::schema::{read_schema, parse_column_defs, ColumnDef};

三、扩展 main.rs

main.rs 中增加两处改动:

  1. 新增 parse_count_star() 辅助函数,从命令字符串中识别 SELECT count(*) FROM <table> 模式并提取表名
  2. main() 的命令分发逻辑中,优先尝试 count(*) 模式,再尝试 SELECT * 模式

parse_count_star() 辅助函数

/// 从 "SELECT count(*) FROM users" 解析出表名 "users"
/// 大小写不敏感,支持 count(*) / COUNT(*) / Count(*)
fn parse_count_star(sql: &str) -> Option<&str> {
    let upper = sql.to_uppercase();
    // 去掉前缀 "SELECT"
    let rest = upper.strip_prefix("SELECT")?.trim_start();
    // 去掉 "COUNT(*)"(忽略大小写已经由 upper 处理)
    let rest = rest.strip_prefix("COUNT(*)")?.trim_start();
    // 去掉 "FROM"
    let rest = rest.strip_prefix("FROM")?.trim_start();
    // 在原始 sql 中定位对应位置(保持原始大小写)
    let offset = sql.len() - rest.len();
    // 表名到下一个空白或字符串末尾
    let name = sql[offset..].split_whitespace().next()?;
    Some(name)
}

更新 main() 命令分发

// src/main.rs

mod header; mod page; mod pager; mod btree;
mod varint; mod record; mod schema; mod sql;

use header::DbHeader;
use pager::Pager;
use schema::read_schema;
use sql::{select_all, count_all};

fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() < 3 {
        eprintln!("Usage: sqlite-rs <db> <.tables | SELECT * FROM <table> | SELECT count(*) FROM <table>>");
        std::process::exit(1);
    }

    let db_path = &args[1];
    let command = args[2..].join(" ");

    let db_header = DbHeader::read_from_file(db_path).unwrap();
    let mut pager = Pager::open(db_path, db_header.page_size).unwrap();

    if command == ".tables" {
        let entries = read_schema(&mut pager);
        let names: Vec<&str> = entries.iter()
            .filter(|e| e.object_type == "table" && !e.name.starts_with("sqlite_"))
            .map(|e| e.name.as_str())
            .collect();
        println!("{}", names.join(" "));
        return;
    }

    // 优先尝试 SELECT count(*) FROM <table>
    if let Some(table_name) = parse_count_star(&command) {
        match count_all(&mut pager, table_name) {
            Ok(n) => println!("{}", n),
            Err(e) => {
                eprintln!("Error: {}", e);
                std::process::exit(1);
            }
        }
        return;
    }

    // 再尝试 SELECT * FROM <table>
    if let Some(table_name) = parse_select_star(&command) {
        match select_all(&mut pager, table_name) {
            Ok(result) => {
                println!("{}", result.columns.join("|"));
                for row in &result.rows {
                    println!("{}", row.join("|"));
                }
            }
            Err(e) => {
                eprintln!("Error: {}", e);
                std::process::exit(1);
            }
        }
        return;
    }

    eprintln!("Unsupported command: {}", command);
    eprintln!("Supported: .tables | SELECT * FROM <table> | SELECT count(*) FROM <table>");
    std::process::exit(1);
}

/// 从 "SELECT * FROM users" 解析出表名 "users"
fn parse_select_star(sql: &str) -> Option<&str> {
    let upper = sql.to_uppercase();
    let rest = upper.strip_prefix("SELECT")?.trim_start()
        .strip_prefix('*')?.trim_start()
        .strip_prefix("FROM")?
        .trim_start();
    let offset = sql.len() - rest.len();
    let name = sql[offset..].split_whitespace().next()?;
    Some(name)
}

/// 从 "SELECT count(*) FROM users" 解析出表名 "users"
fn parse_count_star(sql: &str) -> Option<&str> {
    let upper = sql.to_uppercase();
    let rest = upper.strip_prefix("SELECT")?.trim_start();
    let rest = rest.strip_prefix("COUNT(*)")?.trim_start();
    let rest = rest.strip_prefix("FROM")?.trim_start();
    let offset = sql.len() - rest.len();
    let name = sql[offset..].split_whitespace().next()?;
    Some(name)
}

注意 parse_count_starmain() 中被调用时,函数定义可以放在文件任意位置(Rust 不要求先声明后使用)。上面把两个解析函数都列在 main() 之后,保持代码结构一致。

四、测试

test.db(3 行)

cargo run -- test.db "SELECT count(*) FROM users"
3

SELECT * 的输出对照:

cargo run -- test.db "SELECT * FROM users"
id|name|age
1|Alice|30
2|Bob|25
3|Charlie|35

行数确认一致。

big.db(500 行)

cargo run -- big.db "SELECT count(*) FROM users"
500

big.db 的 B-Tree 跨越多个页,collect_leaf_cells() 递归遍历所有叶节点后返回 500 个地址,cells.len() 即 500。

multi.db(多表)

cargo run -- multi.db "SELECT count(*) FROM employees"
2
cargo run -- multi.db "SELECT count(*) FROM departments"
2

大小写不敏感

cargo run -- test.db "select COUNT(*) from users"
3

parse_count_star() 对整个命令字符串调用 to_uppercase() 后匹配,因此 selectSELECTcount(*)COUNT(*) 均可识别。

五、count(*) 与 SELECT * 的性能对比

两条路径在 B-Tree 遍历阶段完全相同——都调用 collect_leaf_cells(),都需要读取每一个内部节点页来找到叶节点,都需要读取每一个叶节点页来枚举 Cell 指针。差别发生在拿到 Vec<CellAddr> 之后:

阶段 SELECT count(*) SELECT *
遍历内部节点
读取叶节点页(Cell 指针区)
读取 Cell payload(行数据) 是(每行)
解码 varint(payload 长度、rowid) 是(每行)
解析 Record header(serial types) 是(每行)
解码每列值(整数、文本、浮点) 是(每行每列)
格式化输出 一次 println! 每行一次 println!

实际上,SQLite 官方对 count(*) 有更激进的优化:对于不含 WITHOUT ROWID 的普通表,可以通过读取 B-Tree 统计信息(sqlite_stat1)直接返回近似行数,或者只扫描 Cell 指针而不读取页内容。我们的实现没有做到这一步——仍然需要将所有叶节点页读入内存——但已经省去了所有的数据解码开销。

对于行数极大、每行数据量又很大的表(例如每行存储几 KB 的 BLOB),count(*) 的优势会非常显著:B-Tree 的叶节点数量与行数成正比,而每行 payload 的大小对 count(*) 完全不影响(我们根本不读 payload)。

六、当前代码结构

sqlite-rs/
└── src/
    ├── main.rs     ← 命令路由:.tables / SELECT * / SELECT count(*)(本篇更新)
    ├── header.rs   ← 文件头解析(第 01 篇)
    ├── page.rs     ← 页头 + Cell 指针(第 02 篇)
    ├── pager.rs    ← 页读取器(第 02 篇)
    ├── btree.rs    ← B-Tree 遍历(第 03 篇)
    ├── varint.rs   ← varint 解码(第 04 篇)
    ├── record.rs   ← Record 解析(第 04 篇)
    ├── schema.rs   ← sqlite_schema 解析 + 列定义解析(第 05、06 篇)
    └── sql.rs      ← SQL 执行器:select_all + count_all(本篇更新)

本篇只改动了两个文件:sql.rs(新增 count_all 函数)和 main.rs(新增 parse_count_star 函数和分发逻辑)。其余 7 个文件原封不动。

七、关键点总结

  • SELECT count(*) 只需统计 B-Tree 叶节点的 Cell 数量,无需解码任何行数据
  • collect_leaf_cells() 返回的 Vec<CellAddr> 长度直接等于表的行数
  • count_all()select_all() 共享同一个 B-Tree 遍历和 schema 查找逻辑,代码复用度高
  • parse_count_star() 通过 to_uppercase() 实现大小写不敏感匹配,与 parse_select_star() 风格一致
  • 命令分发中 count(*) 模式先于 SELECT * 模式匹配,避免后者误识别(两者都以 SELECT 开头)
  • count(*) 相比 SELECT * 省去了读取 Cell payload、解码 varint、解析 Record、格式化输出等所有行级操作

下一篇:实现 WHERE 过滤——在 SELECT * FROM <table> WHERE <col> = <value> 中,解析简单等值条件,在遍历每一行时比较指定列的值,只输出满足条件的行。