diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 865179a..b5788e7 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -27,6 +27,17 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { std::round(seconds * timeBase.den / timeBase.num)); } +inline char* find_codec(const char* input) { + const char* codecs[] = {"h264", "hevc", "av1", "vp9"}; + size_t codec_len = sizeof(codecs) / sizeof(codecs[0]); + for (size_t i = 0; i < codec_len; ++i) { + if (strstr(input, codecs[i])) { + return (char*)codecs[i]; + } + } + return NULL; +} + // Some videos aren't properly encoded and do not specify pts values for // packets, and thus for frames. Unset values correspond to INT64_MIN. When that // happens, we fallback to the dts value which hopefully exists and is correct. @@ -425,9 +436,22 @@ void SingleStreamDecoder::addStream( // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { if (deviceInterface_) { - avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) - .value_or(avCodec)); + if (device.type() != torch::kCUDA) { + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); + } + else { + const char* cuvid_suffix = "_cuvid"; + char* codec_name = find_codec(avCodec->name); + size_t cuvid_length = std::strlen(codec_name) + std::strlen(cuvid_suffix) + 1; + char* cuvid_name = new char[cuvid_length]; + std::strcpy(cuvid_name, codec_name); + std::strcat(cuvid_name, cuvid_suffix); + avCodec = avcodec_find_decoder_by_name(cuvid_name); + delete[] cuvid_name; + TORCH_CHECK(avCodec != nullptr); + } } }