Skip to content

Executing ./satclip/main.py: "Dataset not found or corrupted." #28

@raphael10-collab

Description

@raphael10-collab

Executing the following main.py:

import lightning.pytorch
import torch
from datamodules.s2geo_dataset import S2GeoDataModule
from lightning.pytorch.cli import LightningCLI
from loss import SatCLIPLoss
from model import SatCLIP

# Trade a little float32 matmul precision for speed: "high" lets PyTorch use
# faster internal formats (e.g. TF32 on GPUs that support it).
torch.set_float32_matmul_precision("high")

class SatCLIPLightningModule(lightning.pytorch.LightningModule):
    """Lightning wrapper that trains a ``SatCLIP`` model with ``SatCLIPLoss``.

    Every constructor argument except ``learning_rate`` and ``weight_decay`` is
    forwarded to ``SatCLIP`` unchanged; those two configure the AdamW optimizer
    built in ``configure_optimizers``. All constructor arguments are recorded
    with ``save_hyperparameters``.
    """

    def __init__(
        self,
        embed_dim=512,
        image_resolution=256,
        vision_layers=12,
        vision_width=768,
        vision_patch_size=32,
        in_channels=4,
        le_type="grid",
        pe_type="siren",
        frequency_num=16,
        max_radius=260,
        min_radius=1,
        legendre_polys=16,
        harmonics_calculation="analytic",
        sh_embedding_dims=32,
        learning_rate=1e-4,
        weight_decay=0.01,
        num_hidden_layers=2,
        capacity=256,
    ) -> None:
        super().__init__()

        # Model first, then loss: keeps submodule registration order stable.
        self.model = SatCLIP(
            embed_dim=embed_dim,
            image_resolution=image_resolution,
            vision_layers=vision_layers,
            vision_width=vision_width,
            vision_patch_size=vision_patch_size,
            in_channels=in_channels,
            le_type=le_type,
            pe_type=pe_type,
            frequency_num=frequency_num,
            max_radius=max_radius,
            min_radius=min_radius,
            legendre_polys=legendre_polys,
            harmonics_calculation=harmonics_calculation,
            sh_embedding_dims=sh_embedding_dims,
            num_hidden_layers=num_hidden_layers,
            capacity=capacity,
        )
        self.loss_fun = SatCLIPLoss()

        # Optimizer settings, read back in configure_optimizers().
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.save_hyperparameters()

    def common_step(self, batch, batch_idx):
        """Return the SatCLIP loss for a batch with ``"image"`` and ``"point"`` entries.

        ``batch["point"]`` is cast to float before being passed to the model.
        """
        logits_img, logits_loc = self.model(batch["image"], batch["point"].float())
        return self.loss_fun(logits_img, logits_loc)

    def training_step(self, batch, batch_idx):
        """Compute the batch loss, log it as ``train_loss`` and return it."""
        train_loss = self.common_step(batch, batch_idx)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the batch loss, log it as ``val_loss`` and return it."""
        val_loss = self.common_step(batch, batch_idx)
        self.log("val_loss", val_loss)
        return val_loss

    def configure_optimizers(self):
        """Build AdamW over ``self.model`` with two parameter groups.

        Group 0 holds parameters excluded from weight decay (decay 0.0).
        Group 1 holds every other trainable parameter (``self.weight_decay``).
        Parameters with ``requires_grad=False`` go into neither group.
        """

        def skip_decay(name, param):
            # Tensors with fewer than 2 dims, plus anything whose name contains
            # "bn", "ln", "bias" or "logit_scale", get no weight decay.
            return (
                param.ndim < 2
                or "bn" in name
                or "ln" in name
                or "bias" in name
                or "logit_scale" in name
            )

        no_decay, with_decay = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            (no_decay if skip_decay(name, param) else with_decay).append(param)

        param_groups = [
            {"params": no_decay, "weight_decay": 0.0},
            # weight_decay: specify in configs/default.yaml
            {"params": with_decay, "weight_decay": self.weight_decay},
        ]
        # learning rate: specify in configs/default.yaml
        return torch.optim.AdamW(param_groups, lr=self.learning_rate)


class MyLightningCLI(LightningCLI):
    """``LightningCLI`` subclass that registers one extra command-line flag."""

    def add_arguments_to_parser(self, parser):
        """Register the boolean ``--watchmodel`` switch (``store_true``).

        NOTE(review): nothing in this file reads the parsed value; confirm it
        is consumed elsewhere or remove the flag.
        """
        parser.add_argument(
            "--watchmodel",
            action="store_true",
        )


def cli_main(default_config_filename="./configs/default.yaml"):
    """Build the SatCLIP ``LightningCLI`` from a YAML config and run training.

    Args:
        default_config_filename: Path to the default YAML config. The config
            saved by the CLI uses the same path with ``.yaml`` replaced by
            ``-latest.yaml`` and overwrites any previous copy.

    Raises:
        RuntimeError: propagated from ``S2GeoDataModule.setup`` during
            ``trainer.fit`` when the dataset cannot be found (for example
            ``<data_dir>/index.csv`` is missing). The data directory comes from
            the config file. It can presumably be overridden on the command
            line with ``--data.data_dir``; TODO confirm the datamodule's
            init-arg name.
    """
    # Both names are used below, but the file never imports them. Without these
    # imports the script dies with NameError right after the CLI is built.
    from datetime import datetime
    from pathlib import Path

    save_config_fn = default_config_filename.replace(".yaml", "-latest.yaml")
    # modify configs/default.yaml for learning rate etc.
    cli = MyLightningCLI(
        model_class=SatCLIPLightningModule,
        datamodule_class=S2GeoDataModule,
        save_config_kwargs=dict(
            config_filename=save_config_fn,
            overwrite=True,
        ),
        trainer_defaults={
            "accumulate_grad_batches": 16,
            "log_every_n_steps": 10,
        },
        parser_kwargs={"default_config_files": [default_config_filename]},
        seed_everything_default=0,
        run=False,  # build objects only; fit() is called explicitly below
    )

    ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    run_name = f"SatCLIP_S2_{ts}"
    if cli.trainer.logger is not None:
        cli.trainer.logger.experiment.name = run_name
        # this seems to be necessary to force logging of datamodule hyperparams
        cli.trainer.logger.log_hyperparams(cli.datamodule.hparams)

    # Create folder to log configs.
    # NOTE: Lightning does not handle config paths with subfolders, so the
    # config's parent folder (e.g. "configs/") is created under log_dir
    # ahead of time. Presumably the config is written there during fit().
    dirname_cfg = Path(default_config_filename).parent
    dir_log_cfg = Path(cli.trainer.log_dir) / dirname_cfg
    dir_log_cfg.mkdir(parents=True, exist_ok=True)

    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
    )
    
if __name__ == "__main__":
    config_fn = "./configs/default.yaml"

    # No per-GPU TF32 toggling is needed here: the module-level call
    # torch.set_float32_matmul_precision('high') already allows faster
    # float32 matmul kernels wherever the hardware supports them.
    cli_main(config_fn)

I get these errors: first an IndentationError, and then, after editing main.py, a missing-dataset error:

(.satclip) (base) root@WorldMap:~/VectorEmbeddingsFromGeoCoordinates/satclip/satclip# python main.py 
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/main.py", line 162
    torch.backends.cpu
IndentationError: unexpected indent
(.satclip) (base) root@WorldMap:~/VectorEmbeddingsFromGeoCoordinates/satclip/satclip# nano main.py 
(.satclip) (base) root@WorldMap:~/VectorEmbeddingsFromGeoCoordinates/satclip/satclip# python main.py 
2025-05-20 10:56:04.598684: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1747731364.714080   74468 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747731364.756910   74468 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747731364.919538   74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747731364.919665   74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747731364.919674   74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747731364.919680   74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2025-05-20 10:56:04.949589: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Seed set to 0
using vision transformer
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

            No dataset found. To download, please follow instructions on: https://github.com/microsoft/satclip
            
/data/s2/index.csv missing
Traceback (most recent call last):
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/main.py", line 163, in <module>
    cli_main(config_fn)
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/main.py", line 146, in cli_main
    cli.trainer.fit(
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
    call._call_and_handle_interrupt(
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 974, in _run
    call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 107, in _call_setup_hook
    _call_lightning_datamodule_hook(trainer, "setup", stage=fn)
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 198, in _call_lightning_datamodule_hook
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/datamodules/s2geo_dataset.py", line 52, in setup
    dataset = S2Geo(root=self.data_dir, transform=self.train_transform, mode=self.mode)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/datamodules/s2geo_dataset.py", line 111, in __init__
    raise RuntimeError("Dataset not found or corrupted.")
RuntimeError: Dataset not found or corrupted.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions