大型语言模型 (LLM) 正在彻底改变人工智能领域,其应用范围从智能聊天机器人到高级内容生成,无所不包。虽然预训练的 LLM 已经具备强大的开箱即用功能,但通过在自定义数据集上进行微调,可以显著提高其在特定任务和领域的表现。本文将深入探讨如何使用 Keras 微调 Gemma 2B LLM,利用 JAX 的高性能计算能力,并在 TPU 设备网格上进行分布式训练,从而实现无与伦比的速度和效率。无论您是希望将 LLM 适配到企业的知识库,创建特定领域的聊天机器人,还是优化现有模型,本指南都将引导您完成整个过程,从环境设置到分布式训练和模型保存。

为什么选择 Keras, JAX 和 TPUs?

选择 Keras,是因为它提供了一个高级的、用户友好的 API,可以简化神经网络的构建和训练过程。它的抽象特性允许在后端(TensorFlow、JAX、PyTorch)之间无缝切换,使其具有极高的灵活性。例如,您可以先使用 Keras 在 CPU 上快速原型化一个模型,然后在准备大规模训练时,轻松切换到 JAX 后端并利用 TPU 的强大算力。

JAX 是 Google 的高性能数值计算库。它提供自动微分、XLA 编译(为 TPU 等特定硬件优化代码)和函数式编程范式,使其成为大规模机器学习研究和生产的理想选择。JAX 的 XLA 编译器可以将 Keras 定义的模型转化为高度优化的代码,充分利用 TPU 的架构特点,从而显著提升训练速度。

TPU (Tensor Processing Units) 由 Google 开发,是专门为加速机器学习工作负载而构建的。它们擅长矩阵乘法,这是神经网络计算的核心,尤其是在像 LLM 这样的基于 Transformer 的模型中。将 JAX 与 TPU 结合使用,可以实现高效且可扩展的训练。例如,使用 TPU v4 Pod,可以在数小时内完成原本需要数天甚至数周才能完成的 LLM 训练任务。

环境配置

首先,我们需要准备 Colab 环境。我们将安装必要的库并配置后端。

!pip install -q -U keras keras-hub

接下来,我们将设置环境变量。如果在 Colab notebook 中运行,可以安全地使用 userdata.get 来获取 Kaggle 凭据。否则,请确保手动设置 KAGGLE_USERNAMEKAGGLE_KEY。至关重要的是,我们将 KERAS_BACKEND 设置为 “jax”,以确保 Keras 使用 JAX 进行操作。我们还配置了 XLA_PYTHON_CLIENT_MEM_FRACTION 以避免内存碎片,这是在使用 JAX 在加速器上工作时的常见问题。

import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.9"

import keras
import keras_hub
import jax

jax.devices()

jax.devices() 调用将显示可用的 JAX 设备。在 Colab TPU 运行时,应该会看到列出的多个 TPU 核心(例如,8 个设备)。

使用设备网格和模型并行进行分布式微调

使用 JAX 与 TPU 进行 LLM 训练最强大的方面之一是利用分布式训练的能力。对于大型模型,通常无法将它们安装在单个设备上。模型并行允许我们将模型的参数拆分到多个设备上,从而有效地分配计算负载和内存需求。

在这里,我们将设置一个 DeviceMeshLayoutMap 来定义我们的模型如何在可用的 TPU 核心上分布。

device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)

DeviceMesh((1, 8), ["batch", "model"], ...) 中,我们创建了一个 1×8 网格,这意味着我们的操作将在 8 个设备上分布。我们将这些维度命名为 “batch” 和 “model”。 “model” 维度对于模型并行尤其重要。LayoutMap 然后指定 Gemma 模型权重的不同部分(kernel)将如何在此 device_mesh 上分片。

layout_map[".*lora_kernel_a"] = (None, None)
layout_map[".*lora_kernel_b"] = (None, None)
layout_map["decoder_block.*attention.*(query|key|value).*kernel$"] = (None, model_dim, None)
layout_map["decoder_block.*attention_output.*kernel$"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel$"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel$"] = (model_dim, None)

