diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index c1aadb570145..4bfecfbbe2cf 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -22,6 +22,7 @@ #include #include #include +#include #include static struct kset *iommu_group_kset; @@ -185,10 +186,21 @@ int iommu_probe_device(struct device *dev) if (!iommu_get_dev_param(dev)) return -ENOMEM; + if (!try_module_get(ops->owner)) { + ret = -EINVAL; + goto err_free_dev_param; + } + ret = ops->add_device(dev); if (ret) - iommu_free_dev_param(dev); + goto err_module_put; + return 0; + +err_module_put: + module_put(ops->owner); +err_free_dev_param: + iommu_free_dev_param(dev); return ret; } @@ -199,7 +211,10 @@ void iommu_release_device(struct device *dev) if (dev->iommu_group) ops->remove_device(dev); - iommu_free_dev_param(dev); + if (dev->iommu_param) { + module_put(ops->owner); + iommu_free_dev_param(dev); + } } static struct iommu_domain *__iommu_domain_alloc(struct bus_type *bus, diff --git a/include/linux/iommu.h b/include/linux/iommu.h index 29bac5345563..d9dab5a3e912 100644 --- a/include/linux/iommu.h +++ b/include/linux/iommu.h @@ -245,6 +245,7 @@ struct iommu_iotlb_gather { * @sva_get_pasid: Get PASID associated to a SVA handle * @page_response: handle page request response * @pgsize_bitmap: bitmap of all possible supported page sizes + * @owner: Driver module providing these ops */ struct iommu_ops { bool (*capable)(enum iommu_cap); @@ -308,6 +309,7 @@ struct iommu_ops { struct iommu_page_response *msg); unsigned long pgsize_bitmap; + struct module *owner; }; /**