In [ ]:
# ==============================================================================
# 0. SETUP & LIBRARIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')

# Install specific libraries for Prithvi
!pip install -q uv
!uv pip install rasterio geopandas geojson torchinfo tslearn transformers segmentation-models-pytorch

import ee
import geemap
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel, AutoConfig

# Authenticate & Initialize
try:
    ee.Initialize(project='[REDACTED_FOR_SECURITY]') # Using your specified project ID
except:
    ee.Authenticate()
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')

print(f" Earth Engine Initialized with project: [REDACTED_FOR_SECURITY]")

# ==============================================================================
# 1. GEE DATA PIPELINE (MATURE WHEAT PHASE: JAN - APR)
# ==============================================================================

# ------------------- USER INPUTS -------------------
# You want 2024 dataset mature state -> Jan 2025 to Apr 2025
year = 2024
wheat_mask = ee.Image(f'projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved') # Verify path
roi = ee.FeatureCollection('projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab')

# ------------------- PARAMETERS -------------------
# UPDATED: Mature state only (Jan 1 to Apr 15)
START_DATE = f'{year+1}-01-01'
END_DATE   = f'{year+1}-04-16'
print(f" Data Range: {START_DATE} to {END_DATE} (Mature Phase)")

S1_MIN_DB = -25
S1_MAX_DB = 0
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']

# ------------------- SENTINEL-2 PROCESSING -------------------
MAX_CLOUD_PROB = 70

s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
         .filterDate(START_DATE, END_DATE)
         .filterBounds(roi)
         .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100)))

s2Clouds = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
            .filterDate(START_DATE, END_DATE)
            .filterBounds(roi))

s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
    primary=s2Raw,
    secondary=s2Clouds,
    condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))

def maskS2Clouds(img):
    cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
    isProbCloud = cloudProb.gt(MAX_CLOUD_PROB)
    optical = img.select(s2Bands).multiply(0.0001) # Scale 0-1
    finalMask = isProbCloud.Not()
    ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
    return (optical.addBands(ndvi)
            .updateMask(finalMask)
            .copyProperties(img, ['system:time_start']))

s2Clean = s2Joined.map(maskS2Clouds)

# ------------------- SENTINEL-1 PROCESSING -------------------
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
      .filterDate(START_DATE, END_DATE)
      .filterBounds(roi)
      .filter(ee.Filter.eq('instrumentMode', 'IW'))
      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
      .filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING'))
      .select(['VV', 'VH'])
      .map(lambda img: img.rename(['VV_sar', 'VH_sar'])))

# ------------------- INTERVAL GENERATION -------------------
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)

while curDate < endDateObj:
    yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day

    if dayVal <= 15:
        nextDate = datetime.datetime(yearVal, monthVal, 16)
        suffix = '1'
    else:
        if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
        else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
        suffix = '2'

    if nextDate > endDateObj: nextDate = endDateObj

    startStr = curDate.strftime('%Y-%m-%d')
    endStr = (nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
    label = f'{yearVal}_{monthVal:02d}_{suffix}'

    if curDate < nextDate:
        intervals.append([startStr, endStr, label])

    curDate = nextDate

print(f" Generated {len(intervals)} fortnightly intervals")
eeIntervals = ee.List(intervals)

# ------------------- COMPOSITE FUNCTION -------------------
def makeComposite(item):
    item = ee.List(item)
    start = ee.Date(item.get(0))
    end = ee.Date(item.get(1)).advance(1, 'day')
    label = ee.String(item.get(2))

    # S2 Composite (Median to avoid cloud artifacts better in winter)
    s2 = s2Clean.filterDate(start, end).median()
    # Fill gaps
    s2 = ee.Image(ee.Algorithms.If(
        s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)
    ))

    # S1 Composite (Mean)
    s1_comp = s1.filterDate(start, end).mean()
    s1_comp = ee.Image(ee.Algorithms.If(
        s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])
    ))

    return s2.addBands(s1_comp).set('label', label).clip(roi)

fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))

# ------------------- STACKING & DOWNLOAD -------------------
def stack_collection(collection):
    col_list = collection.toList(collection.size())
    count = collection.size().getInfo()
    if count == 0: raise ValueError("Empty collection")
    stack = ee.Image(col_list.get(0))
    for i in range(1, count):
        stack = stack.addBands(ee.Image(col_list.get(i)))
    return stack

print("Stacking collection...")
stacked_input = stack_collection(fortnightly_col)
input_band_names = stacked_input.bandNames().getInfo()
print(f"Total Input Channels: {len(input_band_names)}")

# Download Sample
# Using a buffer to get a reasonable training patch
sample_roi = roi.geometry().centroid(1).buffer(2500).bounds()

print(" Downloading data (may take a minute)...")
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)

# Cleaning
ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)

# Resizing to 224x224 patch friendly size if needed, but we tile 64x64
print(f"Downloaded Shape: {ds.shape}")

# Tiling function
def create_tiles(image, mask, patch_size=64):
    h, w, c = image.shape
    patches_x, patches_y = [], []
    for y in range(0, h - patch_size + 1, patch_size):
        for x in range(0, w - patch_size + 1, patch_size):
            img_p = image[y:y+patch_size, x:x+patch_size, :]
            mask_p = mask[y:y+patch_size, x:x+patch_size]
            if np.mean(img_p) > 0.01: # Filter empty
                patches_x.append(img_p)
                patches_y.append(mask_p)
    return np.array(patches_x), np.array(patches_y)

print("Tiling...")
X_np, y_np = create_tiles(ds, lb, patch_size=64)

# Transpose to (N, C, H, W) for PyTorch
X_data = torch.tensor(np.transpose(X_np, (0, 3, 1, 2)), dtype=torch.float32)
y_data = torch.tensor(np.expand_dims(y_np, 1), dtype=torch.float32)

print(f" Final Dataset: {X_data.shape}")
Mounted at /content/drive
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.2/22.2 MB 35.3 MB/s eta 0:00:00
Using Python 3.12.12 environment at: /usr
Resolved 70 packages in 1.57s
Prepared 4 packages in 54ms
Installed 4 packages in 4ms
 + geojson==3.2.0
 + segmentation-models-pytorch==0.5.0
 + torchinfo==1.8.0
 + tslearn==0.7.0
✅ Earth Engine Initialized with project: phd0-473604
📅 Data Range: 2025-01-01 to 2025-04-16 (Mature Phase)
✅ Generated 7 fortnightly intervals
Stacking collection...
Total Input Channels: 91
⏳ Downloading data (may take a minute)...
---------------------------------------------------------------------------
HttpError                                 Traceback (most recent call last)
/usr/local/lib/python3.12/dist-packages/ee/data.py in _execute_cloud_call(call, num_retries)
    407   try:
--> 408     return call.execute(num_retries=num_retries)
    409   except googleapiclient.errors.HttpError as e:

/usr/local/lib/python3.12/dist-packages/googleapiclient/_helpers.py in positional_wrapper(*args, **kwargs)
    129                     logger.warning(message)
--> 130             return wrapped(*args, **kwargs)
    131 

/usr/local/lib/python3.12/dist-packages/googleapiclient/http.py in execute(self, http, num_retries)
    937         if resp.status >= 300:
--> 938             raise HttpError(resp, content, uri=self.uri)
    939         return self.postproc(resp, content)

HttpError: <HttpError 400 when requesting https://earthengine.googleapis.com/v1/projects/phd0-473604/image:computePixels? returned "Total request size (230704642 bytes) must be less than or equal to 50331648 bytes.". Details: "Total request size (230704642 bytes) must be less than or equal to 50331648 bytes.">

During handling of the above exception, another exception occurred:

EEException                               Traceback (most recent call last)
/usr/local/lib/python3.12/dist-packages/geemap/common.py in ee_to_numpy(ee_object, region, scale, bands, **kwargs)
   3133     try:
-> 3134         struct_array = ee.data.computePixels(kwargs)
   3135         array = np.dstack(([struct_array[band] for band in struct_array.dtype.names]))

/usr/local/lib/python3.12/dist-packages/ee/data.py in computePixels(params)
    953   _maybe_populate_workload_tag(params)
--> 954   data = _execute_cloud_call(
    955       _get_cloud_projects_raw()

/usr/local/lib/python3.12/dist-packages/ee/data.py in _execute_cloud_call(call, num_retries)
    409   except googleapiclient.errors.HttpError as e:
--> 410     raise _translate_cloud_exception(e)  # pylint: disable=raise-missing-from
    411 

EEException: Total request size (230704642 bytes) must be less than or equal to 50331648 bytes.

During handling of the above exception, another exception occurred:

Exception                                 Traceback (most recent call last)
/tmp/ipython-input-3007931609.py in <cell line: 0>()
    164 
    165 print("⏳ Downloading data (may take a minute)...")
--> 166 ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
    167 lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
    168 

/usr/local/lib/python3.12/dist-packages/geemap/common.py in ee_to_numpy(ee_object, region, scale, bands, **kwargs)
   3136         return array
   3137     except Exception as e:
-> 3138         raise Exception(e)
   3139 
   3140 

Exception: Total request size (230704642 bytes) must be less than or equal to 50331648 bytes.
In [ ]:
# ==============================================================================
# CELL 1: SETUP & AUTHENTICATION
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')

# Install dependencies for Prithvi
!pip install -q uv
!uv pip install rasterio geopandas geojson torchinfo tslearn transformers segmentation-models-pytorch

import ee
import geemap
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel

# Authenticate & Initialize with YOUR Project ID
try:
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
    ee.Authenticate()
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')

print(" Earth Engine Initialized with project: [REDACTED_FOR_SECURITY]")
In [ ]:
# ==============================================================================
# CELL 2: GEE DATA DOWNLOAD PIPELINE
# ==============================================================================

# ------------------- USER INPUTS -------------------
year = 2024
# We use assets from the other project, but COMPUTE on [REDACTED_FOR_SECURITY]
wheat_mask_path = 'projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved'
roi_path = 'projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab'

# Load Assets
wheat_mask = ee.Image(wheat_mask_path)
roi = ee.FeatureCollection(roi_path)

# ------------------- PARAMETERS (Mature Phase) -------------------
START_DATE = f'{year+1}-01-01' # Jan 2025
END_DATE   = f'{year+1}-04-16' # Apr 2025
print(f" Data Range: {START_DATE} to {END_DATE} (Mature Wheat Phase)")

s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
MAX_CLOUD_PROB = 70

# ------------------- SENTINEL-2 PROCESSING -------------------
s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
         .filterDate(START_DATE, END_DATE)
         .filterBounds(roi)
         .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100)))

s2Clouds = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
            .filterDate(START_DATE, END_DATE)
            .filterBounds(roi))

s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
    primary=s2Raw,
    secondary=s2Clouds,
    condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))

def maskS2Clouds(img):
    cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
    isProbCloud = cloudProb.gt(MAX_CLOUD_PROB)
    optical = img.select(s2Bands).multiply(0.0001)
    finalMask = isProbCloud.Not()
    ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
    return (optical.addBands(ndvi)
            .updateMask(finalMask)
            .copyProperties(img, ['system:time_start']))

s2Clean = s2Joined.map(maskS2Clouds)

# ------------------- SENTINEL-1 PROCESSING -------------------
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
      .filterDate(START_DATE, END_DATE)
      .filterBounds(roi)
      .filter(ee.Filter.eq('instrumentMode', 'IW'))
      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
      .filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING'))
      .select(['VV', 'VH'])
      .map(lambda img: img.rename(['VV_sar', 'VH_sar'])))

# ------------------- INTERVALS -------------------
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)

while curDate < endDateObj:
    yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day
    if dayVal <= 15:
        nextDate = datetime.datetime(yearVal, monthVal, 16)
        suffix = '1'
    else:
        if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
        else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
        suffix = '2'

    if nextDate > endDateObj: nextDate = endDateObj
    intervals.append([curDate.strftime('%Y-%m-%d'),
                      (nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d'),
                      f'{yearVal}_{monthVal:02d}_{suffix}'])
    curDate = nextDate

eeIntervals = ee.List(intervals)

# ------------------- COMPOSITES -------------------
def makeComposite(item):
    item = ee.List(item)
    start = ee.Date(item.get(0))
    end = ee.Date(item.get(1)).advance(1, 'day')
    label = ee.String(item.get(2))

    s2 = s2Clean.filterDate(start, end).median()
    s2 = ee.Image(ee.Algorithms.If(s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)))

    s1_comp = s1.filterDate(start, end).mean()
    s1_comp = ee.Image(ee.Algorithms.If(s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])))

    return s2.addBands(s1_comp).set('label', label).clip(roi)

fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))

# Stack
def stack_collection(collection):
    col_list = collection.toList(collection.size())
    count = collection.size().getInfo()
    stack = ee.Image(col_list.get(0))
    for i in range(1, count): stack = stack.addBands(ee.Image(col_list.get(i)))
    return stack

print("Stacking...")
stacked_input = stack_collection(fortnightly_col)
input_band_names = stacked_input.bandNames().getInfo()
print(f" Total Channels: {len(input_band_names)}")

# Download Sample
sample_roi = roi.geometry().centroid(1).buffer(2500).bounds()
print(" Downloading patch...")
ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)

ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
print(f" Downloaded Shape: {ds.shape}")

# ... (Previous Cell 2 code remains the same up to 'Download Sample') ...

# --- MODIFIED DOWNLOAD SECTION ---

# Reduce buffer from 2500 -> 1000 to fit in 50MB RAM limit
print(" ROI too large for direct download. Using smaller sample for training prototype.")
sample_roi = roi.geometry().centroid(1).buffer(1000).bounds()

print(" Downloading patch (smaller area)...")
try:
    ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
    lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
except Exception as e:
    print(f" Still too large? Error: {e}")
    # Fallback to even smaller if 1000 fails
    print("Trying ultra-small patch (500m buffer)...")
    sample_roi = roi.geometry().centroid(1).buffer(500).bounds()
    ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
    lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)

ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
print(f" Downloaded Shape: {ds.shape}")
In [ ]:
# ==============================================================================
# CELL 2: GEE DATA DOWNLOAD PIPELINE
# ==============================================================================

# ------------------- USER INPUTS -------------------
year = 2024
# We use assets from the other project, but COMPUTE on [REDACTED_FOR_SECURITY]
wheat_mask_path = 'projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved'
roi_path = 'projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab'

# Load Assets
wheat_mask = ee.Image(wheat_mask_path)
roi = ee.FeatureCollection(roi_path)

# ------------------- PARAMETERS (Mature Phase) -------------------
START_DATE = f'{year+1}-01-01' # Jan 2025
END_DATE   = f'{year+1}-04-16' # Apr 2025
print(f" Data Range: {START_DATE} to {END_DATE} (Mature Wheat Phase)")

s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
MAX_CLOUD_PROB = 70

# ------------------- SENTINEL-2 PROCESSING -------------------
s2Raw = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
         .filterDate(START_DATE, END_DATE)
         .filterBounds(roi)
         .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 100)))

s2Clouds = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
            .filterDate(START_DATE, END_DATE)
            .filterBounds(roi))

s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
    primary=s2Raw,
    secondary=s2Clouds,
    condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))

def maskS2Clouds(img):
    cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
    isProbCloud = cloudProb.gt(MAX_CLOUD_PROB)
    optical = img.select(s2Bands).multiply(0.0001)
    finalMask = isProbCloud.Not()
    ndvi = optical.normalizedDifference(['B8', 'B4']).rename('NDVI')
    return (optical.addBands(ndvi)
            .updateMask(finalMask)
            .copyProperties(img, ['system:time_start']))

s2Clean = s2Joined.map(maskS2Clouds)

# ------------------- SENTINEL-1 PROCESSING -------------------
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
      .filterDate(START_DATE, END_DATE)
      .filterBounds(roi)
      .filter(ee.Filter.eq('instrumentMode', 'IW'))
      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
      .filter(ee.Filter.eq('orbitProperties_pass', 'DESCENDING'))
      .select(['VV', 'VH'])
      .map(lambda img: img.rename(['VV_sar', 'VH_sar'])))

# ------------------- INTERVALS -------------------
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)

while curDate < endDateObj:
    yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day
    if dayVal <= 15:
        nextDate = datetime.datetime(yearVal, monthVal, 16)
        suffix = '1'
    else:
        if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
        else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
        suffix = '2'

    if nextDate > endDateObj: nextDate = endDateObj
    intervals.append([curDate.strftime('%Y-%m-%d'),
                      (nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d'),
                      f'{yearVal}_{monthVal:02d}_{suffix}'])
    curDate = nextDate

eeIntervals = ee.List(intervals)

# ------------------- COMPOSITES -------------------
def makeComposite(item):
    item = ee.List(item)
    start = ee.Date(item.get(0))
    end = ee.Date(item.get(1)).advance(1, 'day')
    label = ee.String(item.get(2))

    s2 = s2Clean.filterDate(start, end).median()
    s2 = ee.Image(ee.Algorithms.If(s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)))

    s1_comp = s1.filterDate(start, end).mean()
    s1_comp = ee.Image(ee.Algorithms.If(s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])))

    return s2.addBands(s1_comp).set('label', label).clip(roi)

fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))

# Stack
def stack_collection(collection):
    col_list = collection.toList(collection.size())
    count = collection.size().getInfo()
    stack = ee.Image(col_list.get(0))
    for i in range(1, count): stack = stack.addBands(ee.Image(col_list.get(i)))
    return stack

print("Stacking...")
stacked_input = stack_collection(fortnightly_col)
input_band_names = stacked_input.bandNames().getInfo()
print(f" Total Channels: {len(input_band_names)}")

# --- MODIFIED DOWNLOAD SECTION ---
# Reduced buffer from 2500 -> 1000 to fit in 50MB RAM limit
print(" Using reduced 1000m buffer to fit memory limit...")
sample_roi = roi.geometry().centroid(1).buffer(1000).bounds()

print(" Downloading patch...")
try:
    ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
    lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)
except Exception as e:
    print(f" 1000m failed ({e}), trying 500m...")
    sample_roi = roi.geometry().centroid(1).buffer(500).bounds()
    ds = geemap.ee_to_numpy(stacked_input, region=sample_roi, scale=10)
    lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)

ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)
print(f" Downloaded Shape: {ds.shape}")
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')

# Install libraries for Prithvi & Geodata
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch

import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt

print(" Environment Ready")
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive')

# Install libraries for Prithvi & Geodata
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch

import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt

# Authenticate GEE (Required if using any GEE features, though data is in Drive)
import ee
try:
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
    ee.Authenticate()
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')

print(" Environment Ready")
In [ ]:
import os

# Your folder path
folder_path = '/content/drive/MyDrive/Punjab Wheat Mask_Binary/'

print(f" Checking files in: {folder_path}\n")

try:
    files = os.listdir(folder_path)
    for f in files:
        print(f" {f}")
except FileNotFoundError:
    print(" Folder not found. Please check the path.")
In [ ]:
# ==============================================================================
# CELL 2: GENERATE DATA FROM SATELLITE (MATURE PHASE)
# ==============================================================================
import datetime  # <--- ADDED THIS IMPORT
import ee
import geemap
import numpy as np
import torch # Ensure torch is imported if used later in this cell

# 1. Define ROI & Mask
wheat_mask = ee.Image('projects/[REDACTED_FOR_SECURITY]/assets/Binary_Punjab_Wheat_2024_Sieved')
roi = ee.FeatureCollection('projects/[REDACTED_FOR_SECURITY]/assets/ShapeFiles/Punjab')

# 2. Define Timeframe: MATURE STATE (Jan - Apr)
year = 2024
START_DATE = f'{year+1}-01-01' # Jan 1, 2025
END_DATE   = f'{year+1}-04-16' # Apr 16, 2025
print(f" Generating data for Mature Phase: {START_DATE} to {END_DATE}")

# 3. Sentinel-2 (Optical) Processing
s2Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
s2Raw = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterDate(START_DATE, END_DATE).filterBounds(roi)
s2Clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY').filterDate(START_DATE, END_DATE).filterBounds(roi)

s2Joined = ee.ImageCollection(ee.Join.saveFirst('cloud_mask_img').apply(
    primary=s2Raw, secondary=s2Clouds,
    condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
))

def maskS2Clouds(img):
    cloudProb = ee.Image(img.get('cloud_mask_img')).select('probability')
    isProbCloud = cloudProb.gt(70)
    return img.select(s2Bands).multiply(0.0001).updateMask(isProbCloud.Not())

s2Clean = s2Joined.map(maskS2Clouds)

# 4. Sentinel-1 (Radar) Processing
s1 = (ee.ImageCollection('COPERNICUS/S1_GRD')
      .filterDate(START_DATE, END_DATE).filterBounds(roi)
      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
      .select(['VV', 'VH']).map(lambda img: img.rename(['VV_sar', 'VH_sar'])))