model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)

让我们分解这些 layout_map 条目:

  • LoRA Kernels: layout_map[".*lora_kernel_a"] = (None, None)layout_map[".*lora_kernel_b"] = (None, None) 表明 LoRA (Low-Rank Adaptation) kernel,我们稍后将启用它,不会被分片。这是 LoRA 的常见做法,因为它的矩阵通常足够小,可以安装在单个设备上,而分片它们可能会引入不必要的开销。
  • Original Kernels: 后续行使用正则表达式来定位 Gemma 模型转换器块 (decoder_block) 中的特定 kernel 权重。例如,layout_map["decoder_block.*attention.*(query|key|value).*kernel$"] = (None, model_dim, None) 告诉 Keras 沿 model_dim 对注意力机制的 query, key, 和 value kernel 进行分片。这会将这些大型矩阵的参数分布在我们的 8 个 TPU 核心上。类似地,像注意力输出、FFW gating 和 FFW linear kernel 这样的其他关键层也被分片。

最后,keras.distribution.set_distribution(model_parallel) 将此分发策略全局应用于 Keras 操作。通过这种方式,即使是 2B 参数的 Gemma 模型,也能有效地在多个 TPU 核心上进行训练,避免了单卡内存溢出的问题。

加载 Gemma LLM

配置好分布式设置后,我们现在可以从 keras_hub 加载 Gemma 2B Causal Language Model。

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

此行下载并加载预训练的 Gemma 2B 模型。 GemmaCausalLM 是一个 Keras 模型,它包装了 Gemma backbone,使其易于用于文本生成和微调。

准备微调数据

在本教程中,我们使用一个小型自定义数据集,该数据集是关于印度尼西亚大学 Institut STTS(也称为 Institut Sains dan Teknologi Terpadu Surabaya)。本教程向您展示如何调整 LLM 以理解和回答来自知识的特定问题。 train_contexts 提供事实信息,train_questions_answers 包含链接到这些上下文的问答对。

train_contexts = [
    "Website iSTTS: https://www.istts.ac.id/",
    "Website pendaftaran iSTTS : https://pmb.istts.ac.id/manajemen/?menu=pendaftaran_online&page=pendaftaran_s1",
    "iSTTS merupakan sekolah teknik terbaik di Jawa Timur, dan sudah berdiri lebih dari 40 tahun untuk terus mengabdi pada bangsa dan negara Indonesia",
    "iSTTS memiliki dua fakultas yaitu fakultas teknik dan desain yang sangat selaras dengan teknologi yang sedang berkembang di dunia",
    "Untuk fakultas teknik terdapat D3 Sistem Informasi, S1 Teknik Elektro, S1 Informatika, S1 Teknik Industri, S1 Sistem Informasi Bisnis, S1 Informatika profesional (kelas malam)",
    "Untuk fakultas desain terdapat S1 Desain Produk, S1 Desain Komunikasi Visual",
    "iSTTS menyediakan Magister Teknologi Informasi",
    "iSTTS menyediakan Bachelor International Program yang merupakan program dual degree, pada program ini iSTTS bekerja sama dengan Swinburne university of technology di australia, Murdoch university di singapore, Dongseo university di korea",
    "D3 Sistem Informasi memfokuskan pada Web Desain, Jaringan Komputer, serta Basis Data dan Sistem Informasi untuk menghasilkan tenaga kerja profesional",
    "S1 Teknik Elektro memberikan konsep dan pengembangan perangkat keras maupun perangkat lunak yang terfokus pada Industrial Automation dan Industrial Internet of Things (IoT)",
    "S1 Informatika belajar memanfaatkan komputasi untuk problem solving menggunakan teknologi informasi yang memiliki 3 penjurusan yaitu, Software Technology, Computational Intelligence, dan Internet Technology",
    "S1 Teknik Industri memberikan pengajaran mengenai melakukan pengelolaan sistem yang memiliki penjurusan terkait Manufacturing System dan Quality Management",
    "S1 Sistem Informasi Bisnis yang menggabungkan Information Technology (IT) dan dunia Bisnis, serta memiliki 3 fokus pendidikan yaitu, Information Science, Multimedia and Game Technology, dan Social Media Business",
    "S1 Informatika profesional (kelas malam) menyediakan para perkerja yang melanjutkan studinya ataupun bagi para lulusan SMA yang ingin kuliah sambil bekerja dengan fokus pendidikan, antara lain Kewirausahaan serta Java dan .NET",
    "S1 Desain Produk memadukan desain dan teknologi otomasi untuk menghasilkan produk yang tidak hanya menarik secara tampilan, tetapi memiliki fungsi smart product. S1 Desain Produk memiliki 4 fokus pendidikan, yaitu Desain Produk Cerdas, Desain Kriya Modern, Desain Transportasi, dan Desain Interior",
    "S1 Desain Komunikasi Visual berfokus pada 3 pilar utama yaitu Seni, Teknologi, dan Bisnis. S1 Desain Komunikasi Visual memiliki 4 fokus pendidikan, yaitu Animasi, Perfilman, Ilustrasi, dan Fotografi",
    "Magister Teknologi Informasi memberikan pengetahuan dan wawasan tentang teknologi Informasi terbaru yang disertai dengan penggunaan dan pemanfaatan dalam dunia kerja. IT Utilization dan Artificial Intelligence merupakan 2 fokus pendidikan di Magister Teknologi Informasi",
    "iSTTS Memiliki 4 Laboratorium Komputer yang berada di L204, L304, L404, serta E401",
    "iSTTS Memiliki Research Laboratory untuk melakukan riset-riset mengenai teknologi-teknologi baru yang muncul",
    "iSTTS Juga memiliki fasilitas lain seperti Lapangan, Ruang Band, Theater, serta 2 Perpustakaan",
    "iSTTS Menyediakan sertifikasi yang dapat diambil oleh mahasiswa, seperti AWS, RedHat, MikroTik, Cisco, dan lain-lain",
    "Mahasiswa dapat masuk iSTTS tanpa melakukan tes masuk"]

