Gemma 2B Fine Tuned Lightweight model

Blog post source: https://kumaran198726.blogspot.com/2024/12/gemma-2b-fine-tuned-lightweight-model.html

Kaggle notebook for reusability: https://www.kaggle.com/code/kumarandatascientist/gemma-2b-fine-tuned-lightweight-model

Step 1: Configure GPU for Memory Growth

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth enabled.")
    except RuntimeError as e:
        print(e)
else:
    print("No GPU found. Using CPU.")

  • Purpose: Ensures the GPU is set to dynamically allocate memory instead of pre-allocating all available GPU memory. This approach prevents memory wastage and allows multiple processes to use the GPU without running into memory allocation errors.
  • Technical Details:
      • tf.config.list_physical_devices('GPU'): Lists the available GPUs.
      • tf.config.experimental.set_memory_growth(gpu, True): Allows TensorFlow to allocate GPU memory on demand.
      • Fallback: If no GPU is found, the code defaults to CPU computation.
      • A short verification sketch is shown below.
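
A quick way to confirm the setting took effect (a minimal sketch; tf.config.experimental.get_memory_growth is the read-side counterpart of the call above):

import tensorflow as tf

# Print the memory-growth flag for each visible GPU; it should read True
# once the setting from Step 1 has been applied.
for gpu in tf.config.list_physical_devices('GPU'):
    print(gpu.name, tf.config.experimental.get_memory_growth(gpu))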


Step 2: Enable Mixed Precision for Memory Optimization

from tensorflow.keras.mixed_precision import set_global_policy

policy = tf.keras.mixed_precision.Policy("mixed_float16")
set_global_policy(policy)
print(f"Mixed precision enabled with policy: {policy}")

  • Purpose: Reduces memory usage and increases computational speed by using lower-precision data types (e.g., float16) where appropriate, while keeping critical calculations in higher precision (e.g., float32).
  • Technical Details:
      • tf.keras.mixed_precision.Policy("mixed_float16"): Specifies float16 for operations and float32 for variables and accumulations.
      • set_global_policy(policy): Applies the mixed precision policy globally.
      • Mixed precision is especially effective on GPUs with Tensor Cores (e.g., NVIDIA Volta, Ampere).
      • A short check of the resulting dtypes is sketched below.
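
A minimal check of the active policy (assuming Step 2 has run): operations compute in float16 while variables stay in float32.

# Inspect the globally active mixed-precision policy.
policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype)   # expected: 'float16'
print(policy.variable_dtype)  # expected: 'float32'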


Step 3: Load a Smaller Model Variant

import keras_nlp

try:
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_1b_en")
    print("Loaded Gemma LM (1B model).")
except Exception:
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
    print("Loaded Gemma LM (2B model).")

  • Purpose: Dynamically load a smaller model variant if possible, reducing memory and computational requirements. Falls back to a larger model if the smaller one is unavailable.
  • Technical Details:
      • keras_nlp.models.GemmaCausalLM.from_preset: Loads a preconfigured Gemma language model with pretrained weights.
      • "gemma2_instruct_1b_en": A 1-billion-parameter variant tried first.
      • "gemma2_instruct_2b_en": A 2-billion-parameter variant used as a fallback.
      • A quick generation sanity check is sketched below.
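
Once loaded, a quick generation call confirms the model responds (a minimal sketch; the prompt is illustrative):

# Generate a short completion from the loaded instruct model.
output = gemma_lm.generate("Explain mixed precision training in one sentence.", max_length=64)
print(output)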


Step 4: Apply Low-Rank Adaptation (LoRA) for Reduced Parameters

gemma_lm.backbone.enable_lora(rank=2)
print("LoRA enabled with rank=2.")

  • Purpose: Reduces the number of trainable parameters during fine-tuning by freezing the original weight matrices and learning small low-rank update matrices instead, which lowers the memory needed for gradients and optimizer state.
  • Technical Details:
      • LoRA modifies the transformer layers to optimize memory use during training while retaining performance.
      • rank=2: Specifies the rank of the adaptation, balancing efficiency and accuracy.
      • The parameter-count check below shows the effect.
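
To see the effect, one can count the backbone's trainable parameters after enabling LoRA (a minimal sketch):

# With LoRA enabled, only the low-rank adapter weights remain trainable,
# so this number is far smaller than the full parameter count.
trainable = sum(int(tf.size(w)) for w in gemma_lm.backbone.trainable_weights)
print(f"Trainable parameters: {trainable:,}")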


Step 5: Reduce Sequence Length for Lower Memory Usage

gemma_lm.preprocessor.sequence_length = 128
print("Sequence length set to 128.")

  • Purpose: Lowers the memory usage during training or inference by reducing the number of tokens processed per sequence.
  • Technical Details:
      • Shorter sequences mean fewer computations, leading to faster and more memory-efficient runs.
      • The reduction from a typical length (e.g., 512 or 1024) to 128 is significant in terms of resource savings.
      • A quick shape check is sketched below.
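
Running a sample through the preprocessor confirms the new length (a minimal sketch; the exact output structure can vary between keras_nlp versions):

# The causal-LM preprocessor returns (features, labels, sample_weights);
# token_ids should now have a trailing dimension of 128.
x, y, sw = gemma_lm.preprocessor(["A short example sentence."])
print(x["token_ids"].shape)  # expected: (1, 128)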


Step 6: Compile the Model with Optimized Settings

# Initializer available for any added layers; compile itself does not apply it.
initializer = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05)

gemma_lm.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    weighted_metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
print("Model compiled successfully.")

  • Purpose: Prepares the model for training or inference by configuring loss functions, optimizers, and metrics.
  • Technical Details:
      • Initializer: tf.keras.initializers.TruncatedNormal keeps initial weights close to zero, which can aid convergence (it is defined here for any added layers; compile itself does not use it).
      • Loss: SparseCategoricalCrossentropy(from_logits=True) suits next-token prediction, a multi-class classification over the vocabulary where outputs are logits.
      • Optimizer: Adam(learning_rate=3e-5), a widely used optimizer balancing efficiency and convergence.
      • Metrics: SparseCategoricalAccuracy tracks classification accuracy for sparse (integer) label formats.
      • A minimal fine-tuning call is sketched below.
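
With the model compiled, fine-tuning can be started on a small text dataset (a hedged sketch; train_texts is a hypothetical list of training strings, and the attached preprocessor handles tokenization):

# Hypothetical toy dataset; replace with real fine-tuning text.
train_texts = [
    "Instruction: Summarize the article.\nResponse: ...",
    "Instruction: Translate the sentence.\nResponse: ...",
]
gemma_lm.fit(train_texts, batch_size=1, epochs=1)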


Step 7: Save Optimized Versions of the Model

Saving Model Weights

weights_path = "gemma_lm_lightweight.weights.h5"
gemma_lm.backbone.save_weights(weights_path)
print(f"Model weights saved to: {weights_path}")

  • Purpose: Saves only the weights of the backbone model to a lightweight file for reuse or transfer.
  • Technical Details:
      • .h5 format: Common for storing Keras models and weights.
      • A reload sketch is shown below.
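
The saved file can later be restored onto a freshly built model (a minimal sketch, assuming the same preset and LoRA configuration are recreated first):

# Rebuild the architecture, then restore the lightweight weights.
restored = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
restored.backbone.enable_lora(rank=2)  # must match the saved configuration
restored.backbone.load_weights("gemma_lm_lightweight.weights.h5")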

Saving Quantized TensorFlow Lite Model

converter = tf.lite.TFLiteConverter.from_keras_model(gemma_lm)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()

tflite_path = "gemma_lm_lightweight_v2.tflite"
with open(tflite_path, "wb") as f:
    f.write(quantized_model)
print(f"Quantized model saved to: {tflite_path}")

  • Purpose: Converts the model to TensorFlow Lite format with quantization, making it suitable for deployment on resource-constrained devices.
  • Technical Details:
      • tf.lite.TFLiteConverter.from_keras_model: Converts a Keras model to TensorFlow Lite.
      • converter.optimizations = [tf.lite.Optimize.DEFAULT]: Applies optimizations such as quantization to reduce size and improve performance.
      • .tflite: A lightweight format for on-device deployment.
      • A loading sketch for the converted file is shown below.
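
If the conversion succeeds, the file can be loaded with the TFLite interpreter (a minimal sketch; large causal LMs may require additional converter settings, so this is illustrative only):

# Load the quantized model and inspect its input/output signatures.
interpreter = tf.lite.Interpreter(model_path="gemma_lm_lightweight_v2.tflite")
interpreter.allocate_tensors()
print(interpreter.get_input_details())
print(interpreter.get_output_details())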

Saving Backbone Only

backbone_path = "gemma_lm_lightweight_backbone.h5"
gemma_lm.backbone.save(backbone_path)
print(f"Backbone model saved to: {backbone_path}")

  • Purpose: Saves only the backbone of the model, excluding preprocessing or output layers.
  • Technical Details:
      • Saving the backbone allows reuse of the core transformer layers for transfer learning or fine-tuning on other tasks.
      • A reload sketch is shown below.
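
The saved backbone can later be reloaded as a standalone Keras model (a hedged sketch; deserializing keras_nlp layers from .h5 may require passing the library's classes as custom objects, which is an assumption here):

# Reload the backbone; custom_objects tells Keras how to rebuild the
# keras_nlp layers stored in the file (assumed sufficient for this model).
backbone = tf.keras.models.load_model(
    "gemma_lm_lightweight_backbone.h5",
    custom_objects={"GemmaBackbone": keras_nlp.models.GemmaBackbone},
)
backbone.summary()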


Taken together, these steps optimize the workflow for memory use, computational efficiency, and deployment versatility, covering model loading, fine-tuning, and export.