# 5. Create Fortnightly Composites
intervals = []
curDate = datetime.datetime.fromisoformat(START_DATE)
endDateObj = datetime.datetime.fromisoformat(END_DATE)

while curDate < endDateObj:
    # Logic to split month into 1-15 and 16-end
    yearVal, monthVal, dayVal = curDate.year, curDate.month, curDate.day
    if dayVal <= 15: nextDate = datetime.datetime(yearVal, monthVal, 16); suffix='1'
    else:
        if monthVal == 12: nextDate = datetime.datetime(yearVal + 1, 1, 1)
        else: nextDate = datetime.datetime(yearVal, monthVal + 1, 1)
        suffix='2'
    if nextDate > endDateObj: nextDate = endDateObj

    intervals.append([curDate.strftime('%Y-%m-%d'), (nextDate - datetime.timedelta(days=1)).strftime('%Y-%m-%d')])
    curDate = nextDate

print(f" Created {len(intervals)} fortnightly intervals")

# Stack Function
eeIntervals = ee.List(intervals)
def makeComposite(item):
    start, end = ee.Date(ee.List(item).get(0)), ee.Date(ee.List(item).get(1)).advance(1, 'day')
    s2 = s2Clean.filterDate(start, end).median()
    s2 = ee.Image(ee.Algorithms.If(s2.bandNames().size().gt(0), s2, ee.Image.constant([0]*len(s2Bands)).rename(s2Bands)))
    s1_comp = s1.filterDate(start, end).mean()
    s1_comp = ee.Image(ee.Algorithms.If(s1_comp.bandNames().size().gt(0), s1_comp, ee.Image.constant([0,0]).rename(['VV_sar','VH_sar'])))
    return s2.addBands(s1_comp)

fortnightly_col = ee.ImageCollection.fromImages(eeIntervals.map(makeComposite))

# Flatten to single image stack
stack = ee.Image(fortnightly_col.toBands())
print(f" Total Input Channels: {stack.bandNames().size().getInfo()}")

# 6. DOWNLOAD PATCH (Using safe buffer size)
print(" Downloading Patch (1000m buffer)...")
sample_roi = roi.geometry().centroid(1).buffer(1000).bounds()

ds = geemap.ee_to_numpy(stack, region=sample_roi, scale=10)
lb = geemap.ee_to_numpy(wheat_mask, region=sample_roi, scale=10)

# Clean
ds = np.nan_to_num(ds, nan=0.0)
lb = np.nan_to_num(lb, nan=0.0)

print(f" Downloaded Data Shape: {ds.shape}")
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Install required libraries
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch opencv-python-headless

import ee
import geemap
import rasterio
from rasterio.windows import from_bounds
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt

# Authenticate with YOUR project
try:
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
    ee.Authenticate()
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')

print(" Environment Ready")
Mounted at /content/drive
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.2/22.2 MB 118.2 MB/s eta 0:00:00
Using Python 3.12.12 environment at: /usr
Resolved 63 packages in 314ms
Prepared 2 packages in 27ms
Installed 2 packages in 3ms
 + segmentation-models-pytorch==0.5.0
 + torchinfo==1.8.0
✅ Environment Ready
In [ ]:
# ==============================================================================
# CELL 2: HYBRID DATA LOADING (Mature Phase)
# ==============================================================================

# 1. Setup Path
mask_path = '/content/drive/MyDrive/Punjab Wheat Mask_Binary/Punjab Mask 2024.tif'

# 2. Define Sample Area from Mask (Drive)
with rasterio.open(mask_path) as src:
    full_bounds = src.bounds
    crs = src.crs

    # Pick center
    center_x = (full_bounds.left + full_bounds.right) / 2
    center_y = (full_bounds.bottom + full_bounds.top) / 2

    # Define small window (approx 5km x 5km)
    half_size = 0.04
    sample_bounds = (
        center_x - half_size, center_y - half_size,
        center_x + half_size, center_y + half_size
    )

    print(f" Selecting sample area from mask...")

    window = from_bounds(*sample_bounds, transform=src.transform)
    mask_data = src.read(1, window=window)
    # Ensure binary
    mask_data = np.where(mask_data > 0, 1, 0).astype(np.float32)

print(f" Mask Loaded! Shape: {mask_data.shape}")

# 3. Download Matching Satellite Data (GEE)
roi = ee.Geometry.Rectangle(
    [sample_bounds[0], sample_bounds[1], sample_bounds[2], sample_bounds[3]],
    proj=str(crs), geodesic=False
)

# MATURE PHASE (Jan - Apr) - adjusted for 2024/2025 crop cycle
# Using 2024 crop mask implies crop was sown late 2023, harvested April 2024.
# So Mature Phase is Jan 2024 - Apr 2024.
START_DATE = '2024-01-01'
END_DATE   = '2024-04-15'

print(f" Downloading Sentinel-2 (Mature Phase: {START_DATE} to {END_DATE})...")

# Create a Composite (Median) to minimize clouds
# Selecting bands compatible with Prithvi (Blue, Green, Red, NIR, SWIR, SWIR)
# S2 Bands: B2, B3, B4, B8A, B11, B12
s2_img = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
          .filterBounds(roi)
          .filterDate(START_DATE, END_DATE)
          .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20))
          .median()
          .select(['B2', 'B3', 'B4', 'B8A', 'B11', 'B12'])
          .clip(roi))

# Download
image_data = geemap.ee_to_numpy(s2_img, region=roi, scale=10)

# 4. Handle NaNs & Normalize
image_data = np.nan_to_num(image_data, nan=0.0)
# Sentinel-2 is 0-10000 approx, clip to 0-3000 (reflectance 0.3) for visualization/training
image_data = np.clip(image_data / 3000.0, 0, 1)

print(f" Satellite Data Downloaded! Shape: {image_data.shape}")

# 5. Fix Shape Mismatch
# Resize Satellite Image to match Mask exactly
target_h, target_w = mask_data.shape
if image_data.shape[:2] != (target_h, target_w):
    print(f" Resizing Image: {image_data.shape[:2]} -> {(target_h, target_w)}")
    image_data = cv2.resize(image_data, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

print(f" Final Aligned Image Shape: {image_data.shape}")
📍 Selecting sample area from mask...
✅ Mask Loaded! Shape: (891, 891)
⏳ Downloading Sentinel-2 (Mature Phase: 2024-01-01 to 2024-04-15)...
✅ Satellite Data Downloaded! Shape: (891, 892, 6)
⚠️ Resizing Image: (891, 892) -> (891, 891)
✅ Final Aligned Image Shape: (891, 891, 6)
In [ ]:
# ==============================================================================
# CELL 3: TILING
# ==============================================================================

def create_tiles(image, mask, patch_size=64):
    h, w, c = image.shape
    patches_x, patches_y = [], []

    for y in range(0, h - patch_size + 1, patch_size):
        for x in range(0, w - patch_size + 1, patch_size):
            img_p = image[y:y+patch_size, x:x+patch_size, :]
            mask_p = mask[y:y+patch_size, x:x+patch_size]

            # Filter empty patches
            if np.mean(img_p) > 0.01:
                patches_x.append(img_p)
                patches_y.append(mask_p)

    return np.array(patches_x), np.array(patches_y)

print(" Creating tiles...")
X_np, y_np = create_tiles(image_data, mask_data, patch_size=64)

# Convert to Tensor (N, C, H, W)
X_data = torch.tensor(np.transpose(X_np, (0, 3, 1, 2)), dtype=torch.float32)
y_data = torch.tensor(np.expand_dims(y_np, 1), dtype=torch.float32)

print(f" Dataset Ready: {X_data.shape}")
✂️ Creating tiles...
✅ Dataset Ready: torch.Size([169, 6, 64, 64])
In [ ]:
# ==============================================================================
# CELL 4: PRITHVI MODEL (MANUAL DEFINITION)
# ==============================================================================
import torch
import torch.nn as nn
from functools import partial

# --- 1. Define Standard ViT Components (Prithvi Architecture) ---
# We define the Transformer blocks manually to avoid loading errors

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class PrithviBackbone(nn.Module):
    """ Simplified Prithvi-100M Backbone for 64x64 Input """
    def __init__(self, img_size=64, patch_size=4, in_chans=6, embed_dim=768, depth=12, num_heads=12):
        super().__init__()

        # Patch Embed
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Positional Embed
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=0.0)

        # Transformer Blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4., qkv_bias=True)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x: [B, 6, 64, 64]
        B = x.shape[0]
        x = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, 256, 768]

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[:, :x.shape[1], :]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x # [B, 257, 768]

# --- 2. Main Segmentation Model ---

class PrithviSegmentation(nn.Module):
    def __init__(self, in_channels, out_classes=1):
        super().__init__()

        print(" Initializing Prithvi Backbone from scratch...")

        # Adapter: Your N bands -> 6 bands
        self.input_adapter = nn.Conv2d(in_channels, 6, kernel_size=1)

        # Manual Prithvi Backbone (No Hugging Face loading needed)
        # Matches Prithvi-100M specs: dim=768, depth=12, heads=12
        self.encoder = PrithviBackbone(
            img_size=64,
            patch_size=4, # 16x16 tokens
            in_chans=6,
            embed_dim=768,
            depth=12,
            num_heads=12
        )

        # Decoder
        embed_dim = 768
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=2, stride=2), # 16->32
            nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), # 32->64
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, out_classes, kernel_size=1)
        )

    def forward(self, x):
        # Adapt Input
        x_adapted = self.input_adapter(x) # [B, 6, 64, 64]

        # Encode
        features = self.encoder(x_adapted) # [B, 257, 768]

        # Reshape (Remove CLS)
        features = features[:, 1:, :] # [B, 256, 768]
        B, N, C = features.shape
        H = int(N**0.5) # 16
        features = features.permute(0, 2, 1).reshape(B, C, H, H) # [B, 768, 16, 16]

        # Decode
        return self.decoder(features)

print(" Prithvi Model Defined Manually (No HF Token Needed)")
✅ Prithvi Model Defined Manually (No HF Token Needed)
In [ ]:
# ==============================================================================
# CELL 5: TRAINING (LOG EVERY EPOCH)
# ==============================================================================
import time

# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 200  # Targeted for ~20 mins on GPU
PATIENCE = 25

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Training on: {device}")

if device.type == 'cpu':
    print(" WARNING: You are running on CPU. This will be very slow.")

# Splits
dataset = TensorDataset(X_data, y_data)
train_size = int(0.7 * len(dataset))
val_size   = int(0.15 * len(dataset))
test_size  = len(dataset) - train_size - val_size

train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Model
model = PrithviSegmentation(in_channels=X_data.shape[1], out_classes=1).to(device)

# Loss & Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Loop Variables
train_losses, val_losses = [], []
best_loss = float('inf')
patience_counter = 0

print(f"\nSTARTING TRAINING ({EPOCHS} Epochs)...")
start_time = time.time()

for epoch in range(EPOCHS):
    # --- TRAIN ---
    model.train()
    running_loss = 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        preds = model(X)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # --- VALIDATE ---
    model.eval()
    running_val_loss = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            preds = model(X)
            running_val_loss += criterion(preds, y).item()

    avg_val_loss = running_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # --- SCHEDULER & LOGGING ---
    scheduler.step(avg_val_loss)
    elapsed = (time.time() - start_time) / 60
    current_lr = optimizer.param_groups[0]['lr']

    # Print EVERY Epoch
    print(f"Epoch {epoch+1}/{EPOCHS} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.1e} | Time: {elapsed:.1f} min")

    # --- EARLY STOPPING ---
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_prithvi.pth')
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f" Early stopping triggered at Epoch {epoch+1}")
            break

total_time = (time.time() - start_time) / 60
print(f"\n Training Complete in {total_time:.1f} minutes.")
🚀 Training on: cuda
⏳ Initializing Prithvi Backbone from scratch...

STARTING TRAINING (200 Epochs)...
Epoch 1/200 | Train: 0.7448 | Val: 0.7152 | LR: 1.0e-04 | Time: 0.1 min
Epoch 2/200 | Train: 0.6588 | Val: 0.6610 | LR: 1.0e-04 | Time: 0.3 min
Epoch 3/200 | Train: 0.5964 | Val: 0.6010 | LR: 1.0e-04 | Time: 0.4 min
Epoch 4/200 | Train: 0.5584 | Val: 0.5773 | LR: 1.0e-04 | Time: 0.5 min
Epoch 5/200 | Train: 0.5314 | Val: 0.5717 | LR: 1.0e-04 | Time: 0.6 min
Epoch 6/200 | Train: 0.5267 | Val: 0.5048 | LR: 1.0e-04 | Time: 0.7 min
Epoch 7/200 | Train: 0.5313 | Val: 0.4976 | LR: 1.0e-04 | Time: 0.9 min
Epoch 8/200 | Train: 0.5173 | Val: 0.5684 | LR: 1.0e-04 | Time: 1.0 min
Epoch 9/200 | Train: 0.5065 | Val: 0.5242 | LR: 1.0e-04 | Time: 1.1 min
Epoch 10/200 | Train: 0.5134 | Val: 0.5212 | LR: 1.0e-04 | Time: 1.2 min
Epoch 11/200 | Train: 0.5084 | Val: 0.4881 | LR: 1.0e-04 | Time: 1.3 min
Epoch 12/200 | Train: 0.4999 | Val: 0.4829 | LR: 1.0e-04 | Time: 1.4 min
Epoch 13/200 | Train: 0.4968 | Val: 0.6236 | LR: 1.0e-04 | Time: 1.5 min
Epoch 14/200 | Train: 0.4996 | Val: 0.4786 | LR: 1.0e-04 | Time: 1.6 min
Epoch 15/200 | Train: 0.4978 | Val: 0.5607 | LR: 1.0e-04 | Time: 1.7 min
Epoch 16/200 | Train: 0.4977 | Val: 0.4824 | LR: 1.0e-04 | Time: 1.8 min
Epoch 17/200 | Train: 0.4975 | Val: 0.5084 | LR: 1.0e-04 | Time: 1.9 min
Epoch 18/200 | Train: 0.4974 | Val: 0.5067 | LR: 5.0e-05 | Time: 2.0 min
Epoch 19/200 | Train: 0.4933 | Val: 0.4945 | LR: 5.0e-05 | Time: 2.1 min
Epoch 20/200 | Train: 0.4859 | Val: 0.4771 | LR: 5.0e-05 | Time: 2.2 min
Epoch 21/200 | Train: 0.4934 | Val: 0.4892 | LR: 5.0e-05 | Time: 2.3 min
Epoch 22/200 | Train: 0.4865 | Val: 0.4907 | LR: 5.0e-05 | Time: 2.4 min
Epoch 23/200 | Train: 0.4876 | Val: 0.4869 | LR: 5.0e-05 | Time: 2.5 min
Epoch 24/200 | Train: 0.4860 | Val: 0.4779 | LR: 2.5e-05 | Time: 2.6 min
Epoch 25/200 | Train: 0.4917 | Val: 0.4879 | LR: 2.5e-05 | Time: 2.7 min
Epoch 26/200 | Train: 0.4845 | Val: 0.4937 | LR: 2.5e-05 | Time: 2.7 min
Epoch 27/200 | Train: 0.4820 | Val: 0.4807 | LR: 2.5e-05 | Time: 2.8 min
Epoch 28/200 | Train: 0.4916 | Val: 0.4817 | LR: 1.3e-05 | Time: 2.9 min
Epoch 29/200 | Train: 0.4783 | Val: 0.4791 | LR: 1.3e-05 | Time: 3.0 min
Epoch 30/200 | Train: 0.4813 | Val: 0.4824 | LR: 1.3e-05 | Time: 3.1 min
Epoch 31/200 | Train: 0.4807 | Val: 0.4873 | LR: 1.3e-05 | Time: 3.2 min
Epoch 32/200 | Train: 0.4844 | Val: 0.4828 | LR: 6.3e-06 | Time: 3.3 min
Epoch 33/200 | Train: 0.4899 | Val: 0.4814 | LR: 6.3e-06 | Time: 3.4 min
Epoch 34/200 | Train: 0.4819 | Val: 0.4830 | LR: 6.3e-06 | Time: 3.5 min
Epoch 35/200 | Train: 0.4892 | Val: 0.4821 | LR: 6.3e-06 | Time: 3.6 min
Epoch 36/200 | Train: 0.4855 | Val: 0.4829 | LR: 3.1e-06 | Time: 3.7 min
Epoch 37/200 | Train: 0.4824 | Val: 0.4818 | LR: 3.1e-06 | Time: 3.8 min
Epoch 38/200 | Train: 0.4853 | Val: 0.4799 | LR: 3.1e-06 | Time: 3.9 min
Epoch 39/200 | Train: 0.4857 | Val: 0.4818 | LR: 3.1e-06 | Time: 4.0 min
Epoch 40/200 | Train: 0.4790 | Val: 0.4810 | LR: 1.6e-06 | Time: 4.1 min
Epoch 41/200 | Train: 0.4859 | Val: 0.4821 | LR: 1.6e-06 | Time: 4.2 min
Epoch 42/200 | Train: 0.4816 | Val: 0.4827 | LR: 1.6e-06 | Time: 4.2 min
Epoch 43/200 | Train: 0.4796 | Val: 0.4831 | LR: 1.6e-06 | Time: 4.3 min
Epoch 44/200 | Train: 0.4813 | Val: 0.4850 | LR: 7.8e-07 | Time: 4.4 min
Epoch 45/200 | Train: 0.4780 | Val: 0.4855 | LR: 7.8e-07 | Time: 4.5 min
🛑 Early stopping triggered at Epoch 45

✅ Training Complete in 4.5 minutes.
In [ ]:
# ==============================================================================
# CELL 6: VISUALIZATION & METRICS
# ==============================================================================
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score

# 1. Load Best Model
print(" Loading Best Model...")
model.load_state_dict(torch.load('best_prithvi.pth'))
model.eval()

# 2. Plot Loss Curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='orange')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# 3. Calculate Metrics on Test Set
print(" Calculating Metrics on Test Set...")
all_preds = []
all_targets = []

with torch.no_grad():
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        preds = model(X)
        preds = torch.sigmoid(preds) # Convert to probability
        preds = (preds > 0.5).float() # Threshold to 0/1

        all_preds.append(preds.cpu().numpy().flatten())
        all_targets.append(y.cpu().numpy().flatten())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

acc = accuracy_score(all_targets, all_preds)
prec = precision_score(all_targets, all_preds, zero_division=0)
rec = recall_score(all_targets, all_preds, zero_division=0)
f1 = f1_score(all_targets, all_preds, zero_division=0)
iou = jaccard_score(all_targets, all_preds, zero_division=0)

print(f"\n TEST RESULTS:")
print(f"   Accuracy:  {acc:.4f}")
print(f"   Precision: {prec:.4f}")
print(f"   Recall:    {rec:.4f}")
print(f"   F1-Score:  {f1:.4f}")
print(f"   IoU:       {iou:.4f}")

# 4. Visual Predictions
print("\n Visualizing Predictions...")

def visualize_prediction(loader, num_samples=3):
    X, y = next(iter(loader))
    X, y = X.to(device), y.to(device)

    with torch.no_grad():
        preds = torch.sigmoid(model(X))
        preds = (preds > 0.5).float()

    for idx in range(min(num_samples, len(X))):
        # Prepare Input Image (False Color Composite: NIR, Red, Green -> Bands 3,2,1)
        # Note: Adjust indices based on your specific stack.
        # Usually: B2, B3, B4, B8... -> Indices 0,1,2,3...
        # So NIR(3), Red(2), Green(1) makes a nice vegetation map.
        rgb = X[idx].cpu().numpy()[[3,2,1],:,:].transpose(1,2,0)
        rgb = np.clip(rgb, 0, 1) # Normalize for display

        mask = y[idx, 0].cpu().numpy()
        pred = preds[idx, 0].cpu().numpy()

        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(rgb)
        plt.title(f"Input (False Color) - Sample {idx+1}")
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(mask, cmap='gray')
        plt.title("Ground Truth Mask")
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(pred, cmap='gray')
        plt.title("Prithvi Prediction")
        plt.axis('off')

        plt.show()

visualize_prediction(test_loader, num_samples=3)
⏳ Loading Best Model...
No description has been provided for this image
📊 Calculating Metrics on Test Set...

✅ TEST RESULTS:
   Accuracy:  0.7358
   Precision: 0.7455
   Recall:    0.8816
   F1-Score:  0.8078
   IoU:       0.6776

🖼️ Visualizing Predictions...
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# ==============================================================================
# CELL 6: VISUALIZATION (TRUE COLOR + GREEN OVERLAY)
# ==============================================================================
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score

# 1. Load Best Model
print(" Loading Best Model...")
model.load_state_dict(torch.load('best_prithvi.pth'))
model.eval()

# 2. Plot Loss Curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='orange')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# 3. Visual Predictions (Green Overlay)
print("\n Visualizing Predictions (True Color + Green Overlay)...")

def visualize_green_overlay(loader, num_samples=3):
    X, y = next(iter(loader))
    X, y = X.to(device), y.to(device)

    with torch.no_grad():
        preds = torch.sigmoid(model(X))
        preds = (preds > 0.5).float()

    for idx in range(min(num_samples, len(X))):
        # Prepare True Color Image (Red=2, Green=1, Blue=0)
        # Scale brightness by 3x to make it visible (satellite data is dark)
        img_rgb = X[idx].cpu().numpy()[[2,1,0],:,:].transpose(1,2,0)
        img_rgb = np.clip(img_rgb * 3.0, 0, 1)

        # Ground Truth Mask
        mask_gt = y[idx, 0].cpu().numpy()

        # Prediction Mask
        mask_pred = preds[idx, 0].cpu().numpy()

        # Create Green Overlay for Prediction
        # We create a copy of the RGB image and tint pixels GREEN where mask_pred == 1
        overlay = img_rgb.copy()
        overlay[mask_pred == 1] = [0, 1, 0] # Pure Green

        # Blend: 70% Original + 30% Green Overlay
        blended = cv2.addWeighted(img_rgb, 0.7, overlay, 0.3, 0)

        plt.figure(figsize=(12, 4))

        # 1. True Color Input
        plt.subplot(1, 3, 1)
        plt.imshow(img_rgb)
        plt.title(f"True Color Input {idx+1}")
        plt.axis('off')

        # 2. Ground Truth (Black/White)
        plt.subplot(1, 3, 2)
        plt.imshow(mask_gt, cmap='gray')
        plt.title("Ground Truth (Target)")
        plt.axis('off')

        # 3. Prediction (Green Overlay)
        plt.subplot(1, 3, 3)
        plt.imshow(blended)
        plt.title("Prediction (Green Overlay)")
        plt.axis('off')

        plt.show()

visualize_green_overlay(test_loader, num_samples=3)
⏳ Loading Best Model...
No description has been provided for this image
🖼️ Visualizing Predictions (True Color + Green Overlay)...
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# ==============================================================================
# CELL 5: TRAINING (FORCE 200 EPOCHS + COSINE ANNEALING)
# ==============================================================================
import time

# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 200  # Will complete ALL 200 epochs
# PATIENCE removed because we want to force full training

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Training on: {device}")

# Splits
dataset = TensorDataset(X_data, y_data)
train_size = int(0.7 * len(dataset))
val_size   = int(0.15 * len(dataset))
test_size  = len(dataset) - train_size - val_size

train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Model
model = PrithviSegmentation(in_channels=X_data.shape[1], out_classes=1).to(device)

# Loss & Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)

# --- COSINE ANNEALING SCHEDULER (New) ---
# This scheduler oscillates the LR, preventing the model from getting stuck.
# T_0=50: Resets LR every 50 epochs.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=50, T_mult=1, eta_min=1e-6
)

# Loop Variables
train_losses, val_losses = [], []
best_loss = float('inf')

print(f"\nSTARTING FULL TRAINING ({EPOCHS} Epochs)...")
print("Early Stopping is DISABLED to force full training.")
start_time = time.time()

for epoch in range(EPOCHS):
    # --- TRAIN ---
    model.train()
    running_loss = 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        preds = model(X)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # --- VALIDATE ---
    model.eval()
    running_val_loss = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            preds = model(X)
            running_val_loss += criterion(preds, y).item()

    avg_val_loss = running_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # --- SCHEDULER STEP ---
    # Cosine Scheduler steps every epoch (not based on val_loss)
    scheduler.step()

    # --- LOGGING ---
    elapsed = (time.time() - start_time) / 60
    current_lr = optimizer.param_groups[0]['lr']

    # Print every 10 epochs to reduce clutter
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.1e} | Time: {elapsed:.1f} min")

    # Save Best Model (but DON'T stop)
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_prithvi.pth')

total_time = (time.time() - start_time) / 60
print(f"\n Full 200 Epochs Complete in {total_time:.1f} minutes.")
🚀 Training on: cuda
⏳ Initializing Prithvi Backbone from scratch...

STARTING FULL TRAINING (200 Epochs)...
Early Stopping is DISABLED to force full training.
Epoch 10/200 | Train: 0.4919 | Val: 0.5170 | LR: 9.1e-05 | Time: 1.1 min
Epoch 20/200 | Train: 0.4944 | Val: 0.5119 | LR: 6.6e-05 | Time: 2.0 min
Epoch 30/200 | Train: 0.4816 | Val: 0.5318 | LR: 3.5e-05 | Time: 3.0 min
Epoch 40/200 | Train: 0.4834 | Val: 0.5110 | LR: 1.0e-05 | Time: 4.0 min
Epoch 50/200 | Train: 0.4792 | Val: 0.4999 | LR: 1.0e-04 | Time: 4.9 min
Epoch 60/200 | Train: 0.4940 | Val: 0.5165 | LR: 9.1e-05 | Time: 5.8 min
Epoch 70/200 | Train: 0.4710 | Val: 0.5205 | LR: 6.6e-05 | Time: 6.8 min
Epoch 80/200 | Train: 0.4558 | Val: 0.5072 | LR: 3.5e-05 | Time: 7.7 min
Epoch 90/200 | Train: 0.4500 | Val: 0.5153 | LR: 1.0e-05 | Time: 8.6 min
Epoch 100/200 | Train: 0.4530 | Val: 0.5127 | LR: 1.0e-04 | Time: 9.6 min
Epoch 110/200 | Train: 0.4507 | Val: 0.5662 | LR: 9.1e-05 | Time: 10.5 min
Epoch 120/200 | Train: 0.4507 | Val: 0.5344 | LR: 6.6e-05 | Time: 11.5 min
Epoch 130/200 | Train: 0.4384 | Val: 0.5301 | LR: 3.5e-05 | Time: 12.4 min
Epoch 140/200 | Train: 0.4364 | Val: 0.5264 | LR: 1.0e-05 | Time: 13.4 min
Epoch 150/200 | Train: 0.4252 | Val: 0.5291 | LR: 1.0e-04 | Time: 14.3 min
Epoch 160/200 | Train: 0.4326 | Val: 0.5518 | LR: 9.1e-05 | Time: 15.2 min
Epoch 170/200 | Train: 0.4115 | Val: 0.5975 | LR: 6.6e-05 | Time: 16.2 min
Epoch 180/200 | Train: 0.4200 | Val: 0.5475 | LR: 3.5e-05 | Time: 17.1 min
Epoch 190/200 | Train: 0.3916 | Val: 0.5476 | LR: 1.0e-05 | Time: 18.1 min
Epoch 200/200 | Train: 0.3911 | Val: 0.5508 | LR: 1.0e-04 | Time: 19.0 min

✅ Full 200 Epochs Complete in 19.0 minutes.
In [ ]:
# ==============================================================================
# CELL 6: VISUALIZATION & METRICS
# ==============================================================================
import cv2
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score

# 1. Load Best Model (Saved during the 200 epochs)
print(" Loading Best Model Checkpoint...")
model.load_state_dict(torch.load('best_prithvi.pth'))
model.eval()

# 2. Plot Loss Curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='orange')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss (200 Epochs)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# 3. Calculate Metrics on Test Set
print(" Calculating Metrics on Test Set...")
all_preds = []
all_targets = []

with torch.no_grad():
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        preds = torch.sigmoid(model(X))
        preds = (preds > 0.5).float()

        all_preds.append(preds.cpu().numpy().flatten())
        all_targets.append(y.cpu().numpy().flatten())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

# Compute Scores
acc = accuracy_score(all_targets, all_preds)
f1 = f1_score(all_targets, all_preds, zero_division=0)
iou = jaccard_score(all_targets, all_preds, zero_division=0)

print(f"\n TEST RESULTS:")
print(f"   Accuracy:  {acc:.4f}")
print(f"   F1-Score:  {f1:.4f}")
print(f"   IoU:       {iou:.4f}")

# 4. Visual Predictions (Green Overlay)
print("\n Visualizing Predictions (Green Overlay)...")

def visualize_green_overlay(loader, num_samples=3):
    X, y = next(iter(loader))
    X, y = X.to(device), y.to(device)

    with torch.no_grad():
        preds = torch.sigmoid(model(X))
        preds = (preds > 0.5).float()

    for idx in range(min(num_samples, len(X))):
        # Prepare True Color Image (Red=2, Green=1, Blue=0)
        # We assume bands are downloaded in that order or standard S2 order
        # Adjust scale (*3.0) for brightness
        img_rgb = X[idx].cpu().numpy()[[2,1,0],:,:].transpose(1,2,0)
        img_rgb = np.clip(img_rgb * 3.5, 0, 1)

        mask_gt = y[idx, 0].cpu().numpy()
        mask_pred = preds[idx, 0].cpu().numpy()

        # Create Green Overlay
        overlay = img_rgb.copy()
        overlay[mask_pred == 1] = [0, 1, 0] # Pure Green

        # Blend
        blended = cv2.addWeighted(img_rgb, 0.7, overlay, 0.3, 0)

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 3, 1)
        plt.imshow(img_rgb)
        plt.title(f"Input {idx+1}")
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(mask_gt, cmap='gray')
        plt.title("Ground Truth")
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(blended)
        plt.title("Prediction (Green)")
        plt.axis('off')

        plt.show()

visualize_green_overlay(test_loader, num_samples=5)
⏳ Loading Best Model Checkpoint...
No description has been provided for this image
📊 Calculating Metrics on Test Set...

✅ TEST RESULTS:
   Accuracy:  0.7873
   F1-Score:  0.8501
   IoU:       0.7392

🖼️ Visualizing Predictions (Green Overlay)...
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# ==============================================================================
# CELL 1: SETUP & DEPENDENCIES
# ==============================================================================
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Install required libraries
!pip install -q uv
!uv pip install rasterio geopandas torchinfo transformers segmentation-models-pytorch opencv-python-headless

import ee
import geemap
import rasterio
from rasterio.windows import from_bounds
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel
import matplotlib.pyplot as plt

# Authenticate with YOUR project
try:
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')
except:
    ee.Authenticate()
    ee.Initialize(project='[REDACTED_FOR_SECURITY]')

print(" Environment Ready")
Mounted at /content/drive
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.2/22.2 MB 52.4 MB/s eta 0:00:00
Using Python 3.12.12 environment at: /usr
Resolved 63 packages in 791ms
Prepared 2 packages in 41ms
Installed 2 packages in 5ms
 + segmentation-models-pytorch==0.5.0
 + torchinfo==1.8.0
✅ Environment Ready
In [ ]:
# ==============================================================================
# CELL 2: HYBRID DATA LOADING (Mature Phase)
# ==============================================================================

# 1. Setup Path
mask_path = '/content/drive/MyDrive/Punjab Wheat Mask_Binary/Punjab Mask 2024.tif'

# 2. Define Sample Area from Mask (Drive)
with rasterio.open(mask_path) as src:
    full_bounds = src.bounds
    crs = src.crs

    # Pick center
    center_x = (full_bounds.left + full_bounds.right) / 2
    center_y = (full_bounds.bottom + full_bounds.top) / 2

    # Define small window (approx 5km x 5km)
    half_size = 0.04
    sample_bounds = (
        center_x - half_size, center_y - half_size,
        center_x + half_size, center_y + half_size
    )

    print(f" Selecting sample area from mask...")

    window = from_bounds(*sample_bounds, transform=src.transform)
    mask_data = src.read(1, window=window)
    # Ensure binary
    mask_data = np.where(mask_data > 0, 1, 0).astype(np.float32)

print(f" Mask Loaded! Shape: {mask_data.shape}")

# 3. Download Matching Satellite Data (GEE)
roi = ee.Geometry.Rectangle(
    [sample_bounds[0], sample_bounds[1], sample_bounds[2], sample_bounds[3]],
    proj=str(crs), geodesic=False
)

# MATURE PHASE (Jan - Apr) - adjusted for 2024/2025 crop cycle
# Using 2024 crop mask implies crop was sown late 2023, harvested April 2024.
# So Mature Phase is Jan 2024 - Apr 2024.
START_DATE = '2024-01-01'
END_DATE   = '2024-04-15'

print(f" Downloading Sentinel-2 (Mature Phase: {START_DATE} to {END_DATE})...")

# Create a Composite (Median) to minimize clouds
# Selecting bands compatible with Prithvi (Blue, Green, Red, NIR, SWIR, SWIR)
# S2 Bands: B2, B3, B4, B8A, B11, B12
s2_img = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
          .filterBounds(roi)
          .filterDate(START_DATE, END_DATE)
          .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20))
          .median()
          .select(['B2', 'B3', 'B4', 'B8A', 'B11', 'B12'])
          .clip(roi))

# Download
image_data = geemap.ee_to_numpy(s2_img, region=roi, scale=10)

# 4. Handle NaNs & Normalize
image_data = np.nan_to_num(image_data, nan=0.0)
# Sentinel-2 is 0-10000 approx, clip to 0-3000 (reflectance 0.3) for visualization/training
image_data = np.clip(image_data / 3000.0, 0, 1)

print(f" Satellite Data Downloaded! Shape: {image_data.shape}")

# 5. Fix Shape Mismatch
# Resize Satellite Image to match Mask exactly
target_h, target_w = mask_data.shape
if image_data.shape[:2] != (target_h, target_w):
    print(f" Resizing Image: {image_data.shape[:2]} -> {(target_h, target_w)}")
    image_data = cv2.resize(image_data, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

print(f" Final Aligned Image Shape: {image_data.shape}")
📍 Selecting sample area from mask...
✅ Mask Loaded! Shape: (891, 891)
⏳ Downloading Sentinel-2 (Mature Phase: 2024-01-01 to 2024-04-15)...
✅ Satellite Data Downloaded! Shape: (891, 892, 6)
⚠️ Resizing Image: (891, 892) -> (891, 891)
✅ Final Aligned Image Shape: (891, 891, 6)
In [ ]:
# ==============================================================================
# CELL 3: TILING
# ==============================================================================

def create_tiles(image, mask, patch_size=64):
    h, w, c = image.shape
    patches_x, patches_y = [], []

    for y in range(0, h - patch_size + 1, patch_size):
        for x in range(0, w - patch_size + 1, patch_size):
            img_p = image[y:y+patch_size, x:x+patch_size, :]
            mask_p = mask[y:y+patch_size, x:x+patch_size]

            # Filter empty patches
            if np.mean(img_p) > 0.01:
                patches_x.append(img_p)
                patches_y.append(mask_p)

    return np.array(patches_x), np.array(patches_y)

print(" Creating tiles...")
X_np, y_np = create_tiles(image_data, mask_data, patch_size=64)

# Convert to Tensor (N, C, H, W)
X_data = torch.tensor(np.transpose(X_np, (0, 3, 1, 2)), dtype=torch.float32)
y_data = torch.tensor(np.expand_dims(y_np, 1), dtype=torch.float32)

print(f" Dataset Ready: {X_data.shape}")
✂️ Creating tiles...
✅ Dataset Ready: torch.Size([169, 6, 64, 64])
In [ ]:
# ==============================================================================
# CELL 4: PRITHVI MODEL (MANUAL DEFINITION)
# ==============================================================================
import torch
import torch.nn as nn
from functools import partial



class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class PrithviBackbone(nn.Module):
    """ Simplified Prithvi-100M Backbone for 64x64 Input """
    def __init__(self, img_size=64, patch_size=4, in_chans=6, embed_dim=768, depth=12, num_heads=12):
        super().__init__()

        # Patch Embed
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Positional Embed
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=0.0)

        # Transformer Blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4., qkv_bias=True)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x: [B, 6, 64, 64]
        B = x.shape[0]
        x = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, 256, 768]

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[:, :x.shape[1], :]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x # [B, 257, 768]

# --- 2. Main Segmentation Model ---

class PrithviSegmentation(nn.Module):
    def __init__(self, in_channels, out_classes=1):
        super().__init__()

        print(" Initializing Prithvi Backbone from scratch...")

        # Adapter: Your N bands -> 6 bands
        self.input_adapter = nn.Conv2d(in_channels, 6, kernel_size=1)

        # Manual Prithvi Backbone (No Hugging Face loading needed)
        # Matches Prithvi-100M specs: dim=768, depth=12, heads=12
        self.encoder = PrithviBackbone(
            img_size=64,
            patch_size=4, # 16x16 tokens
            in_chans=6,
            embed_dim=768,
            depth=12,
            num_heads=12
        )

        # Decoder
        embed_dim = 768
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 256, kernel_size=2, stride=2), # 16->32
            nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), # 32->64
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, out_classes, kernel_size=1)
        )

    def forward(self, x):
        # Adapt Input
        x_adapted = self.input_adapter(x) # [B, 6, 64, 64]

        # Encode
        features = self.encoder(x_adapted) # [B, 257, 768]

        # Reshape (Remove CLS)
        features = features[:, 1:, :] # [B, 256, 768]
        B, N, C = features.shape
        H = int(N**0.5) # 16
        features = features.permute(0, 2, 1).reshape(B, C, H, H) # [B, 768, 16, 16]

        # Decode
        return self.decoder(features)

print(" Prithvi Model Defined Manually (No HF Token Needed)")
✅ Prithvi Model Defined Manually (No HF Token Needed)
In [ ]:
# ==============================================================================
# CELL 5: TRAINING (FORCE 500 EPOCHS + SMOOTH SCHEDULER)
# ==============================================================================
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR

# Hyperparameters
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 500       # Increased to 500
WEIGHT_DECAY = 0.05 # Stronger regularization for long training

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Training on: {device}")

# Splits
dataset = TensorDataset(X_data, y_data)
train_size = int(0.7 * len(dataset))
val_size   = int(0.15 * len(dataset))
test_size  = len(dataset) - train_size - val_size

train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Model
model = PrithviSegmentation(in_channels=X_data.shape[1], out_classes=1).to(device)

# Loss & Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# --- CHANGED SCHEDULER (To Avoid Spikes) ---
# CosineAnnealingLR: Smoothly decreases LR from 1e-4 to 1e-6 over 500 epochs.
# No sudden "Restarts", so validation loss should be much smoother.
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

# Loop Variables
train_losses, val_losses = [], []
best_loss = float('inf')
best_epoch = 0

print(f"\nSTARTING FULL TRAINING ({EPOCHS} Epochs)...")
print("Early Stopping DISABLED. Using Smooth Cosine Decay.")
start_time = time.time()

for epoch in range(EPOCHS):
    # --- TRAIN ---
    model.train()
    running_loss = 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        preds = model(X)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # --- VALIDATE ---
    model.eval()
    running_val_loss = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            preds = model(X)
            running_val_loss += criterion(preds, y).item()

    avg_val_loss = running_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # --- SCHEDULER STEP ---
    # Step every epoch. This will gently lower LR.
    scheduler.step()

    # --- LOGGING ---
    elapsed = (time.time() - start_time) / 60
    current_lr = optimizer.param_groups[0]['lr']

    # Print every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | LR: {current_lr:.1e} | Time: {elapsed:.1f} min")

    # Save Best Model (Track "Peak" Performance)
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), 'best_prithvi_500.pth')

total_time = (time.time() - start_time) / 60
print(f"\n Full 500 Epochs Complete in {total_time:.1f} minutes.")
print(f" Best Validation Loss was {best_loss:.4f} at Epoch {best_epoch}")
🚀 Training on: cuda
⏳ Initializing Prithvi Backbone from scratch...

