Source code for topobenchmarkx.loss.loss

"""Loss module for the topobenchmarkx package."""

import torch
import torch_geometric

from topobenchmarkx.loss.base import AbstractLoss


[docs] class TBXLoss(AbstractLoss): r"""Defines the default model loss for the given task. Parameters ---------- task : str Task type, either "classification" or "regression". loss_type : str, optional Loss type, either "cross_entropy", "mse", or "mae" (default: None). """ def __init__(self, task, loss_type=None): super().__init__() self.task = task if task == "classification" and loss_type == "cross_entropy": self.criterion = torch.nn.CrossEntropyLoss() elif task == "regression" and loss_type == "mse": self.criterion = torch.nn.MSELoss() elif task == "regression" and loss_type == "mae": self.criterion = torch.nn.L1Loss() else: raise Exception("Loss is not defined") self.loss_type = loss_type def __repr__(self) -> str: return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})"
[docs] def forward(self, model_out: dict, batch: torch_geometric.data.Data): r"""Forward pass of the loss function. Parameters ---------- model_out : dict Dictionary containing the model output. batch : torch_geometric.data.Data Batch object containing the batched domain data. Returns ------- dict Dictionary containing the model output with the loss. """ logits = model_out["logits"] target = model_out["labels"] if self.task == "regression": target = target.unsqueeze(1) model_out["loss"] = self.criterion(logits, target) return model_out