Gemma 2B Fine-Tuned Lightweight Model
Blog post source:
Kaggle notebook for reusability: https://www.kaggle.com/code/kumarandatascientist/gemma-2b-fine-tuned-lightweight-model
Step 1: Configure GPU for Memory Growth
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth must be set before the GPUs are initialized.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth enabled.")
    except RuntimeError as e:
        print(e)
else:
    print("No GPU found. Using CPU.")
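If you would rather reserve a fixed slice of GPU memory than let TensorFlow grow its allocation on demand, a hard cap is also possible. A minimal sketch; the 4096 MB limit is an arbitrary example, and like memory growth it must be set before the GPU is initialized:

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.set_logical_device_configuration(
        gpus[0],
        # Cap the first GPU at ~4 GB (example value; tune for your hardware).
        [tf.config.LogicalDeviceConfiguration(memory_limit=4096)],
    )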
Step 2: Enable Mixed Precision for Memory Optimization
policy = tf.keras.mixed_precision.Policy("mixed_float16")
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision enabled with policy: {policy}")
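Under mixed_float16, computations run in float16 while variables stay in float32 for numerical stability. A quick sanity check (a sketch):

policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype)   # expected: float16
print(policy.variable_dtype)  # expected: float32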
Step 3: Load a Smaller Model Variant
import keras_nlp

try:
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_1b_en")
    print("Loaded Gemma LM (1B model).")
except Exception:
    # Fall back to the 2B variant if the smaller preset is unavailable.
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
    print("Loaded Gemma LM (2B model).")
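from_preset downloads the weights from Kaggle on first use. Inside a Kaggle notebook the credentials are attached automatically; running elsewhere, you typically need to accept the Gemma license on Kaggle and export your API credentials first (the values below are placeholders):

import os

os.environ["KAGGLE_USERNAME"] = "your_username"  # placeholder: your Kaggle username
os.environ["KAGGLE_KEY"] = "your_api_key"        # placeholder: your Kaggle API key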
Step 4: Apply Low-Rank Adaptation (LoRA) for Reduced Parameters
gemma_lm.backbone.enable_lora(rank=2)
print("LoRA enabled with rank=2.")
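enable_lora freezes the base weights and injects small low-rank adapter matrices, so only a tiny fraction of the parameters remain trainable. A quick check (a sketch):

import numpy as np

trainable = sum(int(np.prod(w.shape)) for w in gemma_lm.backbone.trainable_weights)
total = sum(int(np.prod(w.shape)) for w in gemma_lm.backbone.weights)
print(f"Trainable parameters: {trainable:,} of {total:,}")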
Step 5: Reduce Sequence Length for Lower Memory Usage
gemma_lm.preprocessor.sequence_length = 128
print("Sequence length set to 128.")
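Attention memory grows rapidly with sequence length, so capping the context at 128 tokens keeps activations small. A quick smoke test that the shortened context still produces sensible output (the prompt is an arbitrary example):

output = gemma_lm.generate("Explain LoRA in one sentence.", max_length=128)
print(output)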
Step 6: Compile the Model with Optimized Settings
gemma_lm.compile(
    # from_logits=True because the model outputs raw logits, not probabilities.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    weighted_metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
print("Model compiled successfully.")
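With the model compiled, fine-tuning is a standard fit call. Because the preprocessor is attached to the task model, raw strings can be passed directly. A minimal sketch with a two-example toy dataset (replace it with your own instruction data; batch size and epoch count are illustrative):

# Hypothetical toy dataset for illustration only.
train_texts = [
    "Instruction: What is LoRA?\nResponse: A method that fine-tunes small low-rank adapter matrices.",
    "Instruction: Why use mixed precision?\nResponse: It cuts memory use by computing in float16.",
]
gemma_lm.fit(train_texts, batch_size=1, epochs=1)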
Step 7: Save Optimized Versions of the Model
Saving Model Weights
weights_path = "gemma_lm_lightweight.weights.h5"
gemma_lm.backbone.save_weights(weights_path)
print(f"Model weights saved to: {weights_path}")
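To restore these weights later, rebuild the same architecture first: load the preset, re-enable LoRA with the same rank, then load the file. A sketch, assuming the 2B preset was the one loaded in Step 3:

restored = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
restored.backbone.enable_lora(rank=2)  # must match the rank used at save time
restored.backbone.load_weights(weights_path)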
Saving Quantized TensorFlow Lite Model
converter = tf.lite.TFLiteConverter.from_keras_model(gemma_lm)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()

tflite_path = "gemma_lm_lightweight_v2.tflite"
with open(tflite_path, "wb") as f:
    f.write(quantized_model)
print(f"Quantized model saved to: {tflite_path}")
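Assuming the conversion succeeds (exporting a large causal LM to TFLite can require additional steps in practice), the quantized file can be inspected and run with the TFLite interpreter. A minimal sketch:

interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
print(interpreter.get_input_details())   # expected input tensor shapes and dtypes
print(interpreter.get_output_details())  # expected output tensor shapes and dtypes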
Saving Backbone Only
backbone_path = "gemma_lm_lightweight_backbone.h5"
gemma_lm.backbone.save(backbone_path)
print(f"Backbone model saved to: {backbone_path}")
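The saved backbone can be reloaded with the standard Keras loader; KerasNLP registers its layers as serializable, though this assumes matching Keras/KerasNLP versions at load time. A sketch:

restored_backbone = tf.keras.models.load_model(backbone_path)
restored_backbone.summary()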
Together, these steps optimize the model for memory footprint, computational efficiency, and deployment versatility, covering the main stages from configuration through fine-tuning to export.