train_questions_answers = [
    {
        "context_index": 0,
        "question": "Apakah website iSTTS?",
        "answer": "https://www.istts.ac.id"
    },
    {
        "context_index": 0,
        "question": "Apakah iSTTS mempunyai website?",
        "answer": "Ya, iSTTS memiliki website yang terdapat di https://www.istts.ac.id"
    },
    {
        "context_index": 1,
        "question": "Apakah website pendaftaran iSTTS?",
        "answer": "https://pmb.istts.ac.id/manajemen/?menu=pendaftaran_online&page=pendaftaran_s1"
    },
    {
        "context_index": 1,
        "question": "Di manakah untuk melakukan pendaftaran iSTTS?",
        "answer": "Pendaftaran iSTTS bisa dilakukan di halaman website https://pmb.istts.ac.id/manajemen/?menu=pendaftaran_online&page=pendaftaran_s1"
    },
    {
        "context_index": 1,
        "question": "Apakah iSTTS memiliki website pendaftaran?",
        "answer": "Ya, iSTTS memiliki website pendaftaran yang terdapat di https://pmb.istts.ac.id/manajemen/?menu=pendaftaran_online&page=pendaftaran_s1"
    },
    {
        "context_index": 2,
        "question": "Apa itu iSTTS?",
        "answer": "iSTTS merupakan sekolah teknik terbaik di Jawa Timur dan sudah berdiri lebih dari 40 tahun untuk terus mengabdi pada bangsa dan negara Indonesia"
    },
    {
        "context_index": 2,
        "question": "Sudah berapa lama iSTTS berdiri?",
        "answer": "iSTTS sudah berdiri lebih dari 40 tahun"
    },
    {
        "context_index": 3,
        "question": "Apa saja fakultas yang ada di iSTTS?",
        "answer": "iSTTS memiliki dua fakultas yaitu fakultas teknik dan desain"
    },
    {
        "context_index": 4,
        "question": "Apa saja program studi yang ada di fakultas teknik iSTTS?",
        "answer": "Terdapat D3 Sistem Informasi, S1 Teknik Elektro, S1 Informatika, S1 Teknik Industri, S1 Sistem Informasi Bisnis, dan S1 Informatika Profesional (kelas malam)"
    },
    {
        "context_index": 5,
        "question": "Apa saja program studi yang ada di fakultas desain iSTTS?",
        "answer": "Terdapat S1 Desain Produk dan S1 Desain Komunikasi Visual"
    },
    {
        "context_index": 6,
        "question": "Apakah terdapat program studi untuk lanjutan di iSTTS?",
        "answer": "iSTTS juga menyediakan program Magister untuk Teknologi Informasi"
    },
    {
        "context_index": 7,
        "question": "Apakah iSTTS menyediakan program studi untuk ke luar negeri?",
        "answer": "Ya, iSTTS menyediakan program Bachelor International Program yang merupakan program dual degree"
    },
    {
        "context_index": 7,
        "question": "Apakah iSTTS memiliki hubungan kerja sama dengan universitas lain?",
        "answer": "Ya, iSTTS memiliki beberapa hubungan kerja sama dengan universitas di luar negeri untuk dual degree. Terdapat Swinburne university of technology di Australia, Murdoch university di Singapore, dan Dongseo university di Korea"
    },
    {
        "context_index": 7,
        "question": "Apa ada kerja sama dengan iSTTS di Australia?",
        "answer": "Ya ada, iSTTS memiliki kerja sama dengan Swinburne university of technology di Australia untuk program dual degree"
    },
    {
        "context_index": 7,
        "question": "Apa ada kerja sama dengan iSTTS di Singapore?",
        "answer": "Ya ada, iSTTS memiliki kerja sama dengan Murdoch univerrsity di Singapore untuk program dual degree"
    },
    {
        "context_index": 7,
        "question": "Apa ada kerja sama dengan iSTTS di Korea?",
        "answer": "Ya ada, iSTTS memiliki kerja sama dengan Dongseo university di Korea untuk program dual degree"
    },
    {
        "context_index": 8,
        "question": "Apabila ingin untuk fokus pada Web Desain program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi D3 Sistem Informasi berfokus pada Web Desain, Jaringan Komputer, Basis Data, dan Sistem Informasi"
    },
    {
        "context_index": 8,
        "question": "Apabila ingin untuk fokus pada Jaringan Komputer program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi D3 Sistem Informasi berfokus pada Web Desain, Jaringan Komputer, Basis Data, dan Sistem Informasi"
    },
    {
        "context_index": 8,
        "question": "Apabila ingin untuk fokus pada Basis Data program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi D3 Sistem Informasi berfokus pada Web Desain, Jaringan Komputer, Basis Data, dan Sistem Informasi"
    },
    {
        "context_index": 8,
        "question": "Apabila ingin untuk fokus pada Sistem Informasi program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi D3 Sistem Informasi berfokus pada Web Desain, Jaringan Komputer, Basis Data, dan Sistem Informasi"
    },
    {
        "context_index": 9,
        "question": "Apabila ingin untuk belajar mengenai perangkat keras program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Elektro memberikan konsep dan pengembangan perangkat keras maupun perangkat lunak"
    },
    {
        "context_index": 9,
        "question": "Apabila ingin untuk fokus pada Industrial Automation program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Elektro memiliki fokus penjurusan pada Industrial Automation"
    },
    {
        "context_index": 9,
        "question": "Apabila ingin untuk fokus pada IoT program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Elektro memiliki fokus penjurusan pada Internet of Things (IoT)"
    },
    {
        "context_index": 9,
        "question": "Apa saja penjurusan program studi Teknik Elektro di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Elektro memiliki 2 penjurusan, yaitu Industrial Automation dan Industrial Internet of Things (IoT)"
    },
    {
        "context_index": 10,
        "question": "Apabila ingin untuk belajar untuk coding program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Informatika belajar memanfaatkan komputasi untuk problem solving menggunakan teknologi informasi"
    },
    {
        "context_index": 10,
        "question": "Apabila ingin untuk fokus pada metode perancangan pembuatan aplikasi program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Informatika memiliki fokus pembelajaran pada Software Technology"
    },
    {
        "context_index": 10,
        "question": "Apabila ingin untuk fokus pada AI program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Informatika memiliki fokus pembelajaran pada Computational Intelligence"
    },
    {
        "context_index": 10,
        "question": "Apabila ingin untuk fokus pada jaringan internet program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Informatika memiliki fokus pembelajaran pada Internet Technology"
    },
    {
        "context_index": 10,
        "question": "Apa saja penjurusan program studi Informatika di iSTTS?",
        "answer": "Di iSTTS program studi S1 Informatika memiliki 3 penjurusan, yaitu Software Technology, Computational Intelligence, dan Internet Technology"
    },
    {
        "context_index": 11,
        "question": "Apa saja penjurusan program studi Teknik Industri di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Industri memiliki 2 penjurusan, yaitu Manufacturing System dan Quality Management"
    },
    {
        "context_index": 11,
        "question": "Apabila ingin untuk fokus pada pengelolaan sistem program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Industri memberikan pengajaran mengenai melakukan pengelolaan sistem"
    },
    {
        "context_index": 11,
        "question": "Apabila ingin untuk fokus pada mengelola Manufacturing System program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Industri memiliki penjurusan Manufacturing System"
    },
    {
        "context_index": 11,
        "question": "Apabila ingin untuk fokus pada Quality Management saat mengelola suatu sistem program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Teknik Industri memiliki penjurusan terkait Quality Management"
    },
    {
        "context_index": 12,
        "question": "Apa saja penjurusan program studi Sistem Informasi Bisnis di iSTTS?",
        "answer": "Di iSTTS program studi S1 Sistem Informasi Bisnis memiliki 3 penjurusan, yaitu Information Science, Multimedia and Game Technology, dan Social Media Business"
    },
    {
        "context_index": 12,
        "question": "Apabila ingin untuk belajar IT dan bisnis program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Sistem Informasi Bisnis yang menggabungkan Information Technology (IT) dan dunia Bisnis"
    },
    {
        "context_index": 12,
        "question": "Apabila ingin untuk belajar Game development program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Sistem Informasi Bisnis memiliki fokus pendidikan Multimedia and Game Technology yang bisa merambah ke Game Development"
    },
    {
        "context_index": 12,
        "question": "Apabila ingin untuk belajar mengenai AI tapi untuk bisnis program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Sistem Informasi Bisnis memiliki fokus pendidikan Information Science yang bertujuan untuk menghasilkan AI untuk dunia bisnis"
    },
    {
        "context_index": 12,
        "question": "Apabila ingin untuk belajar mengenai digital business program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Sistem Informasi Bisnis memiliki fokus pendidikan Social Media Business yang mengajarkan untuk membuat digital business"
    },
    {
        "context_index": 13,
        "question": "Apabila ingin untuk belajar namun sambil bekerja program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Informatika profesional (kelas malam) tersedia untuk orang yang ingin kuliah sambil bekerja"
    },
    {
        "context_index": 13,
        "question": "Apa saja fokus pendidikan S1 Informatika Profesional?",
        "answer": "Di iSTTS program studi S1 Informatika profesional (kelas malam) memiliki fokus pendidikan antara lain Kewirausahaan, Java, dan .NET"
    },
    {
        "context_index": 14,
        "question": "Apabila ingin belajar untuk cara yang baik menghasilkan produk program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Produk memadukan desain dan teknologi otomasi untuk menghasilkan produk yang tidak hanya menarik secara tampilan namun memiliki fungsi smart product"
    },
    {
        "context_index": 14,
        "question": "Apabila ingin belajar untuk merancang interior program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Produk memiliki penjurusan Desain Interior"
    },
    {
        "context_index": 14,
        "question": "Apabila ingin belajar untuk merancang produk cerdas program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Produk memiliki penjurusan Desain Produk Cerdas"
    },
    {
        "context_index": 14,
        "question": "Apa saja penjurusan pada program studi Teknik Industri di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Produk memiliki 4 penjurusan, yaitu Desain Produk Cerdas, Desain Kriya Modern, Desain Transportasi, dan Desain Interior"
    },
    {
        "context_index": 15,
        "question": "Apabila ingin belajar fotografi program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Komunikasi Visual memiliki penjurusan Fotografi"
    },
    {
        "context_index": 15,
        "question": "Apabila ingin belajar untuk membuat film program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Komunikasi Visual memiliki penjurusan Perfilman"
    },
    {
        "context_index": 15,
        "question": "Apabila ingin belajar untuk membuat animasi program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Komunikasi Visual memiliki penjurusan Animasi"
    },
    {
        "context_index": 15,
        "question": "Apabila ingin belajar melukis program studi apa yang harus saya ambil di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Komunikasi Visual memiliki penjurusan Ilustrasi"
    },
    {
        "context_index": 15,
        "question": "Apa saja penjurusan pada program studi Desain Komunikasi Visual di iSTTS?",
        "answer": "Di iSTTS program studi S1 Desain Komunikasi Visual memiliki penjurusan 4, yaitu Animasi, Perfilman, Ilustrasi, dan Fotografi"
    },
    {
        "context_index": 16,
        "question": "Apa saja fokus pembelajaran program studi Magister Teknologi Informasi di iSTTS?",
        "answer": "Di iSTTS program studi Magister Teknologi Informasi memiliki 2 fokus pembelajaran, yaitu IT Utilization dan Artificial Intelligence"
    },
    {
        "context_index": 17,
        "question": "Apa saja fasilitas di iSTTS?",
        "answer": "Di iSTTS terdapat 4 Laboratorium Komputer"
    },
    {
        "context_index": 18,
        "question": "Apakah saya bisa melakukan riset di iSTTS?",
        "answer": "Di iSTTS juga terdapat Research Laboratory untuk melakukan riset mengenai teknologi terbaru"
    },
    {
        "context_index": 19,
        "question": "Apa saja fasilitas di iSTTS?",
        "answer": "Di iSTTS terdapat Lapangan, Ruang Band, Theater, dan 2 Perpustakaan"
    },
    {
        "context_index": 20,
        "question": "Apakah iSTTS menyediakan program untuk sertifikasi?",
        "answer": "Di iSTTS menyediakan banyak program sertifikasi seperti AWS, RedHat, MikroTik, Cisco, dan lain-lain"
    },
    {
        "context_index": 21,
        "question": "Apakah saya perlu melakukan tes untuk masuk ke iSTTS?",
        "answer": "Mahasiswa dapat masuk iSTTS tanpa melakukan tes masuk"
    },]

data = []
for features in train_questions_answers:
  features['context'] = train_contexts[features['context_index']]
  template = "<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\nContext adalah {context}\nJadi jawabannya adalah {answer}<end_of_turn>\n"
  data.append(template.format(**features))

接下来的步骤是将数据格式化为基于对话的轮次结构,遵循 Gemma 模型首选的输入格式:<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\nContext adalah {context}\nJadi jawabannya adalah {answer}<end_of_turn>\n。这种显式格式有助于模型理解微调期间的上下文和所需的输出。

LoRA: 高效微调

微调整个 LLM 可能会在计算上非常昂贵并且占用大量内存。 Low-Rank Adaptation (LoRA) 是一种参数高效的微调技术,可显著减少可训练参数的数量,从而使微调更容易实现。使用 LoRA,我们将小的、低秩矩阵注入到现有的预训练权重中。仅训练这些新矩阵,而原始预训练权重保持冻结。这大大减少了内存占用和计算成本,同时仍能实现出色的性能。例如,在 Gemma 2B 模型上使用 LoRA,可以将可训练参数的数量从 20 亿减少到几百万,从而可以在资源受限的环境中进行微调。

gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

在这里,gemma_lm.backbone.enable_lora(rank=4) 在 Gemma 模型的主干上启用 LoRA,并将低秩矩阵的秩设置为 4。值为 4 的秩是一个好的起点,可以在表达能力和参数效率之间取得平衡。 summary() 方法现在将显示比总参数少得多的可训练参数,表明 LoRA 已激活。

模型配置和训练

在训练之前,我们配置模型的预处理器、优化器和损失函数。

gemma_lm.preprocessor.sequence_length = 512

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    sampler=sampler
)

gemma_lm.fit(data, epochs=10, batch_size=1)
  • sequence_length: 我们将输入序列长度限制为 512,以管理内存使用,这在 TPU 上尤其重要。
  • optimizer: 我们使用 AdamW,它是 Transformer 模型的常用选择,学习率为 5e-5,权重衰减为 0.01。最好将偏差和比例项从权重衰减中排除。
  • sampler: TopKSampler 用于评估期间的文本生成,以确保生成输出的多样性。
  • loss: 对于因果语言模型,使用 SparseCategoricalCrossentropy 作为损失函数,因为我们预测序列中的下一个 token。 from_logits=True 表示模型的输出是原始 logits。
  • fit(): gemma_lm.fit(data, epochs=10, batch_size=1) 函数调用启动微调过程。即使 batch_size 值为 1,TPU 上的分布式设置也能有效地处理此问题。 epochs 参数确定模型将迭代整个数据集的次数。

与微调后的模型互动

训练完成后,您可以使用 ChatState 类和多个辅助函数,如下所示。这段代码定义了几个用于与微调后的模型交互的实用工具函数。display_chat 函数用于以带颜色的 Markdown 格式显示用户提示和模型响应,使对话更易于阅读。to_markdown 函数将纯文本转换为 Markdown 格式。ChatState 类管理对话历史记录,遵循 Gemma 模型的 turn-based 对话格式化指南。它提供添加用户消息、添加模型响应以及检索完整对话历史记录的功能。send_message 函数将用户查询发送到微调后的模型,并返回模型的响应。

