Google JAX Cookbook: Perform machine learning and numerical computing with combined capabilities of TensorFlow and NumPy
暫譯: Google JAX 食譜:結合 TensorFlow 和 NumPy 的機器學習與數值計算

Quent, Zephyr

  • 出版商: Gitforgits
  • 出版日期: 2024-10-30
  • 售價: $2,320
  • 貴賓價: 9.5$2,204
  • 語言: 英文
  • 頁數: 252
  • 裝訂: Quality Paper - also called trade paper
  • ISBN: 8197950415
  • ISBN-13: 9788197950414
  • 相關分類: PythonDeepLearningTensorFlowMachine Learning
  • 海外代購書籍(需單獨結帳)

相關主題

商品描述

This is the practical, solution-oriented book for every data scientists, machine learning engineers, and AI engineers to utilize the most of Google JAX for efficient and advanced machine learning. It covers essential tasks, troubleshooting scenarios, and optimization techniques to address common challenges encountered while working with JAX across machine learning and numerical computing projects.

The book starts with the move from NumPy to JAX. It introduces the best ways to speed up computations, handle data types, generate random numbers, and perform in-place operations. It then shows you how to use profiling techniques to monitor computation time and device memory, helping you to optimize training and performance. The debugging section provides clear and effective strategies for resolving common runtime issues, including shape mismatches, NaNs, and control flow errors. The book goes on to show you how to master Pytrees for data manipulation, integrate external functions through the Foreign Function Interface (FFI), and utilize advanced serialization and type promotion techniques for stable computations.

If you want to optimize training processes, this book has you covered. It includes recipes for efficient data loading, building custom neural networks, implementing mixed precision, and tracking experiments with Penzai. You'll learn how to visualize model performance and monitor metrics to assess training progress effectively. The recipes in this book tackle real-world scenarios and give users the power to fix issues and fine-tune models quickly.

Key Learnings

Get your calculations done faster by moving from NumPy to JAX's optimized framework.

Make your training pipelines more efficient by profiling how long things take and how much memory they use.

Use debugging techniques to fix runtime issues like shape mismatches and numerical instability.

Get to grips with Pytrees for managing complex, nested data structures across various machine learning tasks.

Use JAX's Foreign Function Interface (FFI) to bring in external functions and give your computational capabilities a boost.

Take advantage of mixed-precision training to speed up neural network computations without sacrificing model accuracy.

Keep your experiments on track with Penzai. This lets you reproduce results and monitor key metrics.

Create your own neural networks and optimizers directly in JAX so you have full control of the architecture.

Use serialization techniques to save, load, and transfer models and training checkpoints efficiently.

Table of Content

Transition NumPy to JAX

Profiling Computation and Device Memory

Debugging Runtime Values and Errors

Mastering Pytrees for Data Structures

Exporting and Serialization

Type Promotion Semantics and Mixed Precision

Integrating Foreign Functions (FFI)

Training Neural Networks with JAX

商品描述(中文翻譯)

這是一本實用且以解決方案為導向的書籍,適合每位數據科學家、機器學習工程師和人工智慧工程師充分利用 Google JAX 進行高效且先進的機器學習。書中涵蓋了基本任務、故障排除情境和優化技術,以解決在機器學習和數值計算專案中使用 JAX 時常遇到的挑戰。

本書首先介紹了從 NumPy 轉移到 JAX。它介紹了加速計算、處理數據類型、生成隨機數和執行原地操作的最佳方法。接著,書中展示了如何使用性能分析技術來監控計算時間和設備記憶體,幫助您優化訓練和性能。調試部分提供了清晰且有效的策略來解決常見的運行時問題,包括形狀不匹配、NaN 和控制流錯誤。本書接著展示了如何掌握 Pytrees 進行數據操作,通過外部函數介面 (Foreign Function Interface, FFI) 整合外部函數,以及利用先進的序列化和類型提升技術以實現穩定的計算

如果您想優化訓練過程,本書將為您提供幫助。它包括高效數據加載、自定義神經網絡構建、實現混合精度以及使用 Penzai 追蹤實驗的食譜。您將學會如何可視化模型性能並監控指標,以有效評估訓練進度。本書中的食譜針對現實情境,賦予用戶快速修復問題和微調模型的能力。

關鍵學習
- 通過從 NumPy 轉移到 JAX 的優化框架,讓您的計算更快完成。
- 通過分析任務所需的時間和使用的記憶體,使您的訓練管道更高效。
- 使用調試技術修復運行時問題,如形狀不匹配和數值不穩定。
- 熟悉 Pytrees,以管理各種機器學習任務中的複雜嵌套數據結構。
- 使用 JAX 的外部函數介面 (FFI) 引入外部函數,提升計算能力。
- 利用混合精度訓練,加速神經網絡計算而不犧牲模型準確性。
- 使用 Penzai 讓您的實驗保持在正軌,這樣您可以重現結果並監控關鍵指標。
- 直接在 JAX 中創建自己的神經網絡和優化器,以便完全控制架構。
- 使用序列化技術高效地保存、加載和轉移模型及訓練檢查點。

目錄
- 從 NumPy 轉移到 JAX
- 計算和設備記憶體的性能分析
- 調試運行時值和錯誤
- 掌握 Pytrees 進行數據結構管理
- 匯出和序列化
- 類型提升語義和混合精度
- 整合外部函數 (FFI)
- 使用 JAX 訓練神經網絡