STARTING FULL TRAINING (500 Epochs)...
Early Stopping DISABLED. Using Smooth Cosine Decay.
Epoch 10/500 | Train: 0.4926 | Val: 0.4957 | LR: 1.0e-04 | Time: 1.1 min
Epoch 20/500 | Train: 0.4850 | Val: 0.4974 | LR: 1.0e-04 | Time: 2.0 min
Epoch 30/500 | Train: 0.4847 | Val: 0.5035 | LR: 9.9e-05 | Time: 2.9 min
Epoch 40/500 | Train: 0.4774 | Val: 0.5534 | LR: 9.8e-05 | Time: 3.9 min
Epoch 50/500 | Train: 0.4710 | Val: 0.5228 | LR: 9.8e-05 | Time: 4.9 min
Epoch 60/500 | Train: 0.4553 | Val: 0.5221 | LR: 9.7e-05 | Time: 5.8 min
Epoch 70/500 | Train: 0.4293 | Val: 0.5289 | LR: 9.5e-05 | Time: 6.8 min
Epoch 80/500 | Train: 0.4343 | Val: 0.5558 | LR: 9.4e-05 | Time: 7.7 min
Epoch 90/500 | Train: 0.4314 | Val: 0.5229 | LR: 9.2e-05 | Time: 8.7 min
Epoch 100/500 | Train: 0.4274 | Val: 0.5967 | LR: 9.1e-05 | Time: 9.6 min
Epoch 110/500 | Train: 0.4131 | Val: 0.5723 | LR: 8.9e-05 | Time: 10.6 min
Epoch 120/500 | Train: 0.4018 | Val: 0.5985 | LR: 8.7e-05 | Time: 11.5 min
Epoch 130/500 | Train: 0.4072 | Val: 0.5780 | LR: 8.4e-05 | Time: 12.5 min
Epoch 140/500 | Train: 0.4023 | Val: 0.5846 | LR: 8.2e-05 | Time: 13.4 min
Epoch 150/500 | Train: 0.3725 | Val: 0.6684 | LR: 8.0e-05 | Time: 14.4 min
Epoch 160/500 | Train: 0.3378 | Val: 0.7010 | LR: 7.7e-05 | Time: 15.3 min
Epoch 170/500 | Train: 0.3029 | Val: 0.6721 | LR: 7.4e-05 | Time: 16.3 min
Epoch 180/500 | Train: 0.2590 | Val: 0.7995 | LR: 7.2e-05 | Time: 17.2 min
Epoch 190/500 | Train: 0.2286 | Val: 0.8050 | LR: 6.9e-05 | Time: 18.2 min
Epoch 200/500 | Train: 0.2003 | Val: 0.7671 | LR: 6.6e-05 | Time: 19.2 min
Epoch 210/500 | Train: 0.1876 | Val: 0.8271 | LR: 6.3e-05 | Time: 20.1 min
Epoch 220/500 | Train: 0.1773 | Val: 0.8541 | LR: 6.0e-05 | Time: 21.1 min
Epoch 230/500 | Train: 0.1657 | Val: 0.9243 | LR: 5.7e-05 | Time: 22.0 min
Epoch 240/500 | Train: 0.1514 | Val: 0.9354 | LR: 5.4e-05 | Time: 23.0 min
Epoch 250/500 | Train: 0.1515 | Val: 0.9574 | LR: 5.1e-05 | Time: 23.9 min
Epoch 260/500 | Train: 0.1364 | Val: 0.9202 | LR: 4.7e-05 | Time: 24.9 min
Epoch 270/500 | Train: 0.1302 | Val: 0.9741 | LR: 4.4e-05 | Time: 25.9 min
Epoch 280/500 | Train: 0.1248 | Val: 1.0028 | LR: 4.1e-05 | Time: 26.8 min
Epoch 290/500 | Train: 0.1163 | Val: 0.9780 | LR: 3.8e-05 | Time: 27.8 min
Epoch 300/500 | Train: 0.1082 | Val: 1.0251 | LR: 3.5e-05 | Time: 28.7 min
Epoch 310/500 | Train: 0.1107 | Val: 1.0805 | LR: 3.2e-05 | Time: 29.7 min
Epoch 320/500 | Train: 0.0964 | Val: 1.0552 | LR: 2.9e-05 | Time: 30.7 min
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipython-input-3282085532.py in <cell line: 0>()
     57         X, y = X.to(device), y.to(device)
     58         optimizer.zero_grad()
---> 59         preds = model(X)
     60         loss = criterion(preds, y)
     61         loss.backward()

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/tmp/ipython-input-3994323363.py in forward(self, x)
    126 
    127         # Encode
--> 128         features = self.encoder(x_adapted) # [B, 257, 768]
    129 
    130         # Reshape (Remove CLS)

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/tmp/ipython-input-3994323363.py in forward(self, x)
     84 
     85         for blk in self.blocks:
---> 86             x = blk(x)
     87 
     88         x = self.norm(x)

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/tmp/ipython-input-3994323363.py in forward(self, x)
     48     def forward(self, x):
     49         x = x + self.attn(self.norm1(x))
---> 50         x = x + self.mlp(self.norm2(x))
     51         return x
     52 

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/container.py in forward(self, input)
    248         """
    249         for module in self:
--> 250             input = module(input)
    251         return input
    252 

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    132         Runs the forward pass.
    133         """
--> 134         return F.linear(input, self.weight, self.bias)
    135 
    136     def extra_repr(self) -> str:

KeyboardInterrupt: 
In [ ]:
# ==========================================
# EVALUATION (Run this after stopping training)
# ==========================================
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
import numpy as np
import torch

def evaluate_model(loader, model, device):
    model.eval()
    all_preds = []
    all_targets = []

    print(" Calculating Metrics on Test Set (using current model state)...")

    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            outputs = model(X)

            # Convert logits to binary (0 or 1) using threshold 0.5
            preds = (torch.sigmoid(outputs) > 0.5).float()

            # Flatten to 1D array for metric calculation
            all_preds.append(preds.cpu().numpy().flatten())
            all_targets.append(y.cpu().numpy().flatten())

    # Concatenate all batches
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_targets)

    # Calculate Metrics
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    iou = jaccard_score(y_true, y_pred, zero_division=0)

    print("\n FINAL TEST RESULTS (Epoch ~310):")
    print(f"   Accuracy:  {acc:.4f}")
    print(f"   Precision: {prec:.4f}")
    print(f"   Recall:    {rec:.4f}")
    print(f"   F1-Score:  {f1:.4f}")
    print(f"   IoU:       {iou:.4f}")

    return acc, f1, iou

# Run Evaluation on Test Loader
# (Ensure 'test_loader' was defined in your split earlier. If not, use 'val_loader')
try:
    evaluate_model(test_loader, model, device)
except NameError:
    print(" 'test_loader' not found. Evaluating on 'val_loader' instead.")
    evaluate_model(val_loader, model, device)

# Save this model manually since you interrupted
torch.save(model.state_dict(), 'model_stopped_epoch_310.pth')
print(" Model saved as 'model_stopped_epoch_310.pth'")
 Calculating Metrics on Test Set (using current model state)...

 FINAL TEST RESULTS (Epoch ~310):
   Accuracy:  0.7024
   Precision: 0.7750
   Recall:    0.7993
   F1-Score:  0.7870
   IoU:       0.6487
 Model saved as 'model_stopped_epoch_310.pth'
In [ ]:
import matplotlib.pyplot as plt

