手撕FocalLoss

news/2025/2/26 20:45:52

文章目录

  • 前言
  • 1、FocalLoss
    • 1.1.公式定义
  • 2、代码
  • 总结


前言

 为了加深对Focal Loss理解,本文提供了一个简单的手写Demo。

1、FocalLoss

 介绍FocalLoss的文章已经很多了,这里简单提一下:

1.1.公式定义

 Focal Loss 的公式如下:

FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t) FL(pt)=αt(1pt)γlog(pt)

 ;根据真实标签 y y y 的不同,Focal Loss 可以分为两种情况:

 1) 当真实标签 y = 1 y = 1 y=1 时,公式变为:

FL ( p ) = − α ( 1 − p ) γ log ⁡ ( p ) \text{FL}(p) = -\alpha (1 - p)^{\gamma} \log(p) FL(p)=α(1p)γlog(p)

 2) 当真实标签 y = 0 y = 0 y=0 时,公式变为:

FL ( p ) = − ( 1 − α ) p γ log ⁡ ( 1 − p ) \text{FL}(p) = -(1 - \alpha) p^{\gamma} \log(1 - p) FL(p)=(1α)pγlog(1p)

 Focal Loss 的完整公式可以写为:

FL ( y , p ) = − [ y ⋅ α ( 1 − p ) γ log ⁡ ( p ) + ( 1 − y ) ⋅ ( 1 − α ) p γ log ⁡ ( 1 − p ) ] \text{FL}(y, p) = -\left[ y \cdot \alpha (1 - p)^{\gamma} \log(p) + (1 - y) \cdot (1 - \alpha) p^{\gamma} \log(1 - p) \right] FL(y,p)=[yα(1p)γlog(p)+(1y)(1α)pγlog(1p)]

其中 p p p表示经过sigmoid的预测值。本文实现的是完整版的公式,而且没有引入额外的封装函数。

2、代码

python">import torch
import torch.nn as nn
import torch.nn.functional as F

# focal_loss = pos_loss + neg_loss 
# if y == 1: pos_loss = -|1-p|^gamma * log(p)  
# if y == 0: neg_loss = -|0-p|^gamma * log(1-p)
class FocalLoss(nn.Module):
    def __init__(self,alpha=0.25,gamma=2.0,reduce='sum'):
        super(FocalLoss,self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self,classifications,targets):
        alpha = self.alpha
        gamma = self.gamma
        classifications = classifications.view(-1)
        p = torch.sigmoid(classifications)
        targets = targets.view(-1)
        # 获取pos 和 neg 的索引
        pos_idx = torch.nonzero(targets==1).view(-1)
        neg_idx = torch.nonzero(targets==0).view(-1)
        # step1: cpt pos loss       
        pos_loss = -(1-p[pos_idx]).abs() ** gamma * torch.log(p[pos_idx])
        # step2: cpt neg loss 
        neg_loss = -(0-p[neg_idx]).abs() ** gamma * torch.log(1-p[neg_idx])
        loss = torch.cat((pos_loss, neg_loss), dim=0)
        # targets 也需要重新排序 来跟loss值对应 
        concat_idx = torch.cat((pos_idx, neg_idx), dim=0)
        targets = targets[concat_idx]
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        if self.reduce=='sum':
            loss = loss.sum()
        elif self.reduce=='mean':
            loss = loss.mean()
        else:
            raise ValueError('reduce type is wrong!')
        return loss


# ---test unit --- #
def main():
    # single cls focal loss 
    focal_loss = FocalLoss()
    pred = torch.FloatTensor([0.1,0.9,0.2,0.8,0.7]) # nb_anchors :5
    tgt  = torch.FloatTensor([0,1,0,1,1])           # neg:0 pos:1 ; no ignore
    loss = focal_loss(pred, tgt)
    print('loss:', loss) 

总结

 本文只是简单实现了一个二分类的FocalLoss,旨在加深读者对其理解。欢迎批评指正。


http://www.niftyadmin.cn/n/5869156.html

相关文章

利用Python爬虫精准获取VIP商品详情:实战案例指南

在电商竞争日益激烈的今天,VIP商品的详细信息对于商家制定策略、优化用户体验以及进行市场分析具有至关重要的价值。然而,VIP商品页面结构复杂且可能随时更新,这给爬虫开发带来了不小的挑战。本文将通过一个完整的案例,展示如何利…

mysql将表导出为sql文件

使用mysqldump命令 mysqldump是MySQL提供的一个命令行工具,用于导出数据库或表的结构和数据。要将表导出为SQL文件,可以使用以下命令: mysqldump -uroot -p123456 database_name table_name > output_file.sql

Python 环境管理介绍

pip pip 是 Python 的标准包管理工具&#xff0c;用于安装和管理 Python 软件包。它允许你从 Python 包索引&#xff08;PyPI&#xff09;下载并安装第三方库&#xff0c;并能自动解决依赖问题。 第三方库的安装与卸载 pip install <package>pip uninstall <packag…

【Python LeetCode 专题】动态规划

斐波那契类型70. 爬楼梯746. 使用最小花费爬楼梯198. 打家劫舍740. 删除并获得点数矩阵62. 不同路径方法一:二维 DP方法二:递归(`@cache`)64. 最小路径和63. 不同路径 II120. 三角形最小路径和221. 最大正方形字符串139. 单词拆分5. 最长回文子串516. 最长回文子序列72. 编…

利用 Open3D 保存并载入相机视角的简单示例

1. 前言 在使用 Open3D 进行三维可视化和点云处理时&#xff0c;有时需要将当前的视角&#xff08;Camera Viewpoint&#xff09;保存下来&#xff0c;以便下次再次打开时能够还原到同样的视角。本文将演示如何在最新的 Open3D GUI 界面&#xff08;o3d.visualization.gui / o…

ref和reactive的区别 Vue3

Vue3中ref和reactive的区别 ref 可以定义基本数据类型&#xff0c;也可定义对象类型的响应式数据 reactive 只能定义对象类型的响应式数据 ref和reactive定义对象类型的响应式数据有什么不同 不同点1 ref定义的响应式数据&#xff0c;取值时需要先 .value 不同点2 替换整…

单片机的串口(USART)

Tx - 数据的发送引脚&#xff0c;Rx - 数据的接受引脚。 串口的数据帧格式 空闲状态高电平&#xff0c;起始位低电平&#xff0c;数据位有8位校验位&#xff0c;9位校验位&#xff0c;停止位是高电平保持一位或者半位&#xff0c;又或者两位的状态。 8位无校验位传输一个字节…

KubeSphere部署redis集群

一、部署前准备 &#xff08;一&#xff09;KubeSphere部署redis集群思路 参考上一篇文章的部署思路&#xff1a;KubeSphere安装mysql-CSDN博客 &#xff08;二&#xff09;部署方法参考 1、参考Docker Hub的中docker部署redis的方法 部署方法按照Docker Hub官网部署redis的…