from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """
  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    """
    Initializes the chat state.

    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])

  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.

    Args:
        message: The user's message.

    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response
    self.add_to_history_as_model(result)
    return result

在本教程中,我们将专注于核心互动,以与微调后的模型互动。

# Assuming ChatState class is defined as in the original code
chat = ChatState(gemma_lm)
message = "Apakah website iSTTS?"
display_chat(message, chat.send_message(message))

chat.send_message(message) 函数将用户查询发送到微调后的 Gemma 模型,然后该模型将根据从微调数据中新获得的知识生成响应。您应该观察到该模型现在可以准确地回答有关 iSTTS 的问题,并从提供的 train_contexts 中提取信息。 例如,在微调之前,Gemma 模型可能无法准确回答关于 iSTTS 的问题,但在微调之后,它可以根据提供的上下文准确地回答,展示了微调的有效性。

保存微调后的权重

最后,至关重要的是要保存微调后的模型的权重,以便以后可以重复使用它们,而无需重新训练。由于我们使用了 LoRA,因此我们只需要保存 LoRA 权重,这比完整模型小得多。

gemma_lm.backbone.save_weights("gemma_lm_lora.weights.h5")

此命令保存特定于 LoRA 的权重。要再次使用此微调模型,您需要加载基础 Gemma 模型,然后将这些 LoRA 权重加载到其主干上。通过保存 LoRA 权重,您可以显著减小模型的大小,并