def plot_history(train_losses, val_losses):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Validation Loss', color='orange')
    plt.title('Training vs Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# Run Plot
# (This uses the lists 'train_losses' and 'val_losses' from your training loop)
try:
    plot_history(train_losses, val_losses)
except NameError:
    print(" Could not find loss history lists. Did you run the training loop?")
No description has been provided for this image
In [ ]:
import numpy as np
import torch
import matplotlib.pyplot as plt

def visualize_predictions(loader, model, device, num_samples=3):
    model.eval()
    samples_shown = 0

    # Randomly pick a batch
    data_iter = iter(loader)
    images, masks = next(data_iter)

    images = images.to(device)

    with torch.no_grad():
        outputs = model(images)
        preds = (torch.sigmoid(outputs) > 0.5).float()

    # Move to CPU for plotting
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()
    preds = preds.cpu().numpy()

    plt.figure(figsize=(12, 4 * num_samples))

    for i in range(min(num_samples, len(images))):
        # 1. Prepare RGB Image (Bands 3,2,1 usually -> Indices 2,1,0)
        # Assuming shape [Channels, H, W] -> [H, W, Channels]
        # Sentinel-2: R=Band3(idx2), G=Band2(idx1), B=Band1(idx0) for 12-band stack
        # (Adjust indices if your stack order is different)
        img_rgb = images[i][[3, 2, 1], :, :].transpose(1, 2, 0) # Use bands 4,3,2 (NIR,Red,Green) or 3,2,1

        # Normalize for display (0-1 range)
        img_rgb = (img_rgb - img_rgb.min()) / (img_rgb.max() - img_rgb.min())

        # 2. Prepare Masks
        true_mask = masks[i].squeeze()
        pred_mask = preds[i].squeeze()

        # Plot
        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(img_rgb)
        plt.title("Satellite Input (RGB)")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(true_mask, cmap='gray')
        plt.title("Ground Truth Mask")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(pred_mask, cmap='gray')
        plt.title("Model Prediction")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Run Visualization
try:
    visualize_predictions(test_loader, model, device)
except NameError:
    print(" 'test_loader' not found. Using 'val_loader' instead.")
    visualize_predictions(val_loader, model, device)
No description has been provided for this image
In [1]:
!pip install terratorch
Requirement already satisfied: terratorch in /usr/local/lib/python3.12/dist-packages (1.2)
Requirement already satisfied: torch>2.0 in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.9.0+cu126)
Requirement already satisfied: numpy>=2.2 in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.2.6)
Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.24.0+cu126)
Requirement already satisfied: rioxarray in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.20.0)
Requirement already satisfied: albumentations in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.0.8)
Requirement already satisfied: albucore in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.0.24)
Requirement already satisfied: rasterio in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.4.4)
Requirement already satisfied: torchmetrics in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.8.2)
Requirement already satisfied: geopandas in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.1.1)
Requirement already satisfied: lightly==1.5.22 in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.5.22)
Requirement already satisfied: h5py in /usr/local/lib/python3.12/dist-packages (from terratorch) (3.15.1)
Requirement already satisfied: lightning>=2.6.0 in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.6.0)
Requirement already satisfied: segmentation-models-pytorch in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.5.0)
Requirement already satisfied: jsonargparse>=4.40.0 in /usr/local/lib/python3.12/dist-packages (from terratorch) (4.45.0)
Requirement already satisfied: torchgeo in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.8.0)
Requirement already satisfied: einops in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.8.1)
Requirement already satisfied: timm>=1.0.15 in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.0.22)
Requirement already satisfied: pycocotools in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.0.10)
Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.36.0)
Requirement already satisfied: tifffile in /usr/local/lib/python3.12/dist-packages (from terratorch) (2025.12.12)
Requirement already satisfied: python-box in /usr/local/lib/python3.12/dist-packages (from terratorch) (7.3.2)
Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from terratorch) (4.67.1)
Requirement already satisfied: wandb in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.23.1)
Requirement already satisfied: tensorboard in /usr/local/lib/python3.12/dist-packages (from terratorch) (2.19.0)
Requirement already satisfied: diffusers in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.36.0)
Requirement already satisfied: scikit-learn>=1.3.2 in /usr/local/lib/python3.12/dist-packages (from terratorch) (1.6.1)
Requirement already satisfied: scikit-image in /usr/local/lib/python3.12/dist-packages (from terratorch) (0.25.2)
Requirement already satisfied: certifi>=14.05.14 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2025.11.12)
Requirement already satisfied: hydra-core>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (1.3.2)
Requirement already satisfied: lightly_utils~=0.0.0 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (0.0.2)
Requirement already satisfied: python_dateutil>=2.5.3 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.9.0.post0)
Requirement already satisfied: requests>=2.27.0 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.32.4)
Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (1.17.0)
Requirement already satisfied: pydantic>=1.10.5 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.12.3)
Requirement already satisfied: pytorch_lightning>=1.0.4 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.6.0)
Requirement already satisfied: urllib3>=1.25.3 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (2.5.0)
Requirement already satisfied: aenum>=3.1.11 in /usr/local/lib/python3.12/dist-packages (from lightly==1.5.22->terratorch) (3.1.16)
Requirement already satisfied: PyYAML>=3.13 in /usr/local/lib/python3.12/dist-packages (from jsonargparse>=4.40.0->terratorch) (6.0.3)
Requirement already satisfied: fsspec<2027.0,>=2022.5.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (2025.3.0)
Requirement already satisfied: lightning-utilities<2.0,>=0.10.0 in /usr/local/lib/python3.12/dist-packages (from lightning>=2.6.0->terratorch) (0.15.2)
Requirement already satisfied: packaging<27.0,>=20.0 in /usr/local/lib/python3.12/dist-packages (from lightning>=2.6.0->terratorch) (25.0)
Requirement already satisfied: typing-extensions<6.0,>4.5.0 in /usr/local/lib/python3.12/dist-packages (from lightning>=2.6.0->terratorch) (4.15.0)
Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.3.2->terratorch) (1.16.3)
Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.3.2->terratorch) (1.5.3)
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=1.3.2->terratorch) (3.6.0)
Requirement already satisfied: safetensors in /usr/local/lib/python3.12/dist-packages (from timm>=1.0.15->terratorch) (0.7.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.20.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (75.2.0)
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (1.14.0)
Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.6.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.1.6)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.77)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.77)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.80)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (11.3.0.4)
Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (10.3.7.77)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (11.7.1.2)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.5.4.2)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (2.27.5)
Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.3.20)
Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.77)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (12.6.85)
Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (1.11.1.6)
Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>2.0->terratorch) (3.5.0)
Requirement already satisfied: stringzilla>=3.10.4 in /usr/local/lib/python3.12/dist-packages (from albucore->terratorch) (4.5.1)
Requirement already satisfied: simsimd>=5.9.2 in /usr/local/lib/python3.12/dist-packages (from albucore->terratorch) (6.5.3)
Requirement already satisfied: opencv-python-headless>=4.9.0.80 in /usr/local/lib/python3.12/dist-packages (from albucore->terratorch) (4.12.0.88)
Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (8.7.0)
Requirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (0.28.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (2025.11.3)
Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from diffusers->terratorch) (11.3.0)
Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub->terratorch) (1.2.0)
Requirement already satisfied: pyogrio>=0.7.2 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (0.12.1)
Requirement already satisfied: pandas>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (2.2.2)
Requirement already satisfied: pyproj>=3.5.0 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (3.7.2)
Requirement already satisfied: shapely>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from geopandas->terratorch) (2.1.2)
Requirement already satisfied: affine in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (2.4.0)
Requirement already satisfied: attrs in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (25.4.0)
Requirement already satisfied: click!=8.2.*,>=4.0 in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (8.3.1)
Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (0.7.2)
Requirement already satisfied: click-plugins in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (1.1.1.2)
Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from rasterio->terratorch) (3.2.5)
Requirement already satisfied: xarray>=2024.7.0 in /usr/local/lib/python3.12/dist-packages (from rioxarray->terratorch) (2025.12.0)
Requirement already satisfied: imageio!=2.35.0,>=2.33 in /usr/local/lib/python3.12/dist-packages (from scikit-image->terratorch) (2.37.2)
Requirement already satisfied: lazy-loader>=0.4 in /usr/local/lib/python3.12/dist-packages (from scikit-image->terratorch) (0.4)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (1.4.0)
Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (1.76.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (3.10)
Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (5.29.5)
Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (0.7.2)
Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from tensorboard->terratorch) (3.1.4)
Requirement already satisfied: kornia>=0.8.2 in /usr/local/lib/python3.12/dist-packages (from torchgeo->terratorch) (0.8.2)
Requirement already satisfied: matplotlib>=3.6 in /usr/local/lib/python3.12/dist-packages (from torchgeo->terratorch) (3.10.0)
Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb->terratorch) (3.1.45)
Requirement already satisfied: platformdirs in /usr/local/lib/python3.12/dist-packages (from wandb->terratorch) (4.5.1)
Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb->terratorch) (2.47.0)
Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (3.13.2)
Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.12/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb->terratorch) (4.0.12)
Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers->terratorch) (4.12.0)
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers->terratorch) (1.0.9)
Requirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers->terratorch) (3.11)
Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->diffusers->terratorch) (0.16.0)
Requirement already satisfied: omegaconf<2.4,>=2.2 in /usr/local/lib/python3.12/dist-packages (from hydra-core>=1.0.0->lightly==1.5.22->terratorch) (2.3.0)
Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.12/dist-packages (from hydra-core>=1.0.0->lightly==1.5.22->terratorch) (4.9.3)
Requirement already satisfied: docstring-parser>=0.17 in /usr/local/lib/python3.12/dist-packages (from jsonargparse[signatures]>=4.25->torchgeo->terratorch) (0.17.0)
Requirement already satisfied: typeshed-client>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from jsonargparse[signatures]>=4.25->torchgeo->terratorch) (2.8.2)
Requirement already satisfied: kornia_rs>=0.1.9 in /usr/local/lib/python3.12/dist-packages (from kornia>=0.8.2->torchgeo->terratorch) (0.1.10)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (4.61.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.6->torchgeo->terratorch) (1.4.9)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.0.0->geopandas->terratorch) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas>=2.0.0->geopandas->terratorch) (2025.3)
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.10.5->lightly==1.5.22->terratorch) (0.7.0)
Requirement already satisfied: pydantic-core==2.41.4 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.10.5->lightly==1.5.22->terratorch) (2.41.4)
Requirement already satisfied: typing-inspection>=0.4.2 in /usr/local/lib/python3.12/dist-packages (from pydantic>=1.10.5->lightly==1.5.22->terratorch) (0.4.2)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.27.0->lightly==1.5.22->terratorch) (3.4.4)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>2.0->terratorch) (1.3.0)
Requirement already satisfied: markupsafe>=2.1.1 in /usr/local/lib/python3.12/dist-packages (from werkzeug>=1.0.1->tensorboard->terratorch) (3.0.3)
Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.12/dist-packages (from importlib_metadata->diffusers->terratorch) (3.23.0)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (1.4.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (0.4.1)
Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2027.0,>=2022.5.0->lightning>=2.6.0->terratorch) (1.22.0)
Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb->terratorch) (5.0.2)
Requirement already satisfied: importlib_resources>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from typeshed-client>=2.8.2->jsonargparse[signatures]>=4.25->torchgeo->terratorch) (6.5.2)
In [2]:
from terratorch.registry import BACKBONE_REGISTRY

# Load the Prithvi-EO-1.0-100M model backbone
# Use the appropriate key for the 100M model (e.g., 'prithvi_eo_1_0_100m' or similar key found in the docs)
# The key below is for v2, please verify the exact 100M v1 key in the Prithvi docs
model = BACKBONE_REGISTRY.build("prithvi_eo_v2_600_tl", pretrained=True)
model.eval() # Set model to evaluation mode
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipython-input-2529966910.py in <cell line: 0>()
----> 1 from terratorch.registry import BACKBONE_REGISTRY
      2 
      3 # Load the Prithvi-EO-1.0-100M model backbone
      4 # Use the appropriate key for the 100M model (e.g., 'prithvi_eo_1_0_100m' or similar key found in the docs)
      5 # The key below is for v2, please verify the exact 100M v1 key in the Prithvi docs

/usr/local/lib/python3.12/dist-packages/terratorch/__init__.py in <module>
      5 os.environ["NO_ALBUMENTATIONS_UPDATE"] = "True"
      6 
----> 7 import terratorch.models  # noqa: F401
      8 from terratorch.models.backbones import *  # register models in registries # noqa: F403
      9 from terratorch.registry import BACKBONE_REGISTRY, DECODER_REGISTRY, MODEL_FACTORY_REGISTRY, FULL_MODEL_REGISTRY  # noqa: F401

/usr/local/lib/python3.12/dist-packages/terratorch/models/__init__.py in <module>
      4 import logging
      5 
----> 6 import terratorch.models.necks  # register necks  # noqa: F401
      7 from terratorch.models.clay1_5_model_factory import Clay1_5ModelFactory
      8 from terratorch.models.clay_model_factory import ClayModelFactory

/usr/local/lib/python3.12/dist-packages/terratorch/models/necks.py in <module>
     10 from einops import rearrange
     11 from torch import nn
---> 12 from torchvision.ops import FeaturePyramidNetwork
     13 
     14 from terratorch.registry import NECK_REGISTRY, TERRATORCH_NECK_REGISTRY

/usr/local/lib/python3.12/dist-packages/torchvision/__init__.py in <module>
      8 # .extensions) before entering _meta_registrations.
      9 from .extension import _HAS_OPS  # usort:skip
---> 10 from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
     11 
     12 try:

/usr/local/lib/python3.12/dist-packages/torchvision/models/__init__.py in <module>
      1 from .alexnet import *
----> 2 from .convnext import *
      3 from .densenet import *
      4 from .efficientnet import *
      5 from .googlenet import *

/usr/local/lib/python3.12/dist-packages/torchvision/models/convnext.py in <module>
      7 from torch.nn import functional as F
      8 
----> 9 from ..ops.misc import Conv2dNormActivation, Permute
     10 from ..ops.stochastic_depth import StochasticDepth
     11 from ..transforms._presets import ImageClassification

/usr/local/lib/python3.12/dist-packages/torchvision/ops/__init__.py in <module>
     21 from .giou_loss import generalized_box_iou_loss
     22 from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation
---> 23 from .poolers import MultiScaleRoIAlign
     24 from .ps_roi_align import ps_roi_align, PSRoIAlign
     25 from .ps_roi_pool import ps_roi_pool, PSRoIPool

/usr/local/lib/python3.12/dist-packages/torchvision/ops/poolers.py in <module>
      8 
      9 from ..utils import _log_api_usage_once
---> 10 from .roi_align import roi_align
     11 
     12 

/usr/local/lib/python3.12/dist-packages/torchvision/ops/roi_align.py in <module>
      5 import torch.fx
      6 from torch import nn, Tensor
----> 7 from torch._dynamo.utils import is_compile_supported
      8 from torch.jit.annotations import BroadcastingList2
      9 from torch.nn.modules.utils import _pair

KeyboardInterrupt: