vfs: check userland buffers before reading them.
[haiku.git] / src / add-ons / media / plugins / ape_reader / MAClib / WAVInputSource.cpp
blob97570ff919b018ac3fce46befccd48ad39a3b95f
1 #include "All.h"
2 #include "WAVInputSource.h"
3 #include IO_HEADER_FILE
4 #include "MACLib.h"
5 #include "GlobalFunctions.h"
// On-disk layout of the 8-byte header that opens a RIFF file.
// This struct is read straight from the file with sizeof(RIFF_HEADER), so its
// in-memory size must be exactly 8 bytes. The length field is therefore a
// 32-bit unsigned int rather than unsigned long, which is 8 bytes wide on
// LP64 targets and would make the read consume the wrong number of bytes.
struct RIFF_HEADER
{
    char cRIFF[4];          // the characters 'RIFF' indicating that it's a RIFF file
    unsigned int nBytes;    // the number of bytes following this header (32 bits on disk)
};
// Four-character form type that follows the RIFF header in the file.
struct DATA_TYPE_ID_HEADER
{
    // Expected to hold 'W','A','V','E' for a WAV file.
    char cDataTypeID[4];
};
// On-disk layout of the fixed 16-byte payload of a WAV 'fmt ' chunk.
// The struct is read directly from the file with sizeof(WAV_FORMAT_HEADER),
// so the two rate fields are 32-bit unsigned int rather than unsigned long,
// which is 8 bytes on LP64 targets and would shift every following field.
struct WAV_FORMAT_HEADER
{
    unsigned short nFormatTag;          // the format of the WAV...should equal 1 for a PCM file
    unsigned short nChannels;           // the number of channels
    unsigned int   nSamplesPerSecond;   // the number of samples per second (32 bits on disk)
    unsigned int   nBytesPerSecond;     // the bytes per second (32 bits on disk)
    unsigned short nBlockAlign;         // block alignment
    unsigned short nBitsPerSample;      // the number of bits per sample
};
// On-disk layout of the 8-byte header that precedes every RIFF chunk.
// The struct is read directly from the file with sizeof(RIFF_CHUNK_HEADER),
// so the length is a 32-bit unsigned int rather than unsigned long, which is
// 8 bytes on LP64 targets and would make each chunk header read 16 bytes.
struct RIFF_CHUNK_HEADER
{
    char cChunkLabel[4];        // four-character chunk label, e.g. "fmt " or "data"
    unsigned int nChunkBytes;   // the bytes of the chunk (32 bits on disk)
};
35 CInputSource * __stdcall CreateInputSource(const wchar_t * pSourceName, WAVEFORMATEX * pwfeSource, int * pTotalBlocks, int * pHeaderBytes, int * pTerminatingBytes, int * pErrorCode)
37 // error check the parameters
38 if ((pSourceName == NULL) || (wcslen(pSourceName) == 0))
40 if (pErrorCode) *pErrorCode = ERROR_BAD_PARAMETER;
41 return NULL;
44 // get the extension
45 const wchar_t * pExtension = &pSourceName[wcslen(pSourceName)];
46 while ((pExtension > pSourceName) && (*pExtension != '.'))
47 pExtension--;
49 // create the proper input source
50 // SHINTA -->
51 // if (wcsicmp(pExtension, L".wav") == 0)
52 // {
53 if (pErrorCode) *pErrorCode = ERROR_SUCCESS;
54 return new CWAVInputSource(pSourceName, pwfeSource, pTotalBlocks, pHeaderBytes, pTerminatingBytes, pErrorCode);
55 // }
56 // else
57 // {
58 // if (pErrorCode) *pErrorCode = ERROR_INVALID_INPUT_FILE;
59 // return NULL;
60 // }
61 // <-- SHINTA
64 CWAVInputSource::CWAVInputSource(CIO * pIO, WAVEFORMATEX * pwfeSource, int * pTotalBlocks, int * pHeaderBytes, int * pTerminatingBytes, int * pErrorCode)
65 : CInputSource(pIO, pwfeSource, pTotalBlocks, pHeaderBytes, pTerminatingBytes, pErrorCode)
67 m_bIsValid = FALSE;
69 if (pIO == NULL || pwfeSource == NULL)
71 if (pErrorCode) *pErrorCode = ERROR_BAD_PARAMETER;
72 return;
75 m_spIO.Assign(pIO, FALSE, FALSE);
77 int nRetVal = AnalyzeSource();
78 if (nRetVal == ERROR_SUCCESS)
80 // fill in the parameters
81 if (pwfeSource) memcpy(pwfeSource, &m_wfeSource, sizeof(WAVEFORMATEX));
82 if (pTotalBlocks) *pTotalBlocks = m_nDataBytes / m_wfeSource.nBlockAlign;
83 if (pHeaderBytes) *pHeaderBytes = m_nHeaderBytes;
84 if (pTerminatingBytes) *pTerminatingBytes = m_nTerminatingBytes;
86 m_bIsValid = TRUE;
89 if (pErrorCode) *pErrorCode = nRetVal;
92 CWAVInputSource::CWAVInputSource(const wchar_t * pSourceName, WAVEFORMATEX * pwfeSource, int * pTotalBlocks, int * pHeaderBytes, int * pTerminatingBytes, int * pErrorCode)
93 : CInputSource(pSourceName, pwfeSource, pTotalBlocks, pHeaderBytes, pTerminatingBytes, pErrorCode)
95 m_bIsValid = FALSE;
97 if (pSourceName == NULL || pwfeSource == NULL)
99 if (pErrorCode) *pErrorCode = ERROR_BAD_PARAMETER;
100 return;
103 m_spIO.Assign(new IO_CLASS_NAME);
104 if (m_spIO->Open(pSourceName) != ERROR_SUCCESS)
106 m_spIO.Delete();
107 if (pErrorCode) *pErrorCode = ERROR_INVALID_INPUT_FILE;
108 return;
111 int nRetVal = AnalyzeSource();
112 if (nRetVal == ERROR_SUCCESS)
114 // fill in the parameters
115 if (pwfeSource) memcpy(pwfeSource, &m_wfeSource, sizeof(WAVEFORMATEX));
116 if (pTotalBlocks) *pTotalBlocks = m_nDataBytes / m_wfeSource.nBlockAlign;
117 if (pHeaderBytes) *pHeaderBytes = m_nHeaderBytes;
118 if (pTerminatingBytes) *pTerminatingBytes = m_nTerminatingBytes;
120 m_bIsValid = TRUE;
123 if (pErrorCode) *pErrorCode = nRetVal;
126 CWAVInputSource::~CWAVInputSource()
132 int CWAVInputSource::AnalyzeSource()
134 // seek to the beginning (just in case)
135 m_spIO->Seek(0, FILE_BEGIN);
137 // get the file size
138 m_nFileBytes = m_spIO->GetSize();
140 // get the RIFF header
141 RIFF_HEADER RIFFHeader;
142 RETURN_ON_ERROR(ReadSafe(m_spIO, &RIFFHeader, sizeof(RIFFHeader)))
144 // make sure the RIFF header is valid
145 if (!(RIFFHeader.cRIFF[0] == 'R' && RIFFHeader.cRIFF[1] == 'I' && RIFFHeader.cRIFF[2] == 'F' && RIFFHeader.cRIFF[3] == 'F'))
146 return ERROR_INVALID_INPUT_FILE;
148 // read the data type header
149 DATA_TYPE_ID_HEADER DataTypeIDHeader;
150 RETURN_ON_ERROR(ReadSafe(m_spIO, &DataTypeIDHeader, sizeof(DataTypeIDHeader)))
152 // make sure it's the right data type
153 if (!(DataTypeIDHeader.cDataTypeID[0] == 'W' && DataTypeIDHeader.cDataTypeID[1] == 'A' && DataTypeIDHeader.cDataTypeID[2] == 'V' && DataTypeIDHeader.cDataTypeID[3] == 'E'))
154 return ERROR_INVALID_INPUT_FILE;
156 // find the 'fmt ' chunk
157 RIFF_CHUNK_HEADER RIFFChunkHeader;
158 RETURN_ON_ERROR(ReadSafe(m_spIO, &RIFFChunkHeader, sizeof(RIFFChunkHeader)))
160 while (!(RIFFChunkHeader.cChunkLabel[0] == 'f' && RIFFChunkHeader.cChunkLabel[1] == 'm' && RIFFChunkHeader.cChunkLabel[2] == 't' && RIFFChunkHeader.cChunkLabel[3] == ' '))
162 // move the file pointer to the end of this chunk
163 m_spIO->Seek(RIFFChunkHeader.nChunkBytes, FILE_CURRENT);
165 // check again for the data chunk
166 RETURN_ON_ERROR(ReadSafe(m_spIO, &RIFFChunkHeader, sizeof(RIFFChunkHeader)))
169 // read the format info
170 WAV_FORMAT_HEADER WAVFormatHeader;
171 RETURN_ON_ERROR(ReadSafe(m_spIO, &WAVFormatHeader, sizeof(WAVFormatHeader)))
173 // error check the header to see if we support it
174 if (WAVFormatHeader.nFormatTag != 1)
175 return ERROR_INVALID_INPUT_FILE;
177 // copy the format information to the WAVEFORMATEX passed in
178 FillWaveFormatEx(&m_wfeSource, WAVFormatHeader.nSamplesPerSecond, WAVFormatHeader.nBitsPerSample, WAVFormatHeader.nChannels);
180 // skip over any extra data in the header
181 int nWAVFormatHeaderExtra = RIFFChunkHeader.nChunkBytes - sizeof(WAVFormatHeader);
182 if (nWAVFormatHeaderExtra < 0)
183 return ERROR_INVALID_INPUT_FILE;
184 else
185 m_spIO->Seek(nWAVFormatHeaderExtra, FILE_CURRENT);
187 // find the data chunk
188 RETURN_ON_ERROR(ReadSafe(m_spIO, &RIFFChunkHeader, sizeof(RIFFChunkHeader)))
190 while (!(RIFFChunkHeader.cChunkLabel[0] == 'd' && RIFFChunkHeader.cChunkLabel[1] == 'a' && RIFFChunkHeader.cChunkLabel[2] == 't' && RIFFChunkHeader.cChunkLabel[3] == 'a'))
192 // move the file pointer to the end of this chunk
193 m_spIO->Seek(RIFFChunkHeader.nChunkBytes, FILE_CURRENT);
195 // check again for the data chunk
196 RETURN_ON_ERROR(ReadSafe(m_spIO, &RIFFChunkHeader, sizeof(RIFFChunkHeader)))
199 // we're at the data block
200 m_nHeaderBytes = m_spIO->GetPosition();
201 m_nDataBytes = RIFFChunkHeader.nChunkBytes;
202 if (m_nDataBytes < 0)
203 m_nDataBytes = m_nFileBytes - m_nHeaderBytes;
205 // make sure the data bytes is a whole number of blocks
206 if ((m_nDataBytes % m_wfeSource.nBlockAlign) != 0)
207 return ERROR_INVALID_INPUT_FILE;
209 // calculate the terminating byts
210 m_nTerminatingBytes = m_nFileBytes - m_nDataBytes - m_nHeaderBytes;
212 // we made it this far, everything must be cool
213 return ERROR_SUCCESS;
216 int CWAVInputSource::GetData(unsigned char * pBuffer, int nBlocks, int * pBlocksRetrieved)
218 if (!m_bIsValid) return ERROR_UNDEFINED;
220 int nBytes = (m_wfeSource.nBlockAlign * nBlocks);
221 unsigned int nBytesRead = 0;
223 if (m_spIO->Read(pBuffer, nBytes, &nBytesRead) != ERROR_SUCCESS)
224 return ERROR_IO_READ;
226 if (pBlocksRetrieved) *pBlocksRetrieved = (nBytesRead / m_wfeSource.nBlockAlign);
228 return ERROR_SUCCESS;
231 int CWAVInputSource::GetHeaderData(unsigned char * pBuffer)
233 if (!m_bIsValid) return ERROR_UNDEFINED;
235 int nRetVal = ERROR_SUCCESS;
237 if (m_nHeaderBytes > 0)
239 int nOriginalFileLocation = m_spIO->GetPosition();
241 m_spIO->Seek(0, FILE_BEGIN);
243 unsigned int nBytesRead = 0;
244 int nReadRetVal = m_spIO->Read(pBuffer, m_nHeaderBytes, &nBytesRead);
246 if ((nReadRetVal != ERROR_SUCCESS) || (m_nHeaderBytes != int(nBytesRead)))
248 nRetVal = ERROR_UNDEFINED;
251 m_spIO->Seek(nOriginalFileLocation, FILE_BEGIN);
254 return nRetVal;
257 int CWAVInputSource::GetTerminatingData(unsigned char * pBuffer)
259 if (!m_bIsValid) return ERROR_UNDEFINED;
261 int nRetVal = ERROR_SUCCESS;
263 if (m_nTerminatingBytes > 0)
265 int nOriginalFileLocation = m_spIO->GetPosition();
267 m_spIO->Seek(-m_nTerminatingBytes, FILE_END);
269 unsigned int nBytesRead = 0;
270 int nReadRetVal = m_spIO->Read(pBuffer, m_nTerminatingBytes, &nBytesRead);
272 if ((nReadRetVal != ERROR_SUCCESS) || (m_nTerminatingBytes != int(nBytesRead)))
274 nRetVal = ERROR_UNDEFINED;
277 m_spIO->Seek(nOriginalFileLocation, FILE_BEGIN);
280 return nRetVal;