Ta spletni dnevnik bo ponazoril metodo za uporabo metode »torch.argmax()« v PyTorchu.
Kako uporabiti metodo “torch.argmax()” v PyTorchu?
Metoda »torch.argmax()« vzame kateri koli 1D ali 2D tenzor kot vhod in vrne tenzor, ki vsebuje indekse/indekse največjih vrednosti vzdolž dane dimenzije.
Sintaksa metode »torch.argmax()« je podana spodaj:
svetilka. argmax ( < vhodni_tenzor > )
Če želite uporabiti to metodo v PyTorchu, preglejte naslednje primere za boljše razumevanje:
1. primer: uporabite metodo »torch.argmax()« z 1D tenzorjem
V prvem primeru bomo ustvarili 1D tenzor in z njim uporabili metodo »torch.argmax()«. Sledimo spodnjemu postopku po korakih:
1. korak: uvozite knjižnico PyTorch
Najprej uvozite » svetilka ” za uporabo metode “torch.argmax()”:
uvoz svetilka2. korak: Ustvarite 1D tenzor
Nato ustvarite 1D tenzor in natisnite njegove elemente. Tukaj ustvarjamo naslednje ' desetice1 ' tenzor s seznama z uporabo ' torch.tensor() ” funkcija:
desetice1 = svetilka. tenzor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )
tiskanje ( desetice1 )
To je ustvarilo 1D tenzor, kot je prikazano spodaj:
3. korak: Poiščite indekse največje vrednosti
Zdaj pa uporabite » torch.argmax() ' za iskanje indeksa/indeksov največje vrednosti v ' desetice1 ” tenzor:
T1_ind = svetilka. argmax ( desetice1 )4. korak: Natisnite indeks največje vrednosti
Na koncu prikažite indeks največje vrednosti v vhodnem tenzorju:
tiskanje ( 'Indeksi:' , T1_ind )Spodnji rezultat prikazuje indeks največje vrednosti v ' desetice1 ” tenzor, tj. 4. To pomeni, da je najvišja vrednost tenzorja pri 4. indeksu, ki je “ 9 ”:
2. primer: uporabite metodo »torch.argmax()« z 2D tenzorjem
V drugem primeru bomo ustvarili 2D tenzor in z njim uporabili metodo »torch.argmax()«. Sledimo navedenim korakom:
1. korak: uvozite knjižnico PyTorch
Najprej uvozite » svetilka ” za uporabo metode “torch.argmax()”:
uvoz svetilka2. korak: Ustvarite 2D tenzor
Nato uporabite » torch.tensor() ” ustvariti 2D tenzor in natisniti njegove elemente. Tukaj ustvarjamo naslednje ' desetice2 '2D tenzor:
desetice2 = svetilka. tenzor ( [ [ 4 , 1 , - 7 ] , [ petnajst , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )tiskanje ( desetice2 )
To je ustvarilo 2D tenzor, kot je prikazano spodaj:
3. korak: Poiščite indekse največje vrednosti
Zdaj poiščite indeks največje vrednosti v ' desetice2 ' tenzor z uporabo ' torch.argmax() ” funkcija:
T2_ind = svetilka. argmax ( desetice2 )4. korak: Natisnite indeks največje vrednosti
Končno prikaži indeks največje vrednosti v vhodnem tenzorju:
tiskanje ( 'Indeksi:' , T2_ind )Glede na spodnji rezultat je indeks največje vrednosti v ' desetice2 ” tenzor je “3”. To pomeni, da je najvišja vrednost tenzorja pri 3. indeksu, ki je ' petnajst ”:
5. korak: Poiščite indekse največje vrednosti vzdolž stolpcev
Poleg tega lahko uporabniki najdejo tudi indekse/indekse največjih vrednosti vzdolž vsakega stolpca tenzorja. Na primer, lahko uporabimo ' dim=0 « s funkcijo »torch.argmax()«. Najde indekse največjih vrednosti vzdolž stolpcev v ' desetice2 ” tenzor in nato natisne te indekse:
col_index = svetilka. argmax ( desetice2 , dim = 0 )tiskanje ( 'Indeksi v stolpcih:' , col_index )
Spodnji rezultat prikazuje indekse največjih vrednosti vzdolž vsakega stolpca tenzorja:
6. korak: Poiščite indekse največje vrednosti vzdolž vrstic
Podobno lahko uporabniki najdejo tudi indekse/indekse največjih vrednosti vzdolž vsake vrstice tenzorja. Uporabite na primer » dim=1 « s funkcijo »torch.argmax()«, da poiščete indekse največjih vrednosti vzdolž vrstic v tenzorju »Tens2« in nato natisnete te indekse:
indeks_vrstice = svetilka. argmax ( desetice2 , dim = 1 )tiskanje ( 'Indeksi v vrsticah:' , indeks_vrstice )
Indekse največje vrednosti vzdolž vsake vrstice tenzorja 'Tens2' lahko vidite spodaj:
Učinkovito smo razložili metodo za uporabo metode »torch.argmax()« v PyTorchu.
Opomba : Do našega zvezka Google Colab lahko dostopate tukaj povezava .
Zaključek
Če želite uporabiti metodo »torch.argmax()« v PyTorchu, najprej uvozite » svetilka ' knjižnica. Nato ustvarite želeni 1D ali 2D tenzor in si oglejte njegove elemente. Nato uporabite » torch.argmax() ” za iskanje/izračun indeksov/indeksov največjih vrednosti v tenzorju. Poleg tega lahko uporabniki najdejo tudi indekse največje vrednosti vzdolž vsake vrstice ali stolpca v tenzorju z uporabo ' dim ' prepir. Na koncu prikažite indeks največje vrednosti v vhodnem tenzorju. Ta blog je ponazoril metodo za uporabo metode »torch.argmax()« v PyTorchu.