| 
83 | 83 |     "import os\n",  | 
84 | 84 |     "import sys\n",  | 
85 | 85 |     "import torch\n",  | 
 | 86 | +    "import subprocess\n",  | 
86 | 87 |     "need_pytorch3d=False\n",  | 
87 | 88 |     "try:\n",  | 
88 | 89 |     "    import pytorch3d\n",  | 
89 | 90 |     "except ModuleNotFoundError:\n",  | 
90 | 91 |     "    need_pytorch3d=True\n",  | 
91 | 92 |     "if need_pytorch3d:\n",  | 
92 |  | -    "    if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n",  | 
93 |  | -    "        # We try to install PyTorch3D via a released wheel.\n",  | 
94 |  | -    "        pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",  | 
95 |  | -    "        version_str=\"\".join([\n",  | 
96 |  | -    "            f\"py3{sys.version_info.minor}_cu\",\n",  | 
97 |  | -    "            torch.version.cuda.replace(\".\",\"\"),\n",  | 
98 |  | -    "            f\"_pyt{pyt_version_str}\"\n",  | 
99 |  | -    "        ])\n",  | 
100 |  | -    "        !pip install fvcore iopath\n",  | 
 | 93 | +    "    pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",  | 
 | 94 | +    "    version_str=\"\".join([\n",  | 
 | 95 | +    "        f\"py3{sys.version_info.minor}_cu\",\n",  | 
 | 96 | +    "        torch.version.cuda.replace(\".\",\"\"),\n",  | 
 | 97 | +    "        f\"_pyt{pyt_version_str}\"\n",  | 
 | 98 | +    "    ])\n",  | 
 | 99 | +    "    !pip install fvcore iopath\n",  | 
 | 100 | +    "    if sys.platform.startswith(\"linux\"):\n",  | 
 | 101 | +    "        print(\"Trying to install wheel for PyTorch3D\")\n",  | 
101 | 102 |     "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",  | 
102 |  | -    "    else:\n",  | 
103 |  | -    "        # We try to install PyTorch3D from source.\n",  | 
104 |  | -    "        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"  | 
 | 103 | +    "        pip_list = !pip freeze\n",  | 
 | 104 | +    "        need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for  i in pip_list)\n",  | 
 | 105 | +    "    if need_pytorch3d:\n",  | 
 | 106 | +    "        print(f\"failed to find/install wheel for {version_str}\")\n",  | 
 | 107 | +    "if need_pytorch3d:\n",  | 
 | 108 | +    "    print(\"Installing PyTorch3D from source\")\n",  | 
 | 109 | +    "    !pip install ninja\n",  | 
 | 110 | +    "    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"  | 
105 | 111 |    ]  | 
106 | 112 |   },  | 
107 | 113 |   {  | 
 | 
0 commit comments