# pre_trained_model.py

from torchvision import models
import torch.nn as nn
import torch

def make_resnet_model(cls_count=1000):
    """Build a pretrained ResNet-18 with a new ``cls_count``-way classifier head.

    The torchvision ImageNet weights (IMAGENET1K_V1) are loaded. The final
    fully-connected layer is then replaced with a freshly initialised
    ``nn.Linear(in_features, cls_count)``. The model is moved to the first
    CUDA device when one is available, otherwise to the CPU. When more than
    one GPU is visible, it is wrapped in ``nn.DataParallel``.

    Args:
        cls_count: Number of output classes for the new head (default 1000).

    Returns:
        A ``(model, device)`` tuple. ``model`` is the ResNet-18, possibly
        wrapped in ``nn.DataParallel``. ``device`` is the ``torch.device`` it
        was moved to.
    """
    # Fall back to CPU so the function does not crash on CUDA-less machines.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights`` enum; IMAGENET1K_V1 is exactly what the legacy flag loaded.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)

    # Swap the classification head for one with the requested number of classes.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, cls_count)

    model = model.to(device)
    # DataParallel replicates onto GPUs; only meaningful when on CUDA with >1 device.
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    return model, device
