In [ ]:
# ==============================================================================
# 0. SETUP & LIBRARIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')
# Install specific libraries for Prithvi
!pip install -q uv
!uv pip install rasterio geopandas geojson torchinfo tslearn transformers segmentation-models-pytorch
import ee
import geemap
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel, AutoConfig
# Authenticate & Initialize
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]') # Using your specified project ID
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print(f" Earth Engine Initialized with project: [REDACTED_FOR_SECURITY]")
# ==============================================================================
# 1. GEE DATA PIPELINE (MATURE WHEAT PHASE: JAN - APR)
# ==============================================================================
# ------------------- USER INPUTS -------------------
# You want 2024 dataset mature state -> Jan 2025 to Apr 2025
year = 2024
wheat_mask = ee.Image(f'projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved') # Verify path
roi = ee.FeatureCollection('projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab')
# ------------------- PARAMETERS -------------------
# UPDATED: Mature state only (Jan 1 to Apr 15)
START_DATE = f'{year+1}-01-01'
END_DATE = f'{year+1}-04-16'
print(f" Data Range: {START_DATE} to {END_DATE} (Mature Phase)")
S1_MIN_DB = -25
S1_MAX_DB = 0
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
# ------------------- SENTINEL-2 PROCESSING -------------------
MAX_CLOUD_PROB = 70
s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100)))
s2Clouds = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi))
s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
primary=s2Raw,
secondary=s2Clouds,
condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))
def maskS2Clouds(img):
cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
isProbCloud = cloudProb.gt(MAX_CLOUD_PROB)
optical = img.select(s2Bands).multiply(0.0001) # Scale 0-1
finalMask = isProbCloud.Not()
ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
return (optical.addBands(ndvi)
.updateMask(finalMask)
.copyProperties(img, ['system:time_start']))
s2Clean = s2Joined.map(maskS2Clouds)
# ------------------- SENTINEL-1 PROCESSING -------------------
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi)
.filter(ee.Filter.eq('instrumentMode', 'IW'))
.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
.filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING'))
.select(['VV', 'VH'])
.map(lambda img: img.rename(['VV_sar', 'VH_sar'])))
# ------------------- INTERVAL GENERATION -------------------
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)
while curDate < endDateObj:
yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day
if dayVal <= 15:
nextDate = datetime.datetime(yearVal, monthVal, 16)
suffix = '1'
else:
if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
suffix = '2'
if nextDate > endDateObj: nextDate = endDateObj
startStr = curDate.strftime('%Y-%m-%d')
endStr = (nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
label = f'{yearVal}_{monthVal:02d}_{suffix}'
if curDate < nextDate:
intervals.append([startStr, endStr, label])
curDate = nextDate
print(f" Generated {len(intervals)} fortnightly intervals")
eeIntervals = ee.List(intervals)
# ------------------- COMPOSITE FUNCTION -------------------
def makeComposite(item):
item = ee.List(item)
start = ee.Date(item.get(0))
end = ee.Date(item.get(1)).advance(1, 'day')
label = ee.String(item.get(2))
# S2 Composite (Median to avoid cloud artifacts better in winter)
s2 = s2Clean.filterDate(start, end).median()
# Fill gaps
s2 = ee.Image(ee.Algorithms.If(
s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)
))
# S1 Composite (Mean)
s1_comp = s1.filterDate(start, end).mean()
s1_comp = ee.Image(ee.Algorithms.If(
s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])
))
return s2.addBands(s1_comp).set('label', label).clip(roi)
fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))
# ------------------- STACKING & DOWNLOAD -------------------
def stack_collection(collection):
col_list = collection.toList(collection.size())
count = collection.size().getInfo()
if count == 0: raise ValueError("Empty collection")
stack = ee.Image(col_list.get(0))
for i in range(1, count):
stack = stack.addBands(ee.Image(col_list.get(i)))
return stack
print("Stacking collection...")
stacked_input = stack_collection(fortnightly_col)
input_band_names = stacked_input.bandNames().getInfo()
print(f"Total Input Channels: {len(input_band_names)}")
# Download Sample
# Using a buffer to get a reasonable training patch
sample_roi = roi.geometry().centroid(1).buffer(2500).bounds()
print(" Downloading data (may take a minute)...")
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
# Cleaning
ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
# Resizing to 224x224 patch friendly size if needed, but we tile 64x64
print(f"Downloaded Shape: {ds.shape}")
# Tiling function
def create_tiles(image, mask, patch_size=64):
h, w, c = image.shape
patches_x, patches_y = [], []
for y in range(0, h - patch_size + 1, patch_size):
for x in range(0, w - patch_size + 1, patch_size):
img_p = image[y:y+patch_size, x:x+patch_size, :]
mask_p = mask[y:y+patch_size, x:x+patch_size]
if np.mean(img_p) > 0.01: # Filter empty
patches_x.append(img_p)
patches_y.append(mask_p)
return np.array(patches_x), np.array(patches_y)
print("Tiling...")
X_np, y_np = create_tiles(ds, lb, patch_size=64)
# Transpose to (N, C, H, W) for PyTorch
X_data = torch.tensor(np.transpose(X_np, (0, 3, 1, 2)), dtype=torch.float32)
y_data = torch.tensor(np.expand_dims(y_np, 1), dtype=torch.float32)
print(f" Final Dataset: {X_data.shape}")
Mounted at /content/drive ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.2/22.2 MB 35.3 MB/s eta 0:00:00 Using Python 3.12.12 environment at: /usr Resolved 70 packages in 1.57s Prepared 4 packages in 54ms Installed 4 packages in 4ms + geojson==3.2.0 + segmentation-models-pytorch==0.5.0 + torchinfo==1.8.0 + tslearn==0.7.0 ✅ Earth Engine Initialized with project: phd0-473604 📅 Data Range: 2025-01-01 to 2025-04-16 (Mature Phase) ✅ Generated 7 fortnightly intervals Stacking collection... Total Input Channels: 91 ⏳ Downloading data (may take a minute)...
--------------------------------------------------------------------------- HttpError Traceback (most recent call last) /usr/local/lib/python3.12/dist-packages/ee/data.py in _execute_cloud_call(call, num_retries) 407 try: --> 408 return call.execute(num_retries=num_retries) 409 except googleapiclient.errors.HttpError as e: /usr/local/lib/python3.12/dist-packages/googleapiclient/_helpers.py in positional_wrapper(*args, **kwargs) 129 logger.warning(message) --> 130 return wrapped(*args, **kwargs) 131 /usr/local/lib/python3.12/dist-packages/googleapiclient/http.py in execute(self, http, num_retries) 937 if resp.status >= 300: --> 938 raise HttpError(resp, content, uri=self.uri) 939 return self.postproc(resp, content) HttpError: <HttpError 400 when requesting https://earthengine.googleapis.com/v1/projects/phd0-473604/image:computePixels? returned "Total request size (230704642 bytes) must be less than or equal to 50331648 bytes.". Details: "Total request size (230704642 bytes) must be less than or equal to 50331648 bytes."> During handling of the above exception, another exception occurred: EEException Traceback (most recent call last) /usr/local/lib/python3.12/dist-packages/geemap/common.py in ee_to_numpy(ee_object, region, scale, bands, **kwargs) 3133 try: -> 3134 struct_array = ee.data.computePixels(kwargs) 3135 array = np.dstack(([struct_array[band] for band in struct_array.dtype.names])) /usr/local/lib/python3.12/dist-packages/ee/data.py in computePixels(params) 953 _maybe_populate_workload_tag(params) --> 954 data = _execute_cloud_call( 955 _get_cloud_projects_raw() /usr/local/lib/python3.12/dist-packages/ee/data.py in _execute_cloud_call(call, num_retries) 409 except googleapiclient.errors.HttpError as e: --> 410 raise _translate_cloud_exception(e) # pylint: disable=raise-missing-from 411 EEException: Total request size (230704642 bytes) must be less than or equal to 50331648 bytes. During handling of the above exception, another exception occurred: Exception Traceback (most recent call last) /tmp/ipython-input-3007931609.py in <cell line: 0>() 164 165 print("⏳ Downloading data (may take a minute)...") --> 166 ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10) 167 lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10) 168 /usr/local/lib/python3.12/dist-packages/geemap/common.py in ee_to_numpy(ee_object, region, scale, bands, **kwargs) 3136 return array 3137 except Exception as e: -> 3138 raise Exception(e) 3139 3140 Exception: Total request size (230704642 bytes) must be less than or equal to 50331648 bytes.
In [ ]:
# ==============================================================================
# CELL 1: SETUP & AUTHENTICATION
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')
# Install dependencies for Prithvi
!pip install -q uv
!uv pip install rasterio geopandas geojson torchinfo tslearn transformers segmentation-models-pytorch
import ee
import geemap
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
# Authenticate & Initialize with YOUR Project ID
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print(" Earth Engine Initialized with project: [REDACTED_FOR_SECURITY]")
In [ ]:
# ==============================================================================
# CELL 2: GEE DATA DOWNLOAD PIPELINE
# ==============================================================================
# ------------------- USER INPUTS -------------------
year = 2024
# We use assets from the other project, but COMPUTE on [REDACTED_FOR_SECURITY]
wheat_mask_path = 'projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved'
roi_path = 'projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab'
# Load Assets
wheat_mask = ee.Image(wheat_mask_path)
roi = ee.FeatureCollection(roi_path)
# ------------------- PARAMETERS (Mature Phase) -------------------
START_DATE = f'{year+1}-01-01' # Jan 2025
END_DATE = f'{year+1}-04-16' # Apr 2025
print(f" Data Range: {START_DATE} to {END_DATE} (Mature Wheat Phase)")
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
MAX_CLOUD_PROB = 70
# ------------------- SENTINEL-2 PROCESSING -------------------
s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100)))
s2Clouds = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi))
s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
primary=s2Raw,
secondary=s2Clouds,
condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))
def maskS2Clouds(img):
cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
isProbCloud = cloudProb.gt(MAX_CLOUD_PROB)
optical = img.select(s2Bands).multiply(0.0001)
finalMask = isProbCloud.Not()
ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
return (optical.addBands(ndvi)
.updateMask(finalMask)
.copyProperties(img, ['system:time_start']))
s2Clean = s2Joined.map(maskS2Clouds)
# ------------------- SENTINEL-1 PROCESSING -------------------
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi)
.filter(ee.Filter.eq('instrumentMode', 'IW'))
.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
.filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING'))
.select(['VV', 'VH'])
.map(lambda img: img.rename(['VV_sar', 'VH_sar'])))
# ------------------- INTERVALS -------------------
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)
while curDate < endDateObj:
yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day
if dayVal <= 15:
nextDate = datetime.datetime(yearVal, monthVal, 16)
suffix = '1'
else:
if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
suffix = '2'
if nextDate > endDateObj: nextDate = endDateObj
intervals.append([curDate.strftime('%Y-%m-%d'),
(nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d'),
f'{yearVal}_{monthVal:02d}_{suffix}'])
curDate = nextDate
eeIntervals = ee.List(intervals)
# ------------------- COMPOSITES -------------------
def makeComposite(item):
item = ee.List(item)
start = ee.Date(item.get(0))
end = ee.Date(item.get(1)).advance(1, 'day')
label = ee.String(item.get(2))
s2 = s2Clean.filterDate(start, end).median()
s2 = ee.Image(ee.Algorithms.If(s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)))
s1_comp = s1.filterDate(start, end).mean()
s1_comp = ee.Image(ee.Algorithms.If(s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])))
return s2.addBands(s1_comp).set('label', label).clip(roi)
fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))
# Stack
def stack_collection(collection):
col_list = collection.toList(collection.size())
count = collection.size().getInfo()
stack = ee.Image(col_list.get(0))
for i in range(1, count): stack = stack.addBands(ee.Image(col_list.get(i)))
return stack
print("Stacking...")
stacked_input = stack_collection(fortnightly_col)
input_band_names = stacked_input.bandNames().getInfo()
print(f" Total Channels: {len(input_band_names)}")
# Download Sample
sample_roi = roi.geometry().centroid(1).buffer(2500).bounds()
print(" Downloading patch...")
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
print(f" Downloaded Shape: {ds.shape}")
# ... (Previous Cell 2 code remains the same up to 'Download Sample') ...
# --- MODIFIED DOWNLOAD SECTION ---
# Reduce buffer from 2500 -> 1000 to fit in 50MB RAM limit
print(" ROI too large for direct download. Using smaller sample for training prototype.")
sample_roi = roi.geometry().centroid(1).buffer(1000).bounds()
print(" Downloading patch (smaller area)...")
try:
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
except Exception as e:
print(f" Still too large? Error: {e}")
# Fallback to even smaller if 1000 fails
print("Trying ultra-small patch (500m buffer)...")
sample_roi = roi.geometry().centroid(1).buffer(500).bounds()
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
print(f" Downloaded Shape: {ds.shape}")
In [ ]:
# ==============================================================================
# CELL 2: GEE DATA DOWNLOAD PIPELINE
# ==============================================================================
# ------------------- USER INPUTS -------------------
year = 2024
# We use assets from the other project, but COMPUTE on [REDACTED_FOR_SECURITY]
wheat_mask_path = 'projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved'
roi_path = 'projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab'
# Load Assets
wheat_mask = ee.Image(wheat_mask_path)
roi = ee.FeatureCollection(roi_path)
# ------------------- PARAMETERS (Mature Phase) -------------------
START_DATE = f'{year+1}-01-01' # Jan 2025
END_DATE = f'{year+1}-04-16' # Apr 2025
print(f" Data Range: {START_DATE} to {END_DATE} (Mature Wheat Phase)")
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
MAX_CLOUD_PROB = 70
# ------------------- SENTINEL-2 PROCESSING -------------------
s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100)))
s2Clouds = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi))
s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
primary=s2Raw,
secondary=s2Clouds,
condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))
def maskS2Clouds(img):
cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
isProbCloud = cloudProb.gt(MAX_CLOUD_PROB)
optical = img.select(s2Bands).multiply(0.0001)
finalMask = isProbCloud.Not()
ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
return (optical.addBands(ndvi)
.updateMask(finalMask)
.copyProperties(img, ['system:time_start']))
s2Clean = s2Joined.map(maskS2Clouds)
# ------------------- SENTINEL-1 PROCESSING -------------------
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
.filterDate(START_DATE, END_DATE)
.filterBounds(roi)
.filter(ee.Filter.eq('instrumentMode', 'IW'))
.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
.filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING'))
.select(['VV', 'VH'])
.map(lambda img: img.rename(['VV_sar', 'VH_sar'])))
# ------------------- INTERVALS -------------------
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)
while curDate < endDateObj:
yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day
if dayVal <= 15:
nextDate = datetime.datetime(yearVal, monthVal, 16)
suffix = '1'
else:
if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
suffix = '2'
if nextDate > endDateObj: nextDate = endDateObj
intervals.append([curDate.strftime('%Y-%m-%d'),
(nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d'),
f'{yearVal}_{monthVal:02d}_{suffix}'])
curDate = nextDate
eeIntervals = ee.List(intervals)
# ------------------- COMPOSITES -------------------
def makeComposite(item):
item = ee.List(item)
start = ee.Date(item.get(0))
end = ee.Date(item.get(1)).advance(1, 'day')
label = ee.String(item.get(2))
s2 = s2Clean.filterDate(start, end).median()
s2 = ee.Image(ee.Algorithms.If(s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)))
s1_comp = s1.filterDate(start, end).mean()
s1_comp = ee.Image(ee.Algorithms.If(s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])))
return s2.addBands(s1_comp).set('label', label).clip(roi)
fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))
# Stack
def stack_collection(collection):
col_list = collection.toList(collection.size())
count = collection.size().getInfo()
stack = ee.Image(col_list.get(0))
for i in range(1, count): stack = stack.addBands(ee.Image(col_list.get(i)))
return stack
print("Stacking...")
stacked_input = stack_collection(fortnightly_col)
input_band_names = stacked_input.bandNames().getInfo()
print(f" Total Channels: {len(input_band_names)}")
# --- MODIFIED DOWNLOAD SECTION ---
# Reduced buffer from 2500 -> 1000 to fit in 50MB RAM limit
print(" Using reduced 1000m buffer to fit memory limit...")
sample_roi = roi.geometry().centroid(1).buffer(1000).bounds()
print(" Downloading patch...")
try:
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
except Exception as e:
print(f" 1000m failed ({e}), trying 500m...")
sample_roi = roi.geometry().centroid(1).buffer(500).bounds()
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
print(f" Downloaded Shape: {ds.shape}")
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')
# Install libraries for Prithvi & Geodata
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch
import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt
print(" Environment Ready")
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')
# Install libraries for Prithvi & Geodata
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch
import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt
# Authenticate GEE (Required if using any GEE features, though data is in Drive)
import ee
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print(" Environment Ready")
In [ ]:
import os
# Your folder path
folder_path = '/content/drive/MyDrive/Punjab Wheat Mask_Binary/'
print(f" Checking files in: {folder_path}\n")
try:
files = os.listdir(folder_path)
for f in files:
print(f" {f}")
except FileNotFoundError:
print(" Folder not found. Please check the path.")
In [ ]:
# ==============================================================================
# CELL 2: GENERATE DATA FROM SATELLITE (MATURE PHASE)
# ==============================================================================
import datetime # <--- ADDED THIS IMPORT
import ee
import geemap
import numpy as np
import torch # Ensure torch is imported if used later in this cell
# 1. Define ROI & Mask
wheat_mask = ee.Image('projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved')
roi = ee.FeatureCollection('projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab')
# 2. Define Timeframe: MATURE STATE (Jan - Apr)
year = 2024
START_DATE = f'{year+1}-01-01' # Jan 1, 2025
END_DATE = f'{year+1}-04-16' # Apr 16, 2025
print(f" Generating data for Mature Phase: {START_DATE} to {END_DATE}")
# 3. Sentinel-2 (Optical) Processing
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
s2Raw = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterDate(START_DATE, END_DATE).filterBounds(roi)
s2Clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY').filterDate(START_DATE, END_DATE).filterBounds(roi)
s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
primary=s2Raw, secondary=s2Clouds,
condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))
def maskS2Clouds(img):
cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
isProbCloud = cloudProb.gt(70)
return img.select(s2Bands).multiply(0.0001).updateMask(isProbCloud.Not())
s2Clean = s2Joined.map(maskS2Clouds)
# 4. Sentinel-1 (Radar) Processing
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
.filterDate(START_DATE, END_DATE).filterBounds(roi)
.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
.select(['VV', 'VH']).map(lambda img: img.rename(['VV_sar', 'VH_sar'])))
# 5. Create Fortnightly Composites
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)
while curDate < endDateObj:
# Logic to split month into 1-15 and 16-end
yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day
if dayVal <= 15: nextDate = datetime.datetime(yearVal, monthVal, 16); suffix='1'
else:
if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
suffix='2'
if nextDate > endDateObj: nextDate = endDateObj
intervals.append([curDate.strftime('%Y-%m-%d'), (nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d')])
curDate = nextDate
print(f" Created {len(intervals)} fortnightly intervals")
# Stack Function
eeIntervals = ee.List(intervals)
def makeComposite(item):
start, end = ee.Date(ee.List(item).get(0)), ee.Date(ee.List(item).get(1)).advance(1, 'day')
s2 = s2Clean.filterDate(start, end).median()
s2 = ee.Image(ee.Algorithms.If(s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)))
s1_comp = s1.filterDate(start, end).mean()
s1_comp = ee.Image(ee.Algorithms.If(s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])))
return s2.addBands(s1_comp)
fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))
# Flatten to single image stack
stack = ee.Image(fortnightly_col.toBands())
print(f" Total Input Channels: {stack.bandNames().size().getInfo()}")
# 6. DOWNLOAD PATCH (Using safe buffer size)
print(" Downloading Patch (1000m buffer)...")
sample_roi = roi.geometry().centroid(1).buffer(1000).bounds()
ds = geemap.ee_to_numpy(stack, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
# Clean
ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
print(f" Downloaded Data Shape: {ds.shape}")
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
# Install required libraries
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch opencv-python-headless
import ee
import geemap
import rasterio
from rasterio.windows import from_bounds
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt
# Authenticate with YOUR project
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print(" Environment Ready")
Mounted at /content/drive ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.2/22.2 MB 118.2 MB/s eta 0:00:00 Using Python 3.12.12 environment at: /usr Resolved 63 packages in 314ms Prepared 2 packages in 27ms Installed 2 packages in 3ms + segmentation-models-pytorch==0.5.0 + torchinfo==1.8.0 ✅ Environment Ready
In [ ]:
# ==============================================================================
# CELL 2: HYBRID DATA LOADING (Mature Phase)
# ==============================================================================
# 1. Setup Path
mask_path = '/content/drive/MyDrive/Punjab Wheat Mask_Binary/Punjab Mask 2024.tif'
# 2. Define Sample Area from Mask (Drive)
with rasterio.open(mask_path) as src:
full_bounds = src.bounds
crs = src.crs
# Pick center
center_x = (full_bounds.left + full_bounds.right) / 2
center_y = (full_bounds.bottom + full_bounds.top) / 2
# Define small window (approx 5km x 5km)
half_size = 0.04
sample_bounds = (
center_x - half_size, center_y - half_size,
center_x + half_size, center_y + half_size
)
print(f" Selecting sample area from mask...")
window = from_bounds(*sample_bounds, transform=src.transform)
mask_data = src.read(1, window=window)
# Ensure binary
mask_data = np.where(mask_data > 0, 1, 0).astype(np.float32)
print(f" Mask Loaded! Shape: {mask_data.shape}")
# 3. Download Matching Satellite Data (GEE)
roi = ee.Geometry.Rectangle(
[sample_bounds[0], sample_bounds[1], sample_bounds[2], sample_bounds[3]],
proj=str(crs), geodesic=False
)
# MATURE PHASE (Jan - Apr) - adjusted for 2024/2025 crop cycle
# Using 2024 crop mask implies crop was sown late 2023, harvested April 2024.
# So Mature Phase is Jan 2024 - Apr 2024.
START_DATE = '2024-01-01'
END_DATE = '2024-04-15'
print(f" Downloading Sentinel-2 (Mature Phase: {START_DATE} to {END_DATE})...")
# Create a Composite (Median) to minimize clouds
# Selecting bands compatible with Prithvi (Blue, Green, Red, NIR, SWIR, SWIR)
# S2 Bands: B2, B3, B4, B8A, B11, B12
s2_img = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
.filterBounds(roi)
.filterDate(START_DATE, END_DATE)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20))
.median()
.select(['B2', 'B3', 'B4', 'B8A', 'B11', 'B12'])
.clip(roi))
# Download
image_data = geemap.ee_to_numpy(s2_img, region=roi, scale=10)
# 4. Handle NaNs & Normalize
image_data = np.nan_to_num(image_data, nan=0.0)
# Sentinel-2 is 0-10000 approx, clip to 0-3000 (reflectance 0.3) for visualization/training
image_data = np.clip(image_data / 3000.0, 0, 1)
print(f" Satellite Data Downloaded! Shape: {image_data.shape}")
# 5. Fix Shape Mismatch
# Resize Satellite Image to match Mask exactly
target_h, target_w = mask_data.shape
if image_data.shape[:2] != (target_h, target_w):
print(f" Resizing Image: {image_data.shape[:2]} -> {(target_h, target_w)}")
image_data = cv2.resize(image_data, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
print(f" Final Aligned Image Shape: {image_data.shape}")
📍 Selecting sample area from mask... ✅ Mask Loaded! Shape: (891, 891) ⏳ Downloading Sentinel-2 (Mature Phase: 2024-01-01 to 2024-04-15)... ✅ Satellite Data Downloaded! Shape: (891, 892, 6) ⚠️ Resizing Image: (891, 892) -> (891, 891) ✅ Final Aligned Image Shape: (891, 891, 6)
In [ ]:
# ==============================================================================
# CELL 3: TILING
# ==============================================================================
def create_tiles(image, mask, patch_size=64):
h, w, c = image.shape
patches_x, patches_y = [], []
for y in range(0, h - patch_size + 1, patch_size):
for x in range(0, w - patch_size + 1, patch_size):
img_p = image[y:y+patch_size, x:x+patch_size, :]
mask_p = mask[y:y+patch_size, x:x+patch_size]
# Filter empty patches
if np.mean(img_p) > 0.01:
patches_x.append(img_p)
patches_y.append(mask_p)
return np.array(patches_x), np.array(patches_y)
print(" Creating tiles...")
X_np, y_np = create_tiles(image_data, mask_data, patch_size=64)
# Convert to Tensor (N, C, H, W)
X_data = torch.tensor(np.transpose(X_np, (0, 3, 1, 2)), dtype=torch.float32)
y_data = torch.tensor(np.expand_dims(y_np, 1), dtype=torch.float32)
print(f" Dataset Ready: {X_data.shape}")
✂️ Creating tiles... ✅ Dataset Ready: torch.Size([169, 6, 64, 64])
In [ ]:
# ==============================================================================
# CELL 4: PRITHVI MODEL (MANUAL DEFINITION)
# ==============================================================================
import torch
import torch.nn as nn
from functools import partial
# --- 1. Define Standard ViT Components (Prithvi Architecture) ---
# We define the Transformer blocks manually to avoid loading errors
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop)
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class PrithviBackbone(nn.Module):
""" Simplified Prithvi-100M Backbone for 64x64 Input """
def __init__(self, img_size=64, patch_size=4, in_chans=6, embed_dim=768, depth=12, num_heads=12):
super().__init__()
# Patch Embed
self.num_patches = (img_size // patch_size) ** 2
self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# Positional Embed
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_drop = nn.Dropout(p=0.0)
# Transformer Blocks
self.blocks = nn.ModuleList([
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4., qkv_bias=True)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
# x: [B, 6, 64, 64]
B = x.shape[0]
x = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, 256, 768]
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[:, :x.shape[1], :]
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x # [B, 257, 768]
# --- 2. Main Segmentation Model ---
class PrithviSegmentation(nn.Module):
def __init__(self, in_channels, out_classes=1):
super().__init__()
print(" Initializing Prithvi Backbone from scratch...")
# Adapter: Your N bands -> 6 bands
self.input_adapter = nn.Conv2d(in_channels, 6, kernel_size=1)
# Manual Prithvi Backbone (No Hugging Face loading needed)
# Matches Prithvi-100M specs: dim=768, depth=12, heads=12
self.encoder = PrithviBackbone(
img_size=64,
patch_size=4, # 16x16 tokens
in_chans=6,
embed_dim=768,
depth=12,
num_heads=12
)
# Decoder
embed_dim = 768
self.decoder = nn.Sequential(
nn.ConvTranspose2d(embed_dim, 256, kernel_size=2, stride=2), # 16->32
nn.BatchNorm2d(256), nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), # 32->64
nn.BatchNorm2d(128), nn.ReLU(),
nn.Conv2d(128, out_classes, kernel_size=1)
)
def forward(self, x):
# Adapt Input
x_adapted = self.input_adapter(x) # [B, 6, 64, 64]
# Encode
features = self.encoder(x_adapted) # [B, 257, 768]
# Reshape (Remove CLS)
features = features[:, 1:, :] # [B, 256, 768]
B, N, C = features.shape
H = int(N**0.5) # 16
features = features.permute(0, 2, 1).reshape(B, C, H, H) # [B, 768, 16, 16]
# Decode
return self.decoder(features)
print(" Prithvi Model Defined Manually (No HF Token Needed)")
✅ Prithvi Model Defined Manually (No HF Token Needed)
In [ ]:
# ==============================================================================
# CELL 5: TRAINING (LOG EVERY EPOCH)
# ==============================================================================
import time
# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 200 # Targeted for ~20 mins on GPU
PATIENCE = 25
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Training on: {device}")
if device.type == 'cpu':
print(" WARNING: You are running on CPU. This will be very slow.")
# Splits
dataset = TensorDataset(X_data, y_data)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
# Model
model = PrithviSegmentation(in_channels=X_data.shape[1], out_classes=1).to(device)
# Loss & Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
# Loop Variables
train_losses, val_losses = [], []
best_loss = float('inf')
patience_counter = 0
print(f"\nSTARTING TRAINING ({EPOCHS} Epochs)...")
start_time = time.time()
for epoch in range(EPOCHS):
# --- TRAIN ---
model.train()
running_loss = 0
for X, y in train_loader:
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
preds = model(X)
loss = criterion(preds, y)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_train_loss = running_loss / len(train_loader)
train_losses.append(avg_train_loss)
# --- VALIDATE ---
model.eval()
running_val_loss = 0
with torch.no_grad():
for X, y in val_loader:
X, y = X.to(device), y.to(device)
preds = model(X)
running_val_loss += criterion(preds, y).item()
avg_val_loss = running_val_loss / len(val_loader)
val_losses.append(avg_val_loss)
# --- SCHEDULER & LOGGING ---
scheduler.step(avg_val_loss)
elapsed = (time.time() - start_time) / 60
current_lr = optimizer.param_groups[0]['lr']
# Print EVERY Epoch
print(f"Epoch {epoch+1}/{EPOCHS} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.1e} | Time: {elapsed:.1f} min")
# --- EARLY STOPPING ---
if avg_val_loss < best_loss:
best_loss = avg_val_loss
patience_counter = 0
torch.save(model.state_dict(), 'best_prithvi.pth')
else:
patience_counter += 1
if patience_counter >= PATIENCE:
print(f" Early stopping triggered at Epoch {epoch+1}")
break
total_time = (time.time() - start_time) / 60
print(f"\n Training Complete in {total_time:.1f} minutes.")
🚀 Training on: cuda ⏳ Initializing Prithvi Backbone from scratch... STARTING TRAINING (200 Epochs)... Epoch 1/200 | Train: 0.7448 | Val: 0.7152 | LR: 1.0e-04 | Time: 0.1 min Epoch 2/200 | Train: 0.6588 | Val: 0.6610 | LR: 1.0e-04 | Time: 0.3 min Epoch 3/200 | Train: 0.5964 | Val: 0.6010 | LR: 1.0e-04 | Time: 0.4 min Epoch 4/200 | Train: 0.5584 | Val: 0.5773 | LR: 1.0e-04 | Time: 0.5 min Epoch 5/200 | Train: 0.5314 | Val: 0.5717 | LR: 1.0e-04 | Time: 0.6 min Epoch 6/200 | Train: 0.5267 | Val: 0.5048 | LR: 1.0e-04 | Time: 0.7 min Epoch 7/200 | Train: 0.5313 | Val: 0.4976 | LR: 1.0e-04 | Time: 0.9 min Epoch 8/200 | Train: 0.5173 | Val: 0.5684 | LR: 1.0e-04 | Time: 1.0 min Epoch 9/200 | Train: 0.5065 | Val: 0.5242 | LR: 1.0e-04 | Time: 1.1 min Epoch 10/200 | Train: 0.5134 | Val: 0.5212 | LR: 1.0e-04 | Time: 1.2 min Epoch 11/200 | Train: 0.5084 | Val: 0.4881 | LR: 1.0e-04 | Time: 1.3 min Epoch 12/200 | Train: 0.4999 | Val: 0.4829 | LR: 1.0e-04 | Time: 1.4 min Epoch 13/200 | Train: 0.4968 | Val: 0.6236 | LR: 1.0e-04 | Time: 1.5 min Epoch 14/200 | Train: 0.4996 | Val: 0.4786 | LR: 1.0e-04 | Time: 1.6 min Epoch 15/200 | Train: 0.4978 | Val: 0.5607 | LR: 1.0e-04 | Time: 1.7 min Epoch 16/200 | Train: 0.4977 | Val: 0.4824 | LR: 1.0e-04 | Time: 1.8 min Epoch 17/200 | Train: 0.4975 | Val: 0.5084 | LR: 1.0e-04 | Time: 1.9 min Epoch 18/200 | Train: 0.4974 | Val: 0.5067 | LR: 5.0e-05 | Time: 2.0 min Epoch 19/200 | Train: 0.4933 | Val: 0.4945 | LR: 5.0e-05 | Time: 2.1 min Epoch 20/200 | Train: 0.4859 | Val: 0.4771 | LR: 5.0e-05 | Time: 2.2 min Epoch 21/200 | Train: 0.4934 | Val: 0.4892 | LR: 5.0e-05 | Time: 2.3 min Epoch 22/200 | Train: 0.4865 | Val: 0.4907 | LR: 5.0e-05 | Time: 2.4 min Epoch 23/200 | Train: 0.4876 | Val: 0.4869 | LR: 5.0e-05 | Time: 2.5 min Epoch 24/200 | Train: 0.4860 | Val: 0.4779 | LR: 2.5e-05 | Time: 2.6 min Epoch 25/200 | Train: 0.4917 | Val: 0.4879 | LR: 2.5e-05 | Time: 2.7 min Epoch 26/200 | Train: 0.4845 | Val: 0.4937 | LR: 2.5e-05 | Time: 2.7 min Epoch 27/200 | Train: 0.4820 | Val: 0.4807 | LR: 2.5e-05 | Time: 2.8 min Epoch 28/200 | Train: 0.4916 | Val: 0.4817 | LR: 1.3e-05 | Time: 2.9 min Epoch 29/200 | Train: 0.4783 | Val: 0.4791 | LR: 1.3e-05 | Time: 3.0 min Epoch 30/200 | Train: 0.4813 | Val: 0.4824 | LR: 1.3e-05 | Time: 3.1 min Epoch 31/200 | Train: 0.4807 | Val: 0.4873 | LR: 1.3e-05 | Time: 3.2 min Epoch 32/200 | Train: 0.4844 | Val: 0.4828 | LR: 6.3e-06 | Time: 3.3 min Epoch 33/200 | Train: 0.4899 | Val: 0.4814 | LR: 6.3e-06 | Time: 3.4 min Epoch 34/200 | Train: 0.4819 | Val: 0.4830 | LR: 6.3e-06 | Time: 3.5 min Epoch 35/200 | Train: 0.4892 | Val: 0.4821 | LR: 6.3e-06 | Time: 3.6 min Epoch 36/200 | Train: 0.4855 | Val: 0.4829 | LR: 3.1e-06 | Time: 3.7 min Epoch 37/200 | Train: 0.4824 | Val: 0.4818 | LR: 3.1e-06 | Time: 3.8 min Epoch 38/200 | Train: 0.4853 | Val: 0.4799 | LR: 3.1e-06 | Time: 3.9 min Epoch 39/200 | Train: 0.4857 | Val: 0.4818 | LR: 3.1e-06 | Time: 4.0 min Epoch 40/200 | Train: 0.4790 | Val: 0.4810 | LR: 1.6e-06 | Time: 4.1 min Epoch 41/200 | Train: 0.4859 | Val: 0.4821 | LR: 1.6e-06 | Time: 4.2 min Epoch 42/200 | Train: 0.4816 | Val: 0.4827 | LR: 1.6e-06 | Time: 4.2 min Epoch 43/200 | Train: 0.4796 | Val: 0.4831 | LR: 1.6e-06 | Time: 4.3 min Epoch 44/200 | Train: 0.4813 | Val: 0.4850 | LR: 7.8e-07 | Time: 4.4 min Epoch 45/200 | Train: 0.4780 | Val: 0.4855 | LR: 7.8e-07 | Time: 4.5 min 🛑 Early stopping triggered at Epoch 45 ✅ Training Complete in 4.5 minutes.
In [ ]:
# ==============================================================================
# CELL 6: VISUALIZATION & METRICS
# ==============================================================================
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
# 1. Load Best Model
print(" Loading Best Model...")
model.load_state_dict(torch.load('best_prithvi.pth'))
model.eval()
# 2. Plot Loss Curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='orange')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# 3. Calculate Metrics on Test Set
print(" Calculating Metrics on Test Set...")
all_preds = []
all_targets = []
with torch.no_grad():
for X, y in test_loader:
X, y = X.to(device), y.to(device)
preds = model(X)
preds = torch.sigmoid(preds) # Convert to probability
preds = (preds > 0.5).float() # Threshold to 0/1
all_preds.append(preds.cpu().numpy().flatten())
all_targets.append(y.cpu().numpy().flatten())
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)
acc = accuracy_score(all_targets, all_preds)
prec = precision_score(all_targets, all_preds, zero_division=0)
rec = recall_score(all_targets, all_preds, zero_division=0)
f1 = f1_score(all_targets, all_preds, zero_division=0)
iou = jaccard_score(all_targets, all_preds, zero_division=0)
print(f"\n TEST RESULTS:")
print(f" Accuracy: {acc:.4f}")
print(f" Precision: {prec:.4f}")
print(f" Recall: {rec:.4f}")
print(f" F1-Score: {f1:.4f}")
print(f" IoU: {iou:.4f}")
# 4. Visual Predictions
print("\n Visualizing Predictions...")
def visualize_prediction(loader, num_samples=3):
X, y = next(iter(loader))
X, y = X.to(device), y.to(device)
with torch.no_grad():
preds = torch.sigmoid(model(X))
preds = (preds > 0.5).float()
for idx in range(min(num_samples, len(X))):
# Prepare Input Image (False Color Composite: NIR, Red, Green -> Bands 3,2,1)
# Note: Adjust indices based on your specific stack.
# Usually: B2, B3, B4, B8... -> Indices 0,1,2,3...
# So NIR(3), Red(2), Green(1) makes a nice vegetation map.
rgb = X[idx].cpu().numpy()[[3,2,1],:,:].transpose(1,2,0)
rgb = np.clip(rgb, 0, 1) # Normalize for display
mask = y[idx, 0].cpu().numpy()
pred = preds[idx, 0].cpu().numpy()
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.imshow(rgb)
plt.title(f"Input (False Color) - Sample {idx+1}")
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(mask, cmap='gray')
plt.title("Ground Truth Mask")
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(pred, cmap='gray')
plt.title("Prithvi Prediction")
plt.axis('off')
plt.show()
visualize_prediction(test_loader, num_samples=3)
⏳ Loading Best Model...
📊 Calculating Metrics on Test Set... ✅ TEST RESULTS: Accuracy: 0.7358 Precision: 0.7455 Recall: 0.8816 F1-Score: 0.8078 IoU: 0.6776 🖼️ Visualizing Predictions...
In [ ]:
# ==============================================================================
# CELL 6: VISUALIZATION (TRUE COLOR + GREEN OVERLAY)
# ==============================================================================
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
# 1. Load Best Model
print(" Loading Best Model...")
model.load_state_dict(torch.load('best_prithvi.pth'))
model.eval()
# 2. Plot Loss Curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='orange')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# 3. Visual Predictions (Green Overlay)
print("\n Visualizing Predictions (True Color + Green Overlay)...")
def visualize_green_overlay(loader, num_samples=3):
X, y = next(iter(loader))
X, y = X.to(device), y.to(device)
with torch.no_grad():
preds = torch.sigmoid(model(X))
preds = (preds > 0.5).float()
for idx in range(min(num_samples, len(X))):
# Prepare True Color Image (Red=2, Green=1, Blue=0)
# Scale brightness by 3x to make it visible (satellite data is dark)
img_rgb = X[idx].cpu().numpy()[[2,1,0],:,:].transpose(1,2,0)
img_rgb = np.clip(img_rgb * 3.0, 0, 1)
# Ground Truth Mask
mask_gt = y[idx, 0].cpu().numpy()
# Prediction Mask
mask_pred = preds[idx, 0].cpu().numpy()
# Create Green Overlay for Prediction
# We create a copy of the RGB image and tint pixels GREEN where mask_pred == 1
overlay = img_rgb.copy()
overlay[mask_pred == 1] = [0, 1, 0] # Pure Green
# Blend: 70% Original + 30% Green Overlay
blended = cv2.addWeighted(img_rgb, 0.7, overlay, 0.3, 0)
plt.figure(figsize=(12, 4))
# 1. True Color Input
plt.subplot(1, 3, 1)
plt.imshow(img_rgb)
plt.title(f"True Color Input {idx+1}")
plt.axis('off')
# 2. Ground Truth (Black/White)
plt.subplot(1, 3, 2)
plt.imshow(mask_gt, cmap='gray')
plt.title("Ground Truth (Target)")
plt.axis('off')
# 3. Prediction (Green Overlay)
plt.subplot(1, 3, 3)
plt.imshow(blended)
plt.title("Prediction (Green Overlay)")
plt.axis('off')
plt.show()
visualize_green_overlay(test_loader, num_samples=3)
⏳ Loading Best Model...
🖼️ Visualizing Predictions (True Color + Green Overlay)...
In [ ]:
# ==============================================================================
# CELL 5: TRAINING (FORCE 200 EPOCHS + COSINE ANNEALING)
# ==============================================================================
import time
# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 200 # Will complete ALL 200 epochs
# PATIENCE removed because we want to force full training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Training on: {device}")
# Splits
dataset = TensorDataset(X_data, y_data)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
# Model
model = PrithviSegmentation(in_channels=X_data.shape[1], out_classes=1).to(device)
# Loss & Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
# --- COSINE ANNEALING SCHEDULER (New) ---
# This scheduler oscillates the LR, preventing the model from getting stuck.
# T_0=50: Resets LR every 50 epochs.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=50, T_mult=1, eta_min=1e-6
)
# Loop Variables
train_losses, val_losses = [], []
best_loss = float('inf')
print(f"\nSTARTING FULL TRAINING ({EPOCHS} Epochs)...")
print("Early Stopping is DISABLED to force full training.")
start_time = time.time()
for epoch in range(EPOCHS):
# --- TRAIN ---
model.train()
running_loss = 0
for X, y in train_loader:
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
preds = model(X)
loss = criterion(preds, y)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_train_loss = running_loss / len(train_loader)
train_losses.append(avg_train_loss)
# --- VALIDATE ---
model.eval()
running_val_loss = 0
with torch.no_grad():
for X, y in val_loader:
X, y = X.to(device), y.to(device)
preds = model(X)
running_val_loss += criterion(preds, y).item()
avg_val_loss = running_val_loss / len(val_loader)
val_losses.append(avg_val_loss)
# --- SCHEDULER STEP ---
# Cosine Scheduler steps every epoch (not based on val_loss)
scheduler.step()
# --- LOGGING ---
elapsed = (time.time() - start_time) / 60
current_lr = optimizer.param_groups[0]['lr']
# Print every 10 epochs to reduce clutter
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{EPOCHS} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.1e} | Time: {elapsed:.1f} min")
# Save Best Model (but DON'T stop)
if avg_val_loss < best_loss:
best_loss = avg_val_loss
torch.save(model.state_dict(), 'best_prithvi.pth')
total_time = (time.time() - start_time) / 60
print(f"\n Full 200 Epochs Complete in {total_time:.1f} minutes.")
🚀 Training on: cuda ⏳ Initializing Prithvi Backbone from scratch... STARTING FULL TRAINING (200 Epochs)... Early Stopping is DISABLED to force full training. Epoch 10/200 | Train: 0.4919 | Val: 0.5170 | LR: 9.1e-05 | Time: 1.1 min Epoch 20/200 | Train: 0.4944 | Val: 0.5119 | LR: 6.6e-05 | Time: 2.0 min Epoch 30/200 | Train: 0.4816 | Val: 0.5318 | LR: 3.5e-05 | Time: 3.0 min Epoch 40/200 | Train: 0.4834 | Val: 0.5110 | LR: 1.0e-05 | Time: 4.0 min Epoch 50/200 | Train: 0.4792 | Val: 0.4999 | LR: 1.0e-04 | Time: 4.9 min Epoch 60/200 | Train: 0.4940 | Val: 0.5165 | LR: 9.1e-05 | Time: 5.8 min Epoch 70/200 | Train: 0.4710 | Val: 0.5205 | LR: 6.6e-05 | Time: 6.8 min Epoch 80/200 | Train: 0.4558 | Val: 0.5072 | LR: 3.5e-05 | Time: 7.7 min Epoch 90/200 | Train: 0.4500 | Val: 0.5153 | LR: 1.0e-05 | Time: 8.6 min Epoch 100/200 | Train: 0.4530 | Val: 0.5127 | LR: 1.0e-04 | Time: 9.6 min Epoch 110/200 | Train: 0.4507 | Val: 0.5662 | LR: 9.1e-05 | Time: 10.5 min Epoch 120/200 | Train: 0.4507 | Val: 0.5344 | LR: 6.6e-05 | Time: 11.5 min Epoch 130/200 | Train: 0.4384 | Val: 0.5301 | LR: 3.5e-05 | Time: 12.4 min Epoch 140/200 | Train: 0.4364 | Val: 0.5264 | LR: 1.0e-05 | Time: 13.4 min Epoch 150/200 | Train: 0.4252 | Val: 0.5291 | LR: 1.0e-04 | Time: 14.3 min Epoch 160/200 | Train: 0.4326 | Val: 0.5518 | LR: 9.1e-05 | Time: 15.2 min Epoch 170/200 | Train: 0.4115 | Val: 0.5975 | LR: 6.6e-05 | Time: 16.2 min Epoch 180/200 | Train: 0.4200 | Val: 0.5475 | LR: 3.5e-05 | Time: 17.1 min Epoch 190/200 | Train: 0.3916 | Val: 0.5476 | LR: 1.0e-05 | Time: 18.1 min Epoch 200/200 | Train: 0.3911 | Val: 0.5508 | LR: 1.0e-04 | Time: 19.0 min ✅ Full 200 Epochs Complete in 19.0 minutes.
In [ ]:
# ==============================================================================
# CELL 6: VISUALIZATION & METRICS
# ==============================================================================
import cv2
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
# 1. Load Best Model (Saved during the 200 epochs)
print(" Loading Best Model Checkpoint...")
model.load_state_dict(torch.load('best_prithvi.pth'))
model.eval()
# 2. Plot Loss Curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='orange')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss (200 Epochs)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# 3. Calculate Metrics on Test Set
print(" Calculating Metrics on Test Set...")
all_preds = []
all_targets = []
with torch.no_grad():
for X, y in test_loader:
X, y = X.to(device), y.to(device)
preds = torch.sigmoid(model(X))
preds = (preds > 0.5).float()
all_preds.append(preds.cpu().numpy().flatten())
all_targets.append(y.cpu().numpy().flatten())
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)
# Compute Scores
acc = accuracy_score(all_targets, all_preds)
f1 = f1_score(all_targets, all_preds, zero_division=0)
iou = jaccard_score(all_targets, all_preds, zero_division=0)
print(f"\n TEST RESULTS:")
print(f" Accuracy: {acc:.4f}")
print(f" F1-Score: {f1:.4f}")
print(f" IoU: {iou:.4f}")
# 4. Visual Predictions (Green Overlay)
print("\n Visualizing Predictions (Green Overlay)...")
def visualize_green_overlay(loader, num_samples=3):
X, y = next(iter(loader))
X, y = X.to(device), y.to(device)
with torch.no_grad():
preds = torch.sigmoid(model(X))
preds = (preds > 0.5).float()
for idx in range(min(num_samples, len(X))):
# Prepare True Color Image (Red=2, Green=1, Blue=0)
# We assume bands are downloaded in that order or standard S2 order
# Adjust scale (*3.0) for brightness
img_rgb = X[idx].cpu().numpy()[[2,1,0],:,:].transpose(1,2,0)
img_rgb = np.clip(img_rgb * 3.5, 0, 1)
mask_gt = y[idx, 0].cpu().numpy()
mask_pred = preds[idx, 0].cpu().numpy()
# Create Green Overlay
overlay = img_rgb.copy()
overlay[mask_pred == 1] = [0, 1, 0] # Pure Green
# Blend
blended = cv2.addWeighted(img_rgb, 0.7, overlay, 0.3, 0)
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(img_rgb)
plt.title(f"Input {idx+1}")
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(mask_gt, cmap='gray')
plt.title("Ground Truth")
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(blended)
plt.title("Prediction (Green)")
plt.axis('off')
plt.show()
visualize_green_overlay(test_loader, num_samples=5)
⏳ Loading Best Model Checkpoint...
📊 Calculating Metrics on Test Set... ✅ TEST RESULTS: Accuracy: 0.7873 F1-Score: 0.8501 IoU: 0.7392 🖼️ Visualizing Predictions (Green Overlay)...
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
# Install required libraries
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch opencv-python-headless
import ee
import geemap
import rasterio
from rasterio.windows import from_bounds
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt
# Authenticate with YOUR project
try:
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
ee.Authenticate()
ee.Initialize(project='[REDACTED_FOR_SECURITY]')
print(" Environment Ready")
Mounted at /content/drive ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.2/22.2 MB 52.4 MB/s eta 0:00:00 Using Python 3.12.12 environment at: /usr Resolved 63 packages in 791ms Prepared 2 packages in 41ms Installed 2 packages in 5ms + segmentation-models-pytorch==0.5.0 + torchinfo==1.8.0 ✅ Environment Ready
In [ ]:
# ==============================================================================
# CELL 2: HYBRID DATA LOADING (Mature Phase)
# ==============================================================================
# 1. Setup Path
mask_path = '/content/drive/MyDrive/Punjab Wheat Mask_Binary/Punjab Mask 2024.tif'
# 2. Define Sample Area from Mask (Drive)
with rasterio.open(mask_path) as src:
full_bounds = src.bounds
crs = src.crs
# Pick center
center_x = (full_bounds.left + full_bounds.right) / 2
center_y = (full_bounds.bottom + full_bounds.top) / 2
# Define small window (approx 5km x 5km)
half_size = 0.04
sample_bounds = (
center_x - half_size, center_y - half_size,
center_x + half_size, center_y + half_size
)
print(f" Selecting sample area from mask...")
window = from_bounds(*sample_bounds, transform=src.transform)
mask_data = src.read(1, window=window)
# Ensure binary
mask_data = np.where(mask_data > 0, 1, 0).astype(np.float32)
print(f" Mask Loaded! Shape: {mask_data.shape}")
# 3. Download Matching Satellite Data (GEE)
roi = ee.Geometry.Rectangle(
[sample_bounds[0], sample_bounds[1], sample_bounds[2], sample_bounds[3]],
proj=str(crs), geodesic=False
)
# MATURE PHASE (Jan - Apr) - adjusted for 2024/2025 crop cycle
# Using 2024 crop mask implies crop was sown late 2023, harvested April 2024.
# So Mature Phase is Jan 2024 - Apr 2024.
START_DATE = '2024-01-01'
END_DATE = '2024-04-15'
print(f" Downloading Sentinel-2 (Mature Phase: {START_DATE} to {END_DATE})...")
# Create a Composite (Median) to minimize clouds
# Selecting bands compatible with Prithvi (Blue, Green, Red, NIR, SWIR, SWIR)
# S2 Bands: B2, B3, B4, B8A, B11, B12
s2_img = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
.filterBounds(roi)
.filterDate(START_DATE, END_DATE)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20))
.median()
.select(['B2', 'B3', 'B4', 'B8A', 'B11', 'B12'])
.clip(roi))
# Download
image_data = geemap.ee_to_numpy(s2_img, region=roi, scale=10)
# 4. Handle NaNs & Normalize
image_data = np.nan_to_num(image_data, nan=0.0)
# Sentinel-2 is 0-10000 approx, clip to 0-3000 (reflectance 0.3) for visualization/training
image_data = np.clip(image_data / 3000.0, 0, 1)
print(f" Satellite Data Downloaded! Shape: {image_data.shape}")
# 5. Fix Shape Mismatch
# Resize Satellite Image to match Mask exactly
target_h, target_w = mask_data.shape
if image_data.shape[:2] != (target_h, target_w):
print(f" Resizing Image: {image_data.shape[:2]} -> {(target_h, target_w)}")
image_data = cv2.resize(image_data, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
print(f" Final Aligned Image Shape: {image_data.shape}")
📍 Selecting sample area from mask... ✅ Mask Loaded! Shape: (891, 891) ⏳ Downloading Sentinel-2 (Mature Phase: 2024-01-01 to 2024-04-15)... ✅ Satellite Data Downloaded! Shape: (891, 892, 6) ⚠️ Resizing Image: (891, 892) -> (891, 891) ✅ Final Aligned Image Shape: (891, 891, 6)
In [ ]:
# ==============================================================================
# CELL 3: TILING
# ==============================================================================
def create_tiles(image, mask, patch_size=64):
h, w, c = image.shape
patches_x, patches_y = [], []
for y in range(0, h - patch_size + 1, patch_size):
for x in range(0, w - patch_size + 1, patch_size):
img_p = image[y:y+patch_size, x:x+patch_size, :]
mask_p = mask[y:y+patch_size, x:x+patch_size]
# Filter empty patches
if np.mean(img_p) > 0.01:
patches_x.append(img_p)
patches_y.append(mask_p)
return np.array(patches_x), np.array(patches_y)
print(" Creating tiles...")
X_np, y_np = create_tiles(image_data, mask_data, patch_size=64)
# Convert to Tensor (N, C, H, W)
X_data = torch.tensor(np.transpose(X_np, (0, 3, 1, 2)), dtype=torch.float32)
y_data = torch.tensor(np.expand_dims(y_np, 1), dtype=torch.float32)
print(f" Dataset Ready: {X_data.shape}")
✂️ Creating tiles... ✅ Dataset Ready: torch.Size([169, 6, 64, 64])
In [ ]:
# ==============================================================================
# CELL 4: PRITHVI MODEL (MANUAL DEFINITION)
# ==============================================================================
import torch
import torch.nn as nn
from functools import partial
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop)
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class PrithviBackbone(nn.Module):
""" Simplified Prithvi-100M Backbone for 64x64 Input """
def __init__(self, img_size=64, patch_size=4, in_chans=6, embed_dim=768, depth=12, num_heads=12):
super().__init__()
# Patch Embed
self.num_patches = (img_size // patch_size) ** 2
self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# Positional Embed
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_drop = nn.Dropout(p=0.0)
# Transformer Blocks
self.blocks = nn.ModuleList([
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4., qkv_bias=True)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
# x: [B, 6, 64, 64]
B = x.shape[0]
x = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, 256, 768]
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[:, :x.shape[1], :]
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x # [B, 257, 768]
# --- 2. Main Segmentation Model ---
class PrithviSegmentation(nn.Module):
def __init__(self, in_channels, out_classes=1):
super().__init__()
print(" Initializing Prithvi Backbone from scratch...")
# Adapter: Your N bands -> 6 bands
self.input_adapter = nn.Conv2d(in_channels, 6, kernel_size=1)
# Manual Prithvi Backbone (No Hugging Face loading needed)
# Matches Prithvi-100M specs: dim=768, depth=12, heads=12
self.encoder = PrithviBackbone(
img_size=64,
patch_size=4, # 16x16 tokens
in_chans=6,
embed_dim=768,
depth=12,
num_heads=12
)
# Decoder
embed_dim = 768
self.decoder = nn.Sequential(
nn.ConvTranspose2d(embed_dim, 256, kernel_size=2, stride=2), # 16->32
nn.BatchNorm2d(256), nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), # 32->64
nn.BatchNorm2d(128), nn.ReLU(),
nn.Conv2d(128, out_classes, kernel_size=1)
)
def forward(self, x):
# Adapt Input
x_adapted = self.input_adapter(x) # [B, 6, 64, 64]
# Encode
features = self.encoder(x_adapted) # [B, 257, 768]
# Reshape (Remove CLS)
features = features[:, 1:, :] # [B, 256, 768]
B, N, C = features.shape
H = int(N**0.5) # 16
features = features.permute(0, 2, 1).reshape(B, C, H, H) # [B, 768, 16, 16]
# Decode
return self.decoder(features)
print(" Prithvi Model Defined Manually (No HF Token Needed)")
✅ Prithvi Model Defined Manually (No HF Token Needed)
In [ ]:
# ==============================================================================
# CELL 5: TRAINING (FORCE 500 EPOCHS + SMOOTH SCHEDULER)
# ==============================================================================
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 500 # Increased to 500
WEIGHT_DECAY = 0.05 # Stronger regularization for long training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Training on: {device}")
# Splits
dataset = TensorDataset(X_data, y_data)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
# Model
model = PrithviSegmentation(in_channels=X_data.shape[1], out_classes=1).to(device)
# Loss & Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# --- CHANGED SCHEDULER (To Avoid Spikes) ---
# CosineAnnealingLR: Smoothly decreases LR from 1e-4 to 1e-6 over 500 epochs.
# No sudden "Restarts", so validation loss should be much smoother.
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)
# Loop Variables
train_losses, val_losses = [], []
best_loss = float('inf')
best_epoch = 0
print(f"\nSTARTING FULL TRAINING ({EPOCHS} Epochs)...")
print("Early Stopping DISABLED. Using Smooth Cosine Decay.")
start_time = time.time()
for epoch in range(EPOCHS):
# --- TRAIN ---
model.train()
running_loss = 0
for X, y in train_loader:
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
preds = model(X)
loss = criterion(preds, y)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_train_loss = running_loss / len(train_loader)
train_losses.append(avg_train_loss)
# --- VALIDATE ---
model.eval()
running_val_loss = 0
with torch.no_grad():
for X, y in val_loader:
X, y = X.to(device), y.to(device)
preds = model(X)
running_val_loss += criterion(preds, y).item()
avg_val_loss = running_val_loss / len(val_loader)
val_losses.append(avg_val_loss)
# --- SCHEDULER STEP ---
# Step every epoch. This will gently lower LR.
scheduler.step()
# --- LOGGING ---
elapsed = (time.time() - start_time) / 60
current_lr = optimizer.param_groups[0]['lr']
# Print every 10 epochs
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{EPOCHS} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.1e} | Time: {elapsed:.1f} min")
# Save Best Model (Track "Peak" Performance)
if avg_val_loss < best_loss:
best_loss = avg_val_loss
best_epoch = epoch + 1
torch.save(model.state_dict(), 'best_prithvi_500.pth')
total_time = (time.time() - start_time) / 60
print(f"\n Full 500 Epochs Complete in {total_time:.1f} minutes.")
print(f" Best Validation Loss was {best_loss:.4f} at Epoch {best_epoch}")
🚀 Training on: cuda ⏳ Initializing Prithvi Backbone from scratch... STARTING FULL TRAINING (500 Epochs)... Early Stopping DISABLED. Using Smooth Cosine Decay. Epoch 10/500 | Train: 0.4926 | Val: 0.4957 | LR: 1.0e-04 | Time: 1.1 min Epoch 20/500 | Train: 0.4850 | Val: 0.4974 | LR: 1.0e-04 | Time: 2.0 min Epoch 30/500 | Train: 0.4847 | Val: 0.5035 | LR: 9.9e-05 | Time: 2.9 min Epoch 40/500 | Train: 0.4774 | Val: 0.5534 | LR: 9.8e-05 | Time: 3.9 min Epoch 50/500 | Train: 0.4710 | Val: 0.5228 | LR: 9.8e-05 | Time: 4.9 min Epoch 60/500 | Train: 0.4553 | Val: 0.5221 | LR: 9.7e-05 | Time: 5.8 min Epoch 70/500 | Train: 0.4293 | Val: 0.5289 | LR: 9.5e-05 | Time: 6.8 min Epoch 80/500 | Train: 0.4343 | Val: 0.5558 | LR: 9.4e-05 | Time: 7.7 min Epoch 90/500 | Train: 0.4314 | Val: 0.5229 | LR: 9.2e-05 | Time: 8.7 min Epoch 100/500 | Train: 0.4274 | Val: 0.5967 | LR: 9.1e-05 | Time: 9.6 min Epoch 110/500 | Train: 0.4131 | Val: 0.5723 | LR: 8.9e-05 | Time: 10.6 min Epoch 120/500 | Train: 0.4018 | Val: 0.5985 | LR: 8.7e-05 | Time: 11.5 min Epoch 130/500 | Train: 0.4072 | Val: 0.5780 | LR: 8.4e-05 | Time: 12.5 min Epoch 140/500 | Train: 0.4023 | Val: 0.5846 | LR: 8.2e-05 | Time: 13.4 min Epoch 150/500 | Train: 0.3725 | Val: 0.6684 | LR: 8.0e-05 | Time: 14.4 min Epoch 160/500 | Train: 0.3378 | Val: 0.7010 | LR: 7.7e-05 | Time: 15.3 min Epoch 170/500 | Train: 0.3029 | Val: 0.6721 | LR: 7.4e-05 | Time: 16.3 min Epoch 180/500 | Train: 0.2590 | Val: 0.7995 | LR: 7.2e-05 | Time: 17.2 min Epoch 190/500 | Train: 0.2286 | Val: 0.8050 | LR: 6.9e-05 | Time: 18.2 min Epoch 200/500 | Train: 0.2003 | Val: 0.7671 | LR: 6.6e-05 | Time: 19.2 min Epoch 210/500 | Train: 0.1876 | Val: 0.8271 | LR: 6.3e-05 | Time: 20.1 min Epoch 220/500 | Train: 0.1773 | Val: 0.8541 | LR: 6.0e-05 | Time: 21.1 min Epoch 230/500 | Train: 0.1657 | Val: 0.9243 | LR: 5.7e-05 | Time: 22.0 min Epoch 240/500 | Train: 0.1514 | Val: 0.9354 | LR: 5.4e-05 | Time: 23.0 min Epoch 250/500 | Train: 0.1515 | Val: 0.9574 | LR: 5.1e-05 | Time: 23.9 min Epoch 260/500 | Train: 0.1364 | Val: 0.9202 | LR: 4.7e-05 | Time: 24.9 min Epoch 270/500 | Train: 0.1302 | Val: 0.9741 | LR: 4.4e-05 | Time: 25.9 min Epoch 280/500 | Train: 0.1248 | Val: 1.0028 | LR: 4.1e-05 | Time: 26.8 min Epoch 290/500 | Train: 0.1163 | Val: 0.9780 | LR: 3.8e-05 | Time: 27.8 min Epoch 300/500 | Train: 0.1082 | Val: 1.0251 | LR: 3.5e-05 | Time: 28.7 min Epoch 310/500 | Train: 0.1107 | Val: 1.0805 | LR: 3.2e-05 | Time: 29.7 min Epoch 320/500 | Train: 0.0964 | Val: 1.0552 | LR: 2.9e-05 | Time: 30.7 min
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) /tmp/ipython-input-3282085532.py in <cell line: 0>() 57 X, y = X.to(device), y.to(device) 58 optimizer.zero_grad() ---> 59 preds = model(X) 60 loss = criterion(preds, y) 61 loss.backward() /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /tmp/ipython-input-3994323363.py in forward(self, x) 126 127 # Encode --> 128 features = self.encoder(x_adapted) # [B, 257, 768] 129 130 # Reshape (Remove CLS) /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /tmp/ipython-input-3994323363.py in forward(self, x) 84 85 for blk in self.blocks: ---> 86 x = blk(x) 87 88 x = self.norm(x) /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /tmp/ipython-input-3994323363.py in forward(self, x) 48 def forward(self, x): 49 x = x + self.attn(self.norm1(x)) ---> 50 x = x + self.mlp(self.norm2(x)) 51 return x 52 /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /usr/local/lib/python3.12/dist-packages/torch/nn/modules/container.py in forward(self, input) 248 """ 249 for module in self: --> 250 input = module(input) 251 return input 252 /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1774 else: -> 1775 return self._call_impl(*args, **kwargs) 1776 1777 # torchrec tests the code consistency with the following code /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1784 or _global_backward_pre_hooks or _global_backward_hooks 1785 or _global_forward_hooks or _global_forward_pre_hooks): -> 1786 return forward_call(*args, **kwargs) 1787 1788 result = None /usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py in forward(self, input) 132 Runs the forward pass. 133 """ --> 134 return F.linear(input, self.weight, self.bias) 135 136 def extra_repr(self) -> str: KeyboardInterrupt:
In [ ]:
# ==========================================
# EVALUATION (Run this after stopping training)
# ==========================================
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
import numpy as np
import torch
def evaluate_model(loader, model, device):
model.eval()
all_preds = []
all_targets = []
print(" Calculating Metrics on Test Set (using current model state)...")
with torch.no_grad():
for X, y in loader:
X, y = X.to(device), y.to(device)
outputs = model(X)
# Convert logits to binary (0 or 1) using threshold 0.5
preds = (torch.sigmoid(outputs) > 0.5).float()
# Flatten to 1D array for metric calculation
all_preds.append(preds.cpu().numpy().flatten())
all_targets.append(y.cpu().numpy().flatten())
# Concatenate all batches
y_pred = np.concatenate(all_preds)
y_true = np.concatenate(all_targets)
# Calculate Metrics
acc = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, zero_division=0)
rec = recall_score(y_true, y_pred, zero_division=0)
f1 = f1_score(y_true, y_pred, zero_division=0)
iou = jaccard_score(y_true, y_pred, zero_division=0)
print("\n FINAL TEST RESULTS (Epoch ~310):")
print(f" Accuracy: {acc:.4f}")
print(f" Precision: {prec:.4f}")
print(f" Recall: {rec:.4f}")
print(f" F1-Score: {f1:.4f}")
print(f" IoU: {iou:.4f}")
return acc, f1, iou
# Run Evaluation on Test Loader
# (Ensure 'test_loader' was defined in your split earlier. If not, use 'val_loader')
try:
evaluate_model(test_loader, model, device)
except NameError:
print(" 'test_loader' not found. Evaluating on 'val_loader' instead.")
evaluate_model(val_loader, model, device)
# Save this model manually since you interrupted
torch.save(model.state_dict(), 'model_stopped_epoch_310.pth')
print(" Model saved as 'model_stopped_epoch_310.pth'")
Calculating Metrics on Test Set (using current model state)... FINAL TEST RESULTS (Epoch ~310): Accuracy: 0.7024 Precision: 0.7750 Recall: 0.7993 F1-Score: 0.7870 IoU: 0.6487 Model saved as 'model_stopped_epoch_310.pth'
In [ ]:
import matplotlib.pyplot as plt
def plot_history(train_losses, val_losses):
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Validation Loss', color='orange')
plt.title('Training vs Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# Run Plot
# (This uses the lists 'train_losses' and 'val_losses' from your training loop)
try:
plot_history(train_losses, val_losses)
except NameError:
print(" Could not find loss history lists. Did you run the training loop?")
In [ ]:
import numpy as np
import torch
import matplotlib.pyplot as plt
def visualize_predictions(loader, model, device, num_samples=3):
model.eval()
samples_shown = 0
# Randomly pick a batch
data_iter = iter(loader)
images, masks = next(data_iter)
images = images.to(device)
with torch.no_grad():
outputs = model(images)
preds = (torch.sigmoid(outputs) > 0.5).float()
# Move to CPU for plotting
images = images.cpu().numpy()
masks = masks.cpu().numpy()
preds = preds.cpu().numpy()
plt.figure(figsize=(12, 4 * num_samples))
for i in range(min(num_samples, len(images))):
# 1. Prepare RGB Image (Bands 3,2,1 usually -> Indices 2,1,0)
# Assuming shape [Channels, H, W] -> [H, W, Channels]
# Sentinel-2: R=Band3(idx2), G=Band2(idx1), B=Band1(idx0) for 12-band stack
# (Adjust indices if your stack order is different)
img_rgb = images[i][[3, 2, 1], :, :].transpose(1, 2, 0) # Use bands 4,3,2 (NIR,Red,Green) or 3,2,1
# Normalize for display (0-1 range)
img_rgb = (img_rgb - img_rgb.min()) / (img_rgb.max() - img_rgb.min())
# 2. Prepare Masks
true_mask = masks[i].squeeze()
pred_mask = preds[i].squeeze()
# Plot
plt.subplot(num_samples, 3, i*3 + 1)
plt.imshow(img_rgb)
plt.title("Satellite Input (RGB)")
plt.axis('off')
plt.subplot(num_samples, 3, i*3 + 2)
plt.imshow(true_mask, cmap='gray')
plt.title("Ground Truth Mask")
plt.axis('off')
plt.subplot(num_samples, 3, i*3 + 3)
plt.imshow(pred_mask, cmap='gray')
plt.title("Model Prediction")
plt.axis('off')
plt.tight_layout()
plt.show()
# Run Visualization
try:
visualize_predictions(test_loader, model, device)
except NameError:
print(" 'test_loader' not found. Using 'val_loader' instead.")
visualize_predictions(val_loader, model, device)
In [1]:
!pip install terratorch
Requirement already satisfied: terratorch in /usr/local/lib/python3.12/dist-packages (1.2) Requirement already satisfied: torch>2.0 in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.9.0+cu126) Requirement already satisfied: numpy>=2.2 in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.2.6) Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.24.0+cu126) Requirement already satisfied: rioxarray in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.20.0) Requirement already satisfied: albumentations in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.0.8) Requirement already satisfied: albucore in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.0.24) Requirement already satisfied: rasterio in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.4.4) Requirement already satisfied: torchmetrics in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.8.2) Requirement already satisfied: geopandas in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.1.1) Requirement already satisfied: lightly==1.5.22 in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.5.22) Requirement already satisfied: h5py in /usr/local/lib/python3.12/dist-packages (from terratorch) (3.15.1) Requirement already satisfied: lightning>=2.6.0 in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.6.0) Requirement already satisfied: segmentation-models-pytorch in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.5.0) Requirement already satisfied: jsonargparse>=4.40.0 in /usr/local/lib/python3.12/dist-packages (from terratorch) (4.45.0) Requirement already satisfied: torchgeo in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.8.0) Requirement already satisfied: einops in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.8.1) Requirement already satisfied: timm>=1.0.15 in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.0.22) Requirement already satisfied: pycocotools in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.0.10) Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.36.0) Requirement already satisfied: tifffile in /usr/local/lib/python3.12/dist-packages (from terratorch) (2025.12.12) Requirement already satisfied: python-box in /usr/local/lib/python3.12/dist-packages (from terratorch) (7.3.2) Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from terratorch) (4.67.1) Requirement already satisfied: wandb in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.23.1) Requirement already satisfied: tensorboard in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.19.0) Requirement already satisfied: diffusers in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.36.0) Requirement already satisfied: scikit-learn>=1.3.2 in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.6.1) Requirement already satisfied: scikit-image in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.25.2) Requirement already satisfied: certifi>=14.05.14 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2025.11.12) Requirement already satisfied: hydra-core>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (1.3.2) Requirement already satisfied: lightly_utils~=0.0.0 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (0.0.2) Requirement already satisfied: python_dateutil>=2.5.3 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.9.0.post0) Requirement already satisfied: requests>=2.27.0 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.32.4) Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (1.17.0) Requirement already satisfied: pydantic>=1.10.5 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.12.3) Requirement already satisfied: pytorch_lightning>=1.0.4 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.6.0) Requirement already satisfied: urllib3>=1.25.3 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.5.0) Requirement already satisfied: aenum>=3.1.11 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (3.1.16) Requirement already satisfied: PyYAML>=3.13 in /usr/local/lib/python3.12/dist-packages (from jsonargparse>=4.40.0->terratorch) (6.0.3) Requirement already satisfied: fsspec<2027.0,>=2022.5.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (2025.3.0) Requirement already satisfied: lightning-utilities<2.0,>=0.10.0 in /usr/local/lib/python3.12/dist-packages (from lightning>=2.6.0->terratorch) (0.15.2) Requirement already satisfied: packaging<27.0,>=20.0 in /usr/local/lib/python3.12/dist-packages (from lightning>=2.6.0->terratorch) (25.0) Requirement already satisfied: typing-extensions<6.0,>4.5.0 in /usr/local/lib/python3.12/dist-packages (from lightning>=2.6.0->terratorch) (4.15.0) Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.3.2->terratorch) (1.16.3) Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.3.2->terratorch) (1.5.3) Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.3.2->terratorch) (3.6.0) Requirement already satisfied: safetensors in /usr/local/lib/python3.12/dist-packages (from timm>=1.0.15->terratorch) (0.7.0) Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.20.0) Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (75.2.0) Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (1.14.0) Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.6.1) Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.1.6) Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.77) Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.77) Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.80) Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (9.10.2.21) Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.4.1) Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (11.3.0.4) Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (10.3.7.77) Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (11.7.1.2) Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.5.4.2) Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (0.7.1) Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (2.27.5) Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.3.20) Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.77) Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.85) Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (1.11.1.6) Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.5.0) Requirement already satisfied: stringzilla>=3.10.4 in /usr/local/lib/python3.12/dist-packages (from albucore->terratorch) (4.5.1) Requirement already satisfied: simsimd>=5.9.2 in /usr/local/lib/python3.12/dist-packages (from albucore->terratorch) (6.5.3) Requirement already satisfied: opencv-python-headless>=4.9.0.80 in /usr/local/lib/python3.12/dist-packages (from albucore->terratorch) (4.12.0.88) Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (8.7.0) Requirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (0.28.1) Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (2025.11.3) Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (11.3.0) Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub->terratorch) (1.2.0) Requirement already satisfied: pyogrio>=0.7.2 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (0.12.1) Requirement already satisfied: pandas>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (2.2.2) Requirement already satisfied: pyproj>=3.5.0 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (3.7.2) Requirement already satisfied: shapely>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (2.1.2) Requirement already satisfied: affine in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (2.4.0) Requirement already satisfied: attrs in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (25.4.0) Requirement already satisfied: click!=8.2.*,>=4.0 in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (8.3.1) Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (0.7.2) Requirement already satisfied: click-plugins in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (1.1.1.2) Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (3.2.5) Requirement already satisfied: xarray>=2024.7.0 in /usr/local/lib/python3.12/dist-packages (from rioxarray->terratorch) (2025.12.0) Requirement already satisfied: imageio!=2.35.0,>=2.33 in /usr/local/lib/python3.12/dist-packages (from scikit-image->terratorch) (2.37.2) Requirement already satisfied: lazy-loader>=0.4 in /usr/local/lib/python3.12/dist-packages (from scikit-image->terratorch) (0.4) Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (1.4.0) Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (1.76.0) Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (3.10) Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (5.29.5) Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (0.7.2) Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (3.1.4) Requirement already satisfied: kornia>=0.8.2 in /usr/local/lib/python3.12/dist-packages (from torchgeo->terratorch) (0.8.2) Requirement already satisfied: matplotlib>=3.6 in /usr/local/lib/python3.12/dist-packages (from torchgeo->terratorch) (3.10.0) Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb->terratorch) (3.1.45) Requirement already satisfied: platformdirs in /usr/local/lib/python3.12/dist-packages (from wandb->terratorch) (4.5.1) Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb->terratorch) (2.47.0) Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (3.13.2) Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.12/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb->terratorch) (4.0.12) Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers->terratorch) (4.12.0) Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers->terratorch) (1.0.9) Requirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers->terratorch) (3.11) Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->diffusers->terratorch) (0.16.0) Requirement already satisfied: omegaconf<2.4,>=2.2 in /usr/local/lib/python3.12/dist-packages (from hydra-core>=1.0.0->lightly==1.5.22->terratorch) (2.3.0) Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.12/dist-packages (from hydra-core>=1.0.0->lightly==1.5.22->terratorch) (4.9.3) Requirement already satisfied: docstring-parser>=0.17 in /usr/local/lib/python3.12/dist-packages (from jsonargparse[signatures]>=4.25->torchgeo->terratorch) (0.17.0) Requirement already satisfied: typeshed-client>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from jsonargparse[signatures]>=4.25->torchgeo->terratorch) (2.8.2) Requirement already satisfied: kornia_rs>=0.1.9 in /usr/local/lib/python3.12/dist-packages (from kornia>=0.8.2->torchgeo->terratorch) (0.1.10) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (1.3.3) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (4.61.1) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (1.4.9) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.0.0->geopandas->terratorch) (2025.2) Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.0.0->geopandas->terratorch) (2025.3) Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.10.5->lightly==1.5.22->terratorch) (0.7.0) Requirement already satisfied: pydantic-core==2.41.4 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.10.5->lightly==1.5.22->terratorch) (2.41.4) Requirement already satisfied: typing-inspection>=0.4.2 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.10.5->lightly==1.5.22->terratorch) (0.4.2) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.27.0->lightly==1.5.22->terratorch) (3.4.4) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>2.0->terratorch) (1.3.0) Requirement already satisfied: markupsafe>=2.1.1 in /usr/local/lib/python3.12/dist-packages (from werkzeug>=1.0.1->tensorboard->terratorch) (3.0.3) Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.12/dist-packages (from importlib_metadata->diffusers->terratorch) (3.23.0) Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (2.6.1) Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (1.4.0) Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (1.8.0) Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (6.7.0) Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (0.4.1) Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (1.22.0) Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb->terratorch) (5.0.2) Requirement already satisfied: importlib_resources>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from typeshed-client>=2.8.2->jsonargparse[signatures]>=4.25->torchgeo->terratorch) (6.5.2)
In [2]:
from terratorch.registry import BACKBONE_REGISTRY
# Load the Prithvi-EO-1.0-100M model backbone
# Use the appropriate key for the 100M model (e.g., 'prithvi_eo_1_0_100m' or similar key found in the docs)
# The key below is for v2, please verify the exact 100M v1 key in the Prithvi docs
model = BACKBONE_REGISTRY.build("prithvi_eo_v2_600_tl", pretrained=True)
model.eval() # Set model to evaluation mode
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) /tmp/ipython-input-2529966910.py in <cell line: 0>() ----> 1 from terratorch.registry import BACKBONE_REGISTRY 2 3 # Load the Prithvi-EO-1.0-100M model backbone 4 # Use the appropriate key for the 100M model (e.g., 'prithvi_eo_1_0_100m' or similar key found in the docs) 5 # The key below is for v2, please verify the exact 100M v1 key in the Prithvi docs /usr/local/lib/python3.12/dist-packages/terratorch/__init__.py in <module> 5 os.environ["NO_ALBUMENTATIONS_UPDATE"] = "True" 6 ----> 7 import terratorch.models # noqa: F401 8 from terratorch.models.backbones import * # register models in registries # noqa: F403 9 from terratorch.registry import BACKBONE_REGISTRY, DECODER_REGISTRY, MODEL_FACTORY_REGISTRY, FULL_MODEL_REGISTRY # noqa: F401 /usr/local/lib/python3.12/dist-packages/terratorch/models/__init__.py in <module> 4 import logging 5 ----> 6 import terratorch.models.necks # register necks # noqa: F401 7 from terratorch.models.clay1_5_model_factory import Clay1_5ModelFactory 8 from terratorch.models.clay_model_factory import ClayModelFactory /usr/local/lib/python3.12/dist-packages/terratorch/models/necks.py in <module> 10 from einops import rearrange 11 from torch import nn ---> 12 from torchvision.ops import FeaturePyramidNetwork 13 14 from terratorch.registry import NECK_REGISTRY, TERRATORCH_NECK_REGISTRY /usr/local/lib/python3.12/dist-packages/torchvision/__init__.py in <module> 8 # .extensions) before entering _meta_registrations. 9 from .extension import _HAS_OPS # usort:skip ---> 10 from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip 11 12 try: /usr/local/lib/python3.12/dist-packages/torchvision/models/__init__.py in <module> 1 from .alexnet import * ----> 2 from .convnext import * 3 from .densenet import * 4 from .efficientnet import * 5 from .googlenet import * /usr/local/lib/python3.12/dist-packages/torchvision/models/convnext.py in <module> 7 from torch.nn import functional as F 8 ----> 9 from ..ops.misc import Conv2dNormActivation, Permute 10 from ..ops.stochastic_depth import StochasticDepth 11 from ..transforms._presets import ImageClassification /usr/local/lib/python3.12/dist-packages/torchvision/ops/__init__.py in <module> 21 from .giou_loss import generalized_box_iou_loss 22 from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation ---> 23 from .poolers import MultiScaleRoIAlign 24 from .ps_roi_align import ps_roi_align, PSRoIAlign 25 from .ps_roi_pool import ps_roi_pool, PSRoIPool /usr/local/lib/python3.12/dist-packages/torchvision/ops/poolers.py in <module> 8 9 from ..utils import _log_api_usage_once ---> 10 from .roi_align import roi_align 11 12 /usr/local/lib/python3.12/dist-packages/torchvision/ops/roi_align.py in <module> 5 import torch.fx 6 from torch import nn, Tensor ----> 7 from torch._dynamo.utils import is_compile_supported 8 from torch.jit.annotations import BroadcastingList2 9 from torch.nn.modules.utils import _pair KeyboardInterrupt: